AMReX_FFT_Helper.H
#ifndef AMREX_FFT_HELPER_H_
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_BLProfiler.H>
#include <AMReX_DataAllocator.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Enum.H>
#include <AMReX_FabArray.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>
#include <AMReX_Periodicity.H>

#if defined(AMREX_USE_CUDA)
#  include <cufft.h>
#  include <cuComplex.h>
#elif defined(AMREX_USE_HIP)
#  if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
#    include <rocfft/rocfft.h>
#  else
#    include <rocfft.h>
#  endif
#  include <hip/hip_complex.h>
#elif defined(AMREX_USE_SYCL)
#  if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
#    include <oneapi/mkl/dft.hpp>
#  else
#    define AMREX_USE_MKL_DFTI_2024 1
#    include <oneapi/mkl/dfti.hpp>
#  endif
#else
#  include <fftw3.h>
#endif

#include <algorithm>
#include <complex>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <variant>

namespace amrex::FFT
{

enum struct Direction { forward, backward, both, none };

AMREX_ENUM( DomainStrategy, slab, pencil );

enum struct Boundary { periodic, even, odd };

enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_oo_f, r2r_oo_b,
                   r2r_ee_f, r2r_ee_b, r2r_oe, r2r_eo };

struct Info
{
    //! Domain composition strategy.
    DomainStrategy domain_strategy = DomainStrategy::slab;

    int pencil_threshold = 8;

    bool twod_mode = false;

    //! We might have a special twod_mode: nx or ny == 1 && nz > 1.
    bool oned_mode = false;

    //! Batched FFT size. Only supported in R2C, not R2X.
    int batch_size = 1;

    //! Max number of processes to use.
    int nprocs = std::numeric_limits<int>::max();

    Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
    Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
    Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
    Info& setOneDMode (bool x) { oned_mode = x; return *this; }
    Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
    Info& setNumProcs (int n) { nprocs = n; return *this; }
};
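// A hedged usage sketch (added; not part of the original header): the setters
// above return *this, so an Info can be configured fluently, e.g.
//   auto info = amrex::FFT::Info{}.setBatchSize(4).setNumProcs(16);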

#ifdef AMREX_USE_HIP
namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
#endif

#ifdef AMREX_USE_SYCL
namespace detail
{
template <typename T, Direction direction, typename P, typename TI, typename TO>
void sycl_execute (P* plan, TI* in, TO* out)
{
#ifndef AMREX_USE_MKL_DFTI_2024
    std::int64_t workspaceSize = 0;
#else
    std::size_t workspaceSize = 0;
#endif
    plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
                    &workspaceSize);
    auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
    plan->set_workspace(buffer);
    sycl::event r;
    if constexpr (std::is_same_v<TI,TO>) {
        amrex::ignore_unused(in);
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, out);
        }
    } else {
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, in, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, in, out);
        }
    }
    r.wait();
    amrex::The_Arena()->free(buffer);
}
}
#endif
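// Descriptive note (added): sycl_execute allocates the external MKL workspace
// from The_Arena on every call and frees it after r.wait(), so each call is
// synchronous with respect to the calling host thread.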

template <typename T>
struct Plan
{
#if defined(AMREX_USE_CUDA)
    using VendorPlan = cufftHandle;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
    using VendorPlan = rocfft_plan;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             float2, double2>;
#elif defined(AMREX_USE_SYCL)
    using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::REAL>;
    using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::COMPLEX>;
    using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
    using VendorComplex = std::complex<T>;
#else
    using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
                                          fftwf_plan, fftw_plan>;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             fftwf_complex, fftw_complex>;
#endif

    int n = 0;
    int howmany = 0;
    Kind kind = Kind::none;
    bool r2r_data_is_complex = false;
    bool defined = false;
    bool defined2 = false;
    VendorPlan plan{};
    VendorPlan plan2{};
    void* pf = nullptr;
    void* pb = nullptr;

#ifdef AMREX_USE_GPU
    void set_ptrs (void* p0, void* p1) {
        pf = p0;
        pb = p1;
    }
#endif

    void destroy ()
    {
        if (defined) {
            destroy_vendor_plan(plan);
            defined = false;
        }
#if !defined(AMREX_USE_GPU)
        if (defined2) {
            destroy_vendor_plan(plan2);
            defined2 = false;
        }
#endif
    }

    template <Direction D>
    void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        int rank = is_2d_transform ? 2 : 1;

        kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
        defined = true;
        pf = (void*)pr;
        pb = (void*)pc;

        int len[2] = {};
        if (rank == 1) {
            len[0] = box.length(0);
            len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
        } else {
            len[0] = box.length(1); // Most FFT libraries assume row-major ordering
            len[1] = box.length(0); // except for rocfft
        }
        int nr = (rank == 1) ? len[0] : len[0]*len[1];
        n = nr;
        int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
        howmany = 1;
#else
        howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
                              : AMREX_D_TERM(1, *1            , *box.length(2));
#endif
        howmany *= ncomp;

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        std::size_t work_size;
        if constexpr (D == Direction::forward) {
            cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
        } else {
            cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
        }

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        // switch to column-major ordering
        std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
        if constexpr (D == Direction::forward) {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_forward, prec, rank,
                                    length, howmany, nullptr));
        } else {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_inverse, prec, rank,
                                    length, howmany, nullptr));
        }

#elif defined(AMREX_USE_SYCL)

        mkl_desc_r* pp;
        if (rank == 1) {
            pp = new mkl_desc_r(len[0]);
        } else {
            pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides;
        strides.push_back(0);
        if (rank == 2) { strides.push_back(len[1]); }
        strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        // Do not set BWD_STRIDES
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        // Do not set BWD_STRIDES
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftwf_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftw_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        }
#endif
    }
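    // A hedged usage sketch (added; not part of the original header): a
    // single-box forward transform might look like
    //   Plan<double> p;
    //   p.init_r2c<Direction::forward>(box, real_ptr, complex_ptr);
    //   p.compute_r2c<Direction::forward>();
    //   p.destroy();
    // where real_ptr and complex_ptr (hypothetical names) are caller-managed
    // arrays of the sizes implied by box.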

    template <Direction D, int M>
    void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);

    template <Direction D>
    void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = (D == Direction::forward) ? Kind::c2c_f : Kind::c2c_b;
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        int len[3] = {};

        if (ndims == 1) {
            n = box.length(0);
            howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
            howmany *= ncomp;
            len[0] = box.length(0);
        }
#if (AMREX_SPACEDIM >= 2)
        else if (ndims == 2) {
            n = box.length(0) * box.length(1);
#if (AMREX_SPACEDIM == 2)
            howmany = ncomp;
#else
            howmany = box.length(2) * ncomp;
#endif
            len[0] = box.length(1);
            len[1] = box.length(0);
        }
#if (AMREX_SPACEDIM == 3)
        else if (ndims == 3) {
            n = box.length(0) * box.length(1) * box.length(2);
            howmany = ncomp;
            len[0] = box.length(2);
            len[1] = box.length(1);
            len[2] = box.length(0);
        }
#endif
#endif

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));

        cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single
                                            : rocfft_precision_double;
        auto dir = (D == Direction::forward) ? rocfft_transform_type_complex_forward
                                             : rocfft_transform_type_complex_inverse;
        std::size_t length[3];
        if (ndims == 1) {
            length[0] = len[0];
        } else if (ndims == 2) {
            length[0] = len[1];
            length[1] = len[0];
        } else {
            length[0] = len[2];
            length[1] = len[1];
            length[2] = len[0];
        }
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
                                length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        mkl_desc_c* pp;
        if (ndims == 1) {
            pp = new mkl_desc_c(n);
        } else if (ndims == 2) {
            pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
        } else {
            pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
        std::vector<std::int64_t> strides(ndims+1);
        strides[0] = 0;
        strides[ndims] = 1;
        for (int i = ndims-1; i >= 1; --i) {
            strides[i] = strides[i+1] * len[i];
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftwf_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftw_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        }
#endif
    }

#ifndef AMREX_USE_GPU
    template <Direction D>
    fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return FFTW_REDFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return FFTW_RODFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return fftw_r2r_kind{};
        }
    }
#endif
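    // Note (added): in FFTW's naming, REDFT10/REDFT01 are the DCT-II and its
    // inverse DCT-III, REDFT11 is the DCT-IV, RODFT10/RODFT01 are the
    // DST-II/DST-III pair, and RODFT11 is the DST-IV.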

    template <Direction D>
    Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? Kind::r2r_ee_f : Kind::r2r_ee_b;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return Kind::r2r_eo;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return Kind::r2r_oe;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? Kind::r2r_oo_f : Kind::r2r_oo_b;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return Kind::none;
        }
    }

    template <Direction D>
    void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
                   int howmany_initval = 1)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));

#if defined(AMREX_USE_GPU)
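        // Descriptive note (added): the GPU backends have no native r2r
        // support, so the input of length n is extended to length nex with the
        // even/odd symmetry implied by the boundary conditions and transformed
        // with a real-to-complex plan; pack_r2r_buffer and unpack_r2r_buffer
        // below build the extended input and recover the cosine/sine
        // coefficients.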
        int nex = 0;
        if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
            Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
                   (bc.first == Boundary::odd && bc.second == Boundary::even)) {
            nex = 4*n;
        } else {
            amrex::Abort("FFT: unsupported BC");
        }
        int nc = (nex/2) + 1;

#if defined (AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        const std::size_t length = nex;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace,
                                rocfft_transform_type_real_forward, prec, 1,
                                &length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        auto* pp = new mkl_desc_r(nex);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#endif

#else /* FFTW */
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }

    template <Direction D>
    void init_r2r (Box const& box, VendorComplex* pc,
                   std::pair<Boundary,Boundary> const& bc)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        auto* p = (T*)pc;

#if defined(AMREX_USE_GPU)

        init_r2r<D>(box, p, bc, 2);
        r2r_data_is_complex = true;

#else

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));

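        // Descriptive note (added): the data here are interleaved complex
        // values, so two stride-2 FFTW plans are created, one for the real
        // parts (p) and one for the imaginary parts (p+1).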
        defined2 = true;
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftwf_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftw_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }

    template <Direction D>
    void compute_r2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
        using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
        auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
        auto* po = (TO*)((D == Direction::forward) ? pb : pf);

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (D == Direction::forward) {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
            }
        } else {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
            }
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pi, (void**)&po);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
#else
        amrex::ignore_unused(pi, po);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

    template <Direction D>
    void compute_c2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        auto* p = (VendorComplex*)pf;

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
#else
        amrex::ignore_unused(p);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

#ifdef AMREX_USE_GPU
    [[nodiscard]] void* alloc_scratch_space () const
    {
        int nc = 0;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            nc = n + 1;
        } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
                   kind == Kind::r2r_oe || kind == Kind::r2r_eo) {
            nc = 2*n+1;
        } else {
            amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
        }
        return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
    }

    static void free_scratch_space (void* p) { The_Arena()->free(p); }

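    // Descriptive note (added): pack_r2r_buffer writes the even/odd extension
    // of each batch of psrc into pbuf, padded as the in-place R2C transform
    // expects; the r2r_data_is_complex variants pack two interleaved sequences
    // (real and imaginary parts) per batch.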
    void pack_r2r_buffer (void* pbuf, T const* psrc) const
    {
        auto* pdst = (T*) pbuf;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
            int ostride = (n+1)*2;
            int istride = n;
            int nex = 2*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else {
                            *po = sign * pi[(2*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else {
                        *po = sign * pi[2*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig-1)) {
                            *po = pi[(2*norig-2-i)*2];
                        } else if (i == (2*norig-1)) {
                            *po = T(0);
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i < (4*norig-1)) {
                            *po = -pi[(4*norig-2-i)*2];
                        } else {
                            *po = T(0);
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig-1)) {
                        *po = pi[2*norig-2-i];
                    } else if (i == (2*norig-1)) {
                        *po = T(0);
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i < (4*norig-1)) {
                        *po = -pi[4*norig-2-i];
                    } else {
                        *po = T(0);
                    }
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i == norig) {
                            *po = T(0);
                        } else if (i < (2*norig+1)) {
                            *po = -pi[(2*norig-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i == 3*norig) {
                            *po = T(0);
                        } else {
                            *po = pi[(4*norig-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i == norig) {
                        *po = T(0);
                    } else if (i < (2*norig+1)) {
                        *po = -pi[2*norig-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i == 3*norig) {
                        *po = T(0);
                    } else {
                        *po = pi[4*norig-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = -pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = -pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = pi[4*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = -pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = -pi[4*norig-1-i];
                    }
                });
            }
        } else {
            amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
        }
    }

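    // Descriptive note (added): unpack_r2r_buffer converts the R2C
    // coefficients of the extended sequence back into DCT/DST coefficients,
    // using Math::sincospi for the twiddle factors.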
    void unpack_r2r_buffer (T* pdst, void const* pbuf) const
    {
        auto const* psrc = (GpuComplex<T> const*) pbuf;
        int norig = n;
        Long nelems = Long(norig)*howmany;
        int ostride = n;

        if (kind == Kind::r2r_oo_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k+1];
                        pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k+1];
                    pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_ee_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k];
                        pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k];
                    pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * yk.real();
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
                });
            }
        } else {
            amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
        }
    }
#endif

    template <Direction D>
    void compute_r2r ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

#if defined(AMREX_USE_GPU)

        auto* pscratch = alloc_scratch_space();

        pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
        }

#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
#endif

        unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);

        Gpu::streamSynchronize();
        free_scratch_space(pscratch);
#if defined(AMREX_USE_CUDA)
        The_Arena()->free(work_area);
#endif

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
            if (defined2) { fftwf_execute(plan2); }
        } else {
            fftw_execute(plan);
            if (defined2) { fftw_execute(plan2); }
        }

#endif
    }

    static void destroy_vendor_plan (VendorPlan plan)
    {
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
        AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
#elif defined(AMREX_USE_SYCL)
        std::visit([](auto&& p) { delete p; }, plan);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_destroy_plan(plan);
        } else {
            fftw_destroy_plan(plan);
        }
#endif
    }
};

using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

namespace detail {
    PlanD* get_vendor_plan_d (Key const& key);
    PlanF* get_vendor_plan_f (Key const& key);

    void add_vendor_plan_d (Key const& key, PlanD plan);
    void add_vendor_plan_f (Key const& key, PlanF plan);
}

template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
{
    static_assert(D == Direction::forward || D == Direction::backward);

    kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
    defined = true;
    pf = pbf;
    pb = pbb;

    n = 1;
    for (auto s : fft_size) { n *= s; }
    howmany = ncomp;

#if defined(AMREX_USE_GPU)
    Key key = {fft_size.template expand<3>(), ncomp, D, kind};
    if (cache) {
        VendorPlan* cached_plan = nullptr;
        if constexpr (std::is_same_v<float,T>) {
            cached_plan = detail::get_vendor_plan_f(key);
        } else {
            cached_plan = detail::get_vendor_plan_d(key);
        }
        if (cached_plan) {
            plan = *cached_plan;
            return;
        }
    }
#else
    amrex::ignore_unused(cache);
#endif

    int len[M];
    for (int i = 0; i < M; ++i) {
        len[i] = fft_size[M-1-i];
    }

    int nc = fft_size[0]/2+1;
    for (int i = 1; i < M; ++i) {
        nc *= fft_size[i];
    }

#if defined(AMREX_USE_CUDA)

    AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
    AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
    cufftType type;
    int n_in, n_out;
    if constexpr (D == Direction::forward) {
        type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
        n_in = n;
        n_out = nc;
    } else {
        type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
        n_in = nc;
        n_out = n;
    }
    std::size_t work_size;
    AMREX_CUFFT_SAFE_CALL
        (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

    auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
    std::size_t length[M];
    for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
    if constexpr (D == Direction::forward) {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_forward, prec, M,
                                length, howmany, nullptr));
    } else {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_inverse, prec, M,
                                length, howmany, nullptr));
    }

#elif defined(AMREX_USE_SYCL)

    mkl_desc_r* pp;
    if (M == 1) {
        pp = new mkl_desc_r(fft_size[0]);
    } else {
        std::vector<std::int64_t> len64(M);
        for (int idim = 0; idim < M; ++idim) {
            len64[idim] = len[idim];
        }
        pp = new mkl_desc_r(len64);
    }
#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                  oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
    pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
    pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
    pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
    std::vector<std::int64_t> strides(M+1);
    strides[0] = 0;
    strides[M] = 1;
    for (int i = M-1; i >= 1; --i) {
        strides[i] = strides[i+1] * fft_size[M-1-i];
    }

#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
    // Do not set BWD_STRIDES
#else
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
    // Do not set BWD_STRIDES
#endif
    pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                  oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
    pp->commit(amrex::Gpu::Device::streamQueue());
    plan = pp;

#else /* FFTW */

    if (pf == nullptr || pb == nullptr) {
        defined = false;
        return;
    }

    if constexpr (std::is_same_v<float,T>) {
        if constexpr (D == Direction::forward) {
            plan = fftwf_plan_many_dft_r2c
                (M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
                 FFTW_ESTIMATE);
        } else {
            plan = fftwf_plan_many_dft_c2r
                (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
                 FFTW_ESTIMATE);
        }
    } else {
        if constexpr (D == Direction::forward) {
            plan = fftw_plan_many_dft_r2c
                (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_dft_c2r
                (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
                 FFTW_ESTIMATE);
        }
    }
#endif

#if defined(AMREX_USE_GPU)
    if (cache) {
        if constexpr (std::is_same_v<float,T>) {
            detail::add_vendor_plan_f(key, plan);
        } else {
            detail::add_vendor_plan_d(key, plan);
        }
    }
#endif
}

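// Descriptive note (added): on GPU builds the vendor plan created above can be
// memoized in a global table keyed by (size, ncomp, direction, kind), so
// repeated transforms of the same shape skip plan creation; the FFTW path
// requires valid data pointers instead and is never cached.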
namespace detail
{
    DistributionMapping make_iota_distromap (Long n);

    template <typename FA>
    typename FA::FABType::value_type * get_fab (FA& fa)
    {
        auto myproc = ParallelContext::MyProcSub();
        if (myproc < fa.size()) {
            return fa.fabPtr(myproc);
        } else {
            return nullptr;
        }
    }

    template <typename FA1, typename FA2>
    std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
    {
        bool not_same_fa = true;
        if constexpr (std::is_same_v<FA1,FA2>) {
            not_same_fa = (&fa1 != &fa2);
        }
        using FAB1 = typename FA1::FABType::value_type;
        using FAB2 = typename FA2::FABType::value_type;
        using T1 = typename FAB1::value_type;
        using T2 = typename FAB2::value_type;
        auto myproc = ParallelContext::MyProcSub();
        bool alloc_1 = (myproc < fa1.size());
        bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
        void* p = nullptr;
        if (alloc_1 && alloc_2) {
            Box const& box1 = fa1.fabbox(myproc);
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
                                            sizeof(T2)*box2.numPts()*ncomp2));
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else if (alloc_1) {
            Box const& box1 = fa1.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
        } else if (alloc_2) {
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else {
            return nullptr;
        }
        return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
    }
}
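// Descriptive note (added, an inference from this header rather than a stated
// guarantee): make_mfs_share backs fa1 and fa2 with a single allocation sized
// for the larger of the two, which saves memory when the two FabArrays are
// never needed at the same time.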

struct Swap01
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.x, i.z};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.x, i.z};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct Swap02
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.y, i.x};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.y, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateFwd
{
    // dest -> src: (x,y,z) -> (y,z,x)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.z, i.x};
    }

    // src -> dest: (x,y,z) -> (z,x,y)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.x, i.y};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateBwd
{
    // dest -> src: (x,y,z) -> (z,x,y)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.x, i.y};
    }

    // src -> dest: (x,y,z) -> (y,z,x)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.z, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

namespace detail
{
    struct SubHelper
    {
        explicit SubHelper (Box const& domain);

        [[nodiscard]] Box make_box (Box const& box) const;

        [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;

        [[nodiscard]] bool ghost_safe (IntVect const& ng) const;

        // This rearranges the order.
        [[nodiscard]] IntVect make_iv (IntVect const& iv) const;

        // This keeps the order, but zeroes out the values in the hidden dimension.
        [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;

        [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;

        [[nodiscard]] IntVect inverse_order (IntVect const& order) const;

        template <typename T>
        [[nodiscard]] T make_array (T const& a) const
        {
#if (AMREX_SPACEDIM == 1)
            return a;
#elif (AMREX_SPACEDIM == 2)
            if (m_case == case_1n) {
                return T{a[1],a[0]};
            } else {
                return a;
            }
#else
            if (m_case == case_11n) {
                return T{a[2],a[0],a[1]};
            } else if (m_case == case_1n1) {
                return T{a[1],a[0],a[2]};
            } else if (m_case == case_1nn) {
                return T{a[1],a[2],a[0]};
            } else if (m_case == case_n1n) {
                return T{a[0],a[2],a[1]};
            } else {
                return a;
            }
#endif
        }

        [[nodiscard]] GpuArray<int,3> xyz_order () const;

        template <typename FA>
        FA make_alias_mf (FA const& mf)
        {
            BoxList bl = mf.boxArray().boxList();
            for (auto& b : bl) {
                b = make_box(b);
            }
            auto const& ng = make_iv(mf.nGrowVect());
            FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
            using FAB = typename FA::fab_type;
            for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
                submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
            }
            return submf;
        }

#if (AMREX_SPACEDIM == 2)
        enum Case { case_1n, case_other };
        int m_case = case_other;
#elif (AMREX_SPACEDIM == 3)
        enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
        int m_case = case_other;
#endif
    };
}

}

#endif