Block-Structured AMR Software Framework

AMReX_FFT_Helper.H
1#ifndef AMREX_FFT_HELPER_H_
2#define AMREX_FFT_HELPER_H_
3#include <AMReX_Config.H>
4
5#include <AMReX.H>
6#include <AMReX_BLProfiler.H>
7#include <AMReX_DataAllocator.H>
8#include <AMReX_DistributionMapping.H>
9#include <AMReX_Enum.H>
10#include <AMReX_FabArray.H>
11#include <AMReX_Gpu.H>
12#include <AMReX_GpuComplex.H>
13#include <AMReX_Math.H>
14#include <AMReX_Periodicity.H>
15
16#if defined(AMREX_USE_CUDA)
17# include <cufft.h>
18# include <cuComplex.h>
19#elif defined(AMREX_USE_HIP)
20# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
21# include <rocfft/rocfft.h>
22# else
23# include <rocfft.h>
24# endif
25# include <hip/hip_complex.h>
26#elif defined(AMREX_USE_SYCL)
27# if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
28# include <oneapi/mkl/dft.hpp>
29# else
30# define AMREX_USE_MKL_DFTI_2024 1
31# include <oneapi/mkl/dfti.hpp>
32# endif
33#else
34# include <fftw3.h>
35#endif
36
37#include <algorithm>
38#include <complex>
39#include <limits>
40#include <memory>
41#include <tuple>
42#include <utility>
43#include <variant>
44
45namespace amrex::FFT
46{
47
48enum struct Direction { forward, backward, both, none };
49
50enum struct DomainStrategy { automatic, slab, pencil };
51
52AMREX_ENUM(Boundary, periodic, even, odd);
53
54enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_oo_f, r2r_oo_b,
55 r2r_ee_f, r2r_ee_b, r2r_oe, r2r_eo };
56
57struct Info
58{
59 //! Domain composition strategy.
60 DomainStrategy domain_strategy = DomainStrategy::automatic;
61
62 //! For the automatic strategy, this is the threshold for switching
63 //! from slab to pencil domain decomposition.
64 int pencil_threshold = 8;
65
66 //! Supported only in 3D. When twod_mode is true, FFT is performed on
67 //! the first two dimensions only and the third dimension size is the
68 //! batch size.
69 bool twod_mode = false;
70
71 //! We might have a special twod_mode: nx or ny == 1 && nz > 1.
72 bool oned_mode = false;
73
74 //! Batched FFT size. Only supported in R2C, not R2X.
75 int batch_size = 1;
76
77 //! Max number of processes to use.
78 int nprocs = std::numeric_limits<int>::max();
79
80 Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
81 Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
82 Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
83 Info& setOneDMode (bool x) { oned_mode = x; return *this; }
84 Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
85 Info& setNumProcs (int n) { nprocs = n; return *this; }
86};
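
For orientation, Info is a builder-style configuration object: each setter assigns one field and returns *this, so options can be chained. A minimal usage sketch (the values are illustrative, not defaults; configure_fft is a hypothetical wrapper):

#include <AMReX_FFT_Helper.H>

void configure_fft ()
{
    amrex::FFT::Info info{};
    info.setTwoDMode(true)   // FFT over the first two dimensions only
        .setBatchSize(8)     // batched transform size (R2C only)
        .setNumProcs(64);    // cap on the number of processes used
    amrex::ignore_unused(info);
}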
87
88#ifdef AMREX_USE_HIP
89namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
90#endif
91
92#ifdef AMREX_USE_SYCL
93namespace detail
94{
95template <typename T, Direction direction, typename P, typename TI, typename TO>
96void sycl_execute (P* plan, TI* in, TO* out)
97{
98#ifndef AMREX_USE_MKL_DFTI_2024
99 std::int64_t workspaceSize = 0;
100#else
101 std::size_t workspaceSize = 0;
102#endif
103 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
104 &workspaceSize);
105 auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
106 plan->set_workspace(buffer);
107 sycl::event r;
108 if (std::is_same_v<TI,TO>) {
109 amrex::ignore_unused(in);
110 if constexpr (direction == Direction::forward) {
111 r = oneapi::mkl::dft::compute_forward(*plan, out);
112 } else {
113 r = oneapi::mkl::dft::compute_backward(*plan, out);
114 }
115 } else {
116 if constexpr (direction == Direction::forward) {
117 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
118 } else {
119 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
120 }
121 }
122 r.wait();
123 amrex::The_Arena()->free(buffer);
124}
125}
126#endif
127
128template <typename T>
129struct Plan
130{
131#if defined(AMREX_USE_CUDA)
132 using VendorPlan = cufftHandle;
133 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
134 cuComplex, cuDoubleComplex>;
135#elif defined(AMREX_USE_HIP)
136 using VendorPlan = rocfft_plan;
137 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
138 float2, double2>;
139#elif defined(AMREX_USE_SYCL)
140 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
141 ? oneapi::mkl::dft::precision::SINGLE
142 : oneapi::mkl::dft::precision::DOUBLE,
143 oneapi::mkl::dft::domain::REAL>;
144 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
145 ? oneapi::mkl::dft::precision::SINGLE
146 : oneapi::mkl::dft::precision::DOUBLE,
147 oneapi::mkl::dft::domain::COMPLEX>;
148 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
149 using VendorComplex = std::complex<T>;
150#else
151 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
152 fftwf_plan, fftw_plan>;
153 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
154 fftwf_complex, fftw_complex>;
155#endif
156
157 int n = 0;
158 int howmany = 0;
159 Kind kind = Kind::none;
160 bool r2r_data_is_complex = false;
161 bool defined = false;
162 bool defined2 = false;
163 VendorPlan plan{};
164 VendorPlan plan2{};
165 void* pf = nullptr;
166 void* pb = nullptr;
167
168#ifdef AMREX_USE_GPU
169 void set_ptrs (void* p0, void* p1) {
170 pf = p0;
171 pb = p1;
172 }
173#endif
174
175 void destroy ()
176 {
177 if (defined) {
178 destroy_vendor_plan(plan);
179 defined = false;
180 }
181#if !defined(AMREX_USE_GPU)
182 if (defined2) {
183 destroy_vendor_plan(plan2);
184 defined2 = false;
185 }
186#endif
187 }
188
189 template <Direction D>
190 void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
191 {
192 static_assert(D == Direction::forward || D == Direction::backward);
193
194 int rank = is_2d_transform ? 2 : 1;
195
196 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
197 defined = true;
198 pf = (void*)pr;
199 pb = (void*)pc;
200
201 int len[2] = {};
202 if (rank == 1) {
203 len[0] = box.length(0);
204 len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
205 } else {
206 len[0] = box.length(1); // Most FFT libraries assume row-major ordering
207 len[1] = box.length(0); // except for rocfft
208 }
209 int nr = (rank == 1) ? len[0] : len[0]*len[1];
210 n = nr;
211 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
212#if (AMREX_SPACEDIM == 1)
213 howmany = 1;
214#else
215 howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
216 : AMREX_D_TERM(1, *1 , *box.length(2));
217#endif
218 howmany *= ncomp;
219
221
222#if defined(AMREX_USE_CUDA)
223
224 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
225 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
226 std::size_t work_size;
227 if constexpr (D == Direction::forward) {
228 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
229 AMREX_CUFFT_SAFE_CALL
230 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
231 } else {
232 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
233 AMREX_CUFFT_SAFE_CALL
234 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
235 }
236
237#elif defined(AMREX_USE_HIP)
238
239 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
240 // switch to column-major ordering
241 std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
242 if constexpr (D == Direction::forward) {
243 AMREX_ROCFFT_SAFE_CALL
244 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
245 rocfft_transform_type_real_forward, prec, rank,
246 length, howmany, nullptr));
247 } else {
248 AMREX_ROCFFT_SAFE_CALL
249 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
250 rocfft_transform_type_real_inverse, prec, rank,
251 length, howmany, nullptr));
252 }
253
254#elif defined(AMREX_USE_SYCL)
255
256 mkl_desc_r* pp;
257 if (rank == 1) {
258 pp = new mkl_desc_r(len[0]);
259 } else {
260 pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
261 }
262#ifndef AMREX_USE_MKL_DFTI_2024
263 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
264 oneapi::mkl::dft::config_value::NOT_INPLACE);
265#else
266 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
267#endif
268 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
269 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
270 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
271 std::vector<std::int64_t> strides;
272 strides.push_back(0);
273 if (rank == 2) { strides.push_back(len[1]); }
274 strides.push_back(1);
275#ifndef AMREX_USE_MKL_DFTI_2024
276 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
277 // Do not set BWD_STRIDES
278#else
279 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
280 // Do not set BWD_STRIDES
281#endif
282 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
283 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
284 pp->commit(amrex::Gpu::Device::streamQueue());
285 plan = pp;
286
287#else /* FFTW */
288
289 if constexpr (std::is_same_v<float,T>) {
290 if constexpr (D == Direction::forward) {
291 plan = fftwf_plan_many_dft_r2c
292 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
293 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
294 } else {
295 plan = fftwf_plan_many_dft_c2r
296 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
297 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
298 }
299 } else {
300 if constexpr (D == Direction::forward) {
301 plan = fftw_plan_many_dft_r2c
302 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
303 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
304 } else {
305 plan = fftw_plan_many_dft_c2r
306 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
307 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
308 }
309 }
310#endif
311 }
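
To make the layout arithmetic above concrete: for a rank-1 transform over an nx-by-ny box, the plan batches howmany = ny rows of length nx with unit stride, real distance nr = nx and complex distance nc = nx/2+1. The same call shape in a standalone FFTW sketch (sizes illustrative):

#include <fftw3.h>

int main ()
{
    int const nx = 32, ny = 8;            // box.length(0), box.length(1)
    int len[1] = {nx};                    // rank-1 transform length
    double* r = fftw_alloc_real(nx*ny);
    fftw_complex* c = fftw_alloc_complex((nx/2+1)*ny);
    for (int i = 0; i < nx*ny; ++i) { r[i] = 0.0; }
    fftw_plan p = fftw_plan_many_dft_r2c
        (1, len, ny,                      // rank, lengths, howmany = ny rows
         r, nullptr, 1, nx,               // input: stride 1, distance nr = nx
         c, nullptr, 1, nx/2+1,           // output: distance nc = nx/2+1
         FFTW_ESTIMATE);
    fftw_execute(p);
    fftw_destroy_plan(p);
    fftw_free(r); fftw_free(c);
    return 0;
}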
312
313 template <Direction D, int M>
314 void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);
315
316 template <Direction D>
317 void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
318 {
319 static_assert(D == Direction::forward || D == Direction::backward);
320
321 kind = (D == Direction::forward) ? Kind::c2c_f : Kind::c2c_b;
322 defined = true;
323 pf = (void*)p;
324 pb = (void*)p;
325
326 int len[3] = {};
327
328 if (ndims == 1) {
329 n = box.length(0);
330 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
331 howmany *= ncomp;
332 len[0] = box.length(0);
333 }
334#if (AMREX_SPACEDIM >= 2)
335 else if (ndims == 2) {
336 n = box.length(0) * box.length(1);
337#if (AMREX_SPACEDIM == 2)
338 howmany = ncomp;
339#else
340 howmany = box.length(2) * ncomp;
341#endif
342 len[0] = box.length(1);
343 len[1] = box.length(0);
344 }
345#if (AMREX_SPACEDIM == 3)
346 else if (ndims == 3) {
347 n = box.length(0) * box.length(1) * box.length(2);
348 howmany = ncomp;
349 len[0] = box.length(2);
350 len[1] = box.length(1);
351 len[2] = box.length(0);
352 }
353#endif
354#endif
355
356#if defined(AMREX_USE_CUDA)
357 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
358 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
359
360 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
361 std::size_t work_size;
362 AMREX_CUFFT_SAFE_CALL
363 (cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
364
365#elif defined(AMREX_USE_HIP)
366
367 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
368 : rocfft_precision_double;
369 auto dir = (D == Direction::forward) ? rocfft_transform_type_complex_forward
370 : rocfft_transform_type_complex_inverse;
371 std::size_t length[3];
372 if (ndims == 1) {
373 length[0] = len[0];
374 } else if (ndims == 2) {
375 length[0] = len[1];
376 length[1] = len[0];
377 } else {
378 length[0] = len[2];
379 length[1] = len[1];
380 length[2] = len[0];
381 }
382 AMREX_ROCFFT_SAFE_CALL
383 (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
384 length, howmany, nullptr));
385
386#elif defined(AMREX_USE_SYCL)
387
388 mkl_desc_c* pp;
389 if (ndims == 1) {
390 pp = new mkl_desc_c(n);
391 } else if (ndims == 2) {
392 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
393 } else {
394 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
395 }
396#ifndef AMREX_USE_MKL_DFTI_2024
397 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
398 oneapi::mkl::dft::config_value::INPLACE);
399#else
400 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
401#endif
402 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
403 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
404 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
405 std::vector<std::int64_t> strides(ndims+1);
406 strides[0] = 0;
407 strides[ndims] = 1;
408 for (int i = ndims-1; i >= 1; --i) {
409 strides[i] = strides[i+1] * len[i];
410 }
411#ifndef AMREX_USE_MKL_DFTI_2024
412 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
413 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
414#else
415 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
416 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
417#endif
418 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
419 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
420 pp->commit(amrex::Gpu::Device::streamQueue());
421 plan = pp;
422
423#else /* FFTW */
424
425 if constexpr (std::is_same_v<float,T>) {
426 if constexpr (D == Direction::forward) {
427 plan = fftwf_plan_many_dft
428 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
429 FFTW_ESTIMATE);
430 } else {
431 plan = fftwf_plan_many_dft
432 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
433 FFTW_ESTIMATE);
434 }
435 } else {
436 if constexpr (D == Direction::forward) {
437 plan = fftw_plan_many_dft
438 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
439 FFTW_ESTIMATE);
440 } else {
441 plan = fftw_plan_many_dft
442 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
443 FFTW_ESTIMATE);
444 }
445 }
446#endif
447 }
448
449#ifndef AMREX_USE_GPU
450 template <Direction D>
451 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
452 {
453 if (bc.first == Boundary::even && bc.second == Boundary::even)
454 {
455 return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
456 }
457 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
458 {
459 return FFTW_REDFT11;
460 }
461 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
462 {
463 return FFTW_RODFT11;
464 }
465 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
466 {
467 return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
468 }
469 else {
470 amrex::Abort("FFT: unsupported BC");
471 return fftw_r2r_kind{};
472 }
473
474 }
475#endif
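
As a reading aid for this mapping: FFTW_REDFT10 is the DCT-II and FFTW_REDFT01 its inverse DCT-III; FFTW_RODFT10/FFTW_RODFT01 are the matching DST-II/DST-III pair; and the mixed even/odd boundaries select the half-sample '11' variants. A small standalone round trip showing the documented 2n normalization of the DCT-II/DCT-III pair (illustrative sketch):

#include <fftw3.h>
#include <cmath>
#include <cstdio>

int main ()
{
    int const n = 8;
    double x[8], y[8], z[8];
    for (int i = 0; i < n; ++i) { x[i] = std::sin(0.3*i); }
    fftw_plan f = fftw_plan_r2r_1d(n, x, y, FFTW_REDFT10, FFTW_ESTIMATE); // DCT-II
    fftw_plan b = fftw_plan_r2r_1d(n, y, z, FFTW_REDFT01, FFTW_ESTIMATE); // DCT-III
    fftw_execute(f);
    fftw_execute(b);
    std::printf("x[3] = %g, round trip / (2n) = %g\n", x[3], z[3]/(2.0*n));
    fftw_destroy_plan(f);
    fftw_destroy_plan(b);
    return 0;
}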
476
477 template <Direction D>
478 Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
479 {
480 if (bc.first == Boundary::even && bc.second == Boundary::even)
481 {
482 return (D == Direction::forward) ? Kind::r2r_ee_f : Kind::r2r_ee_b;
483 }
484 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
485 {
486 return Kind::r2r_eo;
487 }
488 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
489 {
490 return Kind::r2r_oe;
491 }
492 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
493 {
494 return (D == Direction::forward) ? Kind::r2r_oo_f : Kind::r2r_oo_b;
495 }
496 else {
497 amrex::Abort("FFT: unsupported BC");
498 return Kind::none;
499 }
500
501 }
502
503 template <Direction D>
504 void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
505 int howmany_initval = 1)
506 {
507 static_assert(D == Direction::forward || D == Direction::backward);
508
509 kind = get_r2r_kind<D>(bc);
510 defined = true;
511 pf = (void*)p;
512 pb = (void*)p;
513
514 n = box.length(0);
515 howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));
516
517#if defined(AMREX_USE_GPU)
518 int nex=0;
519 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
520 Direction::forward == D) {
521 nex = 2*n;
522 } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
523 Direction::backward == D) {
524 nex = 4*n;
525 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
526 Direction::forward == D) {
527 nex = 2*n;
528 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
529 Direction::backward == D) {
530 nex = 4*n;
531 } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
532 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
533 nex = 4*n;
534 } else {
535 amrex::Abort("FFT: unsupported BC");
536 }
537 int nc = (nex/2) + 1;
538
539#if defined (AMREX_USE_CUDA)
540
541 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
542 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
543 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
544 std::size_t work_size;
545 AMREX_CUFFT_SAFE_CALL
546 (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
547
548#elif defined(AMREX_USE_HIP)
549
551 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
552 const std::size_t length = nex;
553 AMREX_ROCFFT_SAFE_CALL
554 (rocfft_plan_create(&plan, rocfft_placement_inplace,
555 rocfft_transform_type_real_forward, prec, 1,
556 &length, howmany, nullptr));
557
558#elif defined(AMREX_USE_SYCL)
559
560 auto* pp = new mkl_desc_r(nex);
561#ifndef AMREX_USE_MKL_DFTI_2024
562 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
563 oneapi::mkl::dft::config_value::INPLACE);
564#else
565 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
566#endif
567 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
568 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
569 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
570 std::vector<std::int64_t> strides = {0,1};
571#ifndef AMREX_USE_MKL_DFTI_2024
572 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
573 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
574#else
575 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
576 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
577#endif
578 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
579 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
580 pp->commit(amrex::Gpu::Device::streamQueue());
581 plan = pp;
582
583#endif
584
585#else /* FFTW */
586 auto fftw_kind = get_fftw_kind<D>(bc);
587 if constexpr (std::is_same_v<float,T>) {
588 plan = fftwf_plan_many_r2r
589 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
590 FFTW_ESTIMATE);
591 } else {
592 plan = fftw_plan_many_r2r
593 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
594 FFTW_ESTIMATE);
595 }
596#endif
597 }
598
599 template <Direction D>
600 void init_r2r (Box const& box, VendorComplex* pc,
601 std::pair<Boundary,Boundary> const& bc)
602 {
603 static_assert(D == Direction::forward || D == Direction::backward);
604
605 auto* p = (T*)pc;
606
607#if defined(AMREX_USE_GPU)
608
609 init_r2r<D>(box, p, bc, 2);
610 r2r_data_is_complex = true;
611
612#else
613
614 kind = get_r2r_kind<D>(bc);
615 defined = true;
616 pf = (void*)p;
617 pb = (void*)p;
618
619 n = box.length(0);
620 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
621
622 defined2 = true;
623 auto fftw_kind = get_fftw_kind<D>(bc);
624 if constexpr (std::is_same_v<float,T>) {
625 plan = fftwf_plan_many_r2r
626 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
627 FFTW_ESTIMATE);
628 plan2 = fftwf_plan_many_r2r
629 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
630 FFTW_ESTIMATE);
631 } else {
632 plan = fftw_plan_many_r2r
633 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
634 FFTW_ESTIMATE);
635 plan2 = fftw_plan_many_r2r
636 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
637 FFTW_ESTIMATE);
638 }
639#endif
640 }
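
The extension lengths chosen in init_r2r above (nex = 2n for the forward even-even and odd-odd cases, 4n for the backward and mixed cases) implement the classical reduction of cosine/sine transforms to a single real-to-complex FFT of an extended sequence; the Math::sincospi twiddles applied in unpack_r2r_buffer below are the phase factors this reduction produces. For the forward even-even case (a DCT-II), with y the length-2n mirror extension of x and \hat{y} its FFT, the standard identity (consistent with the r2r_ee_f unpack branch) reads:

X_k \;=\; \sum_{j=0}^{n-1} 2\,x_j \cos\!\Big(\frac{\pi k (2j+1)}{2n}\Big)
    \;=\; \mathrm{Re}\!\Big[ e^{-i\pi k/(2n)} \, \hat{y}_k \Big],
\qquad y = (x_0, \dots, x_{n-1},\, x_{n-1}, \dots, x_0),

which is exactly the c*Re(\hat{y}_k) + s*Im(\hat{y}_k) combination computed there with (s,c) = Math::sincospi(k/(2n)).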
641
642 template <Direction D>
643 void compute_r2c ()
644 {
645 static_assert(D == Direction::forward || D == Direction::backward);
646 if (!defined) { return; }
647
648 using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
649 using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
650 auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
651 auto* po = (TO*)((D == Direction::forward) ? pb : pf);
652
653#if defined(AMREX_USE_CUDA)
654 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
655
656 std::size_t work_size = 0;
657 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
658
659 auto* work_area = The_Arena()->alloc(work_size);
660 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
661
662 if constexpr (D == Direction::forward) {
663 if constexpr (std::is_same_v<float,T>) {
664 AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
665 } else {
666 AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
667 }
668 } else {
669 if constexpr (std::is_same_v<float,T>) {
670 AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
671 } else {
672 AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
673 }
674 }
675 Gpu::streamSynchronize();
676 The_Arena()->free(work_area);
677#elif defined(AMREX_USE_HIP)
678 detail::hip_execute(plan, (void**)&pi, (void**)&po);
679#elif defined(AMREX_USE_SYCL)
680 detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
681#else
682 amrex::ignore_unused(pi,po);
683 if constexpr (std::is_same_v<float,T>) {
684 fftwf_execute(plan);
685 } else {
686 fftw_execute(plan);
687 }
688#endif
689 }
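
The CUDA branch above follows cuFFT's manual work-area protocol: the plan was created with cufftSetAutoAllocation(plan, 0), so each execution must attach cufftGetSize bytes of scratch, and that scratch may only be freed after the asynchronous execution has completed. The same protocol in isolation, as a hedged sketch using raw CUDA allocation instead of amrex's arena (exec_d2z is a hypothetical helper; error checking omitted):

#include <cufft.h>
#include <cuda_runtime.h>
#include <cstddef>

void exec_d2z (cufftHandle plan, cudaStream_t stream,
               double* in, cufftDoubleComplex* out)
{
    cufftSetStream(plan, stream);
    std::size_t work_size = 0;
    cufftGetSize(plan, &work_size);     // bytes of scratch this plan needs
    void* work = nullptr;
    cudaMalloc(&work, work_size);
    cufftSetWorkArea(plan, work);
    cufftExecD2Z(plan, in, out);        // runs asynchronously on 'stream'
    cudaStreamSynchronize(stream);      // finish before releasing scratch
    cudaFree(work);
}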
690
691 template <Direction D>
692 void compute_c2c ()
693 {
694 static_assert(D == Direction::forward || D == Direction::backward);
695 if (!defined) { return; }
696
697 auto* p = (VendorComplex*)pf;
698
699#if defined(AMREX_USE_CUDA)
700 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
701
702 std::size_t work_size = 0;
703 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
704
705 auto* work_area = The_Arena()->alloc(work_size);
706 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
707
708 auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
709 if constexpr (std::is_same_v<float,T>) {
710 AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
711 } else {
712 AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
713 }
714 Gpu::streamSynchronize();
715 The_Arena()->free(work_area);
716#elif defined(AMREX_USE_HIP)
717 detail::hip_execute(plan, (void**)&p, (void**)&p);
718#elif defined(AMREX_USE_SYCL)
719 detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
720#else
721 amrex::ignore_unused(p);
722 if constexpr (std::is_same_v<float,T>) {
723 fftwf_execute(plan);
724 } else {
725 fftw_execute(plan);
726 }
727#endif
728 }
729
730#ifdef AMREX_USE_GPU
731 [[nodiscard]] void* alloc_scratch_space () const
732 {
733 int nc = 0;
734 if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
735 nc = n + 1;
736 } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
737 kind == Kind::r2r_oe || kind == Kind::r2r_eo) {
738 nc = 2*n+1;
739 } else {
740 amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
741 }
742 return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
743 }
744
745 static void free_scratch_space (void* p) { The_Arena()->free(p); }
746
747 void pack_r2r_buffer (void* pbuf, T const* psrc) const
748 {
749 auto* pdst = (T*) pbuf;
750 if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
751 T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
752 int ostride = (n+1)*2;
753 int istride = n;
754 int nex = 2*n;
755 int norig = n;
756 Long nelems = Long(nex)*howmany;
757 if (r2r_data_is_complex) {
758 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
759 {
760 auto batch = ielem / Long(nex);
761 auto i = int(ielem - batch*nex);
762 for (int ir = 0; ir < 2; ++ir) {
763 auto* po = pdst + (2*batch+ir)*ostride + i;
764 auto const* pi = psrc + 2*batch*istride + ir;
765 if (i < norig) {
766 *po = pi[i*2];
767 } else {
768 *po = sign * pi[(2*norig-1-i)*2];
769 }
770 }
771 });
772 } else {
773 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
774 {
775 auto batch = ielem / Long(nex);
776 auto i = int(ielem - batch*nex);
777 auto* po = pdst + batch*ostride + i;
778 auto const* pi = psrc + batch*istride;
779 if (i < norig) {
780 *po = pi[i];
781 } else {
782 *po = sign * pi[2*norig-1-i];
783 }
784 });
785 }
786 } else if (kind == Kind::r2r_oo_b) {
787 int ostride = (2*n+1)*2;
788 int istride = n;
789 int nex = 4*n;
790 int norig = n;
791 Long nelems = Long(nex)*howmany;
792 if (r2r_data_is_complex) {
793 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
794 {
795 auto batch = ielem / Long(nex);
796 auto i = int(ielem - batch*nex);
797 for (int ir = 0; ir < 2; ++ir) {
798 auto* po = pdst + (2*batch+ir)*ostride + i;
799 auto const* pi = psrc + 2*batch*istride + ir;
800 if (i < norig) {
801 *po = pi[i*2];
802 } else if (i < (2*norig-1)) {
803 *po = pi[(2*norig-2-i)*2];
804 } else if (i == (2*norig-1)) {
805 *po = T(0);
806 } else if (i < (3*norig)) {
807 *po = -pi[(i-2*norig)*2];
808 } else if (i < (4*norig-1)) {
809 *po = -pi[(4*norig-2-i)*2];
810 } else {
811 *po = T(0);
812 }
813 }
814 });
815 } else {
816 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
817 {
818 auto batch = ielem / Long(nex);
819 auto i = int(ielem - batch*nex);
820 auto* po = pdst + batch*ostride + i;
821 auto const* pi = psrc + batch*istride;
822 if (i < norig) {
823 *po = pi[i];
824 } else if (i < (2*norig-1)) {
825 *po = pi[2*norig-2-i];
826 } else if (i == (2*norig-1)) {
827 *po = T(0);
828 } else if (i < (3*norig)) {
829 *po = -pi[i-2*norig];
830 } else if (i < (4*norig-1)) {
831 *po = -pi[4*norig-2-i];
832 } else {
833 *po = T(0);
834 }
835 });
836 }
837 } else if (kind == Kind::r2r_ee_b) {
838 int ostride = (2*n+1)*2;
839 int istride = n;
840 int nex = 4*n;
841 int norig = n;
842 Long nelems = Long(nex)*howmany;
843 if (r2r_data_is_complex) {
844 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
845 {
846 auto batch = ielem / Long(nex);
847 auto i = int(ielem - batch*nex);
848 for (int ir = 0; ir < 2; ++ir) {
849 auto* po = pdst + (2*batch+ir)*ostride + i;
850 auto const* pi = psrc + 2*batch*istride + ir;
851 if (i < norig) {
852 *po = pi[i*2];
853 } else if (i == norig) {
854 *po = T(0);
855 } else if (i < (2*norig+1)) {
856 *po = -pi[(2*norig-i)*2];
857 } else if (i < (3*norig)) {
858 *po = -pi[(i-2*norig)*2];
859 } else if (i == 3*norig) {
860 *po = T(0);
861 } else {
862 *po = pi[(4*norig-i)*2];
863 }
864 }
865 });
866 } else {
867 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
868 {
869 auto batch = ielem / Long(nex);
870 auto i = int(ielem - batch*nex);
871 auto* po = pdst + batch*ostride + i;
872 auto const* pi = psrc + batch*istride;
873 if (i < norig) {
874 *po = pi[i];
875 } else if (i == norig) {
876 *po = T(0);
877 } else if (i < (2*norig+1)) {
878 *po = -pi[2*norig-i];
879 } else if (i < (3*norig)) {
880 *po = -pi[i-2*norig];
881 } else if (i == 3*norig) {
882 *po = T(0);
883 } else {
884 *po = pi[4*norig-i];
885 }
886 });
887 }
888 } else if (kind == Kind::r2r_eo) {
889 int ostride = (2*n+1)*2;
890 int istride = n;
891 int nex = 4*n;
892 int norig = n;
893 Long nelems = Long(nex)*howmany;
894 if (r2r_data_is_complex) {
895 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
896 {
897 auto batch = ielem / Long(nex);
898 auto i = int(ielem - batch*nex);
899 for (int ir = 0; ir < 2; ++ir) {
900 auto* po = pdst + (2*batch+ir)*ostride + i;
901 auto const* pi = psrc + 2*batch*istride + ir;
902 if (i < norig) {
903 *po = pi[i*2];
904 } else if (i < (2*norig)) {
905 *po = -pi[(2*norig-1-i)*2];
906 } else if (i < (3*norig)) {
907 *po = -pi[(i-2*norig)*2];
908 } else {
909 *po = pi[(4*norig-1-i)*2];
910 }
911 }
912 });
913 } else {
914 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
915 {
916 auto batch = ielem / Long(nex);
917 auto i = int(ielem - batch*nex);
918 auto* po = pdst + batch*ostride + i;
919 auto const* pi = psrc + batch*istride;
920 if (i < norig) {
921 *po = pi[i];
922 } else if (i < (2*norig)) {
923 *po = -pi[2*norig-1-i];
924 } else if (i < (3*norig)) {
925 *po = -pi[i-2*norig];
926 } else {
927 *po = pi[4*norig-1-i];
928 }
929 });
930 }
931 } else if (kind == Kind::r2r_oe) {
932 int ostride = (2*n+1)*2;
933 int istride = n;
934 int nex = 4*n;
935 int norig = n;
936 Long nelems = Long(nex)*howmany;
937 if (r2r_data_is_complex) {
938 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
939 {
940 auto batch = ielem / Long(nex);
941 auto i = int(ielem - batch*nex);
942 for (int ir = 0; ir < 2; ++ir) {
943 auto* po = pdst + (2*batch+ir)*ostride + i;
944 auto const* pi = psrc + 2*batch*istride + ir;
945 if (i < norig) {
946 *po = pi[i*2];
947 } else if (i < (2*norig)) {
948 *po = pi[(2*norig-1-i)*2];
949 } else if (i < (3*norig)) {
950 *po = -pi[(i-2*norig)*2];
951 } else {
952 *po = -pi[(4*norig-1-i)*2];
953 }
954 }
955 });
956 } else {
957 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
958 {
959 auto batch = ielem / Long(nex);
960 auto i = int(ielem - batch*nex);
961 auto* po = pdst + batch*ostride + i;
962 auto const* pi = psrc + batch*istride;
963 if (i < norig) {
964 *po = pi[i];
965 } else if (i < (2*norig)) {
966 *po = pi[2*norig-1-i];
967 } else if (i < (3*norig)) {
968 *po = -pi[i-2*norig];
969 } else {
970 *po = -pi[4*norig-1-i];
971 }
972 });
973 }
974 } else {
975 amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
976 }
977 }
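
For orientation, the length-4n staging sequence assembled in the r2r_oo_b branch above is

y = (x_0,\dots,x_{n-1},\; x_{n-2},\dots,x_0,\; 0,\; -x_0,\dots,-x_{n-1},\; -x_{n-2},\dots,-x_0,\; 0),

whose real-to-complex FFT produces the coefficients that unpack_r2r_buffer below scales by T(0.5) and the Math::sincospi twiddles; the other kinds differ only in where the sign flips and zeros sit.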
978
979 void unpack_r2r_buffer (T* pdst, void const* pbuf) const
980 {
981 auto const* psrc = (GpuComplex<T> const*) pbuf;
982 int norig = n;
983 Long nelems = Long(norig)*howmany;
984 int ostride = n;
985
986 if (kind == Kind::r2r_oo_f) {
987 int istride = n+1;
988 if (r2r_data_is_complex) {
989 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
990 {
991 auto batch = ielem / Long(norig);
992 auto k = int(ielem - batch*norig);
993 auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
994 for (int ir = 0; ir < 2; ++ir) {
995 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
996 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
997 }
998 });
999 } else {
1000 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1001 {
1002 auto batch = ielem / Long(norig);
1003 auto k = int(ielem - batch*norig);
1004 auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
1005 auto const& yk = psrc[batch*istride+k+1];
1006 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
1007 });
1008 }
1009 } else if (kind == Kind::r2r_oo_b) {
1010 int istride = 2*n+1;
1011 if (r2r_data_is_complex) {
1012 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1013 {
1014 auto batch = ielem / Long(norig);
1015 auto k = int(ielem - batch*norig);
1016 auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
1017 for (int ir = 0; ir < 2; ++ir) {
1018 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1019 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
1020 }
1021 });
1022 } else {
1023 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1024 {
1025 auto batch = ielem / Long(norig);
1026 auto k = int(ielem - batch*norig);
1027 auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
1028 auto const& yk = psrc[batch*istride+2*k+1];
1029 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
1030 });
1031 }
1032 } else if (kind == Kind::r2r_ee_f) {
1033 int istride = n+1;
1034 if (r2r_data_is_complex) {
1035 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1036 {
1037 auto batch = ielem / Long(norig);
1038 auto k = int(ielem - batch*norig);
1039 auto [s, c] = Math::sincospi(T(k)/T(2*norig));
1040 for (int ir = 0; ir < 2; ++ir) {
1041 auto const& yk = psrc[(2*batch+ir)*istride+k];
1042 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
1043 }
1044 });
1045 } else {
1046 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1047 {
1048 auto batch = ielem / Long(norig);
1049 auto k = int(ielem - batch*norig);
1050 auto [s, c] = Math::sincospi(T(k)/T(2*norig));
1051 auto const& yk = psrc[batch*istride+k];
1052 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
1053 });
1054 }
1055 } else if (kind == Kind::r2r_ee_b) {
1056 int istride = 2*n+1;
1057 if (r2r_data_is_complex) {
1058 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1059 {
1060 auto batch = ielem / Long(norig);
1061 auto k = int(ielem - batch*norig);
1062 for (int ir = 0; ir < 2; ++ir) {
1063 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1064 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1065 }
1066 });
1067 } else {
1068 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1069 {
1070 auto batch = ielem / Long(norig);
1071 auto k = int(ielem - batch*norig);
1072 auto const& yk = psrc[batch*istride+2*k+1];
1073 pdst[batch*ostride+k] = T(0.5) * yk.real();
1074 });
1075 }
1076 } else if (kind == Kind::r2r_eo) {
1077 int istride = 2*n+1;
1078 if (r2r_data_is_complex) {
1079 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1080 {
1081 auto batch = ielem / Long(norig);
1082 auto k = int(ielem - batch*norig);
1083 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1084 for (int ir = 0; ir < 2; ++ir) {
1085 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1086 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1087 }
1088 });
1089 } else {
1090 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1091 {
1092 auto batch = ielem / Long(norig);
1093 auto k = int(ielem - batch*norig);
1094 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1095 auto const& yk = psrc[batch*istride+2*k+1];
1096 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1097 });
1098 }
1099 } else if (kind == Kind::r2r_oe) {
1100 int istride = 2*n+1;
1101 if (r2r_data_is_complex) {
1102 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1103 {
1104 auto batch = ielem / Long(norig);
1105 auto k = int(ielem - batch*norig);
1106 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1107 for (int ir = 0; ir < 2; ++ir) {
1108 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1109 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1110 }
1111 });
1112 } else {
1113 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1114 {
1115 auto batch = ielem / Long(norig);
1116 auto k = int(ielem - batch*norig);
1117 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1118 auto const& yk = psrc[batch*istride+2*k+1];
1119 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1120 });
1121 }
1122 } else {
1123 amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
1124 }
1125 }
1126#endif
1127
1128 template <Direction D>
1129 void compute_r2r ()
1130 {
1131 static_assert(D == Direction::forward || D == Direction::backward);
1132 if (!defined) { return; }
1133
1134#if defined(AMREX_USE_GPU)
1135
1136 auto* pscratch = alloc_scratch_space();
1137
1138 pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));
1139
1140#if defined(AMREX_USE_CUDA)
1141
1142 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
1143
1144 std::size_t work_size = 0;
1145 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
1146
1147 auto* work_area = The_Arena()->alloc(work_size);
1148 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
1149
1150 if constexpr (std::is_same_v<float,T>) {
1151 AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
1152 } else {
1153 AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
1154 }
1155
1156#elif defined(AMREX_USE_HIP)
1157 detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
1158#elif defined(AMREX_USE_SYCL)
1159 detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
1160#endif
1161
1162 unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);
1163
1164 Gpu::streamSynchronize();
1165 free_scratch_space(pscratch);
1166#if defined(AMREX_USE_CUDA)
1167 The_Arena()->free(work_area);
1168#endif
1169
1170#else /* FFTW */
1171
1172 if constexpr (std::is_same_v<float,T>) {
1173 fftwf_execute(plan);
1174 if (defined2) { fftwf_execute(plan2); }
1175 } else {
1176 fftw_execute(plan);
1177 if (defined2) { fftw_execute(plan2); }
1178 }
1179
1180#endif
1181 }
1182
1183 static void destroy_vendor_plan (VendorPlan plan)
1184 {
1185#if defined(AMREX_USE_CUDA)
1186 AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
1187#elif defined(AMREX_USE_HIP)
1188 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
1189#elif defined(AMREX_USE_SYCL)
1190 std::visit([](auto&& p) { delete p; }, plan);
1191#else
1192 if constexpr (std::is_same_v<float,T>) {
1193 fftwf_destroy_plan(plan);
1194 } else {
1195 fftw_destroy_plan(plan);
1196 }
1197#endif
1198 }
1199};
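
Taken together, the intended lifecycle of a Plan is one init_* call (recording kind, sizes, and data pointers), any number of matching compute_* calls, then destroy(). A hedged sketch for the r2c case (forward_r2c is a hypothetical wrapper; pr and pc are assumed preallocated to box.numPts() reals and the corresponding nc complexes, as required above):

#include <AMReX_FFT_Helper.H>

void forward_r2c (amrex::Box const& box, double* pr,
                  amrex::FFT::Plan<double>::VendorComplex* pc)
{
    using amrex::FFT::Direction;
    amrex::FFT::Plan<double> plan;
    plan.init_r2c<Direction::forward>(box, pr, pc);  // build the vendor plan
    plan.compute_r2c<Direction::forward>();          // execute pr -> pc
    plan.destroy();                                  // release the vendor plan
}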
1200
1201using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
1202using PlanD = typename Plan<double>::VendorPlan;
1203using PlanF = typename Plan<float>::VendorPlan;
1204
1205PlanD* get_vendor_plan_d (Key const& key);
1206PlanF* get_vendor_plan_f (Key const& key);
1207
1208void add_vendor_plan_d (Key const& key, PlanD plan);
1209void add_vendor_plan_f (Key const& key, PlanF plan);
1210
1211template <typename T>
1212template <Direction D, int M>
1213void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
1214{
1215 static_assert(D == Direction::forward || D == Direction::backward);
1216
1217 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
1218 defined = true;
1219 pf = pbf;
1220 pb = pbb;
1221
1222 n = 1;
1223 for (auto s : fft_size) { n *= s; }
1224 howmany = ncomp;
1225
1226#if defined(AMREX_USE_GPU)
1227 Key key = {fft_size.template expand<3>(), ncomp, D, kind};
1228 if (cache) {
1229 VendorPlan* cached_plan = nullptr;
1230 if constexpr (std::is_same_v<float,T>) {
1231 cached_plan = get_vendor_plan_f(key);
1232 } else {
1233 cached_plan = get_vendor_plan_d(key);
1234 }
1235 if (cached_plan) {
1236 plan = *cached_plan;
1237 return;
1238 }
1239 }
1240#else
1241 amrex::ignore_unused(cache);
1242#endif
1243
1244 int len[M];
1245 for (int i = 0; i < M; ++i) {
1246 len[i] = fft_size[M-1-i];
1247 }
1248
1249 int nc = fft_size[0]/2+1;
1250 for (int i = 1; i < M; ++i) {
1251 nc *= fft_size[i];
1252 }
1253
1254#if defined(AMREX_USE_CUDA)
1255
1256 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
1257 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
1258 cufftType type;
1259 int n_in, n_out;
1260 if constexpr (D == Direction::forward) {
1261 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1262 n_in = n;
1263 n_out = nc;
1264 } else {
1265 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1266 n_in = nc;
1267 n_out = n;
1268 }
1269 std::size_t work_size;
1270 AMREX_CUFFT_SAFE_CALL
1271 (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));
1272
1273#elif defined(AMREX_USE_HIP)
1274
1275 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1276 std::size_t length[M];
1277 for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
1278 if constexpr (D == Direction::forward) {
1279 AMREX_ROCFFT_SAFE_CALL
1280 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1281 rocfft_transform_type_real_forward, prec, M,
1282 length, howmany, nullptr));
1283 } else {
1284 AMREX_ROCFFT_SAFE_CALL
1285 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1286 rocfft_transform_type_real_inverse, prec, M,
1287 length, howmany, nullptr));
1288 }
1289
1290#elif defined(AMREX_USE_SYCL)
1291
1292 mkl_desc_r* pp;
1293 if (M == 1) {
1294 pp = new mkl_desc_r(fft_size[0]);
1295 } else {
1296 std::vector<std::int64_t> len64(M);
1297 for (int idim = 0; idim < M; ++idim) {
1298 len64[idim] = len[idim];
1299 }
1300 pp = new mkl_desc_r(len64);
1301 }
1302#ifndef AMREX_USE_MKL_DFTI_2024
1303 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1304 oneapi::mkl::dft::config_value::NOT_INPLACE);
1305#else
1306 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1307#endif
1308 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1309 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1310 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
1311 std::vector<std::int64_t> strides(M+1);
1312 strides[0] = 0;
1313 strides[M] = 1;
1314 for (int i = M-1; i >= 1; --i) {
1315 strides[i] = strides[i+1] * fft_size[M-1-i];
1316 }
1317
1318#ifndef AMREX_USE_MKL_DFTI_2024
1319 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1320 // Do not set BWD_STRIDES
1321#else
1322 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1323 // Do not set BWD_STRIDES
1324#endif
1325 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1326 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1327 pp->commit(amrex::Gpu::Device::streamQueue());
1328 plan = pp;
1329
1330#else /* FFTW */
1331
1332 if (pf == nullptr || pb == nullptr) {
1333 defined = false;
1334 return;
1335 }
1336
1337 if constexpr (std::is_same_v<float,T>) {
1338 if constexpr (D == Direction::forward) {
1339 plan = fftwf_plan_many_dft_r2c
1340 (M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
1341 FFTW_ESTIMATE);
1342 } else {
1343 plan = fftwf_plan_many_dft_c2r
1344 (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
1345 FFTW_ESTIMATE);
1346 }
1347 } else {
1348 if constexpr (D == Direction::forward) {
1349 plan = fftw_plan_many_dft_r2c
1350 (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
1351 FFTW_ESTIMATE);
1352 } else {
1353 plan = fftw_plan_many_dft_c2r
1354 (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
1355 FFTW_ESTIMATE);
1356 }
1357 }
1358#endif
1359
1360#if defined(AMREX_USE_GPU)
1361 if (cache) {
1362 if constexpr (std::is_same_v<float,T>) {
1363 add_vendor_plan_f(key, plan);
1364 } else {
1365 add_vendor_plan_d(key, plan);
1366 }
1367 }
1368#endif
1369}
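
On GPU builds this overload memoizes vendor plans in a global table keyed by (extent, ncomp, direction, kind), since cuFFT/rocFFT/oneMKL plan creation is comparatively expensive; a second init with the same key adopts the cached handle and returns early. A usage sketch (hypothetical extent; the cached handle is shared, so the sketch leaves its destruction to whatever owns the cache in AMReX_FFT.cpp):

#include <AMReX_FFT_Helper.H>

void cached_r2c_plans (void* pr, void* pc)
{
    using amrex::FFT::Direction;
    amrex::IntVectND<3> const fft_size{64, 64, 64};           // illustrative extent
    amrex::FFT::Plan<double> p1, p2;
    p1.init_r2c<Direction::forward>(fft_size, pr, pc, true);  // creates and caches
    p2.init_r2c<Direction::forward>(fft_size, pr, pc, true);  // reuses the cached handle
}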
1370
1371namespace detail
1372{
1373 DistributionMapping make_iota_distromap (Long n);
1374
1375 template <typename FA>
1376 typename FA::FABType::value_type * get_fab (FA& fa)
1377 {
1378 auto myproc = ParallelContext::MyProcSub();
1379 if (myproc < fa.size()) {
1380 return fa.fabPtr(myproc);
1381 } else {
1382 return nullptr;
1383 }
1384 }
1385
1386 template <typename FA1, typename FA2>
1387 std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
1388 {
1389 bool not_same_fa = true;
1390 if constexpr (std::is_same_v<FA1,FA2>) {
1391 not_same_fa = (&fa1 != &fa2);
1392 }
1393 using FAB1 = typename FA1::FABType::value_type;
1394 using FAB2 = typename FA2::FABType::value_type;
1395 using T1 = typename FAB1::value_type;
1396 using T2 = typename FAB2::value_type;
1397 auto myproc = ParallelContext::MyProcSub();
1398 bool alloc_1 = (myproc < fa1.size());
1399 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1400 void* p = nullptr;
1401 if (alloc_1 && alloc_2) {
1402 Box const& box1 = fa1.fabbox(myproc);
1403 Box const& box2 = fa2.fabbox(myproc);
1404 int ncomp1 = fa1.nComp();
1405 int ncomp2 = fa2.nComp();
1406 p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
1407 sizeof(T2)*box2.numPts()*ncomp2));
1408 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1409 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1410 } else if (alloc_1) {
1411 Box const& box1 = fa1.fabbox(myproc);
1412 int ncomp1 = fa1.nComp();
1413 p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
1414 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1415 } else if (alloc_2) {
1416 Box const& box2 = fa2.fabbox(myproc);
1417 int ncomp2 = fa2.nComp();
1418 p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
1419 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1420 } else {
1421 return nullptr;
1422 }
1423 return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
1424 }
1425}
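
make_mfs_share above is an aliasing trick: the (at most one) rank-local FAB of each FabArray is backed by a single arena block sized for the larger of the two, which is valid because the FFT stages never need both arrays' contents alive at once; the returned unique_ptr owns the block for both views. The ownership shape in miniature (simplified stand-in types, hypothetical sizes):

#include <algorithm>
#include <cstddef>
#include <memory>

std::unique_ptr<char[]> share_block (std::size_t bytes1, std::size_t bytes2,
                                     char*& view1, char*& view2)
{
    auto p = std::make_unique<char[]>(std::max(bytes1, bytes2));
    view1 = p.get();  // first array's view of the block
    view2 = p.get();  // second array aliases the same block
    return p;         // a single owner keeps both views valid
}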
1426
1427struct Swap01
1428{
1429 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1430 {
1431 return {i.y, i.x, i.z};
1432 }
1433
1434 static constexpr Dim3 Inverse (Dim3 i)
1435 {
1436 return {i.y, i.x, i.z};
1437 }
1438
1439 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1440 {
1441 return it;
1442 }
1443
1444 static constexpr IndexType Inverse (IndexType it)
1445 {
1446 return it;
1447 }
1448};
1449
1450struct Swap02
1451{
1452 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1453 {
1454 return {i.z, i.y, i.x};
1455 }
1456
1457 static constexpr Dim3 Inverse (Dim3 i)
1458 {
1459 return {i.z, i.y, i.x};
1460 }
1461
1462 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1463 {
1464 return it;
1465 }
1466
1467 static constexpr IndexType Inverse (IndexType it)
1468 {
1469 return it;
1470 }
1471};
1472
1473struct RotateFwd
1474{
1475 // dest -> src: (x,y,z) -> (y,z,x)
1476 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1477 {
1478 return {i.y, i.z, i.x};
1479 }
1480
1481 // src -> dest: (x,y,z) -> (z,x,y)
1482 static constexpr Dim3 Inverse (Dim3 i)
1483 {
1484 return {i.z, i.x, i.y};
1485 }
1486
1487 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1488 {
1489 return it;
1490 }
1491
1492 static constexpr IndexType Inverse (IndexType it)
1493 {
1494 return it;
1495 }
1496};
1497
1498struct RotateBwd
1499{
1500 // dest -> src: (x,y,z) -> (z,x,y)
1501 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1502 {
1503 return {i.z, i.x, i.y};
1504 }
1505
1506 // src -> dest: (x,y,z) -> (y,z,x)
1507 static constexpr Dim3 Inverse (Dim3 i)
1508 {
1509 return {i.y, i.z, i.x};
1510 }
1511
1512 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1513 {
1514 return it;
1515 }
1516
1517 static constexpr IndexType Inverse (IndexType it)
1518 {
1519 return it;
1520 }
1521};
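
These four functors are destination-to-source index maps used by the transpose stages: operator() tells a copy where a destination cell reads from, and Inverse is the reverse map, so Inverse(op(i)) == i must hold componentwise. A standalone check of that contract for the rotation pair (assumes only amrex::Dim3):

#include <AMReX_Dim3.H>

int main ()
{
    amrex::Dim3 const d{3, 5, 7};
    amrex::Dim3 const s{d.y, d.z, d.x};   // RotateFwd::operator(): dest -> src
    amrex::Dim3 const r{s.z, s.x, s.y};   // RotateFwd::Inverse: src -> dest
    return (r.x == d.x && r.y == d.y && r.z == d.z) ? 0 : 1;
}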
1522
1523namespace detail
1524{
1525 struct SubHelper
1526 {
1527 explicit SubHelper (Box const& domain);
1528
1529 [[nodiscard]] Box make_box (Box const& box) const;
1530
1531 [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;
1532
1533 [[nodiscard]] bool ghost_safe (IntVect const& ng) const;
1534
1535 // This rearranges the order.
1536 [[nodiscard]] IntVect make_iv (IntVect const& iv) const;
1537
1538 // This keeps the order, but zero out the values in the hidden dimension.
1539 [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;
1540
1541 [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;
1542
1543 [[nodiscard]] IntVect inverse_order (IntVect const& order) const;
1544
1545 template <typename T>
1546 [[nodiscard]] T make_array (T const& a) const
1547 {
1548#if (AMREX_SPACEDIM == 1)
1549 amrex::ignore_unused(this);
1550 return a;
1551#elif (AMREX_SPACEDIM == 2)
1552 if (m_case == case_1n) {
1553 return T{a[1],a[0]};
1554 } else {
1555 return a;
1556 }
1557#else
1558 if (m_case == case_11n) {
1559 return T{a[2],a[0],a[1]};
1560 } else if (m_case == case_1n1) {
1561 return T{a[1],a[0],a[2]};
1562 } else if (m_case == case_1nn) {
1563 return T{a[1],a[2],a[0]};
1564 } else if (m_case == case_n1n) {
1565 return T{a[0],a[2],a[1]};
1566 } else {
1567 return a;
1568 }
1569#endif
1570 }
1571
1572 [[nodiscard]] GpuArray<int,3> xyz_order () const;
1573
1574 template <typename FA>
1575 FA make_alias_mf (FA const& mf)
1576 {
1577 BoxList bl = mf.boxArray().boxList();
1578 for (auto& b : bl) {
1579 b = make_box(b);
1580 }
1581 auto const& ng = make_iv(mf.nGrowVect());
1582 FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
1583 using FAB = typename FA::fab_type;
1584 for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1585 submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
1586 }
1587 return submf;
1588 }
1589
1590#if (AMREX_SPACEDIM == 2)
1591 enum Case { case_1n, case_other };
1592 int m_case = case_other;
1593#elif (AMREX_SPACEDIM == 3)
1594 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1595 int m_case = case_other;
1596#endif
1597 };
1598}
1599
1600}
1601
1602#endif