1#ifndef AMREX_FFT_HELPER_H_
2#define AMREX_FFT_HELPER_H_
3#include <AMReX_Config.H>
16#if defined(AMREX_USE_CUDA)
18# include <cuComplex.h>
19#elif defined(AMREX_USE_HIP)
20# if __has_include(<rocfft/rocfft.h>)
21# include <rocfft/rocfft.h>
25# include <hip/hip_complex.h>
26#elif defined(AMREX_USE_SYCL)
27# if __has_include(<oneapi/mkl/dft.hpp>)
28# include <oneapi/mkl/dft.hpp>
30# define AMREX_USE_MKL_DFTI_2024 1
31# include <oneapi/mkl/dfti.hpp>
78 int nprocs = std::numeric_limits<int>::max();
90namespace detail {
void hip_execute (rocfft_plan plan,
void **in,
void **out); }
98template <
typename T, Direction direction,
typename P,
typename TI,
typename TO>
99void sycl_execute (P* plan, TI* in, TO* out)
101#ifndef AMREX_USE_MKL_DFTI_2024
102 std::int64_t workspaceSize = 0;
104 std::size_t workspaceSize = 0;
106 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
109 plan->set_workspace(buffer);
111 if (std::is_same_v<TI,TO>) {
113 if constexpr (direction == Direction::forward) {
114 r = oneapi::mkl::dft::compute_forward(*plan, out);
116 r = oneapi::mkl::dft::compute_backward(*plan, out);
119 if constexpr (direction == Direction::forward) {
120 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
122 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
135#if defined(AMREX_USE_CUDA)
138 cuComplex, cuDoubleComplex>;
139#elif defined(AMREX_USE_HIP)
141 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
143#elif defined(AMREX_USE_SYCL)
144 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
145 ? oneapi::mkl::dft::precision::SINGLE
146 : oneapi::mkl::dft::precision::DOUBLE,
147 oneapi::mkl::dft::domain::REAL>;
148 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
149 ? oneapi::mkl::dft::precision::SINGLE
150 : oneapi::mkl::dft::precision::DOUBLE,
151 oneapi::mkl::dft::domain::COMPLEX>;
152 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
155 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
156 fftwf_plan, fftw_plan>;
157 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
158 fftwf_complex, fftw_complex>;
185#if !defined(AMREX_USE_GPU)
193 template <Direction D>
198 int rank = is_2d_transform ? 2 : 1;
213 int nr = (rank == 1) ? len[0] : len[0]*len[1];
215 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
216#if (AMREX_SPACEDIM == 1)
226#if defined(AMREX_USE_CUDA)
230 std::size_t work_size;
232 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
234 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nr,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
236 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
238 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nc,
nullptr, 1, nr, bwd_type,
howmany, &work_size));
241#elif defined(AMREX_USE_HIP)
243 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
245 std::size_t
length[2] = {std::size_t(len[1]), std::size_t(len[0])};
247 AMREX_ROCFFT_SAFE_CALL
248 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
249 rocfft_transform_type_real_forward, prec, rank,
252 AMREX_ROCFFT_SAFE_CALL
253 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
254 rocfft_transform_type_real_inverse, prec, rank,
258#elif defined(AMREX_USE_SYCL)
262 pp =
new mkl_desc_r(len[0]);
264 pp =
new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
266#ifndef AMREX_USE_MKL_DFTI_2024
267 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
268 oneapi::mkl::dft::config_value::NOT_INPLACE);
270 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
272 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
273 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
274 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
275 std::vector<std::int64_t> strides;
276 strides.push_back(0);
277 if (rank == 2) { strides.push_back(len[1]); }
278 strides.push_back(1);
279#ifndef AMREX_USE_MKL_DFTI_2024
280 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
283 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
286 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
287 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
288 pp->commit(amrex::Gpu::Device::streamQueue());
293 if constexpr (std::is_same_v<float,T>) {
295 plan = fftwf_plan_many_dft_r2c
296 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
297 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
299 plan = fftwf_plan_many_dft_c2r
300 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
301 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
305 plan = fftw_plan_many_dft_r2c
306 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
307 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
309 plan = fftw_plan_many_dft_c2r
310 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
311 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
317 template <Direction D,
int M>
320 template <Direction D>
338#if (AMREX_SPACEDIM >= 2)
339 else if (ndims == 2) {
341#if (AMREX_SPACEDIM == 2)
349#if (AMREX_SPACEDIM == 3)
350 else if (ndims == 3) {
360#if defined(AMREX_USE_CUDA)
364 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
365 std::size_t work_size;
367 (cufftMakePlanMany(
plan, ndims, len,
nullptr, 1,
n,
nullptr, 1,
n, t,
howmany, &work_size));
369#elif defined(AMREX_USE_HIP)
371 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
372 : rocfft_precision_double;
374 : rocfft_transform_type_complex_inverse;
378 }
else if (ndims == 2) {
386 AMREX_ROCFFT_SAFE_CALL
387 (rocfft_plan_create(&
plan, rocfft_placement_inplace, dir, prec, ndims,
390#elif defined(AMREX_USE_SYCL)
394 pp =
new mkl_desc_c(
n);
395 }
else if (ndims == 2) {
396 pp =
new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
398 pp =
new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
400#ifndef AMREX_USE_MKL_DFTI_2024
401 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
402 oneapi::mkl::dft::config_value::INPLACE);
404 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
406 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
407 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
n);
408 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
n);
409 std::vector<std::int64_t> strides(ndims+1);
412 for (
int i = ndims-1; i >= 1; --i) {
413 strides[i] = strides[i+1] * len[i];
415#ifndef AMREX_USE_MKL_DFTI_2024
416 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
417 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
419 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
420 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
422 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
423 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
424 pp->commit(amrex::Gpu::Device::streamQueue());
429 if constexpr (std::is_same_v<float,T>) {
431 plan = fftwf_plan_many_dft
432 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
435 plan = fftwf_plan_many_dft
436 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
441 plan = fftw_plan_many_dft
442 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
445 plan = fftw_plan_many_dft
446 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
454 template <Direction D>
455 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary>
const& bc)
475 return fftw_r2r_kind{};
481 template <Direction D>
507 template <Direction D>
508 void init_r2r (
Box const& box, T* p, std::pair<Boundary,Boundary>
const& bc,
509 int howmany_initval = 1)
513 kind = get_r2r_kind<D>(bc);
521#if defined(AMREX_USE_GPU)
541 int nc = (nex/2) + 1;
543#if defined (AMREX_USE_CUDA)
547 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
548 std::size_t work_size;
550 (cufftMakePlanMany(
plan, 1, &nex,
nullptr, 1, nc*2,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
552#elif defined(AMREX_USE_HIP)
555 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
556 const std::size_t
length = nex;
557 AMREX_ROCFFT_SAFE_CALL
558 (rocfft_plan_create(&
plan, rocfft_placement_inplace,
559 rocfft_transform_type_real_forward, prec, 1,
562#elif defined(AMREX_USE_SYCL)
564 auto*
pp =
new mkl_desc_r(nex);
565#ifndef AMREX_USE_MKL_DFTI_2024
566 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
567 oneapi::mkl::dft::config_value::INPLACE);
569 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
571 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
572 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
573 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
574 std::vector<std::int64_t> strides = {0,1};
575#ifndef AMREX_USE_MKL_DFTI_2024
576 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
577 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
579 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
580 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
582 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
583 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
584 pp->commit(amrex::Gpu::Device::streamQueue());
590 auto fftw_kind = get_fftw_kind<D>(bc);
591 if constexpr (std::is_same_v<float,T>) {
592 plan = fftwf_plan_many_r2r
593 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
596 plan = fftw_plan_many_r2r
597 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
603 template <Direction D>
605 std::pair<Boundary,Boundary>
const& bc)
611#if defined(AMREX_USE_GPU)
613 init_r2r<D>(box, p, bc, 2);
618 kind = get_r2r_kind<D>(bc);
627 auto fftw_kind = get_fftw_kind<D>(bc);
628 if constexpr (std::is_same_v<float,T>) {
629 plan = fftwf_plan_many_r2r
630 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
632 plan2 = fftwf_plan_many_r2r
633 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
636 plan = fftw_plan_many_r2r
637 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
639 plan2 = fftw_plan_many_r2r
640 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
646 template <Direction D>
657#if defined(AMREX_USE_CUDA)
660 std::size_t work_size = 0;
667 if constexpr (std::is_same_v<float,T>) {
673 if constexpr (std::is_same_v<float,T>) {
681#elif defined(AMREX_USE_HIP)
682 detail::hip_execute(
plan, (
void**)&pi, (
void**)&po);
683#elif defined(AMREX_USE_SYCL)
684 detail::sycl_execute<T,D>(std::get<0>(
plan), pi, po);
687 if constexpr (std::is_same_v<float,T>) {
695 template <Direction D>
703#if defined(AMREX_USE_CUDA)
706 std::size_t work_size = 0;
713 if constexpr (std::is_same_v<float,T>) {
720#elif defined(AMREX_USE_HIP)
721 detail::hip_execute(
plan, (
void**)&p, (
void**)&p);
722#elif defined(AMREX_USE_SYCL)
723 detail::sycl_execute<T,D>(std::get<1>(
plan), p, p);
726 if constexpr (std::is_same_v<float,T>) {
744 amrex::Abort(
"FFT: alloc_scratch_space: unsupported kind");
753 auto*
pdst = (T*) pbuf;
756 int ostride = (
n+1)*2;
764 auto batch = ielem /
Long(nex);
765 auto i = int(ielem - batch*nex);
766 for (
int ir = 0; ir < 2; ++ir) {
767 auto* po =
pdst + (2*batch+ir)*ostride + i;
768 auto const* pi = psrc + 2*batch*istride + ir;
772 *po = sign * pi[(2*norig-1-i)*2];
779 auto batch = ielem /
Long(nex);
780 auto i = int(ielem - batch*nex);
781 auto* po =
pdst + batch*ostride + i;
782 auto const* pi = psrc + batch*istride;
786 *po = sign * pi[2*norig-1-i];
791 int ostride = (2*
n+1)*2;
799 auto batch = ielem /
Long(nex);
800 auto i = int(ielem - batch*nex);
801 for (
int ir = 0; ir < 2; ++ir) {
802 auto* po =
pdst + (2*batch+ir)*ostride + i;
803 auto const* pi = psrc + 2*batch*istride + ir;
806 }
else if (i < (2*norig-1)) {
807 *po = pi[(2*norig-2-i)*2];
808 }
else if (i == (2*norig-1)) {
810 }
else if (i < (3*norig)) {
811 *po = -pi[(i-2*norig)*2];
812 }
else if (i < (4*norig-1)) {
813 *po = -pi[(4*norig-2-i)*2];
822 auto batch = ielem /
Long(nex);
823 auto i = int(ielem - batch*nex);
824 auto* po =
pdst + batch*ostride + i;
825 auto const* pi = psrc + batch*istride;
828 }
else if (i < (2*norig-1)) {
829 *po = pi[2*norig-2-i];
830 }
else if (i == (2*norig-1)) {
832 }
else if (i < (3*norig)) {
833 *po = -pi[i-2*norig];
834 }
else if (i < (4*norig-1)) {
835 *po = -pi[4*norig-2-i];
842 int ostride = (2*
n+1)*2;
850 auto batch = ielem /
Long(nex);
851 auto i = int(ielem - batch*nex);
852 for (
int ir = 0; ir < 2; ++ir) {
853 auto* po =
pdst + (2*batch+ir)*ostride + i;
854 auto const* pi = psrc + 2*batch*istride + ir;
857 }
else if (i == norig) {
859 }
else if (i < (2*norig+1)) {
860 *po = -pi[(2*norig-i)*2];
861 }
else if (i < (3*norig)) {
862 *po = -pi[(i-2*norig)*2];
863 }
else if (i == 3*norig) {
866 *po = pi[(4*norig-i)*2];
873 auto batch = ielem /
Long(nex);
874 auto i = int(ielem - batch*nex);
875 auto* po =
pdst + batch*ostride + i;
876 auto const* pi = psrc + batch*istride;
879 }
else if (i == norig) {
881 }
else if (i < (2*norig+1)) {
882 *po = -pi[2*norig-i];
883 }
else if (i < (3*norig)) {
884 *po = -pi[i-2*norig];
885 }
else if (i == 3*norig) {
893 int ostride = (2*
n+1)*2;
901 auto batch = ielem /
Long(nex);
902 auto i = int(ielem - batch*nex);
903 for (
int ir = 0; ir < 2; ++ir) {
904 auto* po =
pdst + (2*batch+ir)*ostride + i;
905 auto const* pi = psrc + 2*batch*istride + ir;
908 }
else if (i < (2*norig)) {
909 *po = -pi[(2*norig-1-i)*2];
910 }
else if (i < (3*norig)) {
911 *po = -pi[(i-2*norig)*2];
913 *po = pi[(4*norig-1-i)*2];
920 auto batch = ielem /
Long(nex);
921 auto i = int(ielem - batch*nex);
922 auto* po =
pdst + batch*ostride + i;
923 auto const* pi = psrc + batch*istride;
926 }
else if (i < (2*norig)) {
927 *po = -pi[2*norig-1-i];
928 }
else if (i < (3*norig)) {
929 *po = -pi[i-2*norig];
931 *po = pi[4*norig-1-i];
936 int ostride = (2*
n+1)*2;
944 auto batch = ielem /
Long(nex);
945 auto i = int(ielem - batch*nex);
946 for (
int ir = 0; ir < 2; ++ir) {
947 auto* po =
pdst + (2*batch+ir)*ostride + i;
948 auto const* pi = psrc + 2*batch*istride + ir;
951 }
else if (i < (2*norig)) {
952 *po = pi[(2*norig-1-i)*2];
953 }
else if (i < (3*norig)) {
954 *po = -pi[(i-2*norig)*2];
956 *po = -pi[(4*norig-1-i)*2];
963 auto batch = ielem /
Long(nex);
964 auto i = int(ielem - batch*nex);
965 auto* po =
pdst + batch*ostride + i;
966 auto const* pi = psrc + batch*istride;
969 }
else if (i < (2*norig)) {
970 *po = pi[2*norig-1-i];
971 }
else if (i < (3*norig)) {
972 *po = -pi[i-2*norig];
974 *po = -pi[4*norig-1-i];
995 auto batch = ielem /
Long(norig);
996 auto k = int(ielem - batch*norig);
998 for (
int ir = 0; ir < 2; ++ir) {
999 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
1000 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
1006 auto batch = ielem /
Long(norig);
1007 auto k = int(ielem - batch*norig);
1009 auto const& yk = psrc[batch*istride+k+1];
1010 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
1014 int istride = 2*
n+1;
1018 auto batch = ielem /
Long(norig);
1019 auto k = int(ielem - batch*norig);
1021 for (
int ir = 0; ir < 2; ++ir) {
1022 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1023 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
1029 auto batch = ielem /
Long(norig);
1030 auto k = int(ielem - batch*norig);
1032 auto const& yk = psrc[batch*istride+2*k+1];
1033 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
1041 auto batch = ielem /
Long(norig);
1042 auto k = int(ielem - batch*norig);
1044 for (
int ir = 0; ir < 2; ++ir) {
1045 auto const& yk = psrc[(2*batch+ir)*istride+k];
1046 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
1052 auto batch = ielem /
Long(norig);
1053 auto k = int(ielem - batch*norig);
1055 auto const& yk = psrc[batch*istride+k];
1056 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
1060 int istride = 2*
n+1;
1064 auto batch = ielem /
Long(norig);
1065 auto k = int(ielem - batch*norig);
1066 for (
int ir = 0; ir < 2; ++ir) {
1067 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1068 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1074 auto batch = ielem /
Long(norig);
1075 auto k = int(ielem - batch*norig);
1076 auto const& yk = psrc[batch*istride+2*k+1];
1077 pdst[batch*ostride+k] = T(0.5) * yk.real();
1081 int istride = 2*
n+1;
1085 auto batch = ielem /
Long(norig);
1086 auto k = int(ielem - batch*norig);
1088 for (
int ir = 0; ir < 2; ++ir) {
1089 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1090 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1096 auto batch = ielem /
Long(norig);
1097 auto k = int(ielem - batch*norig);
1099 auto const& yk = psrc[batch*istride+2*k+1];
1100 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1104 int istride = 2*
n+1;
1108 auto batch = ielem /
Long(norig);
1109 auto k = int(ielem - batch*norig);
1111 for (
int ir = 0; ir < 2; ++ir) {
1112 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1113 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1119 auto batch = ielem /
Long(norig);
1120 auto k = int(ielem - batch*norig);
1122 auto const& yk = psrc[batch*istride+2*k+1];
1123 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1127 amrex::Abort(
"FFT: unpack_r2r_buffer: unsupported kind");
1132 template <Direction D>
1138#if defined(AMREX_USE_GPU)
1144#if defined(AMREX_USE_CUDA)
1148 std::size_t work_size = 0;
1154 if constexpr (std::is_same_v<float,T>) {
1160#elif defined(AMREX_USE_HIP)
1161 detail::hip_execute(
plan, (
void**)&pscratch, (
void**)&pscratch);
1162#elif defined(AMREX_USE_SYCL)
1163 detail::sycl_execute<T,Direction::forward>(std::get<0>(
plan), (T*)pscratch, (
VendorComplex*)pscratch);
1170#if defined(AMREX_USE_CUDA)
1176 if constexpr (std::is_same_v<float,T>) {
1177 fftwf_execute(
plan);
1189#if defined(AMREX_USE_CUDA)
1191#elif defined(AMREX_USE_HIP)
1192 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(
plan));
1193#elif defined(AMREX_USE_SYCL)
1194 std::visit([](
auto&& p) {
delete p; },
plan);
1196 if constexpr (std::is_same_v<float,T>) {
1197 fftwf_destroy_plan(
plan);
1199 fftw_destroy_plan(
plan);
1212 PlanD* get_vendor_plan_d (Key
const& key);
1213 PlanF* get_vendor_plan_f (Key
const& key);
1215 void add_vendor_plan_d (Key
const& key, PlanD plan);
1216 void add_vendor_plan_f (Key
const& key, PlanF plan);
1220template <
typename T>
1221template <Direction D,
int M>
1232 for (
auto s : fft_size) { n *= s; }
1235#if defined(AMREX_USE_GPU)
1236 Key key = {fft_size.template expand<3>(), ncomp, D, kind};
1239 if constexpr (std::is_same_v<float,T>) {
1240 cached_plan = detail::get_vendor_plan_f(key);
1242 cached_plan = detail::get_vendor_plan_d(key);
1245 plan = *cached_plan;
1254 for (
int i = 0; i < M; ++i) {
1255 len[i] = fft_size[M-1-i];
1258 int nc = fft_size[0]/2+1;
1259 for (
int i = 1; i < M; ++i) {
1263#if defined(AMREX_USE_CUDA)
1270 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1274 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1278 std::size_t work_size;
1280 (cufftMakePlanMany(plan, M, len,
nullptr, 1, n_in,
nullptr, 1, n_out, type, howmany, &work_size));
1282#elif defined(AMREX_USE_HIP)
1284 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1286 for (
int idim = 0; idim < M; ++idim) {
length[idim] = fft_size[idim]; }
1288 AMREX_ROCFFT_SAFE_CALL
1289 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1290 rocfft_transform_type_real_forward, prec, M,
1291 length, howmany,
nullptr));
1293 AMREX_ROCFFT_SAFE_CALL
1294 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1295 rocfft_transform_type_real_inverse, prec, M,
1296 length, howmany,
nullptr));
1299#elif defined(AMREX_USE_SYCL)
1303 pp =
new mkl_desc_r(fft_size[0]);
1305 std::vector<std::int64_t> len64(M);
1306 for (
int idim = 0; idim < M; ++idim) {
1307 len64[idim] = len[idim];
1309 pp =
new mkl_desc_r(len64);
1311#ifndef AMREX_USE_MKL_DFTI_2024
1312 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1313 oneapi::mkl::dft::config_value::NOT_INPLACE);
1315 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1317 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1318 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1319 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
1320 std::vector<std::int64_t> strides(M+1);
1323 for (
int i = M-1; i >= 1; --i) {
1324 strides[i] = strides[i+1] * fft_size[M-1-i];
1327#ifndef AMREX_USE_MKL_DFTI_2024
1328 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1331 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1334 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1335 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1336 pp->commit(amrex::Gpu::Device::streamQueue());
1341 if (pf ==
nullptr || pb ==
nullptr) {
1346 if constexpr (std::is_same_v<float,T>) {
1348 plan = fftwf_plan_many_dft_r2c
1349 (M, len, howmany, (
float*)pf,
nullptr, 1, n, (fftwf_complex*)pb,
nullptr, 1, nc,
1352 plan = fftwf_plan_many_dft_c2r
1353 (M, len, howmany, (fftwf_complex*)pb,
nullptr, 1, nc, (
float*)pf,
nullptr, 1, n,
1358 plan = fftw_plan_many_dft_r2c
1359 (M, len, howmany, (
double*)pf,
nullptr, 1, n, (fftw_complex*)pb,
nullptr, 1, nc,
1362 plan = fftw_plan_many_dft_c2r
1363 (M, len, howmany, (fftw_complex*)pb,
nullptr, 1, nc, (
double*)pf,
nullptr, 1, n,
1369#if defined(AMREX_USE_GPU)
1371 if constexpr (std::is_same_v<float,T>) {
1372 detail::add_vendor_plan_f(key, plan);
1374 detail::add_vendor_plan_d(key, plan);
1385 template <
typename FA>
1386 typename FA::FABType::value_type * get_fab (FA& fa)
1388 auto myproc = ParallelContext::MyProcSub();
1389 if (myproc < fa.size()) {
1390 return fa.fabPtr(myproc);
1396 template <
typename FA1,
typename FA2>
1397 std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
1399 bool not_same_fa =
true;
1400 if constexpr (std::is_same_v<FA1,FA2>) {
1401 not_same_fa = (&fa1 != &fa2);
1403 using FAB1 =
typename FA1::FABType::value_type;
1404 using FAB2 =
typename FA2::FABType::value_type;
1405 using T1 =
typename FAB1::value_type;
1406 using T2 =
typename FAB2::value_type;
1407 auto myproc = ParallelContext::MyProcSub();
1408 bool alloc_1 = (myproc < fa1.size());
1409 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1411 if (alloc_1 && alloc_2) {
1412 Box const& box1 = fa1.fabbox(myproc);
1413 Box const& box2 = fa2.fabbox(myproc);
1414 int ncomp1 = fa1.nComp();
1415 int ncomp2 = fa2.nComp();
1417 sizeof(T2)*box2.numPts()*ncomp2));
1418 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1419 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1420 }
else if (alloc_1) {
1421 Box const& box1 = fa1.fabbox(myproc);
1422 int ncomp1 = fa1.nComp();
1424 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1425 }
else if (alloc_2) {
1426 Box const& box2 = fa2.fabbox(myproc);
1427 int ncomp2 = fa2.nComp();
1429 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1433 return std::unique_ptr<char,DataDeleter>((
char*)p, DataDeleter{
The_Arena()});
1442 [[nodiscard]]
constexpr Dim3 operator() (Dim3 i)
const noexcept
1444 return {i.y, i.x, i.z};
1447 static constexpr Dim3 Inverse (Dim3 i)
1449 return {i.y, i.x, i.z};
1452 [[nodiscard]]
constexpr IndexType operator() (IndexType it)
const noexcept
1457 static constexpr IndexType Inverse (IndexType it)
1465 [[nodiscard]]
constexpr Dim3 operator() (Dim3 i)
const noexcept
1467 return {i.z, i.y, i.x};
1470 static constexpr Dim3 Inverse (Dim3 i)
1472 return {i.z, i.y, i.x};
1475 [[nodiscard]]
constexpr IndexType operator() (IndexType it)
const noexcept
1480 static constexpr IndexType Inverse (IndexType it)
1489 [[nodiscard]]
constexpr Dim3 operator() (Dim3 i)
const noexcept
1491 return {i.y, i.z, i.x};
1495 static constexpr Dim3 Inverse (Dim3 i)
1497 return {i.z, i.x, i.y};
1500 [[nodiscard]]
constexpr IndexType operator() (IndexType it)
const noexcept
1505 static constexpr IndexType Inverse (IndexType it)
1514 [[nodiscard]]
constexpr Dim3 operator() (Dim3 i)
const noexcept
1516 return {i.z, i.x, i.y};
1520 static constexpr Dim3 Inverse (Dim3 i)
1522 return {i.y, i.z, i.x};
1525 [[nodiscard]]
constexpr IndexType operator() (IndexType it)
const noexcept
1530 static constexpr IndexType Inverse (IndexType it)
1543 explicit SubHelper (Box
const& domain);
1545 [[nodiscard]]
Box make_box (Box
const& box)
const;
1547 [[nodiscard]] Periodicity make_periodicity (Periodicity
const& period)
const;
1549 [[nodiscard]]
bool ghost_safe (IntVect
const& ng)
const;
1552 [[nodiscard]]
IntVect make_iv (IntVect
const& iv)
const;
1555 [[nodiscard]]
IntVect make_safe_ghost (IntVect
const& ng)
const;
1557 [[nodiscard]] BoxArray inverse_boxarray (BoxArray
const& ba)
const;
1559 [[nodiscard]]
IntVect inverse_order (IntVect
const& order)
const;
1561 template <
typename T>
1562 [[nodiscard]] T make_array (T
const& a)
const
1564#if (AMREX_SPACEDIM == 1)
1567#elif (AMREX_SPACEDIM == 2)
1568 if (m_case == case_1n) {
1569 return T{a[1],a[0]};
1574 if (m_case == case_11n) {
1575 return T{a[2],a[0],a[1]};
1576 }
else if (m_case == case_1n1) {
1577 return T{a[1],a[0],a[2]};
1578 }
else if (m_case == case_1nn) {
1579 return T{a[1],a[2],a[0]};
1580 }
else if (m_case == case_n1n) {
1581 return T{a[0],a[2],a[1]};
1588 [[nodiscard]] GpuArray<int,3> xyz_order ()
const;
1590 template <
typename FA>
1591 FA make_alias_mf (FA
const& mf)
1593 BoxList bl = mf.boxArray().boxList();
1594 for (
auto& b : bl) {
1597 auto const& ng = make_iv(mf.nGrowVect());
1598 FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(
false));
1599 using FAB =
typename FA::fab_type;
1600 for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1601 submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
1606#if (AMREX_SPACEDIM == 2)
1607 enum Case { case_1n, case_other };
1608 int m_case = case_other;
1609#elif (AMREX_SPACEDIM == 3)
1610 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1611 int m_case = case_other;
#define AMREX_ENUM(CLASS,...)
Definition AMReX_Enum.H:208
#define AMREX_CUFFT_SAFE_CALL(call)
Definition AMReX_GpuError.H:92
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1090
#define AMREX_D_TERM(a, b, c)
Definition AMReX_SPACE.H:172
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
__host__ __device__ IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition AMReX_Box.H:154
Calculates the distribution of FABs to MPI processes.
Definition AMReX_DistributionMapping.H:43
An Integer Vector in dim-Dimensional Space.
Definition AMReX_IntVect.H:57
amrex_long Long
Definition AMReX_INT.H:30
void ParallelForOMP(T n, L const &f) noexcept
Performance-portable kernel launch function with optional OpenMP threading.
Definition AMReX_GpuLaunch.H:243
Definition AMReX_FFT_Helper.H:46
Direction
Definition AMReX_FFT_Helper.H:48
Boundary
Definition AMReX_FFT_Helper.H:52
DomainStrategy
Definition AMReX_FFT_Helper.H:50
Kind
Definition AMReX_FFT_Helper.H:54
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:263
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:244
__host__ __device__ std::pair< double, double > sincospi(double x)
Return sin(pi*x) and cos(pi*x) given x.
Definition AMReX_Math.H:204
__host__ __device__ void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:138
__host__ __device__ Dim3 length(Array4< T > const &a) noexcept
Definition AMReX_Array4.H:326
BoxND< 3 > Box
Box is an alias for amrex::BoxND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:27
IndexTypeND< 3 > IndexType
IndexType is an alias for amrex::IndexTypeND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:33
IntVectND< 3 > IntVect
IntVect is an alias for amrex::IntVectND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:30
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
Arena * The_Arena()
Definition AMReX_Arena.cpp:783
Definition AMReX_FFT_Helper.H:58
bool twod_mode
Definition AMReX_FFT_Helper.H:69
Info & setNumProcs(int n)
Definition AMReX_FFT_Helper.H:85
bool oned_mode
We might have a special twod_mode: nx or ny == 1 && nz > 1.
Definition AMReX_FFT_Helper.H:72
int batch_size
Batched FFT size. Only support in R2C, not R2X.
Definition AMReX_FFT_Helper.H:75
Info & setDomainStrategy(DomainStrategy s)
Definition AMReX_FFT_Helper.H:80
DomainStrategy domain_strategy
Domain composition strategy.
Definition AMReX_FFT_Helper.H:60
int nprocs
Max number of processes to use.
Definition AMReX_FFT_Helper.H:78
int pencil_threshold
Definition AMReX_FFT_Helper.H:64
Info & setOneDMode(bool x)
Definition AMReX_FFT_Helper.H:83
Info & setBatchSize(int bsize)
Definition AMReX_FFT_Helper.H:84
Info & setPencilThreshold(int t)
Definition AMReX_FFT_Helper.H:81
Info & setTwoDMode(bool x)
Definition AMReX_FFT_Helper.H:82
Definition AMReX_FFT_Helper.H:134
void * pf
Definition AMReX_FFT_Helper.H:169
void unpack_r2r_buffer(T *pdst, void const *pbuf) const
Definition AMReX_FFT_Helper.H:983
std::conditional_t< std::is_same_v< float, T >, cuComplex, cuDoubleComplex > VendorComplex
Definition AMReX_FFT_Helper.H:138
VendorPlan plan2
Definition AMReX_FFT_Helper.H:168
int n
Definition AMReX_FFT_Helper.H:161
void destroy()
Definition AMReX_FFT_Helper.H:179
bool defined2
Definition AMReX_FFT_Helper.H:166
void init_r2r(Box const &box, VendorComplex *pc, std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:604
void pack_r2r_buffer(void *pbuf, T const *psrc) const
Definition AMReX_FFT_Helper.H:751
static void free_scratch_space(void *p)
Definition AMReX_FFT_Helper.H:749
static void destroy_vendor_plan(VendorPlan plan)
Definition AMReX_FFT_Helper.H:1187
Kind get_r2r_kind(std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:482
cufftHandle VendorPlan
Definition AMReX_FFT_Helper.H:136
Kind kind
Definition AMReX_FFT_Helper.H:163
void init_c2c(Box const &box, VendorComplex *p, int ncomp=1, int ndims=1)
Definition AMReX_FFT_Helper.H:321
int howmany
Definition AMReX_FFT_Helper.H:162
void * pb
Definition AMReX_FFT_Helper.H:170
void init_r2c(Box const &box, T *pr, VendorComplex *pc, bool is_2d_transform=false, int ncomp=1)
Definition AMReX_FFT_Helper.H:194
void compute_r2r()
Definition AMReX_FFT_Helper.H:1133
void compute_c2c()
Definition AMReX_FFT_Helper.H:696
bool r2r_data_is_complex
Definition AMReX_FFT_Helper.H:164
void * alloc_scratch_space() const
Definition AMReX_FFT_Helper.H:735
VendorPlan plan
Definition AMReX_FFT_Helper.H:167
void compute_r2c()
Definition AMReX_FFT_Helper.H:647
bool defined
Definition AMReX_FFT_Helper.H:165
void set_ptrs(void *p0, void *p1)
Definition AMReX_FFT_Helper.H:173
void init_r2r(Box const &box, T *p, std::pair< Boundary, Boundary > const &bc, int howmany_initval=1)
Definition AMReX_FFT_Helper.H:508
void init_r2c(IntVectND< M > const &fft_size, void *, void *, bool cache, int ncomp=1)
Definition AMReX_FFT_Helper.H:1222
A host / device complex number type, because std::complex doesn't work in device code with Cuda yet.
Definition AMReX_GpuComplex.H:30