Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_FFT_Helper.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_HELPER_H_
2#define AMREX_FFT_HELPER_H_
3#include <AMReX_Config.H>
4
5#include <AMReX.H>
6#include <AMReX_BLProfiler.H>
9#include <AMReX_Enum.H>
10#include <AMReX_FabArray.H>
11#include <AMReX_Gpu.H>
12#include <AMReX_GpuComplex.H>
13#include <AMReX_Math.H>
14#include <AMReX_Periodicity.H>
15
16#if defined(AMREX_USE_CUDA)
17# include <cufft.h>
18# include <cuComplex.h>
19#elif defined(AMREX_USE_HIP)
20# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
21# include <rocfft/rocfft.h>
22# else
23# include <rocfft.h>
24# endif
25# include <hip/hip_complex.h>
26#elif defined(AMREX_USE_SYCL)
27# if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
28# include <oneapi/mkl/dft.hpp>
29#else
30# define AMREX_USE_MKL_DFTI_2024 1
31# include <oneapi/mkl/dfti.hpp>
32# endif
33#else
34# include <fftw3.h>
35#endif
36
37#include <algorithm>
38#include <complex>
39#include <limits>
40#include <memory>
41#include <tuple>
42#include <utility>
43#include <variant>
44
45namespace amrex::FFT
46{
47
//! Transform direction selector. The Plan methods in this file accept only
//! `forward` or `backward` as their template argument (checked by
//! static_assert); `both` and `none` are additional enumerators.
enum class Direction
{
    forward,
    backward,
    both,
    none
};
49
51
//! Boundary type enumeration (periodic, even, odd), declared through the
//! AMREX_ENUM macro. Pairs of these values (first, second) are what the
//! real-to-real helpers below inspect to choose a transform kind.
AMREX_ENUM( Boundary, periodic, even, odd );
53
56
//! Options bundle for FFT setup. Every setter stores its argument in the
//! corresponding member and returns *this, so calls can be chained.
//! NOTE(review): setPencilThreshold() assigns `pencil_threshold`, but no
//! member of that name is declared in the lines visible here -- confirm the
//! declaration has not been lost from this struct.
struct Info
{

    //! Two-dimensional mode flag; off by default.
    bool twod_mode = false;

    //! Batch size; defaults to 1.
    int batch_size = 1;

    //! Process-count setting; the default is the largest representable int.
    //! NOTE(review): presumably an upper bound on the processes used -- confirm.
    int nprocs = std::numeric_limits<int>::max();

    Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
    Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
    Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
    Info& setNumProcs (int n) { nprocs = n; return *this; }
};
83
#ifdef AMREX_USE_HIP
//! Declaration only (HIP builds). The compute methods of Plan call this with
//! the rocFFT plan and the addresses of their input and output pointers.
namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
#endif
87
88#ifdef AMREX_USE_SYCL
89namespace detail
90{
91template <typename T, Direction direction, typename P, typename TI, typename TO>
92void sycl_execute (P* plan, TI* in, TO* out)
93{
94#ifndef AMREX_USE_MKL_DFTI_2024
95 std::int64_t workspaceSize = 0;
96#else
97 std::size_t workspaceSize = 0;
98#endif
99 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
100 &workspaceSize);
101 auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
102 plan->set_workspace(buffer);
103 sycl::event r;
104 if (std::is_same_v<TI,TO>) {
106 if constexpr (direction == Direction::forward) {
107 r = oneapi::mkl::dft::compute_forward(*plan, out);
108 } else {
109 r = oneapi::mkl::dft::compute_backward(*plan, out);
110 }
111 } else {
112 if constexpr (direction == Direction::forward) {
113 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
114 } else {
115 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
116 }
117 }
118 r.wait();
119 amrex::The_Arena()->free(buffer);
120}
121}
122#endif
123
124template <typename T>
125struct Plan
126{
#if defined(AMREX_USE_CUDA)
    // cuFFT backend: plan handle and the complex type matching T's precision.
    using VendorPlan = cufftHandle;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
    // rocFFT backend.
    using VendorPlan = rocfft_plan;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             float2, double2>;
#elif defined(AMREX_USE_SYCL)
    // oneMKL backend: separate descriptor types for the REAL and COMPLEX
    // domains, both with the precision matching T. VendorPlan can hold a
    // pointer to either one.
    using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::REAL>;
    using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::COMPLEX>;
    using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
    using VendorComplex = std::complex<T>;
#else
    // FFTW backend (no GPU): single- or double-precision plan/complex types.
    using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
                                          fftwf_plan, fftw_plan>;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             fftwf_complex, fftw_complex>;
#endif

    // NOTE(review): the methods of this struct also use members named `kind`,
    // `plan`, `plan2` and `r2r_data_is_complex`, none of which is declared in
    // the lines visible here -- confirm their declarations are present.

    //! Number of elements in one transform; set by each init_* method.
    int n = 0;
    //! Number of batched transforms covered by the plan; set by each init_* method.
    int howmany = 0;
    //! Set to true by the init_* methods; the compute methods return
    //! immediately while it is false.
    bool defined = false;
    //! Set to true only by the non-GPU path of the complex-data init_r2r
    //! overload, which creates a second plan.
    bool defined2 = false;
    //! Forward-side data pointer: the real array for r2c plans, the single
    //! in-place array for c2c and r2r plans.
    void* pf = nullptr;
    //! Backward-side data pointer: the complex array for r2c plans, the same
    //! array as pf for c2c and r2r plans.
    void* pb = nullptr;
163
#ifdef AMREX_USE_GPU
    //! Replace both stored data pointers (GPU builds only).
    //! \param p0 new forward-side pointer (stored in pf)
    //! \param p1 new backward-side pointer (stored in pb)
    void set_ptrs (void* p0, void* p1)
    {
        this->pf = p0;
        this->pb = p1;
    }
#endif
170
    //! Mark the plan as undefined: clears `defined` and, in builds without
    //! AMREX_USE_GPU, `defined2` as well.
    //! NOTE(review): nothing visible in either `if` body releases the
    //! backend plan handle(s) -- confirm whether a release call belongs
    //! ahead of each flag reset.
    void destroy ()
    {
        if (defined) {
            defined = false;
        }
#if !defined(AMREX_USE_GPU)
        if (defined2) {
            defined2 = false;
        }
#endif
    }
184
185 template <Direction D>
186 void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
187 {
188 static_assert(D == Direction::forward || D == Direction::backward);
189
190 int rank = is_2d_transform ? 2 : 1;
191
193 defined = true;
194 pf = (void*)pr;
195 pb = (void*)pc;
196
197 int len[2] = {};
198 if (rank == 1) {
199 len[0] = box.length(0);
200 len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
201 } else {
202 len[0] = box.length(1); // Most FFT libraries assume row-major ordering
203 len[1] = box.length(0); // except for rocfft
204 }
205 int nr = (rank == 1) ? len[0] : len[0]*len[1];
206 n = nr;
207 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
208#if (AMREX_SPACEDIM == 1)
209 howmany = 1;
210#else
211 howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
212 : AMREX_D_TERM(1, *1 , *box.length(2));
213#endif
214 howmany *= ncomp;
215
217
218#if defined(AMREX_USE_CUDA)
219
220 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
221 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
222 std::size_t work_size;
223 if constexpr (D == Direction::forward) {
224 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
226 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
227 } else {
228 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
230 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
231 }
232
233#elif defined(AMREX_USE_HIP)
234
235 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
236 // switch to column-major ordering
237 std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
238 if constexpr (D == Direction::forward) {
239 AMREX_ROCFFT_SAFE_CALL
240 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
241 rocfft_transform_type_real_forward, prec, rank,
242 length, howmany, nullptr));
243 } else {
244 AMREX_ROCFFT_SAFE_CALL
245 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
246 rocfft_transform_type_real_inverse, prec, rank,
247 length, howmany, nullptr));
248 }
249
250#elif defined(AMREX_USE_SYCL)
251
252 mkl_desc_r* pp;
253 if (rank == 1) {
254 pp = new mkl_desc_r(len[0]);
255 } else {
256 pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
257 }
258#ifndef AMREX_USE_MKL_DFTI_2024
259 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
260 oneapi::mkl::dft::config_value::NOT_INPLACE);
261#else
262 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
263#endif
264 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
265 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
266 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
267 std::vector<std::int64_t> strides;
268 strides.push_back(0);
269 if (rank == 2) { strides.push_back(len[1]); }
270 strides.push_back(1);
271#ifndef AMREX_USE_MKL_DFTI_2024
272 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
273 // Do not set BWD_STRIDES
274#else
275 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
276 // Do not set BWD_STRIDES
277#endif
278 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
279 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
280 pp->commit(amrex::Gpu::Device::streamQueue());
281 plan = pp;
282
283#else /* FFTW */
284
285 if constexpr (std::is_same_v<float,T>) {
286 if constexpr (D == Direction::forward) {
287 plan = fftwf_plan_many_dft_r2c
288 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
289 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
290 } else {
291 plan = fftwf_plan_many_dft_c2r
292 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
293 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
294 }
295 } else {
296 if constexpr (D == Direction::forward) {
297 plan = fftw_plan_many_dft_r2c
298 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
299 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
300 } else {
301 plan = fftw_plan_many_dft_c2r
302 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
303 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
304 }
305 }
306#endif
307 }
308
    //! Declaration of an init_r2c overload that takes an M-dimensional FFT
    //! size in place of a Box; no definition appears in this part of the file.
    //! NOTE(review): the two unnamed void* parameters are presumably the
    //! real and complex data pointers, and `cache` presumably controls plan
    //! caching -- confirm at the definition.
    template <Direction D, int M>
    void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);
311
312 template <Direction D>
313 void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
314 {
315 static_assert(D == Direction::forward || D == Direction::backward);
316
318 defined = true;
319 pf = (void*)p;
320 pb = (void*)p;
321
322 int len[3] = {};
323
324 if (ndims == 1) {
325 n = box.length(0);
326 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
327 howmany *= ncomp;
328 len[0] = box.length(0);
329 }
330#if (AMREX_SPACEDIM >= 2)
331 else if (ndims == 2) {
332 n = box.length(0) * box.length(1);
333#if (AMREX_SPACEDIM == 2)
334 howmany = ncomp;
335#else
336 howmany = box.length(2) * ncomp;
337#endif
338 len[0] = box.length(1);
339 len[1] = box.length(0);
340 }
341#if (AMREX_SPACEDIM == 3)
342 else if (ndims == 3) {
343 n = box.length(0) * box.length(1) * box.length(2);
344 howmany = ncomp;
345 len[0] = box.length(2);
346 len[1] = box.length(1);
347 len[2] = box.length(0);
348 }
349#endif
350#endif
351
352#if defined(AMREX_USE_CUDA)
353 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
354 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
355
356 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
357 std::size_t work_size;
359 (cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
360
361#elif defined(AMREX_USE_HIP)
362
363 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
364 : rocfft_precision_double;
365 auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
366 : rocfft_transform_type_complex_inverse;
367 std::size_t length[3];
368 if (ndims == 1) {
369 length[0] = len[0];
370 } else if (ndims == 2) {
371 length[0] = len[1];
372 length[1] = len[0];
373 } else {
374 length[0] = len[2];
375 length[1] = len[1];
376 length[2] = len[0];
377 }
378 AMREX_ROCFFT_SAFE_CALL
379 (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
380 length, howmany, nullptr));
381
382#elif defined(AMREX_USE_SYCL)
383
384 mkl_desc_c* pp;
385 if (ndims == 1) {
386 pp = new mkl_desc_c(n);
387 } else if (ndims == 2) {
388 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
389 } else {
390 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
391 }
392#ifndef AMREX_USE_MKL_DFTI_2024
393 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
394 oneapi::mkl::dft::config_value::INPLACE);
395#else
396 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
397#endif
398 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
399 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
400 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
401 std::vector<std::int64_t> strides(ndims+1);
402 strides[0] = 0;
403 strides[ndims] = 1;
404 for (int i = ndims-1; i >= 1; --i) {
405 strides[i] = strides[i+1] * len[ndims-1-i];
406 }
407#ifndef AMREX_USE_MKL_DFTI_2024
408 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
409 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
410#else
411 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
412 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
413#endif
414 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
415 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
416 pp->commit(amrex::Gpu::Device::streamQueue());
417 plan = pp;
418
419#else /* FFTW */
420
421 if constexpr (std::is_same_v<float,T>) {
422 if constexpr (D == Direction::forward) {
423 plan = fftwf_plan_many_dft
424 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
425 FFTW_ESTIMATE);
426 } else {
427 plan = fftwf_plan_many_dft
428 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
429 FFTW_ESTIMATE);
430 }
431 } else {
432 if constexpr (D == Direction::forward) {
433 plan = fftw_plan_many_dft
434 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
435 FFTW_ESTIMATE);
436 } else {
437 plan = fftw_plan_many_dft
438 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
439 FFTW_ESTIMATE);
440 }
441 }
442#endif
443 }
444
445#ifndef AMREX_USE_GPU
446 template <Direction D>
447 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
448 {
449 if (bc.first == Boundary::even && bc.second == Boundary::even)
450 {
451 return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
452 }
453 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
454 {
455 return FFTW_REDFT11;
456 }
457 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
458 {
459 return FFTW_RODFT11;
460 }
461 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
462 {
463 return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
464 }
465 else {
466 amrex::Abort("FFT: unsupported BC");
467 return fftw_r2r_kind{};
468 }
469
470 }
471#endif
472
    //! Translate a (first, second) boundary pair into a Kind value.
    //! even/odd yields Kind::r2r_eo and odd/even yields Kind::r2r_oe; any
    //! pair other than the four even/odd combinations calls amrex::Abort.
    //! NOTE(review): the even/even and odd/odd branches contain no visible
    //! return statement. Other methods test for Kind::r2r_ee_f/r2r_ee_b and
    //! Kind::r2r_oo_f/r2r_oo_b, so these branches presumably return one of
    //! those depending on D -- confirm the statements have not been lost.
    template <Direction D>
    Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return Kind::r2r_eo;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return Kind::r2r_oe;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return Kind::none;
        }

    }
498
499 template <Direction D>
500 void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
501 int howmany_initval = 1)
502 {
503 static_assert(D == Direction::forward || D == Direction::backward);
504
505 kind = get_r2r_kind<D>(bc);
506 defined = true;
507 pf = (void*)p;
508 pb = (void*)p;
509
510 n = box.length(0);
511 howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));
512
513#if defined(AMREX_USE_GPU)
514 int nex=0;
515 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
516 Direction::forward == D) {
517 nex = 2*n;
518 } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
519 Direction::backward == D) {
520 nex = 4*n;
521 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
522 Direction::forward == D) {
523 nex = 2*n;
524 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
525 Direction::backward == D) {
526 nex = 4*n;
527 } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
528 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
529 nex = 4*n;
530 } else {
531 amrex::Abort("FFT: unsupported BC");
532 }
533 int nc = (nex/2) + 1;
534
535#if defined (AMREX_USE_CUDA)
536
537 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
538 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
539 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
540 std::size_t work_size;
542 (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
543
544#elif defined(AMREX_USE_HIP)
545
547 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
548 const std::size_t length = nex;
549 AMREX_ROCFFT_SAFE_CALL
550 (rocfft_plan_create(&plan, rocfft_placement_inplace,
551 rocfft_transform_type_real_forward, prec, 1,
552 &length, howmany, nullptr));
553
554#elif defined(AMREX_USE_SYCL)
555
556 auto* pp = new mkl_desc_r(nex);
557#ifndef AMREX_USE_MKL_DFTI_2024
558 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
559 oneapi::mkl::dft::config_value::INPLACE);
560#else
561 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
562#endif
563 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
564 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
565 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
566 std::vector<std::int64_t> strides = {0,1};
567#ifndef AMREX_USE_MKL_DFTI_2024
568 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
569 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
570#else
571 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
572 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
573#endif
574 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
575 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
576 pp->commit(amrex::Gpu::Device::streamQueue());
577 plan = pp;
578
579#endif
580
581#else /* FFTW */
582 auto fftw_kind = get_fftw_kind<D>(bc);
583 if constexpr (std::is_same_v<float,T>) {
584 plan = fftwf_plan_many_r2r
585 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
586 FFTW_ESTIMATE);
587 } else {
588 plan = fftw_plan_many_r2r
589 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
590 FFTW_ESTIMATE);
591 }
592#endif
593 }
594
595 template <Direction D>
596 void init_r2r (Box const& box, VendorComplex* pc,
597 std::pair<Boundary,Boundary> const& bc)
598 {
599 static_assert(D == Direction::forward || D == Direction::backward);
600
601 auto* p = (T*)pc;
602
603#if defined(AMREX_USE_GPU)
604
605 init_r2r<D>(box, p, bc, 2);
606 r2r_data_is_complex = true;
607
608#else
609
610 kind = get_r2r_kind<D>(bc);
611 defined = true;
612 pf = (void*)p;
613 pb = (void*)p;
614
615 n = box.length(0);
616 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
617
618 defined2 = true;
619 auto fftw_kind = get_fftw_kind<D>(bc);
620 if constexpr (std::is_same_v<float,T>) {
621 plan = fftwf_plan_many_r2r
622 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
623 FFTW_ESTIMATE);
624 plan2 = fftwf_plan_many_r2r
625 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
626 FFTW_ESTIMATE);
627 } else {
628 plan = fftw_plan_many_r2r
629 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
630 FFTW_ESTIMATE);
631 plan2 = fftw_plan_many_r2r
632 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
633 FFTW_ESTIMATE);
634 }
635#endif
636 }
637
    //! Execute the plan created by init_r2c. Does nothing if the plan is not
    //! defined. For D == forward the input is the real array (pf) and the
    //! output the complex array (pb); for D == backward the roles are swapped.
    //! NOTE(review): no function declarator is visible between the template
    //! header and the opening brace below -- confirm it has not been lost.
    template <Direction D>
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        // TI/TO: element types of the input and output for this direction.
        using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
        using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
        auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
        auto* po = (TO*)((D == Direction::forward) ? pb : pf);

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        // The plan was created with auto allocation off, so query the work
        // area size and attach a buffer from the arena for this execution.
        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (D == Direction::forward) {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
            }
        } else {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
            }
        }
        // NOTE(review): the work area is handed back directly after the exec
        // call is issued; confirm the transform has completed (or that the
        // arena free is safe with respect to the stream) at this point.
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pi, (void**)&po);
#elif defined(AMREX_USE_SYCL)
        // Alternative 0 of the plan variant is the real-domain descriptor.
        detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
#else
        // FFTW plans carry their data pointers, so pi/po are not passed.
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }
686
    //! Execute the in-place complex-to-complex plan created by init_c2c on
    //! the array stored in pf. Does nothing if the plan is not defined.
    //! NOTE(review): no function declarator is visible between the template
    //! header and the opening brace below -- confirm it has not been lost.
    template <Direction D>
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        auto* p = (VendorComplex*)pf;

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        // The plan was created with auto allocation off, so query the work
        // area size and attach a buffer from the arena for this execution.
        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
        }
        // NOTE(review): the work area is handed back directly after the exec
        // call is issued; confirm the transform has completed (or that the
        // arena free is safe with respect to the stream) at this point.
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
        // Alternative 1 of the plan variant is the complex-domain descriptor.
        detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
#else
        // FFTW plans carry their data pointer, so p is not passed.
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }
725
726#ifdef AMREX_USE_GPU
    //! Allocate the scratch buffer for a real-to-real transform from
    //! The_Arena(): `howmany` batches of nc complex values each, where nc is
    //! n+1 for the forward odd/odd and even/even kinds and 2n+1 in the second
    //! branch. Aborts for any kind not matched by either branch.
    //! \return pointer to the buffer; release it with free_scratch_space()
    [[nodiscard]] void* alloc_scratch_space () const
    {
        int nc = 0;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            nc = n + 1;
        } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
            // NOTE(review): the condition above ends in `||` with no visible
            // continuation or opening brace; presumably the remaining r2r
            // kinds (Kind::r2r_eo, Kind::r2r_oe) are tested here -- confirm.
            nc = 2*n+1;
        } else {
            amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
        }
        return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
    }
740
741 static void free_scratch_space (void* p) { The_Arena()->free(p); }
742
    /**
     * Fill the scratch buffer with the extended real signal for the current
     * r2r `kind`, one extended signal of length `nex` per batch.
     *
     * Within each kind there are two kernels. The first runs over nelems/2
     * items, reads the source with stride 2 and writes two destination
     * batches per item (ir = 0, 1); the second runs over nelems items and
     * reads a contiguous source.
     *
     * NOTE(review): in every kind below, the `} else {` separating the two
     * kernels has no visible matching `if`. Judging by unpack_r2r_buffer the
     * missing opener is presumably `if (r2r_data_is_complex) {` -- confirm
     * it has not been lost.
     *
     * \param pbuf destination scratch buffer, written as an array of T with
     *             `ostride` reals between batches
     * \param psrc source data, `n` values per batch
     */
    void pack_r2r_buffer (void* pbuf, T const* psrc) const
    {
        auto* pdst = (T*) pbuf;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            // Extended signal of length 2n: x, then sign * reversed(x),
            // with sign = -1 for odd/odd and +1 for even/even.
            T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
            int ostride = (n+1)*2;
            int istride = n;
            int nex = 2*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else {
                            *po = sign * pi[(2*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else {
                        *po = sign * pi[2*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            // Extended signal of length 4n:
            // x[0..n-1], x[n-2..0], 0, -x[0..n-1], -x[n-2..0], 0.
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig-1)) {
                            *po = pi[(2*norig-2-i)*2];
                        } else if (i == (2*norig-1)) {
                            *po = T(0);
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i < (4*norig-1)) {
                            *po = -pi[(4*norig-2-i)*2];
                        } else {
                            *po = T(0);
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig-1)) {
                        *po = pi[2*norig-2-i];
                    } else if (i == (2*norig-1)) {
                        *po = T(0);
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i < (4*norig-1)) {
                        *po = -pi[4*norig-2-i];
                    } else {
                        *po = T(0);
                    }
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            // Extended signal of length 4n:
            // x[0..n-1], 0, -x[n-1..0], -x[1..n-1], 0, x[n-1..1].
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i == norig) {
                            *po = T(0);
                        } else if (i < (2*norig+1)) {
                            *po = -pi[(2*norig-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i == 3*norig) {
                            *po = T(0);
                        } else {
                            *po = pi[(4*norig-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i == norig) {
                        *po = T(0);
                    } else if (i < (2*norig+1)) {
                        *po = -pi[2*norig-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i == 3*norig) {
                        *po = T(0);
                    } else {
                        *po = pi[4*norig-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_eo) {
            // Extended signal of length 4n: x, -reversed(x), -x, reversed(x).
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = -pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = -pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = pi[4*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oe) {
            // Extended signal of length 4n: x, reversed(x), -x, -reversed(x).
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = -pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = -pi[4*norig-1-i];
                    }
                });
            }
        } else {
            amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
        }
    }
974
    /**
     * Convert the complex FFT output held in the scratch buffer into the
     * `n` real results per batch for the current r2r `kind`.
     *
     * For each output index k one complex value yk is read from the buffer
     * and combined with a sine/cosine pair obtained from Math::sincospi
     * (presumably (sin(pi*x), cos(pi*x)) -- confirm). When
     * `r2r_data_is_complex` is set, each kernel item handles two buffer
     * batches (ir = 0, 1) and writes the destination with stride 2;
     * otherwise the destination is written contiguously.
     *
     * \param pdst destination array, `n` values per batch
     * \param pbuf scratch buffer, read as GpuComplex<T> with `istride`
     *             complex values between batches
     */
    void unpack_r2r_buffer (T* pdst, void const* pbuf) const
    {
        auto const* psrc = (GpuComplex<T> const*) pbuf;
        int norig = n;
        Long nelems = Long(norig)*howmany;
        int ostride = n;

        if (kind == Kind::r2r_oo_f) {
            // Uses y[k+1] with angle pi*(k+1)/(2n): result = s*re - c*im.
            int istride = n+1;
            // NOTE(review): the `} else {` below has no visible matching
            // `if`; from the other kinds in this function it is presumably
            // `if (r2r_data_is_complex) {` -- confirm it has not been lost.
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k+1];
                        pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k+1];
                    pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            // Uses y[2k+1] with angle pi*(2k+1)/(2n): result = (s*re - c*im)/2.
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_ee_f) {
            // Uses y[k] with angle pi*k/(2n): result = c*re + s*im.
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k];
                        pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k];
                    pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            // Uses y[2k+1]: result = re/2 (no rotation needed).
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * yk.real();
                });
            }
        } else if (kind == Kind::r2r_eo) {
            // Uses y[2k+1] with angle pi*(k+1/2)/(2n): result = (c*re + s*im)/2.
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_oe) {
            // Uses y[2k+1] with angle pi*(k+1/2)/(2n): result = (s*re - c*im)/2.
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
                });
            }
        } else {
            amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
        }
    }
1122#endif
1123
1124 template <Direction D>
1126 {
1127 static_assert(D == Direction::forward || D == Direction::backward);
1128 if (!defined) { return; }
1129
1130#if defined(AMREX_USE_GPU)
1131
1132 auto* pscratch = alloc_scratch_space();
1133
1134 pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));
1135
1136#if defined(AMREX_USE_CUDA)
1137
1138 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
1139
1140 std::size_t work_size = 0;
1141 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
1142
1143 auto* work_area = The_Arena()->alloc(work_size);
1144 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
1145
1146 if constexpr (std::is_same_v<float,T>) {
1147 AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
1148 } else {
1149 AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
1150 }
1151
1152#elif defined(AMREX_USE_HIP)
1153 detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
1154#elif defined(AMREX_USE_SYCL)
1155 detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
1156#endif
1157
1158 unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);
1159
1161 free_scratch_space(pscratch);
1162#if defined(AMREX_USE_CUDA)
1163 The_Arena()->free(work_area);
1164#endif
1165
1166#else /* FFTW */
1167
1168 if constexpr (std::is_same_v<float,T>) {
1169 fftwf_execute(plan);
1170 if (defined2) { fftwf_execute(plan2); }
1171 } else {
1172 fftw_execute(plan);
1173 if (defined2) { fftw_execute(plan2); }
1174 }
1175
1176#endif
1177 }
1178
1180 {
1181#if defined(AMREX_USE_CUDA)
1182 AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
1183#elif defined(AMREX_USE_HIP)
1184 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
1185#elif defined(AMREX_USE_SYCL)
1186 std::visit([](auto&& p) { delete p; }, plan);
1187#else
1188 if constexpr (std::is_same_v<float,T>) {
1189 fftwf_destroy_plan(plan);
1190 } else {
1191 fftw_destroy_plan(plan);
1192 }
1193#endif
1194 }
1195};
1196
1197using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
1200
1201PlanD* get_vendor_plan_d (Key const& key);
1202PlanF* get_vendor_plan_f (Key const& key);
1203
1204void add_vendor_plan_d (Key const& key, PlanD plan);
1205void add_vendor_plan_f (Key const& key, PlanF plan);
1206
1207template <typename T>
1208template <Direction D, int M>
1209void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
1210{
1211 static_assert(D == Direction::forward || D == Direction::backward);
1212
1213 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
1214 defined = true;
1215 pf = pbf;
1216 pb = pbb;
1217
1218 n = 1;
1219 for (auto s : fft_size) { n *= s; }
1220 howmany = ncomp;
1221
1222#if defined(AMREX_USE_GPU)
1223 Key key = {fft_size.template expand<3>(), ncomp, D, kind};
1224 if (cache) {
1225 VendorPlan* cached_plan = nullptr;
1226 if constexpr (std::is_same_v<float,T>) {
1227 cached_plan = get_vendor_plan_f(key);
1228 } else {
1229 cached_plan = get_vendor_plan_d(key);
1230 }
1231 if (cached_plan) {
1232 plan = *cached_plan;
1233 return;
1234 }
1235 }
1236#else
1237 amrex::ignore_unused(cache);
1238#endif
1239
1240 int len[M];
1241 for (int i = 0; i < M; ++i) {
1242 len[i] = fft_size[M-1-i];
1243 }
1244
1245 int nc = fft_size[0]/2+1;
1246 for (int i = 1; i < M; ++i) {
1247 nc *= fft_size[i];
1248 }
1249
1250#if defined(AMREX_USE_CUDA)
1251
1252 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
1253 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
1254 cufftType type;
1255 int n_in, n_out;
1256 if constexpr (D == Direction::forward) {
1257 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1258 n_in = n;
1259 n_out = nc;
1260 } else {
1261 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1262 n_in = nc;
1263 n_out = n;
1264 }
1265 std::size_t work_size;
1267 (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));
1268
1269#elif defined(AMREX_USE_HIP)
1270
1271 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1272 std::size_t length[M];
1273 for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
1274 if constexpr (D == Direction::forward) {
1275 AMREX_ROCFFT_SAFE_CALL
1276 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1277 rocfft_transform_type_real_forward, prec, M,
1278 length, howmany, nullptr));
1279 } else {
1280 AMREX_ROCFFT_SAFE_CALL
1281 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1282 rocfft_transform_type_real_inverse, prec, M,
1283 length, howmany, nullptr));
1284 }
1285
1286#elif defined(AMREX_USE_SYCL)
1287
1288 mkl_desc_r* pp;
1289 if (M == 1) {
1290 pp = new mkl_desc_r(fft_size[0]);
1291 } else {
1292 std::vector<std::int64_t> len64(M);
1293 for (int idim = 0; idim < M; ++idim) {
1294 len64[idim] = len[idim];
1295 }
1296 pp = new mkl_desc_r(len64);
1297 }
1298#ifndef AMREX_USE_MKL_DFTI_2024
1299 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1300 oneapi::mkl::dft::config_value::NOT_INPLACE);
1301#else
1302 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1303#endif
1304 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1305 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1306 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
1307 std::vector<std::int64_t> strides(M+1);
1308 strides[0] = 0;
1309 strides[M] = 1;
1310 for (int i = M-1; i >= 1; --i) {
1311 strides[i] = strides[i+1] * fft_size[M-1-i];
1312 }
1313
1314#ifndef AMREX_USE_MKL_DFTI_2024
1315 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1316 // Do not set BWD_STRIDES
1317#else
1318 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1319 // Do not set BWD_STRIDES
1320#endif
1321 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1322 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1323 pp->commit(amrex::Gpu::Device::streamQueue());
1324 plan = pp;
1325
1326#else /* FFTW */
1327
1328 if (pf == nullptr || pb == nullptr) {
1329 defined = false;
1330 return;
1331 }
1332
1333 if constexpr (std::is_same_v<float,T>) {
1334 if constexpr (D == Direction::forward) {
1335 plan = fftwf_plan_many_dft_r2c
1336 (M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
1337 FFTW_ESTIMATE);
1338 } else {
1339 plan = fftwf_plan_many_dft_c2r
1340 (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
1341 FFTW_ESTIMATE);
1342 }
1343 } else {
1344 if constexpr (D == Direction::forward) {
1345 plan = fftw_plan_many_dft_r2c
1346 (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
1347 FFTW_ESTIMATE);
1348 } else {
1349 plan = fftw_plan_many_dft_c2r
1350 (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
1351 FFTW_ESTIMATE);
1352 }
1353 }
1354#endif
1355
1356#if defined(AMREX_USE_GPU)
1357 if (cache) {
1358 if constexpr (std::is_same_v<float,T>) {
1359 add_vendor_plan_f(key, plan);
1360 } else {
1361 add_vendor_plan_d(key, plan);
1362 }
1363 }
1364#endif
1365}
1366
1367namespace detail
1368{
1370
1371 template <typename FA>
1372 typename FA::FABType::value_type * get_fab (FA& fa)
1373 {
1374 auto myproc = ParallelContext::MyProcSub();
1375 if (myproc < fa.size()) {
1376 return fa.fabPtr(myproc);
1377 } else {
1378 return nullptr;
1379 }
1380 }
1381
1382 template <typename FA1, typename FA2>
1383 std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
1384 {
1385 bool not_same_fa = true;
1386 if constexpr (std::is_same_v<FA1,FA2>) {
1387 not_same_fa = (&fa1 != &fa2);
1388 }
1389 using FAB1 = typename FA1::FABType::value_type;
1390 using FAB2 = typename FA2::FABType::value_type;
1391 using T1 = typename FAB1::value_type;
1392 using T2 = typename FAB2::value_type;
1393 auto myproc = ParallelContext::MyProcSub();
1394 bool alloc_1 = (myproc < fa1.size());
1395 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1396 void* p = nullptr;
1397 if (alloc_1 && alloc_2) {
1398 Box const& box1 = fa1.fabbox(myproc);
1399 Box const& box2 = fa2.fabbox(myproc);
1400 int ncomp1 = fa1.nComp();
1401 int ncomp2 = fa2.nComp();
1402 p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
1403 sizeof(T2)*box2.numPts()*ncomp2));
1404 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1405 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1406 } else if (alloc_1) {
1407 Box const& box1 = fa1.fabbox(myproc);
1408 int ncomp1 = fa1.nComp();
1409 p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
1410 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1411 } else if (alloc_2) {
1412 Box const& box2 = fa2.fabbox(myproc);
1413 int ncomp2 = fa2.nComp();
1414 p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
1415 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1416 } else {
1417 return nullptr;
1418 }
1419 return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
1420 }
1421}
1422
1424{
1425 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1426 {
1427 return {i.y, i.x, i.z};
1428 }
1429
1430 static constexpr Dim3 Inverse (Dim3 i)
1431 {
1432 return {i.y, i.x, i.z};
1433 }
1434
1435 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1436 {
1437 return it;
1438 }
1439
1440 static constexpr IndexType Inverse (IndexType it)
1441 {
1442 return it;
1443 }
1444};
1445
1447{
1448 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1449 {
1450 return {i.z, i.y, i.x};
1451 }
1452
1453 static constexpr Dim3 Inverse (Dim3 i)
1454 {
1455 return {i.z, i.y, i.x};
1456 }
1457
1458 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1459 {
1460 return it;
1461 }
1462
1463 static constexpr IndexType Inverse (IndexType it)
1464 {
1465 return it;
1466 }
1467};
1468
1470{
1471 // dest -> src: (x,y,z) -> (y,z,x)
1472 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1473 {
1474 return {i.y, i.z, i.x};
1475 }
1476
1477 // src -> dest: (x,y,z) -> (z,x,y)
1478 static constexpr Dim3 Inverse (Dim3 i)
1479 {
1480 return {i.z, i.x, i.y};
1481 }
1482
1483 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1484 {
1485 return it;
1486 }
1487
1488 static constexpr IndexType Inverse (IndexType it)
1489 {
1490 return it;
1491 }
1492};
1493
1495{
1496 // dest -> src: (x,y,z) -> (z,x,y)
1497 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1498 {
1499 return {i.z, i.x, i.y};
1500 }
1501
1502 // src -> dest: (x,y,z) -> (y,z,x)
1503 static constexpr Dim3 Inverse (Dim3 i)
1504 {
1505 return {i.y, i.z, i.x};
1506 }
1507
1508 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1509 {
1510 return it;
1511 }
1512
1513 static constexpr IndexType Inverse (IndexType it)
1514 {
1515 return it;
1516 }
1517};
1518
1519namespace detail
1520{
1522 {
1523 explicit SubHelper (Box const& domain);
1524
1525 [[nodiscard]] Box make_box (Box const& box) const;
1526
1527 [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;
1528
1529 [[nodiscard]] bool ghost_safe (IntVect const& ng) const;
1530
1531 // This rearranges the order.
1532 [[nodiscard]] IntVect make_iv (IntVect const& iv) const;
1533
1534 // This keeps the order, but zero out the values in the hidden dimension.
1535 [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;
1536
1537 [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;
1538
1539 [[nodiscard]] IntVect inverse_order (IntVect const& order) const;
1540
1541 template <typename T>
1542 [[nodiscard]] T make_array (T const& a) const
1543 {
1544#if (AMREX_SPACEDIM == 1)
1546 return a;
1547#elif (AMREX_SPACEDIM == 2)
1548 if (m_case == case_1n) {
1549 return T{a[1],a[0]};
1550 } else {
1551 return a;
1552 }
1553#else
1554 if (m_case == case_11n) {
1555 return T{a[2],a[0],a[1]};
1556 } else if (m_case == case_1n1) {
1557 return T{a[1],a[0],a[2]};
1558 } else if (m_case == case_1nn) {
1559 return T{a[1],a[2],a[0]};
1560 } else if (m_case == case_n1n) {
1561 return T{a[0],a[2],a[1]};
1562 } else {
1563 return a;
1564 }
1565#endif
1566 }
1567
1568 [[nodiscard]] GpuArray<int,3> xyz_order () const;
1569
1570 template <typename FA>
1571 FA make_alias_mf (FA const& mf)
1572 {
1573 BoxList bl = mf.boxArray().boxList();
1574 for (auto& b : bl) {
1575 b = make_box(b);
1576 }
1577 auto const& ng = make_iv(mf.nGrowVect());
1578 FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
1579 using FAB = typename FA::fab_type;
1580 for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1581 submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
1582 }
1583 return submf;
1584 }
1585
1586#if (AMREX_SPACEDIM == 2)
1587 enum Case { case_1n, case_other };
1588 int m_case = case_other;
1589#elif (AMREX_SPACEDIM == 3)
1590 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1591 int m_case = case_other;
1592#endif
1593 };
1594}
1595
1596}
1597
1598#endif
#define AMREX_ENUM(CLASS,...)
Definition AMReX_Enum.H:133
#define AMREX_CUFFT_SAFE_CALL(call)
Definition AMReX_GpuError.H:92
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1090
#define AMREX_D_TERM(a, b, c)
Definition AMReX_SPACE.H:129
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
A collection of Boxes stored in an Array.
Definition AMReX_BoxArray.H:550
A class for managing a List of Boxes that share a common IndexType. This class implements operations ...
Definition AMReX_BoxList.H:52
AMREX_GPU_HOST_DEVICE IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition AMReX_Box.H:146
AMREX_GPU_HOST_DEVICE Long numPts() const noexcept
Returns the number of points contained in the BoxND.
Definition AMReX_Box.H:346
Calculates the distribution of FABs to MPI processes.
Definition AMReX_DistributionMapping.H:41
Definition AMReX_IntVect.H:48
Definition AMReX_MFIter.H:57
bool isValid() const noexcept
Is the iterator valid i.e. is it associated with a FAB?
Definition AMReX_MFIter.H:141
This provides length of period for periodic domains. 0 means it is not periodic in that direction....
Definition AMReX_Periodicity.H:17
std::unique_ptr< char, DataDeleter > make_mfs_share(FA1 &fa1, FA2 &fa2)
Definition AMReX_FFT_Helper.H:1383
FA::FABType::value_type * get_fab(FA &fa)
Definition AMReX_FFT_Helper.H:1372
DistributionMapping make_iota_distromap(Long n)
Definition AMReX_FFT.cpp:88
Definition AMReX_FFT.cpp:7
Direction
Definition AMReX_FFT_Helper.H:48
void add_vendor_plan_f(Key const &key, PlanF plan)
Definition AMReX_FFT.cpp:78
DomainStrategy
Definition AMReX_FFT_Helper.H:50
typename Plan< float >::VendorPlan PlanF
Definition AMReX_FFT_Helper.H:1199
typename Plan< double >::VendorPlan PlanD
Definition AMReX_FFT_Helper.H:1198
void add_vendor_plan_d(Key const &key, PlanD plan)
Definition AMReX_FFT.cpp:73
Kind
Definition AMReX_FFT_Helper.H:54
PlanF * get_vendor_plan_f(Key const &key)
Definition AMReX_FFT.cpp:64
PlanD * get_vendor_plan_d(Key const &key)
Definition AMReX_FFT.cpp:55
std::tuple< IntVectND< 3 >, int, Direction, Kind > Key
Definition AMReX_FFT_Helper.H:1197
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:218
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE std::pair< double, double > sincospi(double x)
Return sin(pi*x) and cos(pi*x) given x.
Definition AMReX_Math.H:165
int MyProcSub() noexcept
my sub-rank in current frame
Definition AMReX_ParallelContext.H:76
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition AMReX_CTOParallelForImpl.H:191
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:127
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 length(Array4< T > const &a) noexcept
Definition AMReX_Array4.H:322
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
const int[]
Definition AMReX_BLProfiler.cpp:1664
Arena * The_Arena()
Definition AMReX_Arena.cpp:616
Definition AMReX_FabArrayCommI.H:896
Definition AMReX_DataAllocator.H:29
Definition AMReX_Dim3.H:12
int x
Definition AMReX_Dim3.H:12
int z
Definition AMReX_Dim3.H:12
int y
Definition AMReX_Dim3.H:12
Definition AMReX_FFT_Helper.H:58
bool twod_mode
Definition AMReX_FFT_Helper.H:69
Info & setNumProcs(int n)
Definition AMReX_FFT_Helper.H:81
int batch_size
Batched FFT size. Only supported in R2C, not R2X.
Definition AMReX_FFT_Helper.H:72
Info & setDomainStrategy(DomainStrategy s)
Definition AMReX_FFT_Helper.H:77
DomainStrategy domain_strategy
Domain composition strategy.
Definition AMReX_FFT_Helper.H:60
int nprocs
Max number of processes to use.
Definition AMReX_FFT_Helper.H:75
int pencil_threshold
Definition AMReX_FFT_Helper.H:64
Info & setBatchSize(int bsize)
Definition AMReX_FFT_Helper.H:80
Info & setPencilThreshold(int t)
Definition AMReX_FFT_Helper.H:78
Info & setTwoDMode(bool x)
Definition AMReX_FFT_Helper.H:79
Definition AMReX_FFT_Helper.H:126
void * pf
Definition AMReX_FFT_Helper.H:161
void unpack_r2r_buffer(T *pdst, void const *pbuf) const
Definition AMReX_FFT_Helper.H:975
std::conditional_t< std::is_same_v< float, T >, cuComplex, cuDoubleComplex > VendorComplex
Definition AMReX_FFT_Helper.H:130
VendorPlan plan2
Definition AMReX_FFT_Helper.H:160
int n
Definition AMReX_FFT_Helper.H:153
void destroy()
Definition AMReX_FFT_Helper.H:171
bool defined2
Definition AMReX_FFT_Helper.H:158
void init_r2r(Box const &box, VendorComplex *pc, std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:596
void pack_r2r_buffer(void *pbuf, T const *psrc) const
Definition AMReX_FFT_Helper.H:743
static void free_scratch_space(void *p)
Definition AMReX_FFT_Helper.H:741
static void destroy_vendor_plan(VendorPlan plan)
Definition AMReX_FFT_Helper.H:1179
Kind get_r2r_kind(std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:474
cufftHandle VendorPlan
Definition AMReX_FFT_Helper.H:128
Kind kind
Definition AMReX_FFT_Helper.H:155
void init_c2c(Box const &box, VendorComplex *p, int ncomp=1, int ndims=1)
Definition AMReX_FFT_Helper.H:313
int howmany
Definition AMReX_FFT_Helper.H:154
void * pb
Definition AMReX_FFT_Helper.H:162
void init_r2c(Box const &box, T *pr, VendorComplex *pc, bool is_2d_transform=false, int ncomp=1)
Definition AMReX_FFT_Helper.H:186
void compute_r2r()
Definition AMReX_FFT_Helper.H:1125
void compute_c2c()
Definition AMReX_FFT_Helper.H:688
bool r2r_data_is_complex
Definition AMReX_FFT_Helper.H:156
void * alloc_scratch_space() const
Definition AMReX_FFT_Helper.H:727
VendorPlan plan
Definition AMReX_FFT_Helper.H:159
void compute_r2c()
Definition AMReX_FFT_Helper.H:639
bool defined
Definition AMReX_FFT_Helper.H:157
void set_ptrs(void *p0, void *p1)
Definition AMReX_FFT_Helper.H:165
void init_r2r(Box const &box, T *p, std::pair< Boundary, Boundary > const &bc, int howmany_initval=1)
Definition AMReX_FFT_Helper.H:500
void init_r2c(IntVectND< M > const &fft_size, void *, void *, bool cache, int ncomp=1)
Definition AMReX_FFT_Helper.H:1209
Definition AMReX_FFT_Helper.H:1495
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1497
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1513
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1503
Definition AMReX_FFT_Helper.H:1470
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1488
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1472
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1478
Definition AMReX_FFT_Helper.H:1424
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1425
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1440
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1430
Definition AMReX_FFT_Helper.H:1447
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1453
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1463
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1448
Definition AMReX_FFT_Helper.H:1522
T make_array(T const &a) const
Definition AMReX_FFT_Helper.H:1542
Box make_box(Box const &box) const
Definition AMReX_FFT.cpp:142
BoxArray inverse_boxarray(BoxArray const &ba) const
Definition AMReX_FFT.cpp:209
bool ghost_safe(IntVect const &ng) const
Definition AMReX_FFT.cpp:152
GpuArray< int, 3 > xyz_order() const
Definition AMReX_FFT.cpp:326
IntVect inverse_order(IntVect const &order) const
Definition AMReX_FFT.cpp:266
IntVect make_iv(IntVect const &iv) const
Definition AMReX_FFT.cpp:178
FA make_alias_mf(FA const &mf)
Definition AMReX_FFT_Helper.H:1571
Periodicity make_periodicity(Periodicity const &period) const
Definition AMReX_FFT.cpp:147
IntVect make_safe_ghost(IntVect const &ng) const
Definition AMReX_FFT.cpp:183
Definition AMReX_Array.H:34
A host / device complex number type, because std::complex doesn't work in device code with Cuda yet.
Definition AMReX_GpuComplex.H:29
FabArray memory allocation information.
Definition AMReX_FabArray.H:66
MFInfo & SetAlloc(bool a) noexcept
Definition AMReX_FabArray.H:73
Definition AMReX_MFIter.H:20