Block-Structured AMR Software Framework
AMReX_FFT_Helper.H
#ifndef AMREX_FFT_HELPER_H_
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_BLProfiler.H>
#include <AMReX_DataAllocator.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Enum.H>
#include <AMReX_FabArray.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>
#include <AMReX_Periodicity.H>

#if defined(AMREX_USE_CUDA)
# include <cufft.h>
# include <cuComplex.h>
#elif defined(AMREX_USE_HIP)
# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
#   include <rocfft/rocfft.h>
# else
#   include <rocfft.h>
# endif
# include <hip/hip_complex.h>
#elif defined(AMREX_USE_SYCL)
# if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
#   include <oneapi/mkl/dft.hpp>
# else
#   define AMREX_USE_MKL_DFTI_2024 1
#   include <oneapi/mkl/dfti.hpp>
# endif
#else
# include <fftw3.h>
#endif

#include <algorithm>
#include <complex>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <variant>

namespace amrex::FFT
{

enum struct Direction { forward, backward, both, none };

enum struct DomainStrategy { slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_oo_f, r2r_oo_b,
                   r2r_ee_f, r2r_ee_b, r2r_oe, r2r_eo };

struct Info
{
    //! Supported only in 3D. When batch_mode is true, FFT is performed on
    //! the first two dimensions only and the third dimension size is the
    //! batch size.
    bool batch_mode = false;

    //! Max number of processes to use
    int nprocs = ParallelContext::NProcsSub();

    Info& setBatchMode (bool x) { batch_mode = x; return *this; }
    Info& setNumProcs (int n) { nprocs = n; return *this; }
};
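
// A sketch of typical usage (illustrative only, not part of this header):
// the chained setters simply fill in the fields declared above.
//
//   auto info = amrex::FFT::Info{}.setBatchMode(true).setNumProcs(64);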

#ifdef AMREX_USE_HIP
namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
#endif

#ifdef AMREX_USE_SYCL
namespace detail
{
template <typename T, Direction direction, typename P, typename TI, typename TO>
void sycl_execute (P* plan, TI* in, TO* out)
{
#ifndef AMREX_USE_MKL_DFTI_2024
    std::int64_t workspaceSize = 0;
#else
    std::size_t workspaceSize = 0;
#endif
    plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
                    &workspaceSize);
    auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
    plan->set_workspace(buffer);
    sycl::event r;
    if constexpr (std::is_same_v<TI,TO>) {
        amrex::ignore_unused(in);
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, out);
        }
    } else {
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, in, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, in, out);
        }
    }
    r.wait();
    amrex::The_Arena()->free(buffer);
}
}
#endif
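
// The SYCL helper above runs a committed oneMKL DFT descriptor with an
// externally owned workspace: it queries WORKSPACE_BYTES, allocates that many
// bytes from The_Arena(), executes compute_forward/compute_backward (in place
// when the input and output types match), waits on the returned event, and
// frees the buffer. Using an external workspace presumably keeps all
// temporary memory under AMReX's arena rather than MKL's own allocator.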

template <typename T>
struct Plan
{
#if defined(AMREX_USE_CUDA)
    using VendorPlan = cufftHandle;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
    using VendorPlan = rocfft_plan;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             float2, double2>;
#elif defined(AMREX_USE_SYCL)
    using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::REAL>;
    using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::COMPLEX>;
    using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
    using VendorComplex = std::complex<T>;
#else
    using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
                                          fftwf_plan, fftw_plan>;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             fftwf_complex, fftw_complex>;
#endif

    int n = 0;
    int howmany = 0;
    Kind kind = Kind::none;
    bool r2r_data_is_complex = false;
    bool defined = false;
    bool defined2 = false;
    VendorPlan plan{};
    VendorPlan plan2{};
    void* pf = nullptr;
    void* pb = nullptr;

#ifdef AMREX_USE_GPU
    void set_ptrs (void* p0, void* p1) {
        pf = p0;
        pb = p1;
    }
#endif

    void destroy ()
    {
        if (defined) {
            destroy_vendor_plan(plan);
            defined = false;
        }
#if !defined(AMREX_USE_GPU)
        if (defined2) {
            destroy_vendor_plan(plan2);
            defined2 = false;
        }
#endif
    }

    template <Direction D>
    void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        int rank = is_2d_transform ? 2 : 1;

        kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
        defined = true;
        pf = (void*)pr;
        pb = (void*)pc;

        int len[2] = {};
        if (rank == 1) {
            len[0] = box.length(0);
            len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
        } else {
            len[0] = box.length(1); // Most FFT libraries assume row-major ordering
            len[1] = box.length(0); // except for rocfft
        }
        int nr = (rank == 1) ? len[0] : len[0]*len[1];
        n = nr;
        int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
        howmany = 1;
#else
        howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
                              : AMREX_D_TERM(1, *1            , *box.length(2));
#endif

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        std::size_t work_size;
        if constexpr (D == Direction::forward) {
            cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
        } else {
            cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
        }

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        // switch to column-major ordering
        std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
        if constexpr (D == Direction::forward) {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_forward, prec, rank,
                                    length, howmany, nullptr));
        } else {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_inverse, prec, rank,
                                    length, howmany, nullptr));
        }

#elif defined(AMREX_USE_SYCL)

        mkl_desc_r* pp;
        if (rank == 1) {
            pp = new mkl_desc_r(len[0]);
        } else {
            pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides;
        strides.push_back(0);
        if (rank == 2) { strides.push_back(len[1]); }
        strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        // Do not set BWD_STRIDES
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        // Do not set BWD_STRIDES
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftwf_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftw_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        }
#endif
    }
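
    // Illustrative sketch (not part of this header): a forward transform of
    // one box's worth of real data, where `rbox`, `rp`, and `cp` are
    // caller-provided. A forward r2c of nr real values yields nc = nr/2+1
    // complex coefficients per line, which is why nr and nc above differ.
    //
    //   Plan<Real> plan;
    //   plan.init_r2c<Direction::forward>(rbox, rp, cp);
    //   plan.compute_r2c<Direction::forward>();
    //   plan.destroy();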

    template <Direction D, int M>
    void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache);

    template <Direction D>
    void init_c2c (Box const& box, VendorComplex* p)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = (D == Direction::forward) ? Kind::c2c_f : Kind::c2c_b;
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));

        cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single
                                            : rocfft_precision_double;
        auto dir = (D == Direction::forward) ? rocfft_transform_type_complex_forward
                                             : rocfft_transform_type_complex_inverse;
        const std::size_t length = n;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
                                &length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        auto* pp = new mkl_desc_c(n);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftwf_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftw_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        }
#endif
    }

#ifndef AMREX_USE_GPU
    template <Direction D>
    fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return FFTW_REDFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return FFTW_RODFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return fftw_r2r_kind{};
        }
    }
#endif
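
    // The mapping above follows FFTW's r2r naming: even-even is the DCT-II
    // (FFTW_REDFT10) forward and DCT-III (FFTW_REDFT01) backward pair,
    // odd-odd is the DST-II/DST-III pair (FFTW_RODFT10/FFTW_RODFT01), and the
    // mixed pairs are the type-IV transforms (FFTW_REDFT11/FFTW_RODFT11),
    // which are their own inverses up to scaling.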

    template <Direction D>
    Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? Kind::r2r_ee_f : Kind::r2r_ee_b;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return Kind::r2r_eo;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return Kind::r2r_oe;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? Kind::r2r_oo_f : Kind::r2r_oo_b;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return Kind::none;
        }
    }

    template <Direction D>
    void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
                   int howmany_initval = 1)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));

#if defined(AMREX_USE_GPU)
        int nex = 0;
        if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
            Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
                   (bc.first == Boundary::odd && bc.second == Boundary::even)) {
            nex = 4*n;
        } else {
            amrex::Abort("FFT: unsupported BC");
        }
        int nc = (nex/2) + 1;

#if defined (AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        const std::size_t length = nex;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace,
                                rocfft_transform_type_real_forward, prec, 1,
                                &length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        auto* pp = new mkl_desc_r(nex);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#endif

#else /* FFTW */
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }
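
    // On GPUs the vendor FFT libraries provide no native cosine/sine
    // transforms, so the code above builds an ordinary real-to-complex plan
    // of extended length nex (2n or 4n depending on the boundary pair):
    // pack_r2r_buffer below mirrors each line with the appropriate even/odd
    // symmetry, one forward R2C transform is run, and unpack_r2r_buffer
    // recovers the r2r coefficients from the complex output.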

    template <Direction D>
    void init_r2r (Box const& box, VendorComplex* pc,
                   std::pair<Boundary,Boundary> const& bc)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        auto* p = (T*)pc;

#if defined(AMREX_USE_GPU)

        init_r2r<D>(box, p, bc, 2);
        r2r_data_is_complex = true;

#else

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));

        defined2 = true;
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftwf_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftw_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }
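
    // For interleaved complex data on the host, the overload above creates
    // two stride-2 FFTW plans: `plan` transforms the real parts (base pointer
    // p) and `plan2` the imaginary parts (p+1), so one logical transform of
    // complex data runs as two independent real r2r transforms.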

    template <Direction D>
    void compute_r2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
        using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
        auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
        auto* po = (TO*)((D == Direction::forward) ? pb : pf);

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (D == Direction::forward) {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
            }
        } else {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
            }
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pi, (void**)&po);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
#else
        amrex::ignore_unused(pi, po);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }
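
    // Note the cuFFT path above: auto-allocation was disabled at plan
    // creation, so each execution allocates the work area from The_Arena(),
    // attaches it to the plan, runs on the current stream, synchronizes, and
    // frees it. This keeps idle plans lightweight; the per-call allocation
    // cost is presumably amortized by the arena's caching.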

    template <Direction D>
    void compute_c2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        auto* p = (VendorComplex*)pf;

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
#else
        amrex::ignore_unused(p);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

#ifdef AMREX_USE_GPU
    [[nodiscard]] void* alloc_scratch_space () const
    {
        int nc = 0;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            nc = n + 1;
        } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
                   kind == Kind::r2r_oe || kind == Kind::r2r_eo) {
            nc = 2*n+1;
        } else {
            amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
        }
        return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
    }

    static void free_scratch_space (void* p) { The_Arena()->free(p); }

    void pack_r2r_buffer (void* pbuf, T const* psrc) const
    {
        auto* pdst = (T*) pbuf;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
            int ostride = (n+1)*2;
            int istride = n;
            int nex = 2*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else {
                            *po = sign * pi[(2*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else {
                        *po = sign * pi[2*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig-1)) {
                            *po = pi[(2*norig-2-i)*2];
                        } else if (i == (2*norig-1)) {
                            *po = T(0);
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i < (4*norig-1)) {
                            *po = -pi[(4*norig-2-i)*2];
                        } else {
                            *po = T(0);
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig-1)) {
                        *po = pi[2*norig-2-i];
                    } else if (i == (2*norig-1)) {
                        *po = T(0);
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i < (4*norig-1)) {
                        *po = -pi[4*norig-2-i];
                    } else {
                        *po = T(0);
                    }
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i == norig) {
                            *po = T(0);
                        } else if (i < (2*norig+1)) {
                            *po = -pi[(2*norig-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i == 3*norig) {
                            *po = T(0);
                        } else {
                            *po = pi[(4*norig-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i == norig) {
                        *po = T(0);
                    } else if (i < (2*norig+1)) {
                        *po = -pi[2*norig-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i == 3*norig) {
                        *po = T(0);
                    } else {
                        *po = pi[4*norig-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = -pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = -pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = pi[4*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = -pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = -pi[4*norig-1-i];
                    }
                });
            }
        } else {
            amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
        }
    }
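
    // Each packing branch above extends a line of n values to nex samples
    // according to the boundary pair: even symmetry mirrors the data, odd
    // symmetry mirrors and negates it, and explicit zeros are written at the
    // symmetry points. The r2r_data_is_complex variants apply the same
    // extension to the real and imaginary interleaved components separately.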

    void unpack_r2r_buffer (T* pdst, void const* pbuf) const
    {
        auto const* psrc = (GpuComplex<T> const*) pbuf;
        int norig = n;
        Long nelems = Long(norig)*howmany;
        int ostride = n;

        if (kind == Kind::r2r_oo_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k+1];
                        pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k+1];
                    pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_ee_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k];
                        pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k];
                    pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * yk.real();
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
                });
            }
        } else {
            amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
        }
    }
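
    // Unpacking implements the usual DCT/DST post-processing of an R2C
    // result: for most kinds, output k combines the real and imaginary parts
    // of one complex coefficient with weights sin/cos(pi*(k+shift)/(2n)),
    // evaluated together by Math::sincospi, and the kinds built on the 4n
    // extension carry an extra factor of one half.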
#endif

    template <Direction D>
    void compute_r2r ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

#if defined(AMREX_USE_GPU)

        auto* pscratch = alloc_scratch_space();

        pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
        }

#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
#endif

        unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);

        Gpu::streamSynchronize();
        free_scratch_space(pscratch);
#if defined(AMREX_USE_CUDA)
        The_Arena()->free(work_area);
#endif

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
            if (defined2) { fftwf_execute(plan2); }
        } else {
            fftw_execute(plan);
            if (defined2) { fftw_execute(plan2); }
        }

#endif
    }

    static void destroy_vendor_plan (VendorPlan plan)
    {
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
        AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
#elif defined(AMREX_USE_SYCL)
        std::visit([](auto&& p) { delete p; }, plan);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_destroy_plan(plan);
        } else {
            fftw_destroy_plan(plan);
        }
#endif
    }
};

using Key = std::tuple<IntVectND<3>,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

PlanD* get_vendor_plan_d (Key const& key);
PlanF* get_vendor_plan_f (Key const& key);

void add_vendor_plan_d (Key const& key, PlanD plan);
void add_vendor_plan_f (Key const& key, PlanF plan);

template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache)
{
    static_assert(D == Direction::forward || D == Direction::backward);

    kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
    defined = true;
    pf = pbf;
    pb = pbb;

    n = 1;
    for (auto s : fft_size) { n *= s; }
    howmany = 1;

#if defined(AMREX_USE_GPU)
    Key key = {fft_size.template expand<3>(), D, kind};
    if (cache) {
        VendorPlan* cached_plan = nullptr;
        if constexpr (std::is_same_v<float,T>) {
            cached_plan = get_vendor_plan_f(key);
        } else {
            cached_plan = get_vendor_plan_d(key);
        }
        if (cached_plan) {
            plan = *cached_plan;
            return;
        }
    }
#else
    amrex::ignore_unused(cache);
#endif

#if defined(AMREX_USE_CUDA)

    AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
    AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
    cufftType type;
    if constexpr (D == Direction::forward) {
        type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
    } else {
        type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
    }
    std::size_t work_size;
    if constexpr (M == 1) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
    } else if constexpr (M == 2) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
    } else if constexpr (M == 3) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
    }

#elif defined(AMREX_USE_HIP)

    auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
    std::size_t length[M];
    for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
    if constexpr (D == Direction::forward) {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_forward, prec, M,
                                length, howmany, nullptr));
    } else {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_inverse, prec, M,
                                length, howmany, nullptr));
    }

#elif defined(AMREX_USE_SYCL)

    mkl_desc_r* pp;
    if (M == 1) {
        pp = new mkl_desc_r(fft_size[0]);
    } else {
        std::vector<std::int64_t> len(M);
        for (int idim = 0; idim < M; ++idim) {
            len[idim] = fft_size[M-1-idim];
        }
        pp = new mkl_desc_r(len);
    }
#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                  oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif

    std::vector<std::int64_t> strides(M+1);
    strides[0] = 0;
    strides[M] = 1;
    for (int i = M-1; i >= 1; --i) {
        strides[i] = strides[i+1] * fft_size[M-1-i];
    }

#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
    // Do not set BWD_STRIDES
#else
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
    // Do not set BWD_STRIDES
#endif
    pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                  oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
    pp->commit(amrex::Gpu::Device::streamQueue());
    plan = pp;

#else /* FFTW */

    if (pf == nullptr || pb == nullptr) {
        defined = false;
        return;
    }

    int size_for_row_major[M];
    for (int idim = 0; idim < M; ++idim) {
        size_for_row_major[idim] = fft_size[M-1-idim];
    }

    if constexpr (std::is_same_v<float,T>) {
        if constexpr (D == Direction::forward) {
            plan = fftwf_plan_dft_r2c
                (M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
                 FFTW_ESTIMATE);
        } else {
            plan = fftwf_plan_dft_c2r
                (M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
                 FFTW_ESTIMATE);
        }
    } else {
        if constexpr (D == Direction::forward) {
            plan = fftw_plan_dft_r2c
                (M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_dft_c2r
                (M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
                 FFTW_ESTIMATE);
        }
    }
#endif

#if defined(AMREX_USE_GPU)
    if (cache) {
        if constexpr (std::is_same_v<float,T>) {
            add_vendor_plan_f(key, plan);
        } else {
            add_vendor_plan_d(key, plan);
        }
    }
#endif
}
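
// Illustrative sketch (not part of this header): a cached whole-domain
// forward plan, where `fft_size` is an IntVectND<3> and `rp`/`cp` are
// caller-provided buffers. On GPU builds with cache == true, the plan is
// looked up in, or added to, the global cache keyed by (size, direction,
// kind), so repeated construction at the same size reuses the vendor plan.
//
//   Plan<Real> plan;
//   plan.init_r2c<Direction::forward>(fft_size, (void*)rp, (void*)cp, true);
//   plan.compute_r2c<Direction::forward>();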

namespace detail
{
    DistributionMapping make_iota_distromap (Long n);

    template <typename FA>
    typename FA::FABType::value_type * get_fab (FA& fa)
    {
        auto myproc = ParallelContext::MyProcSub();
        if (myproc < fa.size()) {
            return fa.fabPtr(myproc);
        } else {
            return nullptr;
        }
    }

    template <typename FA1, typename FA2>
    std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
    {
        bool not_same_fa = true;
        if constexpr (std::is_same_v<FA1,FA2>) {
            not_same_fa = (&fa1 != &fa2);
        }
        using FAB1 = typename FA1::FABType::value_type;
        using FAB2 = typename FA2::FABType::value_type;
        using T1 = typename FAB1::value_type;
        using T2 = typename FAB2::value_type;
        auto myproc = ParallelContext::MyProcSub();
        bool alloc_1 = (myproc < fa1.size());
        bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
        void* p = nullptr;
        if (alloc_1 && alloc_2) {
            Box const& box1 = fa1.fabbox(myproc);
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
                                            sizeof(T2)*box2.numPts()*ncomp2));
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else if (alloc_1) {
            Box const& box1 = fa1.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
        } else if (alloc_2) {
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else {
            return nullptr;
        }
        return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
    }
}
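
// make_mfs_share above lets two FabArrays whose local data are not needed at
// the same time alias one allocation: a single arena buffer sized for the
// larger of the two local FAB footprints is installed into both containers,
// and the returned unique_ptr owns the memory via The_Arena()'s deleter.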

struct Swap01
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.x, i.z};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.x, i.z};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct Swap02
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.y, i.x};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.y, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateFwd
{
    // dest -> src: (x,y,z) -> (y,z,x)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.z, i.x};
    }

    // src -> dest: (x,y,z) -> (z,x,y)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.x, i.y};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateBwd
{
    // dest -> src: (x,y,z) -> (z,x,y)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.x, i.y};
    }

    // src -> dest: (x,y,z) -> (y,z,x)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.z, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};
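
// Swap01 and Swap02 exchange two index components and are their own
// inverses; RotateFwd and RotateBwd are the two cyclic permutations of
// (x,y,z), each the inverse of the other. They serve as index mappings for
// the transposes between the pencil layouts of successive FFT stages, and
// they leave the cell/node IndexType unchanged.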

namespace detail
{
    struct SubHelper
    {
        explicit SubHelper (Box const& domain);

        [[nodiscard]] Box make_box (Box const& box) const;

        [[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;

        [[nodiscard]] bool ghost_safe (IntVect const& ng) const;

        // This rearranges the order.
        [[nodiscard]] IntVect make_iv (IntVect const& iv) const;

        // This keeps the order, but zeroes out the values in the hidden dimension.
        [[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;

        [[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;

        [[nodiscard]] IntVect inverse_order (IntVect const& order) const;

        template <typename T>
        [[nodiscard]] T make_array (T const& a) const
        {
#if (AMREX_SPACEDIM == 1)
            amrex::ignore_unused(this);
            return a;
#elif (AMREX_SPACEDIM == 2)
            if (m_case == case_1n) {
                return T{a[1],a[0]};
            } else {
                return a;
            }
#else
            if (m_case == case_11n) {
                return T{a[2],a[0],a[1]};
            } else if (m_case == case_1n1) {
                return T{a[1],a[0],a[2]};
            } else if (m_case == case_1nn) {
                return T{a[1],a[2],a[0]};
            } else if (m_case == case_n1n) {
                return T{a[0],a[2],a[1]};
            } else {
                return a;
            }
#endif
        }

        [[nodiscard]] GpuArray<int,3> xyz_order () const;

        template <typename FA>
        FA make_alias_mf (FA const& mf)
        {
            BoxList bl = mf.boxArray().boxList();
            for (auto& b : bl) {
                b = make_box(b);
            }
            auto const& ng = make_iv(mf.nGrowVect());
            FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
            using FAB = typename FA::fab_type;
            for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
                submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
            }
            return submf;
        }

#if (AMREX_SPACEDIM == 2)
        enum Case { case_1n, case_other };
        int m_case = case_other;
#elif (AMREX_SPACEDIM == 3)
        enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
        int m_case = case_other;
#endif
    };
}
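
// SubHelper (above) handles domains that are degenerate in one or more
// directions, e.g. a 1 x ny x nz box: the case enums record which dimensions
// are trivial, make_box/make_iv/make_array reorder data so the non-trivial
// dimensions come first, and the inverse_* helpers map results back to the
// original ordering.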

}

#endif