Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_FFT_Helper.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_HELPER_H_
2#define AMREX_FFT_HELPER_H_
3#include <AMReX_Config.H>
4
5#include <AMReX.H>
6#include <AMReX_BLProfiler.H>
9#include <AMReX_Enum.H>
10#include <AMReX_FabArray.H>
11#include <AMReX_Gpu.H>
12#include <AMReX_GpuComplex.H>
13#include <AMReX_Math.H>
14#include <AMReX_Periodicity.H>
15
22#if defined(AMREX_USE_CUDA)
23# include <cufft.h>
24# include <cuComplex.h>
25#elif defined(AMREX_USE_HIP)
26# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
27# include <rocfft/rocfft.h>
28# else
29# include <rocfft.h>
30# endif
31# include <hip/hip_complex.h>
32#elif defined(AMREX_USE_SYCL)
33# if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
34# include <oneapi/mkl/dft.hpp>
35#else
36# define AMREX_USE_MKL_DFTI_2024 1
37# include <oneapi/mkl/dfti.hpp>
38# endif
39#else
40# include <fftw3.h>
41#endif
42
43#include <algorithm>
44#include <complex>
45#include <limits>
46#include <memory>
47#include <tuple>
48#include <utility>
49#include <variant>
50
51namespace amrex::FFT
52{
53
//! Direction of the transform requested from a Plan: forward, backward,
//! both, or none.
enum struct Direction { forward, backward, both, none };
55
57
59
62
// User-tunable options consumed by the FFT drivers.
// NOTE(review): this is a doc-scrape view — the leading numerals are residual
// source line numbers, and the declaration of `pencil_threshold` (used by
// setPencilThreshold below) appears to have been dropped; confirm upstream.
63struct Info
64{
67
71
// Batch 2D transforms rather than doing a single full transform (3D builds).
75 bool twod_mode = false;
76
// Restrict transforms to 1D mode.
78 bool oned_mode = false;
79
// Number of transforms batched together.
81 int batch_size = 1;
82
// Upper bound on the number of MPI processes to use (default: no bound).
84 int nprocs = std::numeric_limits<int>::max();
85
// Fluent setters: each returns *this so calls can be chained.
99 Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
106 Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
117 Info& setOneDMode (bool x) { oned_mode = x; return *this; }
124 Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
131 Info& setNumProcs (int n) { nprocs = n; return *this; }
132};
133
#ifdef AMREX_USE_HIP
namespace detail {
    //! Execute a rocFFT plan on the given input/output pointers; defined
    //! out of line.
    void hip_execute (rocfft_plan plan, void **in, void **out);
}
#endif
139
140#ifdef AMREX_USE_SYCL
// SYCL/oneMKL execution helpers.
// NOTE(review): doc-scrape view — leading numerals are residual line numbers;
// the assertion call opening assert_no_external_stream's body (line 147) and
// one statement after the is_same_v check (line 166) appear to be missing.
142namespace detail
143{
// Abort if the user has installed an external GPU stream; oneMKL DFT runs on
// AMReX's own stream queue.
144inline void assert_no_external_stream ()
145{
148 "SYCL FFT does not support external GPU streams.");
149}
150
// Run a committed oneMKL DFT descriptor `plan` in the given direction,
// allocating (and freeing) the external workspace it was configured for.
151template <typename T, Direction direction, typename P, typename TI, typename TO>
152void sycl_execute (P* plan, TI* in, TO* out)
153{
154 assert_no_external_stream();
155#ifndef AMREX_USE_MKL_DFTI_2024
156 std::int64_t workspaceSize = 0;
157#else
158 std::size_t workspaceSize = 0;
159#endif
160 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
161 &workspaceSize);
162 auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
163 plan->set_workspace(buffer);
164 sycl::event r;
// In-place when input and output types match; out-of-place otherwise.
165 if (std::is_same_v<TI,TO>) {
167 if constexpr (direction == Direction::forward) {
168 r = oneapi::mkl::dft::compute_forward(*plan, out);
169 } else {
170 r = oneapi::mkl::dft::compute_backward(*plan, out);
171 }
172 } else {
173 if constexpr (direction == Direction::forward) {
174 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
175 } else {
176 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
177 }
178 }
// Block until the transform finishes before releasing the workspace.
179 r.wait();
180 amrex::The_Arena()->free(buffer);
181}
182}
184#endif
185
// Backend-agnostic FFT plan wrapper; T is float or double.
// NOTE(review): doc-scrape view — leading numerals are residual line numbers,
// and several member declarations (e.g. the vendor plan handle(s) `plan`,
// `plan2`, the r2r `kind`, and `r2r_data_is_complex`, all used below) appear
// to have been dropped; confirm upstream.
186template <typename T>
187struct Plan
188{
// Per-backend plan handle and complex-number types.
189#if defined(AMREX_USE_CUDA)
190 using VendorPlan = cufftHandle;
191 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
192 cuComplex, cuDoubleComplex>;
193#elif defined(AMREX_USE_HIP)
194 using VendorPlan = rocfft_plan;
195 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
196 float2, double2>;
197#elif defined(AMREX_USE_SYCL)
198 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
199 ? oneapi::mkl::dft::precision::SINGLE
200 : oneapi::mkl::dft::precision::DOUBLE,
201 oneapi::mkl::dft::domain::REAL>;
202 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
203 ? oneapi::mkl::dft::precision::SINGLE
204 : oneapi::mkl::dft::precision::DOUBLE,
205 oneapi::mkl::dft::domain::COMPLEX>;
206 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
207 using VendorComplex = std::complex<T>;
208#else
209 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
210 fftwf_plan, fftw_plan>;
211 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
212 fftwf_complex, fftw_complex>;
213#endif
214
// Transform length and number of batched transforms.
215 int n = 0;
216 int howmany = 0;
// Whether the primary (and, CPU-only, secondary) plan has been created.
219 bool defined = false;
220 bool defined2 = false;
// Forward-input / backward-input data pointers.
223 void* pf = nullptr;
224 void* pb = nullptr;
225
226#ifdef AMREX_USE_GPU
// Rebind the forward/backward data pointers (GPU builds only).
233 void set_ptrs (void* p0, void* p1) {
234 pf = p0;
235 pb = p1;
236 }
237#endif
238
// Release the vendor plan(s), if any.
// NOTE(review): doc-scrape view — the actual plan-destruction calls (lines
// 245 and 250 of the original) appear to have been dropped; confirm upstream.
242 void destroy ()
243 {
244 if (defined) {
246 defined = false;
247 }
248#if !defined(AMREX_USE_GPU)
// The second (imaginary-part) FFTW plan only exists on CPU builds.
249 if (defined2) {
251 defined2 = false;
252 }
253#endif
254 }
255
// Build a real-to-complex (forward) or complex-to-real (backward) plan over
// `box`, either as batched 1D transforms or, when is_2d_transform is true,
// as batched 2D transforms; ncomp multiplies the batch count.
// NOTE(review): doc-scrape view — leading numerals are residual line numbers;
// a few lines (271-272, 296, 305, 309 of the original, including the
// AMREX_CUFFT_SAFE_CALL( openers for the cufftMakePlanMany calls) are missing.
265 template <Direction D>
266 void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
267 {
268 static_assert(D == Direction::forward || D == Direction::backward);
269
270 int rank = is_2d_transform ? 2 : 1;

273 defined = true;
274 pf = (void*)pr;
275 pb = (void*)pc;
276
277 int len[2] = {};
278 if (rank == 1) {
279 len[0] = box.length(0);
280 len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
281 } else {
282 len[0] = box.length(1); // Most FFT libraries assume row-major ordering
283 len[1] = box.length(0); // except for rocfft
284 }
// nr: real elements per transform; nc: complex outputs (Hermitian-packed).
285 int nr = (rank == 1) ? len[0] : len[0]*len[1];
286 n = nr;
287 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
288#if (AMREX_SPACEDIM == 1)
289 howmany = 1;
290#else
291 howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
292 : AMREX_D_TERM(1, *1 , *box.length(2));
293#endif
294 howmany *= ncomp;
295
297
298#if defined(AMREX_USE_CUDA)
299
300 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
301 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
302 std::size_t work_size;
303 if constexpr (D == Direction::forward) {
304 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
// NOTE(review): the AMREX_CUFFT_SAFE_CALL opener for this call was dropped.
306 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
307 } else {
308 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
310 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
311 }
312
313#elif defined(AMREX_USE_HIP)
314
315 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
316 // switch to column-major ordering
317 std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
318 if constexpr (D == Direction::forward) {
319 AMREX_ROCFFT_SAFE_CALL
320 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
321 rocfft_transform_type_real_forward, prec, rank,
322 length, howmany, nullptr));
323 } else {
324 AMREX_ROCFFT_SAFE_CALL
325 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
326 rocfft_transform_type_real_inverse, prec, rank,
327 length, howmany, nullptr));
328 }
329
330#elif defined(AMREX_USE_SYCL)
331
332 mkl_desc_r* pp;
333 if (rank == 1) {
334 pp = new mkl_desc_r(len[0]);
335 } else {
336 pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
337 }
338#ifndef AMREX_USE_MKL_DFTI_2024
339 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
340 oneapi::mkl::dft::config_value::NOT_INPLACE);
341#else
342 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
343#endif
344 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
345 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
346 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
347 std::vector<std::int64_t> strides;
348 strides.push_back(0);
349 if (rank == 2) { strides.push_back(len[1]); }
350 strides.push_back(1);
351#ifndef AMREX_USE_MKL_DFTI_2024
352 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
353 // Do not set BWD_STRIDES
354#else
355 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
356 // Do not set BWD_STRIDES
357#endif
358 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
359 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
360 detail::assert_no_external_stream();
361 pp->commit(amrex::Gpu::Device::streamQueue());
362 plan = pp;
363
364#else /* FFTW */
365
366 if constexpr (std::is_same_v<float,T>) {
367 if constexpr (D == Direction::forward) {
368 plan = fftwf_plan_many_dft_r2c
369 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
370 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
371 } else {
372 plan = fftwf_plan_many_dft_c2r
373 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
374 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
375 }
376 } else {
377 if constexpr (D == Direction::forward) {
378 plan = fftw_plan_many_dft_r2c
379 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
380 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
381 } else {
382 plan = fftw_plan_many_dft_c2r
383 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
384 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
385 }
386 }
387#endif
388 }
389
// Out-of-line overload: build an r2c plan for a plain size vector, with
// optional caching of the created plan (see the `cache` flag).
399 template <Direction D, int M>
400 void init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp = 1);
401
// Build an in-place complex-to-complex plan over `box` with ndims = 1, 2, or
// 3 transform dimensions; ncomp multiplies the batch count.
// NOTE(review): doc-scrape view — leading numerals are residual line numbers;
// lines 422 and 463 (the AMREX_CUFFT_SAFE_CALL opener for cufftMakePlanMany)
// of the original appear to be missing.
417 template <Direction D>
418 void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
419 {
420 static_assert(D == Direction::forward || D == Direction::backward);
421
423 defined = true;
424 pf = (void*)p;
425 pb = (void*)p;
426
427 int len[3] = {};
428
// len[] is stored row-major (slowest dimension first), hence the reversed
// box.length(...) order for ndims >= 2.
429 if (ndims == 1) {
430 n = box.length(0);
431 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
432 howmany *= ncomp;
433 len[0] = box.length(0);
434 }
435#if (AMREX_SPACEDIM >= 2)
436 else if (ndims == 2) {
437 n = box.length(0) * box.length(1);
438#if (AMREX_SPACEDIM == 2)
439 howmany = ncomp;
440#else
441 howmany = box.length(2) * ncomp;
442#endif
443 len[0] = box.length(1);
444 len[1] = box.length(0);
445 }
446#if (AMREX_SPACEDIM == 3)
447 else if (ndims == 3) {
448 n = box.length(0) * box.length(1) * box.length(2);
449 howmany = ncomp;
450 len[0] = box.length(2);
451 len[1] = box.length(1);
452 len[2] = box.length(0);
453 }
454#endif
455#endif
456
457#if defined(AMREX_USE_CUDA)
458 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
459 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
460
461 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
462 std::size_t work_size;
464 (cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
465
466#elif defined(AMREX_USE_HIP)
467
468 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
469 : rocfft_precision_double;
470 auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
471 : rocfft_transform_type_complex_inverse;
// rocfft wants column-major lengths, so reverse len[] again.
472 std::size_t length[3];
473 if (ndims == 1) {
474 length[0] = len[0];
475 } else if (ndims == 2) {
476 length[0] = len[1];
477 length[1] = len[0];
478 } else {
479 length[0] = len[2];
480 length[1] = len[1];
481 length[2] = len[0];
482 }
483 AMREX_ROCFFT_SAFE_CALL
484 (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
485 length, howmany, nullptr));
486
487#elif defined(AMREX_USE_SYCL)
488
489 mkl_desc_c* pp;
490 if (ndims == 1) {
491 pp = new mkl_desc_c(n);
492 } else if (ndims == 2) {
493 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
494 } else {
495 pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
496 }
497#ifndef AMREX_USE_MKL_DFTI_2024
498 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
499 oneapi::mkl::dft::config_value::INPLACE);
500#else
501 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
502#endif
503 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
504 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
505 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
// strides[0] is the offset; strides[i] for i>=1 are per-dimension strides.
506 std::vector<std::int64_t> strides(ndims+1);
507 strides[0] = 0;
508 strides[ndims] = 1;
509 for (int i = ndims-1; i >= 1; --i) {
510 strides[i] = strides[i+1] * len[i];
511 }
512#ifndef AMREX_USE_MKL_DFTI_2024
513 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
514 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
515#else
516 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
517 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
518#endif
519 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
520 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
521 detail::assert_no_external_stream();
522 pp->commit(amrex::Gpu::Device::streamQueue());
523 plan = pp;
524
525#else /* FFTW */
526
527 if constexpr (std::is_same_v<float,T>) {
528 if constexpr (D == Direction::forward) {
529 plan = fftwf_plan_many_dft
530 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
531 FFTW_ESTIMATE);
532 } else {
533 plan = fftwf_plan_many_dft
534 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
535 FFTW_ESTIMATE);
536 }
537 } else {
538 if constexpr (D == Direction::forward) {
539 plan = fftw_plan_many_dft
540 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
541 FFTW_ESTIMATE);
542 } else {
543 plan = fftw_plan_many_dft
544 (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
545 FFTW_ESTIMATE);
546 }
547 }
548#endif
549 }
550
551#ifndef AMREX_USE_GPU
552 template <Direction D>
553 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
554 {
555 if (bc.first == Boundary::even && bc.second == Boundary::even)
556 {
557 return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
558 }
559 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
560 {
561 return FFTW_REDFT11;
562 }
563 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
564 {
565 return FFTW_RODFT11;
566 }
567 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
568 {
569 return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
570 }
571 else {
572 amrex::Abort("FFT: unsupported BC");
573 return fftw_r2r_kind{};
574 }
575
576 }
577#endif
578
585 template <Direction D>
586 Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
587 {
588 if (bc.first == Boundary::even && bc.second == Boundary::even)
589 {
591 }
592 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
593 {
594 return Kind::r2r_eo;
595 }
596 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
597 {
598 return Kind::r2r_oe;
599 }
600 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
601 {
603 }
604 else {
605 amrex::Abort("FFT: unsupported BC");
606 return Kind::none;
607 }
608
609 }
610
// Build a real-to-real (DCT/DST) plan over `box` for boundary pair `bc`.
// On GPU the r2r transform is emulated by an even/odd extension of length
// `nex` followed by a real-to-complex FFT; on CPU it maps to a native FFTW
// r2r plan. NOTE(review): doc-scrape view — leading numerals are residual
// line numbers; lines 661 and 666 of the original (incl. the
// AMREX_CUFFT_SAFE_CALL opener) appear to be missing.
619 template <Direction D>
620 void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
621 int howmany_initval = 1)
622 {
623 static_assert(D == Direction::forward || D == Direction::backward);
624
625 kind = get_r2r_kind<D>(bc);
626 defined = true;
627 pf = (void*)p;
628 pb = (void*)p;
629
630 n = box.length(0);
631 howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));
632
633#if defined(AMREX_USE_GPU)
// Extended length of the even/odd-reflected signal fed to the R2C plan.
634 int nex=0;
635 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
636 Direction::forward == D) {
637 nex = 2*n;
638 } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
639 Direction::backward == D) {
640 nex = 4*n;
641 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
642 Direction::forward == D) {
643 nex = 2*n;
644 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
645 Direction::backward == D) {
646 nex = 4*n;
647 } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
648 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
649 nex = 4*n;
650 } else {
651 amrex::Abort("FFT: unsupported BC");
652 }
// Hermitian-packed complex output length of the length-nex real transform.
653 int nc = (nex/2) + 1;
654
655#if defined (AMREX_USE_CUDA)
656
657 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
658 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
659 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
660 std::size_t work_size;
662 (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
663
664#elif defined(AMREX_USE_HIP)
665
667 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
668 const std::size_t length = nex;
669 AMREX_ROCFFT_SAFE_CALL
670 (rocfft_plan_create(&plan, rocfft_placement_inplace,
671 rocfft_transform_type_real_forward, prec, 1,
672 &length, howmany, nullptr));
673
674#elif defined(AMREX_USE_SYCL)
675
676 auto* pp = new mkl_desc_r(nex);
677#ifndef AMREX_USE_MKL_DFTI_2024
678 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
679 oneapi::mkl::dft::config_value::INPLACE);
680#else
681 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
682#endif
683 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
684 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
685 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
686 std::vector<std::int64_t> strides = {0,1};
687#ifndef AMREX_USE_MKL_DFTI_2024
688 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
689 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
690#else
691 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
692 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
693#endif
694 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
695 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
696 detail::assert_no_external_stream();
697 pp->commit(amrex::Gpu::Device::streamQueue());
698 plan = pp;
699
700#endif
701
702#else /* FFTW */
703 auto fftw_kind = get_fftw_kind<D>(bc);
704 if constexpr (std::is_same_v<float,T>) {
705 plan = fftwf_plan_many_r2r
706 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
707 FFTW_ESTIMATE);
708 } else {
709 plan = fftw_plan_many_r2r
710 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
711 FFTW_ESTIMATE);
712 }
713#endif
714 }
715
723 template <Direction D>
724 void init_r2r (Box const& box, VendorComplex* pc,
725 std::pair<Boundary,Boundary> const& bc)
726 {
727 static_assert(D == Direction::forward || D == Direction::backward);
728
729 auto* p = (T*)pc;
730
731#if defined(AMREX_USE_GPU)
732
733 init_r2r<D>(box, p, bc, 2);
734 r2r_data_is_complex = true;
735
736#else
737
738 kind = get_r2r_kind<D>(bc);
739 defined = true;
740 pf = (void*)p;
741 pb = (void*)p;
742
743 n = box.length(0);
744 howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
745
746 defined2 = true;
747 auto fftw_kind = get_fftw_kind<D>(bc);
748 if constexpr (std::is_same_v<float,T>) {
749 plan = fftwf_plan_many_r2r
750 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
751 FFTW_ESTIMATE);
752 plan2 = fftwf_plan_many_r2r
753 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
754 FFTW_ESTIMATE);
755 } else {
756 plan = fftw_plan_many_r2r
757 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
758 FFTW_ESTIMATE);
759 plan2 = fftw_plan_many_r2r
760 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
761 FFTW_ESTIMATE);
762 }
763#endif
764 }
765
// Execute the real<->complex plan: real input -> complex output when D is
// forward, complex -> real when backward. No-op if the plan is undefined.
// NOTE(review): doc-scrape view — the function signature line (772 of the
// original, presumably `void compute_r2c ()`) is missing.
771 template <Direction D>
773 {
774 static_assert(D == Direction::forward || D == Direction::backward);
775 if (!defined) { return; }
776
// Forward: input is real (pf), output complex (pb). Backward: swapped.
777 using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
778 using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
779 auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
780 auto* po = (TO*)((D == Direction::forward) ? pb : pf);
781
782#if defined(AMREX_USE_CUDA)
783 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
784
// Work area is allocated per execution because the plan was created with
// auto-allocation disabled.
785 std::size_t work_size = 0;
786 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
787
788 auto* work_area = The_Arena()->alloc(work_size);
789 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
790
791 if constexpr (D == Direction::forward) {
792 if constexpr (std::is_same_v<float,T>) {
793 AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
794 } else {
795 AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
796 }
797 } else {
798 if constexpr (std::is_same_v<float,T>) {
799 AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
800 } else {
801 AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
802 }
803 }
805 The_Arena()->free(work_area);
806#elif defined(AMREX_USE_HIP)
807 detail::hip_execute(plan, (void**)&pi, (void**)&po);
808#elif defined(AMREX_USE_SYCL)
// Variant index 0 holds the real-domain oneMKL descriptor.
809 detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
810#else
812 if constexpr (std::is_same_v<float,T>) {
813 fftwf_execute(plan);
814 } else {
815 fftw_execute(plan);
816 }
817#endif
818 }
819
// Execute the in-place complex-to-complex plan in direction D.
// NOTE(review): doc-scrape view — the function signature line (824 of the
// original, presumably `void compute_c2c ()`) is missing.
823 template <Direction D>
825 {
826 static_assert(D == Direction::forward || D == Direction::backward);
827 if (!defined) { return; }
828
// c2c plans are in-place; pf holds the data for both directions.
829 auto* p = (VendorComplex*)pf;
830
831#if defined(AMREX_USE_CUDA)
832 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
833
834 std::size_t work_size = 0;
835 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
836
837 auto* work_area = The_Arena()->alloc(work_size);
838 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
839
840 auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
841 if constexpr (std::is_same_v<float,T>) {
842 AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
843 } else {
844 AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
845 }
847 The_Arena()->free(work_area);
848#elif defined(AMREX_USE_HIP)
849 detail::hip_execute(plan, (void**)&p, (void**)&p);
850#elif defined(AMREX_USE_SYCL)
// Variant index 1 holds the complex-domain oneMKL descriptor.
851 detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
852#else
854 if constexpr (std::is_same_v<float,T>) {
855 fftwf_execute(plan);
856 } else {
857 fftw_execute(plan);
858 }
859#endif
860 }
861
862#ifdef AMREX_USE_GPU
// Allocate the complex scratch buffer used by the GPU r2r emulation; size
// depends on the extended-transform length implied by `kind`.
// NOTE(review): doc-scrape view — the continuation of the second condition
// (line 874 of the original, presumably `kind == Kind::r2r_eo || kind ==
// Kind::r2r_oe) {`) is missing.
868 [[nodiscard]] void* alloc_scratch_space () const
869 {
870 int nc = 0;
871 if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
872 nc = n + 1;
873 } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
875 nc = 2*n+1;
876 } else {
877 amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
878 }
879 return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
880 }
881
885 static void free_scratch_space (void* p) { The_Arena()->free(p); }
886
// Expand the real r2r input `psrc` into the even/odd-extended buffer `pbuf`
// expected by the vendor real-to-complex plans built in init_r2r. When the
// data is complex (r2r_data_is_complex), real and imaginary parts are packed
// as two separate strided batches.
// NOTE(review): doc-scrape view — leading numerals are residual line numbers;
// the `if (r2r_data_is_complex) {` opener lines for the first three kinds
// (lines 903, 938, 989 of the original) appear to be missing.
893 void pack_r2r_buffer (void* pbuf, T const* psrc) const
894 {
895 auto* pdst = (T*) pbuf;
// Forward oo/ee: length-2n extension; sign = -1 mirrors oddly (DST),
// +1 mirrors evenly (DCT).
896 if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
897 T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
898 int ostride = (n+1)*2;
899 int istride = n;
900 int nex = 2*n;
901 int norig = n;
902 Long nelems = Long(nex)*howmany;
904 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
905 {
906 auto batch = ielem / Long(nex);
907 auto i = int(ielem - batch*nex);
908 for (int ir = 0; ir < 2; ++ir) {
909 auto* po = pdst + (2*batch+ir)*ostride + i;
910 auto const* pi = psrc + 2*batch*istride + ir;
911 if (i < norig) {
912 *po = pi[i*2];
913 } else {
914 *po = sign * pi[(2*norig-1-i)*2];
915 }
916 }
917 });
918 } else {
919 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
920 {
921 auto batch = ielem / Long(nex);
922 auto i = int(ielem - batch*nex);
923 auto* po = pdst + batch*ostride + i;
924 auto const* pi = psrc + batch*istride;
925 if (i < norig) {
926 *po = pi[i];
927 } else {
928 *po = sign * pi[2*norig-1-i];
929 }
930 });
931 }
// Backward oo: length-4n extension with zeros at the half-period points.
932 } else if (kind == Kind::r2r_oo_b) {
933 int ostride = (2*n+1)*2;
934 int istride = n;
935 int nex = 4*n;
936 int norig = n;
937 Long nelems = Long(nex)*howmany;
939 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
940 {
941 auto batch = ielem / Long(nex);
942 auto i = int(ielem - batch*nex);
943 for (int ir = 0; ir < 2; ++ir) {
944 auto* po = pdst + (2*batch+ir)*ostride + i;
945 auto const* pi = psrc + 2*batch*istride + ir;
946 if (i < norig) {
947 *po = pi[i*2];
948 } else if (i < (2*norig-1)) {
949 *po = pi[(2*norig-2-i)*2];
950 } else if (i == (2*norig-1)) {
951 *po = T(0);
952 } else if (i < (3*norig)) {
953 *po = -pi[(i-2*norig)*2];
954 } else if (i < (4*norig-1)) {
955 *po = -pi[(4*norig-2-i)*2];
956 } else {
957 *po = T(0);
958 }
959 }
960 });
961 } else {
962 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
963 {
964 auto batch = ielem / Long(nex);
965 auto i = int(ielem - batch*nex);
966 auto* po = pdst + batch*ostride + i;
967 auto const* pi = psrc + batch*istride;
968 if (i < norig) {
969 *po = pi[i];
970 } else if (i < (2*norig-1)) {
971 *po = pi[2*norig-2-i];
972 } else if (i == (2*norig-1)) {
973 *po = T(0);
974 } else if (i < (3*norig)) {
975 *po = -pi[i-2*norig];
976 } else if (i < (4*norig-1)) {
977 *po = -pi[4*norig-2-i];
978 } else {
979 *po = T(0);
980 }
981 });
982 }
// Backward ee: length-4n extension, sign flips across the half period.
983 } else if (kind == Kind::r2r_ee_b) {
984 int ostride = (2*n+1)*2;
985 int istride = n;
986 int nex = 4*n;
987 int norig = n;
988 Long nelems = Long(nex)*howmany;
990 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
991 {
992 auto batch = ielem / Long(nex);
993 auto i = int(ielem - batch*nex);
994 for (int ir = 0; ir < 2; ++ir) {
995 auto* po = pdst + (2*batch+ir)*ostride + i;
996 auto const* pi = psrc + 2*batch*istride + ir;
997 if (i < norig) {
998 *po = pi[i*2];
999 } else if (i == norig) {
1000 *po = T(0);
1001 } else if (i < (2*norig+1)) {
1002 *po = -pi[(2*norig-i)*2];
1003 } else if (i < (3*norig)) {
1004 *po = -pi[(i-2*norig)*2];
1005 } else if (i == 3*norig) {
1006 *po = T(0);
1007 } else {
1008 *po = pi[(4*norig-i)*2];
1009 }
1010 }
1011 });
1012 } else {
1013 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1014 {
1015 auto batch = ielem / Long(nex);
1016 auto i = int(ielem - batch*nex);
1017 auto* po = pdst + batch*ostride + i;
1018 auto const* pi = psrc + batch*istride;
1019 if (i < norig) {
1020 *po = pi[i];
1021 } else if (i == norig) {
1022 *po = T(0);
1023 } else if (i < (2*norig+1)) {
1024 *po = -pi[2*norig-i];
1025 } else if (i < (3*norig)) {
1026 *po = -pi[i-2*norig];
1027 } else if (i == 3*norig) {
1028 *po = T(0);
1029 } else {
1030 *po = pi[4*norig-i];
1031 }
1032 });
1033 }
// Mixed even/odd: length-4n extension, odd reflection about the upper end.
1034 } else if (kind == Kind::r2r_eo) {
1035 int ostride = (2*n+1)*2;
1036 int istride = n;
1037 int nex = 4*n;
1038 int norig = n;
1039 Long nelems = Long(nex)*howmany;
1040 if (r2r_data_is_complex) {
1041 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1042 {
1043 auto batch = ielem / Long(nex);
1044 auto i = int(ielem - batch*nex);
1045 for (int ir = 0; ir < 2; ++ir) {
1046 auto* po = pdst + (2*batch+ir)*ostride + i;
1047 auto const* pi = psrc + 2*batch*istride + ir;
1048 if (i < norig) {
1049 *po = pi[i*2];
1050 } else if (i < (2*norig)) {
1051 *po = -pi[(2*norig-1-i)*2];
1052 } else if (i < (3*norig)) {
1053 *po = -pi[(i-2*norig)*2];
1054 } else {
1055 *po = pi[(4*norig-1-i)*2];
1056 }
1057 }
1058 });
1059 } else {
1060 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1061 {
1062 auto batch = ielem / Long(nex);
1063 auto i = int(ielem - batch*nex);
1064 auto* po = pdst + batch*ostride + i;
1065 auto const* pi = psrc + batch*istride;
1066 if (i < norig) {
1067 *po = pi[i];
1068 } else if (i < (2*norig)) {
1069 *po = -pi[2*norig-1-i];
1070 } else if (i < (3*norig)) {
1071 *po = -pi[i-2*norig];
1072 } else {
1073 *po = pi[4*norig-1-i];
1074 }
1075 });
1076 }
// Mixed odd/even: length-4n extension, even reflection about the upper end.
1077 } else if (kind == Kind::r2r_oe) {
1078 int ostride = (2*n+1)*2;
1079 int istride = n;
1080 int nex = 4*n;
1081 int norig = n;
1082 Long nelems = Long(nex)*howmany;
1083 if (r2r_data_is_complex) {
1084 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1085 {
1086 auto batch = ielem / Long(nex);
1087 auto i = int(ielem - batch*nex);
1088 for (int ir = 0; ir < 2; ++ir) {
1089 auto* po = pdst + (2*batch+ir)*ostride + i;
1090 auto const* pi = psrc + 2*batch*istride + ir;
1091 if (i < norig) {
1092 *po = pi[i*2];
1093 } else if (i < (2*norig)) {
1094 *po = pi[(2*norig-1-i)*2];
1095 } else if (i < (3*norig)) {
1096 *po = -pi[(i-2*norig)*2];
1097 } else {
1098 *po = -pi[(4*norig-1-i)*2];
1099 }
1100 }
1101 });
1102 } else {
1103 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1104 {
1105 auto batch = ielem / Long(nex);
1106 auto i = int(ielem - batch*nex);
1107 auto* po = pdst + batch*ostride + i;
1108 auto const* pi = psrc + batch*istride;
1109 if (i < norig) {
1110 *po = pi[i];
1111 } else if (i < (2*norig)) {
1112 *po = pi[2*norig-1-i];
1113 } else if (i < (3*norig)) {
1114 *po = -pi[i-2*norig];
1115 } else {
1116 *po = -pi[4*norig-1-i];
1117 }
1118 });
1119 }
1120 } else {
1121 amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
1122 }
1123 }
1124
// Post-process the vendor R2C output in `pbuf` back into the r2r result
// `pdst`, applying the DCT/DST phase factors via Math::sincospi. When the
// data is complex (r2r_data_is_complex), real and imaginary batches are
// unpacked in pairs. NOTE(review): doc-scrape view — leading numerals are
// residual line numbers.
1131 void unpack_r2r_buffer (T* pdst, void const* pbuf) const
1132 {
1133 auto const* psrc = (GpuComplex<T> const*) pbuf;
1134 int norig = n;
1135 Long nelems = Long(norig)*howmany;
1136 int ostride = n;
1137
// Forward DST (odd/odd): phase sin/cos at (k+1)/(2n).
1138 if (kind == Kind::r2r_oo_f) {
1139 int istride = n+1;
1140 if (r2r_data_is_complex) {
1141 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1142 {
1143 auto batch = ielem / Long(norig);
1144 auto k = int(ielem - batch*norig);
1145 auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
1146 for (int ir = 0; ir < 2; ++ir) {
1147 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
1148 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
1149 }
1150 });
1151 } else {
1152 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1153 {
1154 auto batch = ielem / Long(norig);
1155 auto k = int(ielem - batch*norig);
1156 auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
1157 auto const& yk = psrc[batch*istride+k+1];
1158 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
1159 });
1160 }
// Backward DST: odd harmonics only, phase at (2k+1)/(2n), halved.
1161 } else if (kind == Kind::r2r_oo_b) {
1162 int istride = 2*n+1;
1163 if (r2r_data_is_complex) {
1164 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1165 {
1166 auto batch = ielem / Long(norig);
1167 auto k = int(ielem - batch*norig);
1168 auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
1169 for (int ir = 0; ir < 2; ++ir) {
1170 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1171 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
1172 }
1173 });
1174 } else {
1175 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1176 {
1177 auto batch = ielem / Long(norig);
1178 auto k = int(ielem - batch*norig);
1179 auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
1180 auto const& yk = psrc[batch*istride+2*k+1];
1181 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
1182 });
1183 }
// Forward DCT (even/even): phase at k/(2n).
1184 } else if (kind == Kind::r2r_ee_f) {
1185 int istride = n+1;
1186 if (r2r_data_is_complex) {
1187 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1188 {
1189 auto batch = ielem / Long(norig);
1190 auto k = int(ielem - batch*norig);
1191 auto [s, c] = Math::sincospi(T(k)/T(2*norig));
1192 for (int ir = 0; ir < 2; ++ir) {
1193 auto const& yk = psrc[(2*batch+ir)*istride+k];
1194 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
1195 }
1196 });
1197 } else {
1198 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1199 {
1200 auto batch = ielem / Long(norig);
1201 auto k = int(ielem - batch*norig);
1202 auto [s, c] = Math::sincospi(T(k)/T(2*norig));
1203 auto const& yk = psrc[batch*istride+k];
1204 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
1205 });
1206 }
// Backward DCT: odd harmonics, real part only, halved.
1207 } else if (kind == Kind::r2r_ee_b) {
1208 int istride = 2*n+1;
1209 if (r2r_data_is_complex) {
1210 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1211 {
1212 auto batch = ielem / Long(norig);
1213 auto k = int(ielem - batch*norig);
1214 for (int ir = 0; ir < 2; ++ir) {
1215 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1216 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1217 }
1218 });
1219 } else {
1220 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1221 {
1222 auto batch = ielem / Long(norig);
1223 auto k = int(ielem - batch*norig);
1224 auto const& yk = psrc[batch*istride+2*k+1];
1225 pdst[batch*ostride+k] = T(0.5) * yk.real();
1226 });
1227 }
// Mixed even/odd: half-sample-shifted phase (k+1/2)/(2n), cosine form.
1228 } else if (kind == Kind::r2r_eo) {
1229 int istride = 2*n+1;
1230 if (r2r_data_is_complex) {
1231 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1232 {
1233 auto batch = ielem / Long(norig);
1234 auto k = int(ielem - batch*norig);
1235 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1236 for (int ir = 0; ir < 2; ++ir) {
1237 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1238 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1239 }
1240 });
1241 } else {
1242 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1243 {
1244 auto batch = ielem / Long(norig);
1245 auto k = int(ielem - batch*norig);
1246 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1247 auto const& yk = psrc[batch*istride+2*k+1];
1248 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1249 });
1250 }
// Mixed odd/even: half-sample-shifted phase, sine form.
1251 } else if (kind == Kind::r2r_oe) {
1252 int istride = 2*n+1;
1253 if (r2r_data_is_complex) {
1254 ParallelForOMP(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
1255 {
1256 auto batch = ielem / Long(norig);
1257 auto k = int(ielem - batch*norig);
1258 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1259 for (int ir = 0; ir < 2; ++ir) {
1260 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1261 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1262 }
1263 });
1264 } else {
1265 ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
1266 {
1267 auto batch = ielem / Long(norig);
1268 auto k = int(ielem - batch*norig);
1269 auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
1270 auto const& yk = psrc[batch*istride+2*k+1];
1271 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1272 });
1273 }
1274 } else {
1275 amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
1276 }
1277 }
1278#endif
1279
1283 template <Direction D>
// Execute the real-to-real plan for direction D. Forward reads pf and
// writes pb; backward reads pb and writes pf. No-op if the plan was
// never successfully created.
1285 {
1286 static_assert(D == Direction::forward || D == Direction::backward);
1287 if (!defined) { return; }
1288
1289#if defined(AMREX_USE_GPU)
1290
// On GPU backends the real input is first packed into a symmetry-extended
// scratch buffer and then run through an in-place real-to-complex vendor
// transform (see pack_r2r_buffer / unpack_r2r_buffer).
1291 auto* pscratch = alloc_scratch_space();
1292
1293 pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));
1294
1295#if defined(AMREX_USE_CUDA)
1296
1297 AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
1298
// Query the required work-area size and hand cuFFT an explicitly
// allocated buffer; it is freed again after execution below.
1299 std::size_t work_size = 0;
1300 AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));
1301
1302 auto* work_area = The_Arena()->alloc(work_size);
1303 AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));
1304
1305 if constexpr (std::is_same_v<float,T>) {
1306 AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
1307 } else {
1308 AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
1309 }
1310
1311#elif defined(AMREX_USE_HIP)
1312 detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
1313#elif defined(AMREX_USE_SYCL)
// The SYCL path always runs the forward vendor transform; D only selects
// which of pf/pb is the source and which is the destination.
1314 detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
1315#endif
1316
// Fold the spectral result back into the plain real layout of the output.
1317 unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);
1318
1320 free_scratch_space(pscratch);
1321#if defined(AMREX_USE_CUDA)
1322 The_Arena()->free(work_area);
1323#endif
1324
1325#else /* FFTW */
1326
// FFTW supports r2r transforms natively; if a second plan was created
// (defined2), execute it as well.
1327 if constexpr (std::is_same_v<float,T>) {
1328 fftwf_execute(plan);
1329 if (defined2) { fftwf_execute(plan2); }
1330 } else {
1331 fftw_execute(plan);
1332 if (defined2) { fftw_execute(plan2); }
1333 }
1334
1335#endif
1336 }
1337
// Destroy a vendor plan handle with the backend-appropriate API.
1342 {
1343#if defined(AMREX_USE_CUDA)
1344 AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
1345#elif defined(AMREX_USE_HIP)
1346 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
1347#elif defined(AMREX_USE_SYCL)
// The SYCL plan is a variant holding heap-allocated oneMKL descriptor
// pointers (see init_r2c, which stores `new mkl_desc_r(...)`).
1348 std::visit([](auto&& p) { delete p; }, plan);
1349#else
// FFTW has distinct plan types for single and double precision.
1350 if constexpr (std::is_same_v<float,T>) {
1351 fftwf_destroy_plan(plan);
1352 } else {
1353 fftw_destroy_plan(plan);
1354 }
1355#endif
1356 }
1357};
1358
1360
// Cache key for vendor FFT plans: (size expanded to 3D, number of
// components, transform direction, transform kind).
1361using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
// Vendor plan handle types for double and single precision.
1362using PlanD = typename Plan<double>::VendorPlan;
1363using PlanF = typename Plan<float>::VendorPlan;
1364
1365namespace detail {
// Look up a cached vendor plan; returns nullptr when the key is absent.
1366 PlanD* get_vendor_plan_d (Key const& key);
1367 PlanF* get_vendor_plan_f (Key const& key);
1368
// Insert a vendor plan into the cache under the given key.
1369 void add_vendor_plan_d (Key const& key, PlanD plan);
1370 void add_vendor_plan_f (Key const& key, PlanF plan);
1371}
1373
// Create an M-dimensional batched real-to-complex (forward) or
// complex-to-real (backward) plan over fft_size, transforming ncomp
// fields per execution. pbf is the real buffer, pbb the complex buffer.
// On GPU backends, when `cache` is true, identical plans are reused via
// a global cache keyed on (size, ncomp, direction, kind).
1374template <typename T>
1375template <Direction D, int M>
1376void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
1377{
1378 static_assert(D == Direction::forward || D == Direction::backward);
1379
1380 kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
1381 defined = true;
1382 pf = pbf;
1383 pb = pbb;
1384
// n: number of real elements per transform; howmany: batch count.
1385 n = 1;
1386 for (auto s : fft_size) { n *= s; }
1387 howmany = ncomp;
1388
1389#if defined(AMREX_USE_GPU)
// Check the plan cache first to avoid recreating identical vendor plans.
1390 Key key = {fft_size.template expand<3>(), ncomp, D, kind};
1391 if (cache) {
1392 VendorPlan* cached_plan = nullptr;
1393 if constexpr (std::is_same_v<float,T>) {
1394 cached_plan = detail::get_vendor_plan_f(key);
1395 } else {
1396 cached_plan = detail::get_vendor_plan_d(key);
1397 }
1398 if (cached_plan) {
1399 plan = *cached_plan;
1400 return;
1401 }
1402 }
1403#else
1404 amrex::ignore_unused(cache);
1405#endif
1406
// FFTW and cuFFT take dimensions slowest-first, so reverse fft_size
// (which is ordered fastest-first).
1407 int len[M];
1408 for (int i = 0; i < M; ++i) {
1409 len[i] = fft_size[M-1-i];
1410 }
1411
// nc: number of complex elements per transform; Hermitian symmetry
// halves (+1) the fastest dimension.
1412 int nc = fft_size[0]/2+1;
1413 for (int i = 1; i < M; ++i) {
1414 nc *= fft_size[i];
1415 }
1416
1417#if defined(AMREX_USE_CUDA)
1418
// Auto-allocation is disabled; the work area is supplied explicitly at
// execution time (see compute_r2c / compute_r2r).
1419 AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
1420 AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
1421 cufftType type;
1422 int n_in, n_out;
1423 if constexpr (D == Direction::forward) {
1424 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1425 n_in = n;
1426 n_out = nc;
1427 } else {
1428 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1429 n_in = nc;
1430 n_out = n;
1431 }
1432 std::size_t work_size;
1434 (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));
1435
1436#elif defined(AMREX_USE_HIP)
1437
// rocFFT takes lengths fastest-first, so fft_size is used unreversed.
1438 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1439 std::size_t length[M];
1440 for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
1441 if constexpr (D == Direction::forward) {
1442 AMREX_ROCFFT_SAFE_CALL
1443 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1444 rocfft_transform_type_real_forward, prec, M,
1445 length, howmany, nullptr));
1446 } else {
1447 AMREX_ROCFFT_SAFE_CALL
1448 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1449 rocfft_transform_type_real_inverse, prec, M,
1450 length, howmany, nullptr));
1451 }
1452
1453#elif defined(AMREX_USE_SYCL)
1454
// Build a oneMKL DFT descriptor; 1-D uses the scalar constructor.
1455 mkl_desc_r* pp;
1456 if (M == 1) {
1457 pp = new mkl_desc_r(fft_size[0]);
1458 } else {
1459 std::vector<std::int64_t> len64(M);
1460 for (int idim = 0; idim < M; ++idim) {
1461 len64[idim] = len[idim];
1462 }
1463 pp = new mkl_desc_r(len64);
1464 }
1465#ifndef AMREX_USE_MKL_DFTI_2024
1466 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1467 oneapi::mkl::dft::config_value::NOT_INPLACE);
1468#else
1469 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1470#endif
1471 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1472 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1473 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
// Row-major strides of the real domain: strides[0] is the offset,
// strides[M] the unit stride of the fastest dimension.
1474 std::vector<std::int64_t> strides(M+1);
1475 strides[0] = 0;
1476 strides[M] = 1;
1477 for (int i = M-1; i >= 1; --i) {
1478 strides[i] = strides[i+1] * fft_size[M-1-i];
1479 }
1480
1481#ifndef AMREX_USE_MKL_DFTI_2024
1482 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1483 // Do not set BWD_STRIDES
1484#else
1485 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1486 // Do not set BWD_STRIDES
1487#endif
1488 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1489 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1490 detail::assert_no_external_stream();
1491 pp->commit(amrex::Gpu::Device::streamQueue());
1492 plan = pp;
1493
1494#else /* FFTW */
1495
// FFTW plan creation needs valid buffers; without them this Plan stays
// undefined and execution becomes a no-op.
1496 if (pf == nullptr || pb == nullptr) {
1497 defined = false;
1498 return;
1499 }
1500
1501 if constexpr (std::is_same_v<float,T>) {
1502 if constexpr (D == Direction::forward) {
1503 plan = fftwf_plan_many_dft_r2c
1504 (M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
1505 FFTW_ESTIMATE);
1506 } else {
1507 plan = fftwf_plan_many_dft_c2r
1508 (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
1509 FFTW_ESTIMATE);
1510 }
1511 } else {
1512 if constexpr (D == Direction::forward) {
1513 plan = fftw_plan_many_dft_r2c
1514 (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
1515 FFTW_ESTIMATE);
1516 } else {
1517 plan = fftw_plan_many_dft_c2r
1518 (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
1519 FFTW_ESTIMATE);
1520 }
1521 }
1522#endif
1523
1524#if defined(AMREX_USE_GPU)
// Publish the freshly created plan for reuse by later init_r2c calls.
1525 if (cache) {
1526 if constexpr (std::is_same_v<float,T>) {
1527 detail::add_vendor_plan_f(key, plan);
1528 } else {
1529 detail::add_vendor_plan_d(key, plan);
1530 }
1531 }
1532#endif
1533}
1534
1536namespace detail
1537{
// Build a DistributionMapping that assigns box i to rank i, for n boxes.
1538 DistributionMapping make_iota_distromap (Long n);
1539
// Return the FAB owned by this rank, or nullptr when this rank's index
// is beyond the number of boxes (i.e. the rank holds no local data).
1540 template <typename FA>
1541 typename FA::FABType::value_type * get_fab (FA& fa)
1542 {
1543 auto myproc = ParallelContext::MyProcSub();
1544 if (myproc < fa.size()) {
1545 return fa.fabPtr(myproc);
1546 } else {
1547 return nullptr;
1548 }
1549 }
1550
// Make two FabArrays share a single memory allocation on this rank.
// One buffer large enough for either local FAB is allocated and
// installed into whichever of fa1/fa2 this rank participates in
// (both, when fa1 and fa2 are distinct objects). Returns an owning
// pointer to the shared buffer, or nullptr if nothing was allocated.
1551 template <typename FA1, typename FA2>
1552 std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
1553 {
// When fa1 and fa2 are the same object, only one FAB is installed.
1554 bool not_same_fa = true;
1555 if constexpr (std::is_same_v<FA1,FA2>) {
1556 not_same_fa = (&fa1 != &fa2);
1557 }
1558 using FAB1 = typename FA1::FABType::value_type;
1559 using FAB2 = typename FA2::FABType::value_type;
1560 using T1 = typename FAB1::value_type;
1561 using T2 = typename FAB2::value_type;
1562 auto myproc = ParallelContext::MyProcSub();
1563 bool alloc_1 = (myproc < fa1.size());
1564 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1565 void* p = nullptr;
1566 if (alloc_1 && alloc_2) {
// Size the buffer for the larger of the two local FABs.
1567 Box const& box1 = fa1.fabbox(myproc);
1568 Box const& box2 = fa2.fabbox(myproc);
1569 int ncomp1 = fa1.nComp();
1570 int ncomp2 = fa2.nComp();
1571 p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
1572 sizeof(T2)*box2.numPts()*ncomp2));
1573 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1574 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1575 } else if (alloc_1) {
1576 Box const& box1 = fa1.fabbox(myproc);
1577 int ncomp1 = fa1.nComp();
1578 p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
1579 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1580 } else if (alloc_2) {
1581 Box const& box2 = fa2.fabbox(myproc);
1582 int ncomp2 = fa2.nComp();
1583 p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
1584 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1585 } else {
1586 return nullptr;
1587 }
1588 return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
1589 }
1590}
1592
1594
1595struct Swap01
1596{
1597 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1598 {
1599 return {i.y, i.x, i.z};
1600 }
1601
1602 static constexpr Dim3 Inverse (Dim3 i)
1603 {
1604 return {i.y, i.x, i.z};
1605 }
1606
1607 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1608 {
1609 return it;
1610 }
1611
1612 static constexpr IndexType Inverse (IndexType it)
1613 {
1614 return it;
1615 }
1616};
1617
1618struct Swap02
1619{
1620 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1621 {
1622 return {i.z, i.y, i.x};
1623 }
1624
1625 static constexpr Dim3 Inverse (Dim3 i)
1626 {
1627 return {i.z, i.y, i.x};
1628 }
1629
1630 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1631 {
1632 return it;
1633 }
1634
1635 static constexpr IndexType Inverse (IndexType it)
1636 {
1637 return it;
1638 }
1639};
1640
1641struct RotateFwd
1642{
1643 // dest -> src: (x,y,z) -> (y,z,x)
1644 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1645 {
1646 return {i.y, i.z, i.x};
1647 }
1648
1649 // src -> dest: (x,y,z) -> (z,x,y)
1650 static constexpr Dim3 Inverse (Dim3 i)
1651 {
1652 return {i.z, i.x, i.y};
1653 }
1654
1655 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1656 {
1657 return it;
1658 }
1659
1660 static constexpr IndexType Inverse (IndexType it)
1661 {
1662 return it;
1663 }
1664};
1665
1666struct RotateBwd
1667{
1668 // dest -> src: (x,y,z) -> (z,x,y)
1669 [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
1670 {
1671 return {i.z, i.x, i.y};
1672 }
1673
1674 // src -> dest: (x,y,z) -> (y,z,x)
1675 static constexpr Dim3 Inverse (Dim3 i)
1676 {
1677 return {i.y, i.z, i.x};
1678 }
1679
1680 [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
1681 {
1682 return it;
1683 }
1684
1685 static constexpr IndexType Inverse (IndexType it)
1686 {
1687 return it;
1688 }
1689};
1690
1692
1694namespace detail
1695{
// Helper for domains with degenerate (size-1) dimensions: rearranges
// boxes, orders, and arrays so the transform operates on the reduced
// index space. NOTE(review): precise case semantics (case_1n etc.) are
// defined in the .cpp; inferred from the permutations below — confirm
// against SubHelper's constructor.
1696 struct SubHelper
1697 {
1698 explicit SubHelper (Box const& domain);
1699
1700 [[nodiscard]] Box make_box (Box const& box) const;
1701
1702 [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;
1703
1704 [[nodiscard]] bool ghost_safe (IntVect const& ng) const;
1705
1706 // This rearranges the order.
1707 [[nodiscard]] IntVect make_iv (IntVect const& iv) const;
1708
1709 // This keeps the order, but zero out the values in the hidden dimension.
1710 [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;
1711
1712 [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;
1713
1714 [[nodiscard]] IntVect inverse_order (IntVect const& order) const;
1715
// Permute an indexable triple/pair `a` according to the detected
// degenerate-dimension case; identity when no rearrangement applies.
1716 template <typename T>
1717 [[nodiscard]] T make_array (T const& a) const
1718 {
1719#if (AMREX_SPACEDIM == 1)
1721 return a;
1722#elif (AMREX_SPACEDIM == 2)
1723 if (m_case == case_1n) {
1724 return T{a[1],a[0]};
1725 } else {
1726 return a;
1727 }
1728#else
1729 if (m_case == case_11n) {
1730 return T{a[2],a[0],a[1]};
1731 } else if (m_case == case_1n1) {
1732 return T{a[1],a[0],a[2]};
1733 } else if (m_case == case_1nn) {
1734 return T{a[1],a[2],a[0]};
1735 } else if (m_case == case_n1n) {
1736 return T{a[0],a[2],a[1]};
1737 } else {
1738 return a;
1739 }
1740#endif
1741 }
1742
1743 [[nodiscard]] GpuArray<int,3> xyz_order () const;
1744
// Create a non-owning alias of mf whose boxes have been passed through
// make_box(); each FAB reuses mf's data pointer (no new allocation).
1745 template <typename FA>
1746 FA make_alias_mf (FA const& mf)
1747 {
1748 BoxList bl = mf.boxArray().boxList();
1749 for (auto& b : bl) {
1750 b = make_box(b);
1751 }
1752 auto const& ng = make_iv(mf.nGrowVect());
1753 FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
1754 using FAB = typename FA::fab_type;
1755 for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1756 submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
1757 }
1758 return submf;
1759 }
1760
1761#if (AMREX_SPACEDIM == 2)
1762 enum Case { case_1n, case_other };
1763 int m_case = case_other;
1764#elif (AMREX_SPACEDIM == 3)
1765 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1766 int m_case = case_other;
1767#endif
1768 };
1769}
1771
1772}
1773
1774#endif
#define AMREX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition AMReX_BLassert.H:49
#define AMREX_ENUM(CLASS,...)
Definition AMReX_Enum.H:208
#define AMREX_CUFFT_SAFE_CALL(call)
Definition AMReX_GpuError.H:92
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1140
#define AMREX_D_TERM(a, b, c)
Definition AMReX_SPACE.H:172
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
__host__ __device__ IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition AMReX_Box.H:154
Calculates the distribution of FABs to MPI processes.
Definition AMReX_DistributionMapping.H:43
static bool usingExternalStream() noexcept
Definition AMReX_GpuDevice.cpp:836
An Integer Vector in dim-Dimensional Space.
Definition AMReX_IntVect.H:57
amrex_long Long
Definition AMReX_INT.H:30
void ParallelForOMP(T n, L const &f) noexcept
Performance-portable kernel launch function with optional OpenMP threading.
Definition AMReX_GpuLaunch.H:319
__host__ __device__ Dim3 length(Array4< T > const &a) noexcept
Return the spatial extents of an Array4 in Dim3 form.
Definition AMReX_Array4.H:1345
Arena * The_Arena()
Definition AMReX_Arena.cpp:805
Definition AMReX_FFT_Helper.H:52
Direction
Definition AMReX_FFT_Helper.H:54
Boundary
Definition AMReX_FFT_Helper.H:58
DomainStrategy
Definition AMReX_FFT_Helper.H:56
Kind
Definition AMReX_FFT_Helper.H:60
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:310
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:291
__host__ __device__ std::pair< double, double > sincospi(double x)
Return sin(pi*x) and cos(pi*x) given x.
Definition AMReX_Math.H:204
__host__ __device__ void ignore_unused(const Ts &...)
Suppresses compiler warnings about otherwise-unused variables.
Definition AMReX.H:139
BoxND< 3 > Box
Box is an alias for amrex::BoxND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:30
IndexTypeND< 3 > IndexType
IndexType is an alias for amrex::IndexTypeND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:36
IntVectND< 3 > IntVect
IntVect is an alias for amrex::IntVectND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:33
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:240
const int[]
Definition AMReX_BLProfiler.cpp:1664
Definition AMReX_FFT_Helper.H:64
bool twod_mode
Definition AMReX_FFT_Helper.H:75
Info & setNumProcs(int n)
Cap the number of MPI ranks used by FFT.
Definition AMReX_FFT_Helper.H:131
bool oned_mode
We might have a special twod_mode: nx or ny == 1 && nz > 1.
Definition AMReX_FFT_Helper.H:78
int batch_size
Batched FFT size. Only support in R2C, not R2X.
Definition AMReX_FFT_Helper.H:81
Info & setDomainStrategy(DomainStrategy s)
Select how the domain is decomposed across MPI ranks.
Definition AMReX_FFT_Helper.H:92
DomainStrategy domain_strategy
Domain composition strategy.
Definition AMReX_FFT_Helper.H:66
int nprocs
Max number of processes to use.
Definition AMReX_FFT_Helper.H:84
int pencil_threshold
Definition AMReX_FFT_Helper.H:70
Info & setOneDMode(bool x)
Flag the degenerate 2-D mode (nx==1 or ny==1) that still batches along z.
Definition AMReX_FFT_Helper.H:117
Info & setBatchSize(int bsize)
Specify the batch size for FFT.
Definition AMReX_FFT_Helper.H:124
Info & setPencilThreshold(int t)
Override the slab→pencil break-even threshold for the automatic strategy.
Definition AMReX_FFT_Helper.H:99
Info & setTwoDMode(bool x)
Restrict transforms to the first two dimensions (3-D problems only).
Definition AMReX_FFT_Helper.H:106
Definition AMReX_FFT_Helper.H:188
void * pf
Definition AMReX_FFT_Helper.H:223
void unpack_r2r_buffer(T *pdst, void const *pbuf) const
Collapse the spectral R2R result back into the original real layout.
Definition AMReX_FFT_Helper.H:1131
std::conditional_t< std::is_same_v< float, T >, cuComplex, cuDoubleComplex > VendorComplex
Definition AMReX_FFT_Helper.H:192
VendorPlan plan2
Definition AMReX_FFT_Helper.H:222
void init_r2c(IntVectND< M > const &fft_size, void *pbf, void *pbb, bool cache, int ncomp=1)
Initialize an M-dimensional batched real-to-complex plan.
Definition AMReX_FFT_Helper.H:1376
int n
Definition AMReX_FFT_Helper.H:215
void destroy()
Release any vendor FFT plan objects owned by this Plan.
Definition AMReX_FFT_Helper.H:242
bool defined2
Definition AMReX_FFT_Helper.H:220
void init_r2r(Box const &box, VendorComplex *pc, std::pair< Boundary, Boundary > const &bc)
Initialize a real-to-real plan that reads/writes complex storage.
Definition AMReX_FFT_Helper.H:724
void pack_r2r_buffer(void *pbuf, T const *psrc) const
Expand the real R2R input into the symmetry-extended buffer expected by CUFFT/rocFFT.
Definition AMReX_FFT_Helper.H:893
static void free_scratch_space(void *p)
Release GPU scratch allocated via alloc_scratch_space().
Definition AMReX_FFT_Helper.H:885
static void destroy_vendor_plan(VendorPlan plan)
Helper that destroys a vendor plan of the appropriate backend type.
Definition AMReX_FFT_Helper.H:1341
Kind get_r2r_kind(std::pair< Boundary, Boundary > const &bc)
Map boundary conditions to the Plan Kind for real-to-real transforms.
Definition AMReX_FFT_Helper.H:586
cufftHandle VendorPlan
Definition AMReX_FFT_Helper.H:190
Kind kind
Definition AMReX_FFT_Helper.H:217
void init_c2c(Box const &box, VendorComplex *p, int ncomp=1, int ndims=1)
Initialize a complex-to-complex plan across 1/2/3 dimensions.
Definition AMReX_FFT_Helper.H:418
int howmany
Definition AMReX_FFT_Helper.H:216
void * pb
Definition AMReX_FFT_Helper.H:224
void init_r2c(Box const &box, T *pr, VendorComplex *pc, bool is_2d_transform=false, int ncomp=1)
Initialize a 1-D or 2-D real-to-complex plan over the supplied Box.
Definition AMReX_FFT_Helper.H:266
void compute_r2r()
Execute the real-to-real plan, including GPU packing/unpacking.
Definition AMReX_FFT_Helper.H:1284
void compute_c2c()
Execute the complex-to-complex plan in place.
Definition AMReX_FFT_Helper.H:824
bool r2r_data_is_complex
Definition AMReX_FFT_Helper.H:218
void * alloc_scratch_space() const
Allocate GPU scratch space large enough to hold packed R2R data.
Definition AMReX_FFT_Helper.H:868
VendorPlan plan
Definition AMReX_FFT_Helper.H:221
void compute_r2c()
Execute the previously initialized real-to-complex plan.
Definition AMReX_FFT_Helper.H:772
bool defined
Definition AMReX_FFT_Helper.H:219
void set_ptrs(void *p0, void *p1)
Register device pointers used by the forward/backward executions.
Definition AMReX_FFT_Helper.H:233
void init_r2r(Box const &box, T *p, std::pair< Boundary, Boundary > const &bc, int howmany_initval=1)
Initialize a real-to-real (cosine/sine) plan that operates on real buffers.
Definition AMReX_FFT_Helper.H:620
A host / device complex number type, because std::complex doesn't work in device code with Cuda yet.
Definition AMReX_GpuComplex.H:30