AMReX_FFT_Helper.H (Block-Structured AMR Software Framework)
#ifndef AMREX_FFT_HELPER_H_
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_BLProfiler.H>
#include <AMReX_DataAllocator.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Enum.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>

#if defined(AMREX_USE_CUDA)
#  include <cufft.h>
#  include <cuComplex.h>
#elif defined(AMREX_USE_HIP)
#  if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
#    include <rocfft/rocfft.h>
#  else
#    include <rocfft.h>
#  endif
#  include <hip/hip_complex.h>
#elif defined(AMREX_USE_SYCL)
#  if __has_include(<oneapi/mkl/dft.hpp>) // oneAPI 2025.0
#    include <oneapi/mkl/dft.hpp>
#  else
#    define AMREX_USE_MKL_DFTI_2024 1
#    include <oneapi/mkl/dfti.hpp>
#  endif
#else
#  include <fftw3.h>
#endif

#include <algorithm>
#include <complex>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <variant>

namespace amrex::FFT
{

enum struct Direction { forward, backward, both, none };

enum struct DomainStrategy { slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

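// Kind of transform performed by a Plan. r2c/c2c are real-to-complex and
// complex-to-complex transforms, with _f/_b marking forward/backward. The
// r2r variants encode the boundary pair (e = even, o = odd) at the two ends
// of the domain, e.g., r2r_eo is even at the low end and odd at the high end.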
enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_oo_f, r2r_oo_b,
                   r2r_ee_f, r2r_ee_b, r2r_oe, r2r_eo };

struct Info
{
    bool batch_mode = false;

    //! Max number of processes to use
    int nprocs = std::numeric_limits<int>::max();

    Info& setBatchMode (bool x) { batch_mode = x; return *this; }
    Info& setNumProcs (int n) { nprocs = n; return *this; }
};
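
// The setters return *this, so an Info can be configured inline; for example
// (illustrative values only):
//     auto info = amrex::FFT::Info{}.setBatchMode(true).setNumProcs(4);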

#ifdef AMREX_USE_HIP
namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
#endif

#ifdef AMREX_USE_SYCL
namespace detail
{
template <typename T, Direction direction, typename P, typename TI, typename TO>
void sycl_execute (P* plan, TI* in, TO* out)
{
#ifndef AMREX_USE_MKL_DFTI_2024
    std::int64_t workspaceSize = 0;
#else
    std::size_t workspaceSize = 0;
#endif
    plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
                    &workspaceSize);
    auto* buffer = (T*)amrex::The_Arena()->alloc(workspaceSize);
    plan->set_workspace(buffer);
    sycl::event r;
    if (std::is_same_v<TI,TO>) {
        amrex::ignore_unused(in);
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, out);
        }
    } else {
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, in, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, in, out);
        }
    }
    r.wait();
    amrex::The_Arena()->free(buffer);
}
}
#endif
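
// Plan<T> wraps a single vendor FFT plan (cuFFT, rocFFT, oneMKL DFT, or FFTW,
// depending on the build) for a batched 1D or 2D transform, together with the
// metadata needed to execute it: the transform size n, the batch count
// howmany, the transform Kind, and the forward/backward data pointers pf/pb.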
template <typename T>
struct Plan
{
#if defined(AMREX_USE_CUDA)
    using VendorPlan = cufftHandle;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
    using VendorPlan = rocfft_plan;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             float2, double2>;
#elif defined(AMREX_USE_SYCL)
    using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::REAL>;
    using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::COMPLEX>;
    using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
    using VendorComplex = std::complex<T>;
#else
    using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
                                          fftwf_plan, fftw_plan>;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             fftwf_complex, fftw_complex>;
#endif

    int n = 0;
    int howmany = 0;
    Kind kind = Kind::none;
    bool r2r_data_is_complex = false;
    bool defined = false;
    bool defined2 = false;
    VendorPlan plan{};
    VendorPlan plan2{};
    void* pf = nullptr;
    void* pb = nullptr;

#ifdef AMREX_USE_GPU
    void set_ptrs (void* p0, void* p1) {
        pf = p0;
        pb = p1;
    }
#endif

    void destroy ()
    {
        if (defined) {
            destroy_vendor_plan(plan);
            defined = false;
        }
#if !defined(AMREX_USE_GPU)
        if (defined2) {
            destroy_vendor_plan(plan2);
            defined2 = false;
        }
#endif
    }

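    // Create a plan for a real-to-complex (D == Direction::forward) or
    // complex-to-real (D == Direction::backward) transform along x, or in the
    // x-y plane when is_2d_transform is true, batched over the remaining
    // dimensions of box.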
    template <Direction D>
    void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        int rank = is_2d_transform ? 2 : 1;

        kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
        defined = true;
        pf = (void*)pr;
        pb = (void*)pc;

        int len[2] = {};
        if (rank == 1) {
            len[0] = box.length(0);
            len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`.
        } else {
            len[0] = box.length(1); // Most FFT libraries assume row-major ordering
            len[1] = box.length(0); // except for rocfft
        }
        int nr = (rank == 1) ? len[0] : len[0]*len[1];
        n = nr;
        int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
        howmany = 1;
#else
        howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
                              : AMREX_D_TERM(1, *1            , *box.length(2));
#endif

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        std::size_t work_size;
        if constexpr (D == Direction::forward) {
            cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
        } else {
            cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
        }

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        // switch to column-major ordering
        std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
        if constexpr (D == Direction::forward) {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_forward, prec, rank,
                                    length, howmany, nullptr));
        } else {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_inverse, prec, rank,
                                    length, howmany, nullptr));
        }

#elif defined(AMREX_USE_SYCL)

        mkl_desc_r* pp;
        if (rank == 1) {
            pp = new mkl_desc_r(len[0]);
        } else {
            pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides;
        strides.push_back(0);
        if (rank == 2) { strides.push_back(len[1]); }
        strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        // Do not set BWD_STRIDES
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        // Do not set BWD_STRIDES
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftwf_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftw_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        }
#endif
    }

    template <Direction D, int M>
    void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache);

    template <Direction D>
    void init_c2c (Box const& box, VendorComplex* p)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = (D == Direction::forward) ? Kind::c2c_f : Kind::c2c_b;
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));

        cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single
                                            : rocfft_precision_double;
        auto dir = (D == Direction::forward) ? rocfft_transform_type_complex_forward
                                             : rocfft_transform_type_complex_inverse;
        const std::size_t length = n;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
                                &length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        auto* pp = new mkl_desc_c(n);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftwf_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftw_plan_many_dft
                    (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        }
#endif
    }

#ifndef AMREX_USE_GPU
    template <Direction D>
    fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? FFTW_REDFT10 : FFTW_REDFT01;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return FFTW_REDFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return FFTW_RODFT11;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? FFTW_RODFT10 : FFTW_RODFT01;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return fftw_r2r_kind{};
        }
    }
#endif
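
    // For reference, the FFTW r2r codes chosen above are the standard cosine
    // and sine transforms: REDFT10/REDFT01 are DCT-II/DCT-III, REDFT11 is
    // DCT-IV, and RODFT10/RODFT01/RODFT11 are DST-II/DST-III/DST-IV. The
    // even/odd boundary pair selects among them; get_r2r_kind below applies
    // the same classification in a vendor-neutral form.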
    template <Direction D>
    Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
    {
        if (bc.first == Boundary::even && bc.second == Boundary::even)
        {
            return (D == Direction::forward) ? Kind::r2r_ee_f : Kind::r2r_ee_b;
        }
        else if (bc.first == Boundary::even && bc.second == Boundary::odd)
        {
            return Kind::r2r_eo;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::even)
        {
            return Kind::r2r_oe;
        }
        else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
        {
            return (D == Direction::forward) ? Kind::r2r_oo_f : Kind::r2r_oo_b;
        }
        else {
            amrex::Abort("FFT: unsupported BC");
            return Kind::none;
        }
    }

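    // GPU backends have no native real-to-real support, so init_r2r sets up an
    // equivalent real-to-complex transform of a symmetrically extended copy of
    // the data: pack_r2r_buffer mirrors the n input values (with signs chosen
    // by the even/odd boundaries) into a scratch sequence of length nex = 2*n
    // or 4*n, a plain real FFT is run on it, and unpack_r2r_buffer recovers
    // the DCT/DST coefficients from the complex output.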
    template <Direction D>
    void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
                   int howmany_initval = 1)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(howmany_initval, *box.length(1), *box.length(2));

#if defined(AMREX_USE_GPU)
        int nex = 0;
        if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
            Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::forward == D) {
            nex = 2*n;
        } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
                   Direction::backward == D) {
            nex = 4*n;
        } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
                   (bc.first == Boundary::odd && bc.second == Boundary::even)) {
            nex = 4*n;
        } else {
            amrex::Abort("FFT: unsupported BC");
        }
        int nc = (nex/2) + 1;

#if defined (AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
        cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        const std::size_t length = nex;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace,
                                rocfft_transform_type_real_forward, prec, 1,
                                &length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

        auto* pp = new mkl_desc_r(nex);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;

#endif

#else /* FFTW */
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }

    template <Direction D>
    void init_r2r (Box const& box, VendorComplex* pc,
                   std::pair<Boundary,Boundary> const& bc)
    {
        static_assert(D == Direction::forward || D == Direction::backward);

        auto* p = (T*)pc;

#if defined(AMREX_USE_GPU)

        init_r2r<D>(box, p, bc, 2);
        r2r_data_is_complex = true;

#else

        kind = get_r2r_kind<D>(bc);
        defined = true;
        pf = (void*)p;
        pb = (void*)p;

        n = box.length(0);
        howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));

        defined2 = true;
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftwf_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftw_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }

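    // Note for the CUDA path below: plans are created with
    // cufftSetAutoAllocation(plan, 0), so each compute_* call allocates the
    // cuFFT work area from The_Arena, executes on the current GPU stream,
    // synchronizes, and then frees the work area.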
    template <Direction D>
    void compute_r2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        using TI = std::conditional_t<(D == Direction::forward), T, VendorComplex>;
        using TO = std::conditional_t<(D == Direction::backward), T, VendorComplex>;
        auto* pi = (TI*)((D == Direction::forward) ? pf : pb);
        auto* po = (TO*)((D == Direction::forward) ? pb : pf);

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (D == Direction::forward) {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pi, po));
            }
        } else {
            if constexpr (std::is_same_v<float,T>) {
                AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pi, po));
            } else {
                AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pi, po));
            }
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pi, (void**)&po);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
#else
        amrex::ignore_unused(pi, po);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

    template <Direction D>
    void compute_c2c ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

        auto* p = (VendorComplex*)pf;

#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, dir));
        }
        Gpu::streamSynchronize();
        The_Arena()->free(work_area);
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
#else
        amrex::ignore_unused(p);
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

#ifdef AMREX_USE_GPU
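    // The scratch buffer holds the extended sequence of length nex transformed
    // in place as real-to-complex, which requires nex/2+1 complex values per
    // batch: n+1 for the forward 2*n extensions, 2*n+1 for the 4*n extensions.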
    [[nodiscard]] void* alloc_scratch_space () const
    {
        int nc = 0;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            nc = n + 1;
        } else if (kind == Kind::r2r_oo_b || kind == Kind::r2r_ee_b ||
                   kind == Kind::r2r_oe || kind == Kind::r2r_eo) {
            nc = 2*n+1;
        } else {
            amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
        }
        return The_Arena()->alloc(sizeof(GpuComplex<T>)*nc*howmany);
    }

    static void free_scratch_space (void* p) { The_Arena()->free(p); }

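    // Write the symmetrically extended real sequence into the scratch buffer.
    // The output stride of 2*(nex/2+1) reals is the in-place r2c layout that
    // aliases the complex output; when the r2r data are interleaved complex
    // pairs (r2r_data_is_complex), the real and imaginary parts are packed as
    // two separate batches.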
    void pack_r2r_buffer (void* pbuf, T const* psrc) const
    {
        auto* pdst = (T*) pbuf;
        if (kind == Kind::r2r_oo_f || kind == Kind::r2r_ee_f) {
            T sign = (kind == Kind::r2r_oo_f) ? T(-1) : T(1);
            int ostride = (n+1)*2;
            int istride = n;
            int nex = 2*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else {
                            *po = sign * pi[(2*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else {
                        *po = sign * pi[2*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig-1)) {
                            *po = pi[(2*norig-2-i)*2];
                        } else if (i == (2*norig-1)) {
                            *po = T(0);
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i < (4*norig-1)) {
                            *po = -pi[(4*norig-2-i)*2];
                        } else {
                            *po = T(0);
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig-1)) {
                        *po = pi[2*norig-2-i];
                    } else if (i == (2*norig-1)) {
                        *po = T(0);
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i < (4*norig-1)) {
                        *po = -pi[4*norig-2-i];
                    } else {
                        *po = T(0);
                    }
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i == norig) {
                            *po = T(0);
                        } else if (i < (2*norig+1)) {
                            *po = -pi[(2*norig-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i == 3*norig) {
                            *po = T(0);
                        } else {
                            *po = pi[(4*norig-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i == norig) {
                        *po = T(0);
                    } else if (i < (2*norig+1)) {
                        *po = -pi[2*norig-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i == 3*norig) {
                        *po = T(0);
                    } else {
                        *po = pi[4*norig-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = -pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = -pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = pi[4*norig-1-i];
                    }
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int ostride = (2*n+1)*2;
            int istride = n;
            int nex = 4*n;
            int norig = n;
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            *po = pi[i*2];
                        } else if (i < (2*norig)) {
                            *po = pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = -pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        *po = pi[i];
                    } else if (i < (2*norig)) {
                        *po = pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = -pi[4*norig-1-i];
                    }
                });
            }
        } else {
            amrex::Abort("FFT: pack_r2r_buffer: unsupported kind");
        }
    }

    void unpack_r2r_buffer (T* pdst, void const* pbuf) const
    {
        auto const* psrc = (GpuComplex<T> const*) pbuf;
        int norig = n;
        Long nelems = Long(norig)*howmany;
        int ostride = n;

        if (kind == Kind::r2r_oo_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k+1];
                        pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k+1];
                    pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_oo_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(2*k+1)/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_ee_f) {
            int istride = n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k];
                        pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(T(k)/T(2*norig));
                    auto const& yk = psrc[batch*istride+k];
                    pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
                });
            }
        } else if (kind == Kind::r2r_ee_b) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * yk.real();
                });
            }
        } else if (kind == Kind::r2r_eo) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
                });
            }
        } else if (kind == Kind::r2r_oe) {
            int istride = 2*n+1;
            if (r2r_data_is_complex) {
                ParallelFor(nelems/2, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelFor(nelems, [=] AMREX_GPU_DEVICE (Long ielem)
                {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi((k+T(0.5))/T(2*norig));
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
                });
            }
        } else {
            amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
        }
    }
#endif

    template <Direction D>
    void compute_r2r ()
    {
        static_assert(D == Direction::forward || D == Direction::backward);
        if (!defined) { return; }

#if defined(AMREX_USE_GPU)

        auto* pscratch = alloc_scratch_space();

        pack_r2r_buffer(pscratch, (T*)((D == Direction::forward) ? pf : pb));

#if defined(AMREX_USE_CUDA)

        AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

        std::size_t work_size = 0;
        AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

        auto* work_area = The_Arena()->alloc(work_size);
        AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

        if constexpr (std::is_same_v<float,T>) {
            AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
        } else {
            AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, (T*)pscratch, (VendorComplex*)pscratch));
        }

#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
#endif

        unpack_r2r_buffer((T*)((D == Direction::forward) ? pb : pf), pscratch);

        Gpu::streamSynchronize();
        free_scratch_space(pscratch);
#if defined(AMREX_USE_CUDA)
        The_Arena()->free(work_area);
#endif

#else /* FFTW */

        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
            if (defined2) { fftwf_execute(plan2); }
        } else {
            fftw_execute(plan);
            if (defined2) { fftw_execute(plan2); }
        }

#endif
    }

    static void destroy_vendor_plan (VendorPlan plan)
    {
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
        AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
#elif defined(AMREX_USE_SYCL)
        std::visit([](auto&& p) { delete p; }, plan);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_destroy_plan(plan);
        } else {
            fftw_destroy_plan(plan);
        }
#endif
    }
};
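
// A cache of vendor plans, keyed by the FFT size (zero-extended to 3D), the
// direction, and the kind. init_r2c(IntVectND,...) consults and populates it
// when called with cache == true, so repeated transforms of the same shape
// reuse one plan.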

using Key = std::tuple<IntVectND<3>,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

PlanD* get_vendor_plan_d (Key const& key);
PlanF* get_vendor_plan_f (Key const& key);

void add_vendor_plan_d (Key const& key, PlanD plan);
void add_vendor_plan_f (Key const& key, PlanF plan);

template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache)
{
    static_assert(D == Direction::forward || D == Direction::backward);

    kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
    defined = true;
    pf = pbf;
    pb = pbb;

    n = 1;
    for (auto s : fft_size) { n *= s; }
    howmany = 1;

#if defined(AMREX_USE_GPU)
    Key key = {fft_size.template expand<3>(), D, kind};
    if (cache) {
        VendorPlan* cached_plan = nullptr;
        if constexpr (std::is_same_v<float,T>) {
            cached_plan = get_vendor_plan_f(key);
        } else {
            cached_plan = get_vendor_plan_d(key);
        }
        if (cached_plan) {
            plan = *cached_plan;
            return;
        }
    }
#else
    amrex::ignore_unused(cache);
#endif

#if defined(AMREX_USE_CUDA)

    AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
    AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
    cufftType type;
    if constexpr (D == Direction::forward) {
        type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
    } else {
        type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
    }
    std::size_t work_size;
    if constexpr (M == 1) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
    } else if constexpr (M == 2) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
    } else if constexpr (M == 3) {
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
    }

#elif defined(AMREX_USE_HIP)

    auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
    std::size_t length[M];
    for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
    if constexpr (D == Direction::forward) {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_forward, prec, M,
                                length, howmany, nullptr));
    } else {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_inverse, prec, M,
                                length, howmany, nullptr));
    }

#elif defined(AMREX_USE_SYCL)

    mkl_desc_r* pp;
    if (M == 1) {
        pp = new mkl_desc_r(fft_size[0]);
    } else {
        std::vector<std::int64_t> len(M);
        for (int idim = 0; idim < M; ++idim) {
            len[idim] = fft_size[M-1-idim];
        }
        pp = new mkl_desc_r(len);
    }
#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                  oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif

    std::vector<std::int64_t> strides(M+1);
    strides[0] = 0;
    strides[M] = 1;
    for (int i = M-1; i >= 1; --i) {
        strides[i] = strides[i+1] * fft_size[M-1-i];
    }

#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
    // Do not set BWD_STRIDES
#else
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
    // Do not set BWD_STRIDES
#endif
    pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                  oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
    pp->commit(amrex::Gpu::Device::streamQueue());
    plan = pp;

#else /* FFTW */

    if (pf == nullptr || pb == nullptr) {
        defined = false;
        return;
    }

    int size_for_row_major[M];
    for (int idim = 0; idim < M; ++idim) {
        size_for_row_major[idim] = fft_size[M-1-idim];
    }

    if constexpr (std::is_same_v<float,T>) {
        if constexpr (D == Direction::forward) {
            plan = fftwf_plan_dft_r2c
                (M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
                 FFTW_ESTIMATE);
        } else {
            plan = fftwf_plan_dft_c2r
                (M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
                 FFTW_ESTIMATE);
        }
    } else {
        if constexpr (D == Direction::forward) {
            plan = fftw_plan_dft_r2c
                (M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_dft_c2r
                (M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
                 FFTW_ESTIMATE);
        }
    }
#endif

#if defined(AMREX_USE_GPU)
    if (cache) {
        if constexpr (std::is_same_v<float,T>) {
            add_vendor_plan_f(key, plan);
        } else {
            add_vendor_plan_d(key, plan);
        }
    }
#endif
}

namespace detail
{
    DistributionMapping make_iota_distromap (Long n);

    template <typename FA>
    typename FA::FABType::value_type * get_fab (FA& fa)
    {
        auto myproc = ParallelContext::MyProcSub();
        if (myproc < fa.size()) {
            return fa.fabPtr(myproc);
        } else {
            return nullptr;
        }
    }

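    // Make two FabArrays share a single allocation for their local FAB: the
    // buffer is sized for the larger of the two, which presumes the caller
    // never needs both sets of values at the same time.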
    template <typename FA1, typename FA2>
    std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
    {
        bool not_same_fa = true;
        if constexpr (std::is_same_v<FA1,FA2>) {
            not_same_fa = (&fa1 != &fa2);
        }
        using FAB1 = typename FA1::FABType::value_type;
        using FAB2 = typename FA2::FABType::value_type;
        using T1 = typename FAB1::value_type;
        using T2 = typename FAB2::value_type;
        auto myproc = ParallelContext::MyProcSub();
        bool alloc_1 = (myproc < fa1.size());
        bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
        void* p = nullptr;
        if (alloc_1 && alloc_2) {
            Box const& box1 = fa1.fabbox(myproc);
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
                                            sizeof(T2)*box2.numPts()*ncomp2));
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else if (alloc_1) {
            Box const& box1 = fa1.fabbox(myproc);
            int ncomp1 = fa1.nComp();
            p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
            fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
        } else if (alloc_2) {
            Box const& box2 = fa2.fabbox(myproc);
            int ncomp2 = fa2.nComp();
            p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
            fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
        } else {
            return nullptr;
        }
        return std::unique_ptr<char,DataDeleter>((char*)p, DataDeleter{The_Arena()});
    }
}
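
// Index-permutation functors used when transposing data between pencil
// layouts. operator() maps a destination index to the corresponding source
// index and Inverse is the opposite map; the IndexType overloads pass the
// index type through unchanged.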

struct Swap01
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.x, i.z};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.x, i.z};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct Swap02
{
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.y, i.x};
    }

    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.y, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateFwd
{
    // dest -> src: (x,y,z) -> (y,z,x)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.y, i.z, i.x};
    }

    // src -> dest: (x,y,z) -> (z,x,y)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.z, i.x, i.y};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

struct RotateBwd
{
    // dest -> src: (x,y,z) -> (z,x,y)
    [[nodiscard]] constexpr Dim3 operator() (Dim3 i) const noexcept
    {
        return {i.z, i.x, i.y};
    }

    // src -> dest: (x,y,z) -> (y,z,x)
    static constexpr Dim3 Inverse (Dim3 i)
    {
        return {i.y, i.z, i.x};
    }

    [[nodiscard]] constexpr IndexType operator() (IndexType it) const noexcept
    {
        return it;
    }

    static constexpr IndexType Inverse (IndexType it)
    {
        return it;
    }
};

}

#endif