#ifndef AMREX_FFT_HELPER_H_
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

// ...

#if defined(AMREX_USE_CUDA)
#  include <cufft.h>
#  include <cuComplex.h>
#elif defined(AMREX_USE_HIP)
#  if __has_include(<rocfft/rocfft.h>)
#    include <rocfft/rocfft.h>
#  else
#    include <rocfft.h>
#  endif
#  include <hip/hip_complex.h>
#elif defined(AMREX_USE_SYCL)
#  if __has_include(<oneapi/mkl/dft.hpp>)
#    include <oneapi/mkl/dft.hpp>
#  else
#    define AMREX_USE_MKL_DFTI_2024 1
#    include <oneapi/mkl/dfti.hpp>
#  endif
#else
#  include <fftw3.h>
#endif

// ... (the Direction, DomainStrategy, Boundary and Kind enums and the
//      beginning of struct Info are elided in this excerpt)

namespace amrex::FFT
{

struct Info
{
    // ...

    //! Max number of processes to use
    int nprocs = std::numeric_limits<int>::max();

    // ...
};
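// Usage sketch (illustrative, not part of this header): the setters of Info
// return Info&, so options can be chained; the values below are arbitrary.
//
//     amrex::FFT::Info info{};
//     info.setNumProcs(64)      // cap the number of processes used
//         .setBatchSize(4);     // batched transform size (supported in R2C)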
namespace detail
{

#if defined(AMREX_USE_HIP)
void hip_execute (rocfft_plan plan, void** in, void** out);
#endif

#if defined(AMREX_USE_SYCL)
template <typename T, Direction direction, typename P, typename TI, typename TO>
void sycl_execute (P* plan, TI* in, TO* out)
{
#ifndef AMREX_USE_MKL_DFTI_2024
    std::int64_t workspaceSize = 0;
#else
    std::size_t workspaceSize = 0;
#endif
    plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
                    &workspaceSize);
    // ... (workspace buffer allocation elided)
    plan->set_workspace(buffer);
    sycl::event r;
    if (std::is_same_v<TI,TO>) {
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, out);
        }
    } else {
        if constexpr (direction == Direction::forward) {
            r = oneapi::mkl::dft::compute_forward(*plan, in, out);
        } else {
            r = oneapi::mkl::dft::compute_backward(*plan, in, out);
        }
    }
    r.wait();
}
#endif

} // namespace detail
// ...

template <typename T>
struct Plan
{
#if defined(AMREX_USE_CUDA)
    using VendorPlan = cufftHandle;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
    using VendorPlan = rocfft_plan;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             float2, double2>;
#elif defined(AMREX_USE_SYCL)
    using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::REAL>;
    using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
                                                    ? oneapi::mkl::dft::precision::SINGLE
                                                    : oneapi::mkl::dft::precision::DOUBLE,
                                                    oneapi::mkl::dft::domain::COMPLEX>;
    using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
    // ...
#else
    using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
                                          fftwf_plan, fftw_plan>;
    using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
                                             fftwf_complex, fftw_complex>;
#endif
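
    // Note (illustrative): these aliases are paired with "if constexpr"
    // dispatch at the call sites below, so only the vendor API matching T is
    // instantiated. On the FFTW path, for example:
    //
    //     if constexpr (std::is_same_v<float,T>) {
    //         fftwf_execute(plan);   // single precision
    //     } else {
    //         fftw_execute(plan);    // double precision
    //     }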
    int n;
    int howmany;
    Kind kind;
    bool r2r_data_is_complex;
    bool defined;
    bool defined2;
    VendorPlan plan;
    VendorPlan plan2;
    void* pf;
    void* pb;

    // ...

    void destroy ()
    {
#if !defined(AMREX_USE_GPU)
        // ...
#endif
    }

    template <Direction D>
    void init_r2c (Box const& box, T* pr, VendorComplex* pc,
                   bool is_2d_transform = false, int ncomp = 1)
    {
        // ...
        int rank = is_2d_transform ? 2 : 1;
        // ... (n, howmany and len[] are set from box and ncomp)
        int nr = (rank == 1) ? len[0] : len[0]*len[1];
        // ...
        int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
        // ...
#endif
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        // ...
        std::size_t work_size;
        if constexpr (D == Direction::forward) {
            cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc,
                                   fwd_type, howmany, &work_size));
        } else {
            cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
            AMREX_CUFFT_SAFE_CALL
                (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr,
                                   bwd_type, howmany, &work_size));
        }
#elif defined(AMREX_USE_HIP)
        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        // ...
        std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
        if constexpr (D == Direction::forward) {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_forward, prec, rank,
                                    length, howmany, nullptr));
        } else {
            AMREX_ROCFFT_SAFE_CALL
                (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                    rocfft_transform_type_real_inverse, prec, rank,
                                    length, howmany, nullptr));
        }
#elif defined(AMREX_USE_SYCL)
        mkl_desc_r* pp;
        if (rank == 1) {
            pp = new mkl_desc_r(len[0]);
        } else {
            pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides;
        strides.push_back(0);
        if (rank == 2) { strides.push_back(len[1]); }
        strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        // ...
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        // ...
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;
#else
        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftwf_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft_r2c
                    (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            } else {
                plan = fftw_plan_many_dft_c2r
                    (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
                     FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
            }
        }
#endif
    }
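
    // Reference sketch (illustrative, standard FFTW advanced interface; not
    // part of this header): the (nullptr, 1, dist) triplets above are the
    // (embed, stride, dist) arguments of fftw_plan_many_dft_r2c. A batch of
    // `howmany` contiguous 1D transforms of length n1 reads:
    //
    //     int n1 = 64, howmany = 8;
    //     double*       in  = fftw_alloc_real(std::size_t(n1) * howmany);
    //     fftw_complex* out = fftw_alloc_complex(std::size_t(n1/2+1) * howmany);
    //     fftw_plan p = fftw_plan_many_dft_r2c(
    //         /*rank*/1, &n1, howmany,
    //         in,  /*inembed*/nullptr, /*istride*/1, /*idist*/n1,
    //         out, /*onembed*/nullptr, /*ostride*/1, /*odist*/n1/2+1,
    //         FFTW_ESTIMATE);
    //     fftw_execute(p);
    //     fftw_destroy_plan(p);
    //     fftw_free(out); fftw_free(in);
    //
    // A null embed pointer means the logical dimensions are also the storage
    // dimensions; stride 1 with distance n1 (or n1/2+1 complex values) lays
    // the batch members out back to back, matching nr and nc above.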
    template <Direction D, int M>
    void init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb,
                   bool cache, int ncomp = 1);

    template <Direction D>
    void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
    {
        // ... (the ndims == 1 setup of n, howmany and len[] is elided)
#if (AMREX_SPACEDIM >= 2)
        else if (ndims == 2) {
            // ...
        }
#if (AMREX_SPACEDIM == 2)
        // ...
#endif
#if (AMREX_SPACEDIM == 3)
        else if (ndims == 3) {
            // ...
        }
#endif
#endif
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        // ...
        cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t,
                               howmany, &work_size));
#elif defined(AMREX_USE_HIP)
        auto prec = std::is_same_v<float,T> ? rocfft_precision_single
                                            : rocfft_precision_double;
        auto dir = (D == Direction::forward) ? rocfft_transform_type_complex_forward
                                             : rocfft_transform_type_complex_inverse;
        std::size_t length[3];
        if (ndims == 1) {
            // ...
        } else if (ndims == 2) {
            // ...
        } else {
            // ...
        }
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
                                length, howmany, nullptr));
#elif defined(AMREX_USE_SYCL)
        mkl_desc_c* pp;
        if (ndims == 1) {
            pp = new mkl_desc_c(n);
        } else if (ndims == 2) {
            pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
        } else {
            pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
        std::vector<std::int64_t> strides(ndims+1);
        strides[0] = 0;
        strides[ndims] = 1;
        for (int i = ndims-1; i >= 1; --i) {
            strides[i] = strides[i+1] * len[i];
        }
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;
#else
        if constexpr (std::is_same_v<float,T>) {
            if constexpr (D == Direction::forward) {
                plan = fftwf_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftwf_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        } else {
            if constexpr (D == Direction::forward) {
                plan = fftw_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
                     FFTW_ESTIMATE);
            } else {
                plan = fftw_plan_many_dft
                    (ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
                     FFTW_ESTIMATE);
            }
        }
#endif
    }
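
    // Reference sketch (illustrative, standard cuFFT API; not part of this
    // header): cufftMakePlanMany takes the same (embed, stride, dist) layout
    // arguments as FFTW's advanced interface. A batched in-place 1D Z2Z plan
    // of length n would be created as:
    //
    //     cufftHandle h;
    //     std::size_t work_size = 0;
    //     int n = 128, batch = 8;
    //     cufftCreate(&h);
    //     cufftMakePlanMany(h, /*rank*/1, &n,
    //                       /*inembed*/nullptr, /*istride*/1, /*idist*/n,
    //                       /*onembed*/nullptr, /*ostride*/1, /*odist*/n,
    //                       CUFFT_Z2Z, batch, &work_size);
    //     // ... cufftExecZ2Z(h, d_data, d_data, CUFFT_FORWARD); cufftDestroy(h);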
    template <Direction D>
    fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
    {
        // ... (map the boundary pair to the FFTW r2r kind)
        return fftw_r2r_kind{};
    }
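
    // For reference (standard FFTW r2r kinds; the exact mapping used by this
    // function is not reproduced in this excerpt): homogeneous even (Neumann)
    // boundaries on both ends select the cosine-transform family
    // FFTW_REDFT00/10/01/11, and homogeneous odd (Dirichlet) boundaries the
    // sine-transform family FFTW_RODFT00/10/01/11, with the forward/backward
    // direction picking between the type-2 and type-3 variants, e.g.
    //
    //     // illustrative only:
    //     // (even, even) -> FFTW_REDFT10 (forward) / FFTW_REDFT01 (backward)
    //     // (odd,  odd ) -> FFTW_RODFT10 (forward) / FFTW_RODFT01 (backward)
    //
    // which are inverses of each other up to a factor of 2*n.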
    template <Direction D>
    Kind get_r2r_kind (std::pair<Boundary,Boundary> const& bc)
    {
        // ... (map the boundary pair to the Kind enum)
    }

    template <Direction D>
    void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
                   int howmany_initval = 1)
    {
        // ...
        kind = get_r2r_kind<D>(bc);
        // ...
#if defined(AMREX_USE_GPU)
        // ... (compute the extended length nex from the transform kind)
        int nc = (nex/2) + 1;
        // ...
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
        // ...
        cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
        std::size_t work_size;
        AMREX_CUFFT_SAFE_CALL
            (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc,
                               fwd_type, howmany, &work_size));
#elif defined(AMREX_USE_HIP)
        // ...
        auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
        const std::size_t length = nex;
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_inplace,
                                rocfft_transform_type_real_forward, prec, 1,
                                &length, howmany, nullptr));
#elif defined(AMREX_USE_SYCL)
        auto* pp = new mkl_desc_r(nex);
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                      oneapi::mkl::dft::config_value::INPLACE);
#else
        pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
#endif
        pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
        pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
        std::vector<std::int64_t> strides = {0,1};
#ifndef AMREX_USE_MKL_DFTI_2024
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
#else
        pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
        pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
#endif
        pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                      oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
        pp->commit(amrex::Gpu::Device::streamQueue());
        plan = pp;
#endif
#else
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }
    template <Direction D>
    void init_r2r (Box const& box, VendorComplex* pc,
                   std::pair<Boundary,Boundary> const& bc)
    {
        // ...
        auto* p = (T*)pc;
#if defined(AMREX_USE_GPU)
        // ...
        init_r2r<D>(box, p, bc, 2);
        // ...
#else
        // ...
        kind = get_r2r_kind<D>(bc);
        // ...
        auto fftw_kind = get_fftw_kind<D>(bc);
        if constexpr (std::is_same_v<float,T>) {
            plan = fftwf_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftwf_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_r2r
                (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
            plan2 = fftw_plan_many_r2r
                (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
                 FFTW_ESTIMATE);
        }
#endif
    }
    template <Direction D>
    void compute_r2c ()
    {
        // ... (returns early if the plan is not defined; pi and po are set
        //      from pf/pb according to the direction)
#if defined(AMREX_USE_CUDA)
        // ...
        std::size_t work_size = 0;
        // ...
        if constexpr (std::is_same_v<float,T>) {
            // ...
        } else {
            // ...
        }
        // ...
        if constexpr (std::is_same_v<float,T>) {
            // ...
        } else {
            // ...
        }
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pi, (void**)&po);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }
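
    // Reference sketch (illustrative, standard cuFFT calls; the elided CUDA
    // branch above typically looks like this): bind the plan to the AMReX GPU
    // stream and call the precision-specific exec routine, e.g. for a
    // double-precision forward real-to-complex transform
    //
    //     cufftSetStream(plan, amrex::Gpu::gpuStream());
    //     cufftExecD2Z(plan, (cufftDoubleReal*)pi, (cufftDoubleComplex*)po);
    //
    // (single precision uses cufftExecR2C, and the backward direction
    //  cufftExecZ2D / cufftExecC2R).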
    template <Direction D>
    void compute_c2c ()
    {
        // ... (returns early if the plan is not defined; p is the in-place
        //      data pointer)
#if defined(AMREX_USE_CUDA)
        // ...
        std::size_t work_size = 0;
        // ...
        if constexpr (std::is_same_v<float,T>) {
            // ...
        } else {
            // ...
        }
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
        } else {
            fftw_execute(plan);
        }
#endif
    }

    void* alloc_scratch_space () const
    {
        // ... (allocate a buffer whose size depends on kind and howmany)
        amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
        return nullptr;
    }
    // ...

    void pack_r2r_buffer (void* pbuf, T const* psrc) const
    {
        auto* pdst = (T*) pbuf;
        // ... (norig, nex, istride and the kind-dependent sign are set up here;
        //      which of the blocks below runs depends on kind)

        { // extension of length nex = 2*norig
            int ostride = (n+1)*2;
            // ...
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            // ...
                        } else {
                            *po = sign * pi[(2*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        // ...
                    } else {
                        *po = sign * pi[2*norig-1-i];
                    }
                });
            }
        }

        { // extension of length nex = 4*norig
            int ostride = (2*n+1)*2;
            // ...
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            // ...
                        } else if (i < (2*norig-1)) {
                            *po = pi[(2*norig-2-i)*2];
                        } else if (i == (2*norig-1)) {
                            // ...
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i < (4*norig-1)) {
                            *po = -pi[(4*norig-2-i)*2];
                        } else {
                            // ...
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        // ...
                    } else if (i < (2*norig-1)) {
                        *po = pi[2*norig-2-i];
                    } else if (i == (2*norig-1)) {
                        // ...
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i < (4*norig-1)) {
                        *po = -pi[4*norig-2-i];
                    } else {
                        // ...
                    }
                });
            }
        }

        { // another nex = 4*norig extension
            int ostride = (2*n+1)*2;
            // ...
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            // ...
                        } else if (i == norig) {
                            // ...
                        } else if (i < (2*norig+1)) {
                            *po = -pi[(2*norig-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else if (i == 3*norig) {
                            // ...
                        } else {
                            *po = pi[(4*norig-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        // ...
                    } else if (i == norig) {
                        // ...
                    } else if (i < (2*norig+1)) {
                        *po = -pi[2*norig-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else if (i == 3*norig) {
                        // ...
                    } else {
                        // ...
                    }
                });
            }
        }

        { // another nex = 4*norig extension
            int ostride = (2*n+1)*2;
            // ...
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            // ...
                        } else if (i < (2*norig)) {
                            *po = -pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        // ...
                    } else if (i < (2*norig)) {
                        *po = -pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = pi[4*norig-1-i];
                    }
                });
            }
        }

        { // another nex = 4*norig extension
            int ostride = (2*n+1)*2;
            // ...
            Long nelems = Long(nex)*howmany;
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto* po = pdst + (2*batch+ir)*ostride + i;
                        auto const* pi = psrc + 2*batch*istride + ir;
                        if (i < norig) {
                            // ...
                        } else if (i < (2*norig)) {
                            *po = pi[(2*norig-1-i)*2];
                        } else if (i < (3*norig)) {
                            *po = -pi[(i-2*norig)*2];
                        } else {
                            *po = -pi[(4*norig-1-i)*2];
                        }
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(nex);
                    auto i = int(ielem - batch*nex);
                    auto* po = pdst + batch*ostride + i;
                    auto const* pi = psrc + batch*istride;
                    if (i < norig) {
                        // ...
                    } else if (i < (2*norig)) {
                        *po = pi[2*norig-1-i];
                    } else if (i < (3*norig)) {
                        *po = -pi[i-2*norig];
                    } else {
                        *po = -pi[4*norig-1-i];
                    }
                });
            }
        }
        // ...
    }

    void unpack_r2r_buffer (T* pdst, void const* pbuf) const
    {
        auto const* psrc = (GpuComplex<T> const*) pbuf;
        // ... (norig, istride and ostride are set up here; which of the blocks
        //      below runs depends on kind)
        Long nelems = Long(norig)*howmany;

        {
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k+1];
                        pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    auto const& yk = psrc[batch*istride+k+1];
                    pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
                });
            }
        }

        {
            int istride = 2*n+1;
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
                });
            }
        }

        {
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+k];
                        pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    auto const& yk = psrc[batch*istride+k];
                    pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
                });
            }
        }

        {
            int istride = 2*n+1;
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * yk.real();
                });
            }
        }

        {
            int istride = 2*n+1;
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
                });
            }
        }

        {
            int istride = 2*n+1;
            // ...
            if (r2r_data_is_complex) {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    for (int ir = 0; ir < 2; ++ir) {
                        auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
                        pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
                    }
                });
            } else {
                ParallelForOMP(nelems, [=] AMREX_GPU_DEVICE (Long ielem) {
                    auto batch = ielem / Long(norig);
                    auto k = int(ielem - batch*norig);
                    auto [s, c] = Math::sincospi(/* ... */);
                    auto const& yk = psrc[batch*istride+2*k+1];
                    pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
                });
            }
        }

        // ...
        amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
    }
    template <Direction D>
    void compute_r2r ()
    {
        // ... (returns early if the plan is not defined)
#if defined(AMREX_USE_GPU)
        // ... (allocate scratch space and pack the extended data into it)
#if defined(AMREX_USE_CUDA)
        std::size_t work_size = 0;
        // ...
        if constexpr (std::is_same_v<float,T>) {
            // ...
        } else {
            // ...
        }
#elif defined(AMREX_USE_HIP)
        detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
#elif defined(AMREX_USE_SYCL)
        detail::sycl_execute<T,Direction::forward>(std::get<0>(plan),
                                                   (T*)pscratch, (VendorComplex*)pscratch);
#endif
        // ... (unpack the result and free the scratch space)
#if defined(AMREX_USE_CUDA)
        // ...
#endif
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_execute(plan);
            // ...
        } else {
            fftw_execute(plan);
            // ...
        }
#endif
    }

    static void destroy_vendor_plan (VendorPlan plan)
    {
#if defined(AMREX_USE_CUDA)
        AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
        AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
#elif defined(AMREX_USE_SYCL)
        std::visit([](auto&& p) { delete p; }, plan);
#else
        if constexpr (std::is_same_v<float,T>) {
            fftwf_destroy_plan(plan);
        } else {
            fftw_destroy_plan(plan);
        }
#endif
    }
};

using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

PlanD* get_vendor_plan_d (Key const& key);
PlanF* get_vendor_plan_f (Key const& key);
void add_vendor_plan_d (Key const& key, PlanD plan);
void add_vendor_plan_f (Key const& key, PlanF plan);
template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb,
                        bool cache, int ncomp)
{
    // ... (kind, defined, howmany and pf/pb are set up here)
    n = 1;
    for (auto s : fft_size) { n *= s; }
    // ...

#if defined(AMREX_USE_GPU)
    Key key = {fft_size.template expand<3>(), ncomp, D, kind};
    if (cache) {
        VendorPlan* cached_plan = nullptr;
        if constexpr (std::is_same_v<float,T>) {
            cached_plan = get_vendor_plan_f(key);
        } else {
            cached_plan = get_vendor_plan_d(key);
        }
        if (cached_plan) {
            plan = *cached_plan;
            return;
        }
    }
#endif

    int len[M];
    for (int i = 0; i < M; ++i) {
        len[i] = fft_size[M-1-i];
    }

    int nc = fft_size[0]/2+1;
    for (int i = 1; i < M; ++i) {
        nc *= fft_size[i];
    }

#if defined(AMREX_USE_CUDA)
    AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
    // ...
    cufftType type;
    int n_in, n_out;
    if constexpr (D == Direction::forward) {
        n_in = n;
        n_out = nc;
        type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
    } else {
        n_in = nc;
        n_out = n;
        type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
    }
    std::size_t work_size;
    AMREX_CUFFT_SAFE_CALL
        (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type,
                           howmany, &work_size));
#elif defined(AMREX_USE_HIP)
    auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
    std::size_t length[M];
    for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
    if constexpr (D == Direction::forward) {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_forward, prec, M,
                                length, howmany, nullptr));
    } else {
        AMREX_ROCFFT_SAFE_CALL
            (rocfft_plan_create(&plan, rocfft_placement_notinplace,
                                rocfft_transform_type_real_inverse, prec, M,
                                length, howmany, nullptr));
    }
#elif defined(AMREX_USE_SYCL)
    mkl_desc_r* pp;
    if (M == 1) {
        pp = new mkl_desc_r(fft_size[0]);
    } else {
        std::vector<std::int64_t> len64(M);
        for (int idim = 0; idim < M; ++idim) {
            len64[idim] = len[idim];
        }
        pp = new mkl_desc_r(len64);
    }
#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                  oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
    pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
    pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
    pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
    pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
    std::vector<std::int64_t> strides(M+1);
    strides[0] = 0;
    strides[M] = 1;
    for (int i = M-1; i >= 1; --i) {
        strides[i] = strides[i+1] * fft_size[M-1-i];
    }
#ifndef AMREX_USE_MKL_DFTI_2024
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
    // ...
#else
    pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
    // ...
#endif
    pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                  oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
    pp->commit(amrex::Gpu::Device::streamQueue());
    plan = pp;
#else /* FFTW */
    if (pf == nullptr || pb == nullptr) {
        return;
    }
    if constexpr (std::is_same_v<float,T>) {
        if constexpr (D == Direction::forward) {
            plan = fftwf_plan_many_dft_r2c
                (M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb,
                 nullptr, 1, nc, FFTW_ESTIMATE);
        } else {
            plan = fftwf_plan_many_dft_c2r
                (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf,
                 nullptr, 1, n, FFTW_ESTIMATE);
        }
    } else {
        if constexpr (D == Direction::forward) {
            plan = fftw_plan_many_dft_r2c
                (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb,
                 nullptr, 1, nc, FFTW_ESTIMATE);
        } else {
            plan = fftw_plan_many_dft_c2r
                (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf,
                 nullptr, 1, n, FFTW_ESTIMATE);
        }
    }
#endif

#if defined(AMREX_USE_GPU)
    if (cache) {
        if constexpr (std::is_same_v<float,T>) {
            add_vendor_plan_f(key, plan);
        } else {
            add_vendor_plan_d(key, plan);
        }
    }
#endif
}
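
// Note on the GPU plan cache used above (illustrative): with cache == true the
// plan is stored and looked up through get_vendor_plan_d/f and
// add_vendor_plan_d/f, keyed by (transform size, number of components,
// direction, kind), so repeated solves of the same shape reuse one vendor
// plan. A minimal sketch of such a cache (assumed shape, not the actual
// implementation in AMReX_FFT.cpp):
//
//     std::map<Key,PlanD> the_cache_d;   // assumes Key is ordered
//     PlanD* get_vendor_plan_d (Key const& key) {
//         auto it = the_cache_d.find(key);
//         return (it != the_cache_d.end()) ? &it->second : nullptr;
//     }
//     void add_vendor_plan_d (Key const& key, PlanD plan) { the_cache_d[key] = plan; }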
// ...

template <typename FA>
typename FA::FABType::value_type * get_fab (FA& fa)
{
    int myproc = ParallelContext::MyProcSub();
    if (myproc < fa.size()) {
        return fa.fabPtr(myproc);
    } else {
        return nullptr;
    }
}

template <typename FA1, typename FA2>
std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
{
    bool not_same_fa = true;
    if constexpr (std::is_same_v<FA1,FA2>) {
        not_same_fa = (&fa1 != &fa2);
    }
    using FAB1 = typename FA1::FABType::value_type;
    using FAB2 = typename FA2::FABType::value_type;
    using T1 = typename FAB1::value_type;
    using T2 = typename FAB2::value_type;
    int myproc = ParallelContext::MyProcSub();
    bool alloc_1 = (myproc < fa1.size());
    bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
    void* p = nullptr;
    if (alloc_1 && alloc_2) {
        Box const& box1 = fa1.fabbox(myproc);
        Box const& box2 = fa2.fabbox(myproc);
        int ncomp1 = fa1.nComp();
        int ncomp2 = fa2.nComp();
        p = The_Arena()->alloc(std::max(sizeof(T1)*box1.numPts()*ncomp1,
                                        sizeof(T2)*box2.numPts()*ncomp2));
        fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
        fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
    } else if (alloc_1) {
        Box const& box1 = fa1.fabbox(myproc);
        int ncomp1 = fa1.nComp();
        p = The_Arena()->alloc(sizeof(T1)*box1.numPts()*ncomp1);
        fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
    } else if (alloc_2) {
        Box const& box2 = fa2.fabbox(myproc);
        int ncomp2 = fa2.nComp();
        p = The_Arena()->alloc(sizeof(T2)*box2.numPts()*ncomp2);
        fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
    }
    // ... (the shared buffer is returned wrapped in std::unique_ptr<char,DataDeleter>)
}
// ... Four index-mapper structs follow; their declarations and IndexType
// overloads are elided in this excerpt. Each provides operator()(Dim3) and a
// matching static Inverse(Dim3):

    // 1) swap of the first two dimensions (self-inverse)
    constexpr Dim3 operator() (Dim3 i) const noexcept { return {i.y, i.x, i.z}; }
    static constexpr Dim3 Inverse (Dim3 i) { return {i.y, i.x, i.z}; }

    // 2) swap of the first and third dimensions (self-inverse)
    constexpr Dim3 operator() (Dim3 i) const noexcept { return {i.z, i.y, i.x}; }
    static constexpr Dim3 Inverse (Dim3 i) { return {i.z, i.y, i.x}; }

    // 3) forward rotation of the dimensions
    constexpr Dim3 operator() (Dim3 i) const noexcept { return {i.y, i.z, i.x}; }
    static constexpr Dim3 Inverse (Dim3 i) { return {i.z, i.x, i.y}; }

    // 4) backward rotation of the dimensions
    constexpr Dim3 operator() (Dim3 i) const noexcept { return {i.z, i.x, i.y}; }
    static constexpr Dim3 Inverse (Dim3 i) { return {i.y, i.z, i.x}; }
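
// Worked example (illustrative) for the mappers above: with i = {3,5,7},
//   {i.y, i.x, i.z} -> {5,3,7}  (swapping the first two dimensions is its own
//                                inverse)
//   {i.y, i.z, i.x} -> {5,7,3}  and  {i.z, i.x, i.y} -> {7,3,5}
//                               undo each other, i.e. the two rotations are
//                               mutual inverses.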
    // ... (members of the transposition helper struct follow)

    template <typename T>
    T make_array (T const& a) const
    {
#if (AMREX_SPACEDIM == 1)
        // ...
#elif (AMREX_SPACEDIM == 2)
        // ...
        return T{a[1],a[0]};
#else
        // 3d: the permutation depends on m_case; the case labels are elided.
        switch (m_case) {
        // ...
            return T{a[2],a[0],a[1]};
        // ...
            return T{a[1],a[0],a[2]};
        // ...
            return T{a[1],a[2],a[0]};
        // ...
            return T{a[0],a[2],a[1]};
        // ...
        }
#endif
    }

    template <typename FA>
    FA make_alias_mf (FA const& mf)
    {
        BoxList bl = mf.boxArray().boxList();
        for (auto& b : bl) {
            // ... (each box is permuted)
        }
        auto const& ng = make_iv(mf.nGrowVect());
        // ... (an unallocated FA named submf is built from the permuted boxes)
        using FAB = typename FA::fab_type;
        // ... (for each local box:)
        submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
        // ...
        return submf;
    }

    // ...
#if (AMREX_SPACEDIM == 2)
    // ...
#elif (AMREX_SPACEDIM == 3)
    // ...
#endif

// ...