1 #ifndef AMREX_FFT_HELPER_H_
2 #define AMREX_FFT_HELPER_H_
3 #include <AMReX_Config.H>
14 #if defined(AMREX_USE_CUDA)
16 # include <cuComplex.h>
17 #elif defined(AMREX_USE_HIP)
18 # if __has_include(<rocfft/rocfft.h>)
19 # include <rocfft/rocfft.h>
23 # include <hip/hip_complex.h>
24 #elif defined(AMREX_USE_SYCL)
25 # if __has_include(<oneapi/mkl/dft.hpp>)
26 # include <oneapi/mkl/dft.hpp>
28 # define AMREX_USE_MKL_DFTI_2024 1
29 # include <oneapi/mkl/dfti.hpp>
70 namespace detail {
void hip_execute (rocfft_plan plan,
void **in,
void **out); }
76 template <
typename T, Direction direction,
typename P,
typename TI,
typename TO>
77 void sycl_execute (
P* plan, TI* in, TO* out)
79 #ifndef AMREX_USE_MKL_DFTI_2024
80 std::int64_t workspaceSize = 0;
82 std::size_t workspaceSize = 0;
84 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
87 plan->set_workspace(buffer);
89 if (std::is_same_v<TI,TO>) {
92 r = oneapi::mkl::dft::compute_forward(*plan, out);
94 r = oneapi::mkl::dft::compute_backward(*plan, out);
98 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
100 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
109 template <
typename T>
112 #if defined(AMREX_USE_CUDA)
115 cuComplex, cuDoubleComplex>;
116 #elif defined(AMREX_USE_HIP)
118 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
120 #elif defined(AMREX_USE_SYCL)
121 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
122 ? oneapi::mkl::dft::precision::SINGLE
123 : oneapi::mkl::dft::precision::DOUBLE,
124 oneapi::mkl::dft::domain::REAL>;
125 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
126 ? oneapi::mkl::dft::precision::SINGLE
127 : oneapi::mkl::dft::precision::DOUBLE,
128 oneapi::mkl::dft::domain::COMPLEX>;
129 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
132 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
133 fftwf_plan, fftw_plan>;
134 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
135 fftwf_complex, fftw_complex>;
162 #if !defined(AMREX_USE_GPU)
170 template <Direction D>
175 int rank = is_2d_transform ? 2 : 1;
190 int nr = (rank == 1) ? len[0] : len[0]*len[1];
192 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
193 #if (AMREX_SPACEDIM == 1)
202 #if defined(AMREX_USE_CUDA)
206 std::size_t work_size;
208 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
210 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nr,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
212 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
214 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nc,
nullptr, 1, nr, bwd_type,
howmany, &work_size));
217 #elif defined(AMREX_USE_HIP)
219 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
221 std::size_t
length[2] = {std::size_t(len[1]), std::size_t(len[0])};
223 AMREX_ROCFFT_SAFE_CALL
224 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
225 rocfft_transform_type_real_forward, prec, rank,
228 AMREX_ROCFFT_SAFE_CALL
229 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
230 rocfft_transform_type_real_inverse, prec, rank,
234 #elif defined(AMREX_USE_SYCL)
238 pp =
new mkl_desc_r(len[0]);
240 pp =
new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
242 #ifndef AMREX_USE_MKL_DFTI_2024
243 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
244 oneapi::mkl::dft::config_value::NOT_INPLACE);
246 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
248 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
249 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
250 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
251 std::vector<std::int64_t> strides;
252 strides.push_back(0);
253 if (rank == 2) { strides.push_back(len[1]); }
254 strides.push_back(1);
255 #ifndef AMREX_USE_MKL_DFTI_2024
256 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
259 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
262 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
263 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
264 pp->commit(amrex::Gpu::Device::streamQueue());
269 if constexpr (std::is_same_v<float,T>) {
271 plan = fftwf_plan_many_dft_r2c
272 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
273 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
275 plan = fftwf_plan_many_dft_c2r
276 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
277 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
281 plan = fftw_plan_many_dft_r2c
282 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
283 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
285 plan = fftw_plan_many_dft_c2r
286 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
287 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
293 template <Direction D,
int M>
296 template <Direction D>
309 #if defined(AMREX_USE_CUDA)
313 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
314 std::size_t work_size;
316 (cufftMakePlanMany(
plan, 1, &
n,
nullptr, 1,
n,
nullptr, 1,
n, t,
howmany, &work_size));
318 #elif defined(AMREX_USE_HIP)
320 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
321 : rocfft_precision_double;
323 : rocfft_transform_type_complex_inverse;
325 AMREX_ROCFFT_SAFE_CALL
326 (rocfft_plan_create(&
plan, rocfft_placement_inplace, dir, prec, 1,
329 #elif defined(AMREX_USE_SYCL)
331 auto*
pp =
new mkl_desc_c(
n);
332 #ifndef AMREX_USE_MKL_DFTI_2024
333 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
334 oneapi::mkl::dft::config_value::INPLACE);
336 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
338 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
339 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
n);
340 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
n);
341 std::vector<std::int64_t> strides = {0,1};
342 #ifndef AMREX_USE_MKL_DFTI_2024
343 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
344 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
346 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
347 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
349 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
350 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
351 pp->commit(amrex::Gpu::Device::streamQueue());
356 if constexpr (std::is_same_v<float,T>) {
358 plan = fftwf_plan_many_dft
359 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
362 plan = fftwf_plan_many_dft
363 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
368 plan = fftw_plan_many_dft
369 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
372 plan = fftw_plan_many_dft
373 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
380 #ifndef AMREX_USE_GPU
381 template <Direction D>
382 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary>
const& bc)
384 if (bc.first == Boundary::even && bc.second == Boundary::even)
388 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
392 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
396 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
402 return fftw_r2r_kind{};
408 template <Direction D>
411 if (bc.first == Boundary::even && bc.second == Boundary::even)
415 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
419 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
423 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
434 template <Direction D>
435 void init_r2r (
Box const& box, T* p, std::pair<Boundary,Boundary>
const& bc,
436 int howmany_initval = 1)
440 kind = get_r2r_kind<D>(bc);
448 #if defined(AMREX_USE_GPU)
450 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
453 }
else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
456 }
else if (bc.first == Boundary::even && bc.second == Boundary::even &&
459 }
else if (bc.first == Boundary::even && bc.second == Boundary::even &&
462 }
else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
463 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
468 int nc = (nex/2) + 1;
470 #if defined (AMREX_USE_CUDA)
474 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
475 std::size_t work_size;
477 (cufftMakePlanMany(
plan, 1, &nex,
nullptr, 1, nc*2,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
479 #elif defined(AMREX_USE_HIP)
482 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
483 const std::size_t
length = nex;
484 AMREX_ROCFFT_SAFE_CALL
485 (rocfft_plan_create(&
plan, rocfft_placement_inplace,
486 rocfft_transform_type_real_forward, prec, 1,
489 #elif defined(AMREX_USE_SYCL)
491 auto*
pp =
new mkl_desc_r(nex);
492 #ifndef AMREX_USE_MKL_DFTI_2024
493 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
494 oneapi::mkl::dft::config_value::INPLACE);
496 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
498 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
499 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
500 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
501 std::vector<std::int64_t> strides = {0,1};
502 #ifndef AMREX_USE_MKL_DFTI_2024
503 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
504 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
506 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
507 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
509 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
510 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
511 pp->commit(amrex::Gpu::Device::streamQueue());
517 auto fftw_kind = get_fftw_kind<D>(bc);
518 if constexpr (std::is_same_v<float,T>) {
519 plan = fftwf_plan_many_r2r
520 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
523 plan = fftw_plan_many_r2r
524 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
530 template <Direction D>
532 std::pair<Boundary,Boundary>
const& bc)
538 #if defined(AMREX_USE_GPU)
540 init_r2r<D>(box, p, bc, 2);
545 kind = get_r2r_kind<D>(bc);
554 auto fftw_kind = get_fftw_kind<D>(bc);
555 if constexpr (std::is_same_v<float,T>) {
556 plan = fftwf_plan_many_r2r
557 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
559 plan2 = fftwf_plan_many_r2r
560 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
563 plan = fftw_plan_many_r2r
564 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
566 plan2 = fftw_plan_many_r2r
567 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
573 template <Direction D>
584 #if defined(AMREX_USE_CUDA)
587 std::size_t work_size = 0;
594 if constexpr (std::is_same_v<float,T>) {
600 if constexpr (std::is_same_v<float,T>) {
608 #elif defined(AMREX_USE_HIP)
609 detail::hip_execute(
plan, (
void**)&
pi, (
void**)&po);
610 #elif defined(AMREX_USE_SYCL)
611 detail::sycl_execute<T,D>(std::get<0>(
plan),
pi, po);
614 if constexpr (std::is_same_v<float,T>) {
622 template <Direction D>
630 #if defined(AMREX_USE_CUDA)
633 std::size_t work_size = 0;
640 if constexpr (std::is_same_v<float,T>) {
647 #elif defined(AMREX_USE_HIP)
648 detail::hip_execute(
plan, (
void**)&p, (
void**)&p);
649 #elif defined(AMREX_USE_SYCL)
650 detail::sycl_execute<T,D>(std::get<1>(
plan), p, p);
653 if constexpr (std::is_same_v<float,T>) {
671 amrex::Abort(
"FFT: alloc_scratch_space: unsupported kind");
680 auto*
pdst = (T*) pbuf;
683 int ostride = (
n+1)*2;
687 Long nelems = Long(nex)*
howmany;
691 auto batch = ielem / Long(nex);
692 auto i =
int(ielem - batch*nex);
693 for (
int ir = 0; ir < 2; ++ir) {
694 auto* po =
pdst + (2*batch+ir)*ostride + i;
695 auto const*
pi = psrc + 2*batch*istride + ir;
699 *po = sign *
pi[(2*norig-1-i)*2];
706 auto batch = ielem / Long(nex);
707 auto i =
int(ielem - batch*nex);
708 auto* po =
pdst + batch*ostride + i;
709 auto const*
pi = psrc + batch*istride;
713 *po = sign *
pi[2*norig-1-i];
718 int ostride = (2*
n+1)*2;
722 Long nelems = Long(nex)*
howmany;
726 auto batch = ielem / Long(nex);
727 auto i =
int(ielem - batch*nex);
728 for (
int ir = 0; ir < 2; ++ir) {
729 auto* po =
pdst + (2*batch+ir)*ostride + i;
730 auto const*
pi = psrc + 2*batch*istride + ir;
733 }
else if (i < (2*norig-1)) {
734 *po =
pi[(2*norig-2-i)*2];
735 }
else if (i == (2*norig-1)) {
737 }
else if (i < (3*norig)) {
738 *po = -
pi[(i-2*norig)*2];
739 }
else if (i < (4*norig-1)) {
740 *po = -
pi[(4*norig-2-i)*2];
749 auto batch = ielem / Long(nex);
750 auto i =
int(ielem - batch*nex);
751 auto* po =
pdst + batch*ostride + i;
752 auto const*
pi = psrc + batch*istride;
755 }
else if (i < (2*norig-1)) {
756 *po =
pi[2*norig-2-i];
757 }
else if (i == (2*norig-1)) {
759 }
else if (i < (3*norig)) {
760 *po = -
pi[i-2*norig];
761 }
else if (i < (4*norig-1)) {
762 *po = -
pi[4*norig-2-i];
769 int ostride = (2*
n+1)*2;
773 Long nelems = Long(nex)*
howmany;
777 auto batch = ielem / Long(nex);
778 auto i =
int(ielem - batch*nex);
779 for (
int ir = 0; ir < 2; ++ir) {
780 auto* po =
pdst + (2*batch+ir)*ostride + i;
781 auto const*
pi = psrc + 2*batch*istride + ir;
784 }
else if (i == norig) {
786 }
else if (i < (2*norig+1)) {
787 *po = -
pi[(2*norig-i)*2];
788 }
else if (i < (3*norig)) {
789 *po = -
pi[(i-2*norig)*2];
790 }
else if (i == 3*norig) {
793 *po =
pi[(4*norig-i)*2];
800 auto batch = ielem / Long(nex);
801 auto i =
int(ielem - batch*nex);
802 auto* po =
pdst + batch*ostride + i;
803 auto const*
pi = psrc + batch*istride;
806 }
else if (i == norig) {
808 }
else if (i < (2*norig+1)) {
809 *po = -
pi[2*norig-i];
810 }
else if (i < (3*norig)) {
811 *po = -
pi[i-2*norig];
812 }
else if (i == 3*norig) {
820 int ostride = (2*
n+1)*2;
824 Long nelems = Long(nex)*
howmany;
828 auto batch = ielem / Long(nex);
829 auto i =
int(ielem - batch*nex);
830 for (
int ir = 0; ir < 2; ++ir) {
831 auto* po =
pdst + (2*batch+ir)*ostride + i;
832 auto const*
pi = psrc + 2*batch*istride + ir;
835 }
else if (i < (2*norig)) {
836 *po = -
pi[(2*norig-1-i)*2];
837 }
else if (i < (3*norig)) {
838 *po = -
pi[(i-2*norig)*2];
840 *po =
pi[(4*norig-1-i)*2];
847 auto batch = ielem / Long(nex);
848 auto i =
int(ielem - batch*nex);
849 auto* po =
pdst + batch*ostride + i;
850 auto const*
pi = psrc + batch*istride;
853 }
else if (i < (2*norig)) {
854 *po = -
pi[2*norig-1-i];
855 }
else if (i < (3*norig)) {
856 *po = -
pi[i-2*norig];
858 *po =
pi[4*norig-1-i];
863 int ostride = (2*
n+1)*2;
867 Long nelems = Long(nex)*
howmany;
871 auto batch = ielem / Long(nex);
872 auto i =
int(ielem - batch*nex);
873 for (
int ir = 0; ir < 2; ++ir) {
874 auto* po =
pdst + (2*batch+ir)*ostride + i;
875 auto const*
pi = psrc + 2*batch*istride + ir;
878 }
else if (i < (2*norig)) {
879 *po =
pi[(2*norig-1-i)*2];
880 }
else if (i < (3*norig)) {
881 *po = -
pi[(i-2*norig)*2];
883 *po = -
pi[(4*norig-1-i)*2];
890 auto batch = ielem / Long(nex);
891 auto i =
int(ielem - batch*nex);
892 auto* po =
pdst + batch*ostride + i;
893 auto const*
pi = psrc + batch*istride;
896 }
else if (i < (2*norig)) {
897 *po =
pi[2*norig-1-i];
898 }
else if (i < (3*norig)) {
899 *po = -
pi[i-2*norig];
901 *po = -
pi[4*norig-1-i];
914 Long nelems = Long(norig)*
howmany;
922 auto batch = ielem / Long(norig);
923 auto k =
int(ielem - batch*norig);
925 for (
int ir = 0; ir < 2; ++ir) {
926 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
927 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
933 auto batch = ielem / Long(norig);
934 auto k =
int(ielem - batch*norig);
936 auto const& yk = psrc[batch*istride+k+1];
937 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
945 auto batch = ielem / Long(norig);
946 auto k =
int(ielem - batch*norig);
948 for (
int ir = 0; ir < 2; ++ir) {
949 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
950 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
956 auto batch = ielem / Long(norig);
957 auto k =
int(ielem - batch*norig);
959 auto const& yk = psrc[batch*istride+2*k+1];
960 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
968 auto batch = ielem / Long(norig);
969 auto k =
int(ielem - batch*norig);
971 for (
int ir = 0; ir < 2; ++ir) {
972 auto const& yk = psrc[(2*batch+ir)*istride+k];
973 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
979 auto batch = ielem / Long(norig);
980 auto k =
int(ielem - batch*norig);
982 auto const& yk = psrc[batch*istride+k];
983 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
991 auto batch = ielem / Long(norig);
992 auto k =
int(ielem - batch*norig);
993 for (
int ir = 0; ir < 2; ++ir) {
994 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
995 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1001 auto batch = ielem / Long(norig);
1002 auto k =
int(ielem - batch*norig);
1003 auto const& yk = psrc[batch*istride+2*k+1];
1004 pdst[batch*ostride+k] = T(0.5) * yk.real();
1008 int istride = 2*
n+1;
1012 auto batch = ielem / Long(norig);
1013 auto k =
int(ielem - batch*norig);
1015 for (
int ir = 0; ir < 2; ++ir) {
1016 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1017 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1023 auto batch = ielem / Long(norig);
1024 auto k =
int(ielem - batch*norig);
1026 auto const& yk = psrc[batch*istride+2*k+1];
1027 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1031 int istride = 2*
n+1;
1035 auto batch = ielem / Long(norig);
1036 auto k =
int(ielem - batch*norig);
1038 for (
int ir = 0; ir < 2; ++ir) {
1039 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1040 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1046 auto batch = ielem / Long(norig);
1047 auto k =
int(ielem - batch*norig);
1049 auto const& yk = psrc[batch*istride+2*k+1];
1050 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1054 amrex::Abort(
"FFT: unpack_r2r_buffer: unsupported kind");
1059 template <Direction D>
1065 #if defined(AMREX_USE_GPU)
1071 #if defined(AMREX_USE_CUDA)
1075 std::size_t work_size = 0;
1081 if constexpr (std::is_same_v<float,T>) {
1087 #elif defined(AMREX_USE_HIP)
1088 detail::hip_execute(
plan, (
void**)&pscratch, (
void**)&pscratch);
1089 #elif defined(AMREX_USE_SYCL)
1090 detail::sycl_execute<T,Direction::forward>(std::get<0>(
plan), (T*)pscratch, (
VendorComplex*)pscratch);
1097 #if defined(AMREX_USE_CUDA)
1103 if constexpr (std::is_same_v<float,T>) {
1104 fftwf_execute(
plan);
1116 #if defined(AMREX_USE_CUDA)
1118 #elif defined(AMREX_USE_HIP)
1119 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(
plan));
1120 #elif defined(AMREX_USE_SYCL)
1121 std::visit([](
auto&& p) {
delete p; },
plan);
1123 if constexpr (std::is_same_v<float,T>) {
1124 fftwf_destroy_plan(
plan);
1126 fftw_destroy_plan(
plan);
1142 template <
typename T>
1143 template <Direction D,
int M>
1154 for (
auto s : fft_size) { n *= s; }
1157 #if defined(AMREX_USE_GPU)
1158 Key key = {fft_size.template expand<3>(), D, kind};
1161 if constexpr (std::is_same_v<float,T>) {
1167 plan = *cached_plan;
1175 #if defined(AMREX_USE_CUDA)
1181 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1183 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1185 std::size_t work_size;
1186 if constexpr (
M == 1) {
1188 (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
1189 }
else if constexpr (
M == 2) {
1191 (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
1192 }
else if constexpr (
M == 3) {
1194 (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
1197 #elif defined(AMREX_USE_HIP)
1199 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1201 for (
int idim = 0; idim <
M; ++idim) {
length[idim] = fft_size[idim]; }
1203 AMREX_ROCFFT_SAFE_CALL
1204 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1205 rocfft_transform_type_real_forward, prec,
M,
1206 length, howmany,
nullptr));
1208 AMREX_ROCFFT_SAFE_CALL
1209 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1210 rocfft_transform_type_real_inverse, prec,
M,
1211 length, howmany,
nullptr));
1214 #elif defined(AMREX_USE_SYCL)
1218 pp =
new mkl_desc_r(fft_size[0]);
1220 std::vector<std::int64_t> len(
M);
1221 for (
int idim = 0; idim <
M; ++idim) {
1222 len[idim] = fft_size[
M-1-idim];
1224 pp =
new mkl_desc_r(len);
1226 #ifndef AMREX_USE_MKL_DFTI_2024
1227 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1228 oneapi::mkl::dft::config_value::NOT_INPLACE);
1230 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1233 std::vector<std::int64_t> strides(
M+1);
1236 for (
int i =
M-1; i >= 1; --i) {
1237 strides[i] = strides[i+1] * fft_size[
M-1-i];
1240 #ifndef AMREX_USE_MKL_DFTI_2024
1241 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1244 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1247 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1248 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1249 pp->commit(amrex::Gpu::Device::streamQueue());
1254 if (pf ==
nullptr || pb ==
nullptr) {
1259 int size_for_row_major[
M];
1260 for (
int idim = 0; idim <
M; ++idim) {
1261 size_for_row_major[idim] = fft_size[
M-1-idim];
1264 if constexpr (std::is_same_v<float,T>) {
1266 plan = fftwf_plan_dft_r2c
1267 (
M, size_for_row_major, (
float*)pf, (fftwf_complex*)pb,
1270 plan = fftwf_plan_dft_c2r
1271 (
M, size_for_row_major, (fftwf_complex*)pb, (
float*)pf,
1276 plan = fftw_plan_dft_r2c
1277 (
M, size_for_row_major, (
double*)pf, (fftw_complex*)pb,
1280 plan = fftw_plan_dft_c2r
1281 (
M, size_for_row_major, (fftw_complex*)pb, (
double*)pf,
1287 #if defined(AMREX_USE_GPU)
1289 if constexpr (std::is_same_v<float,T>) {
1302 template <
typename FA>
1303 typename FA::FABType::value_type *
get_fab (FA& fa)
1306 if (myproc < fa.size()) {
1307 return fa.fabPtr(myproc);
1313 template <
typename FA1,
typename FA2>
1316 bool not_same_fa =
true;
1317 if constexpr (std::is_same_v<FA1,FA2>) {
1318 not_same_fa = (&fa1 != &fa2);
1320 using FAB1 =
typename FA1::FABType::value_type;
1321 using FAB2 =
typename FA2::FABType::value_type;
1322 using T1 =
typename FAB1::value_type;
1323 using T2 =
typename FAB2::value_type;
1325 bool alloc_1 = (myproc < fa1.size());
1326 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1328 if (alloc_1 && alloc_2) {
1329 Box const& box1 = fa1.fabbox(myproc);
1330 Box const& box2 = fa2.fabbox(myproc);
1331 int ncomp1 = fa1.nComp();
1332 int ncomp2 = fa2.nComp();
1334 sizeof(T2)*box2.
numPts()*ncomp2));
1335 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1336 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1337 }
else if (alloc_1) {
1338 Box const& box1 = fa1.fabbox(myproc);
1339 int ncomp1 = fa1.nComp();
1341 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1342 }
else if (alloc_2) {
1343 Box const& box2 = fa2.fabbox(myproc);
1344 int ncomp2 = fa2.nComp();
1346 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1358 return {i.
y, i.x, i.z};
1363 return {i.
y, i.
x, i.
z};
1381 return {i.
z, i.y, i.x};
1386 return {i.
z, i.
y, i.
x};
1405 return {i.
y, i.z, i.x};
1411 return {i.
z, i.
x, i.
y};
1430 return {i.
z, i.x, i.y};
1436 return {i.
y, i.
z, i.
x};
#define AMREX_CUFFT_SAFE_CALL(call)
Definition: AMReX_GpuError.H:88
#define AMREX_GPU_DEVICE
Definition: AMReX_GpuQualifiers.H:18
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition: AMReX_HypreIJIface.cpp:15
Real * pdst
Definition: AMReX_HypreMLABecLap.cpp:1090
#define AMREX_D_TERM(a, b, c)
Definition: AMReX_SPACE.H:129
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
AMREX_GPU_HOST_DEVICE IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition: AMReX_Box.H:146
AMREX_GPU_HOST_DEVICE Long numPts() const noexcept
Returns the number of points contained in the BoxND.
Definition: AMReX_Box.H:346
Calculates the distribution of FABs to MPI processes.
Definition: AMReX_DistributionMapping.H:41
Definition: AMReX_IntVect.H:48
FA::FABType::value_type * get_fab(FA &fa)
Definition: AMReX_FFT_Helper.H:1303
std::unique_ptr< char, DataDeleter > make_mfs_share(FA1 &fa1, FA2 &fa2)
Definition: AMReX_FFT_Helper.H:1314
DistributionMapping make_iota_distromap(Long n)
Definition: AMReX_FFT.cpp:88
Definition: AMReX_FFT.cpp:7
std::tuple< IntVectND< 3 >, Direction, Kind > Key
Definition: AMReX_FFT_Helper.H:1132
Direction
Definition: AMReX_FFT_Helper.H:46
void add_vendor_plan_f(Key const &key, PlanF plan)
Definition: AMReX_FFT.cpp:78
DomainStrategy
Definition: AMReX_FFT_Helper.H:48
typename Plan< float >::VendorPlan PlanF
Definition: AMReX_FFT_Helper.H:1134
AMREX_ENUM(Boundary, periodic, even, odd)
typename Plan< double >::VendorPlan PlanD
Definition: AMReX_FFT_Helper.H:1133
void add_vendor_plan_d(Key const &key, PlanD plan)
Definition: AMReX_FFT.cpp:73
Kind
Definition: AMReX_FFT_Helper.H:52
PlanF * get_vendor_plan_f(Key const &key)
Definition: AMReX_FFT.cpp:64
PlanD * get_vendor_plan_d(Key const &key)
Definition: AMReX_FFT.cpp:55
void streamSynchronize() noexcept
Definition: AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition: AMReX_GpuDevice.H:218
constexpr std::enable_if_t< std::is_floating_point_v< T >, T > pi()
Definition: AMReX_Math.H:62
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE std::pair< double, double > sincospi(double x)
Return sin(pi*x) and cos(pi*x) given x.
Definition: AMReX_Math.H:165
int MyProcSub() noexcept
my sub-rank in current frame
Definition: AMReX_ParallelContext.H:76
@ max
Definition: AMReX_ParallelReduce.H:17
static constexpr int M
Definition: AMReX_OpenBC.H:13
static constexpr int P
Definition: AMReX_OpenBC.H:14
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition: AMReX_CTOParallelForImpl.H:200
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition: AMReX.H:111
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 length(Array4< T > const &a) noexcept
Definition: AMReX_Array4.H:322
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition: AMReX.cpp:225
const int[]
Definition: AMReX_BLProfiler.cpp:1664
Arena * The_Arena()
Definition: AMReX_Arena.cpp:609
Definition: AMReX_FabArrayCommI.H:841
Definition: AMReX_DataAllocator.H:29
Definition: AMReX_Dim3.H:12
int x
Definition: AMReX_Dim3.H:12
int z
Definition: AMReX_Dim3.H:12
int y
Definition: AMReX_Dim3.H:12
Definition: AMReX_FFT_Helper.H:56
bool batch_mode
Definition: AMReX_FFT_Helper.H:60
Info & setBatchMode(bool x)
Definition: AMReX_FFT_Helper.H:65
int nprocs
Max number of processes to use.
Definition: AMReX_FFT_Helper.H:63
Info & setNumProcs(int n)
Definition: AMReX_FFT_Helper.H:66
Definition: AMReX_FFT_Helper.H:111
void * pf
Definition: AMReX_FFT_Helper.H:146
void init_r2c(Box const &box, T *pr, VendorComplex *pc, bool is_2d_transform=false)
Definition: AMReX_FFT_Helper.H:171
void unpack_r2r_buffer(T *pdst, void const *pbuf) const
Definition: AMReX_FFT_Helper.H:910
std::conditional_t< std::is_same_v< float, T >, cuComplex, cuDoubleComplex > VendorComplex
Definition: AMReX_FFT_Helper.H:115
VendorPlan plan2
Definition: AMReX_FFT_Helper.H:145
int n
Definition: AMReX_FFT_Helper.H:138
void destroy()
Definition: AMReX_FFT_Helper.H:156
bool defined2
Definition: AMReX_FFT_Helper.H:143
void init_r2r(Box const &box, VendorComplex *pc, std::pair< Boundary, Boundary > const &bc)
Definition: AMReX_FFT_Helper.H:531
void pack_r2r_buffer(void *pbuf, T const *psrc) const
Definition: AMReX_FFT_Helper.H:678
static void free_scratch_space(void *p)
Definition: AMReX_FFT_Helper.H:676
static void destroy_vendor_plan(VendorPlan plan)
Definition: AMReX_FFT_Helper.H:1114
Kind get_r2r_kind(std::pair< Boundary, Boundary > const &bc)
Definition: AMReX_FFT_Helper.H:409
cufftHandle VendorPlan
Definition: AMReX_FFT_Helper.H:113
Kind kind
Definition: AMReX_FFT_Helper.H:140
int howmany
Definition: AMReX_FFT_Helper.H:139
void init_r2c(IntVectND< M > const &fft_size, void *, void *, bool cache)
Definition: AMReX_FFT_Helper.H:1144
void * pb
Definition: AMReX_FFT_Helper.H:147
void * alloc_scratch_space() const
Definition: AMReX_FFT_Helper.H:662
void compute_r2r()
Definition: AMReX_FFT_Helper.H:1060
void compute_c2c()
Definition: AMReX_FFT_Helper.H:623
bool r2r_data_is_complex
Definition: AMReX_FFT_Helper.H:141
VendorPlan plan
Definition: AMReX_FFT_Helper.H:144
void compute_r2c()
Definition: AMReX_FFT_Helper.H:574
bool defined
Definition: AMReX_FFT_Helper.H:142
void set_ptrs(void *p0, void *p1)
Definition: AMReX_FFT_Helper.H:150
void init_r2r(Box const &box, T *p, std::pair< Boundary, Boundary > const &bc, int howmany_initval=1)
Definition: AMReX_FFT_Helper.H:435
void init_c2c(Box const &box, VendorComplex *p)
Definition: AMReX_FFT_Helper.H:297
Definition: AMReX_FFT_Helper.H:1426
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition: AMReX_FFT_Helper.H:1428
static constexpr IndexType Inverse(IndexType it)
Definition: AMReX_FFT_Helper.H:1444
static constexpr Dim3 Inverse(Dim3 i)
Definition: AMReX_FFT_Helper.H:1434
Definition: AMReX_FFT_Helper.H:1401
static constexpr IndexType Inverse(IndexType it)
Definition: AMReX_FFT_Helper.H:1419
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition: AMReX_FFT_Helper.H:1403
static constexpr Dim3 Inverse(Dim3 i)
Definition: AMReX_FFT_Helper.H:1409
Definition: AMReX_FFT_Helper.H:1355
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition: AMReX_FFT_Helper.H:1356
static constexpr IndexType Inverse(IndexType it)
Definition: AMReX_FFT_Helper.H:1371
static constexpr Dim3 Inverse(Dim3 i)
Definition: AMReX_FFT_Helper.H:1361
Definition: AMReX_FFT_Helper.H:1378
static constexpr Dim3 Inverse(Dim3 i)
Definition: AMReX_FFT_Helper.H:1384
static constexpr IndexType Inverse(IndexType it)
Definition: AMReX_FFT_Helper.H:1394
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition: AMReX_FFT_Helper.H:1379
A host / device complex number type, because std::complex doesn't work in device code with Cuda yet.
Definition: AMReX_GpuComplex.H:29