1#ifndef AMREX_FFT_HELPER_H_
2#define AMREX_FFT_HELPER_H_
3#include <AMReX_Config.H>
16#if defined(AMREX_USE_CUDA)
18# include <cuComplex.h>
19#elif defined(AMREX_USE_HIP)
20# if __has_include(<rocfft/rocfft.h>)
21# include <rocfft/rocfft.h>
25# include <hip/hip_complex.h>
26#elif defined(AMREX_USE_SYCL)
27# if __has_include(<oneapi/mkl/dft.hpp>)
28# include <oneapi/mkl/dft.hpp>
30# define AMREX_USE_MKL_DFTI_2024 1
31# include <oneapi/mkl/dfti.hpp>
75 int nprocs = std::numeric_limits<int>::max();
85namespace detail {
void hip_execute (rocfft_plan plan,
void **in,
void **out); }
91template <
typename T, Direction direction,
typename P,
typename TI,
typename TO>
92void sycl_execute (P* plan, TI* in, TO* out)
94#ifndef AMREX_USE_MKL_DFTI_2024
95 std::int64_t workspaceSize = 0;
97 std::size_t workspaceSize = 0;
99 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
102 plan->set_workspace(buffer);
104 if (std::is_same_v<TI,TO>) {
107 r = oneapi::mkl::dft::compute_forward(*plan, out);
109 r = oneapi::mkl::dft::compute_backward(*plan, out);
113 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
115 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
127#if defined(AMREX_USE_CUDA)
130 cuComplex, cuDoubleComplex>;
131#elif defined(AMREX_USE_HIP)
133 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
135#elif defined(AMREX_USE_SYCL)
136 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
137 ? oneapi::mkl::dft::precision::SINGLE
138 : oneapi::mkl::dft::precision::DOUBLE,
139 oneapi::mkl::dft::domain::REAL>;
140 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
141 ? oneapi::mkl::dft::precision::SINGLE
142 : oneapi::mkl::dft::precision::DOUBLE,
143 oneapi::mkl::dft::domain::COMPLEX>;
144 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
147 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
148 fftwf_plan, fftw_plan>;
149 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
150 fftwf_complex, fftw_complex>;
177#if !defined(AMREX_USE_GPU)
185 template <Direction D>
190 int rank = is_2d_transform ? 2 : 1;
205 int nr = (rank == 1) ? len[0] : len[0]*len[1];
207 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
208#if (AMREX_SPACEDIM == 1)
218#if defined(AMREX_USE_CUDA)
222 std::size_t work_size;
224 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
226 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nr,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
228 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
230 (cufftMakePlanMany(
plan, rank, len,
nullptr, 1, nc,
nullptr, 1, nr, bwd_type,
howmany, &work_size));
233#elif defined(AMREX_USE_HIP)
235 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
237 std::size_t
length[2] = {std::size_t(len[1]), std::size_t(len[0])};
239 AMREX_ROCFFT_SAFE_CALL
240 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
241 rocfft_transform_type_real_forward, prec, rank,
244 AMREX_ROCFFT_SAFE_CALL
245 (rocfft_plan_create(&
plan, rocfft_placement_notinplace,
246 rocfft_transform_type_real_inverse, prec, rank,
250#elif defined(AMREX_USE_SYCL)
254 pp =
new mkl_desc_r(len[0]);
256 pp =
new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
258#ifndef AMREX_USE_MKL_DFTI_2024
259 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
260 oneapi::mkl::dft::config_value::NOT_INPLACE);
262 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
264 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
265 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
266 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
267 std::vector<std::int64_t> strides;
268 strides.push_back(0);
269 if (rank == 2) { strides.push_back(len[1]); }
270 strides.push_back(1);
271#ifndef AMREX_USE_MKL_DFTI_2024
272 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
275 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
278 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
279 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
280 pp->commit(amrex::Gpu::Device::streamQueue());
285 if constexpr (std::is_same_v<float,T>) {
287 plan = fftwf_plan_many_dft_r2c
288 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
289 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
291 plan = fftwf_plan_many_dft_c2r
292 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
293 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
297 plan = fftw_plan_many_dft_r2c
298 (rank, len,
howmany, pr,
nullptr, 1, nr, pc,
nullptr, 1, nc,
299 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
301 plan = fftw_plan_many_dft_c2r
302 (rank, len,
howmany, pc,
nullptr, 1, nc, pr,
nullptr, 1, nr,
303 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
309 template <Direction D,
int M>
312 template <Direction D>
330#if (AMREX_SPACEDIM >= 2)
331 else if (ndims == 2) {
333#if (AMREX_SPACEDIM == 2)
341#if (AMREX_SPACEDIM == 3)
342 else if (ndims == 3) {
352#if defined(AMREX_USE_CUDA)
356 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
357 std::size_t work_size;
359 (cufftMakePlanMany(
plan, ndims, len,
nullptr, 1,
n,
nullptr, 1,
n, t,
howmany, &work_size));
361#elif defined(AMREX_USE_HIP)
363 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
364 : rocfft_precision_double;
366 : rocfft_transform_type_complex_inverse;
370 }
else if (ndims == 2) {
378 AMREX_ROCFFT_SAFE_CALL
379 (rocfft_plan_create(&
plan, rocfft_placement_inplace, dir, prec, ndims,
382#elif defined(AMREX_USE_SYCL)
386 pp =
new mkl_desc_c(
n);
387 }
else if (ndims == 2) {
388 pp =
new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
390 pp =
new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
392#ifndef AMREX_USE_MKL_DFTI_2024
393 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
394 oneapi::mkl::dft::config_value::INPLACE);
396 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
398 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
399 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE,
n);
400 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE,
n);
401 std::vector<std::int64_t> strides(ndims+1);
404 for (
int i = ndims-1; i >= 1; --i) {
405 strides[i] = strides[i+1] * len[ndims-1-i];
407#ifndef AMREX_USE_MKL_DFTI_2024
408 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
409 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
411 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
412 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
414 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
415 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
416 pp->commit(amrex::Gpu::Device::streamQueue());
421 if constexpr (std::is_same_v<float,T>) {
423 plan = fftwf_plan_many_dft
424 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
427 plan = fftwf_plan_many_dft
428 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
433 plan = fftw_plan_many_dft
434 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, -1,
437 plan = fftw_plan_many_dft
438 (ndims, len,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, +1,
446 template <Direction D>
447 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary>
const& bc)
449 if (bc.first == Boundary::even && bc.second == Boundary::even)
453 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
457 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
461 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
467 return fftw_r2r_kind{};
473 template <Direction D>
476 if (bc.first == Boundary::even && bc.second == Boundary::even)
480 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
484 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
488 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
499 template <Direction D>
500 void init_r2r (
Box const& box, T* p, std::pair<Boundary,Boundary>
const& bc,
501 int howmany_initval = 1)
505 kind = get_r2r_kind<D>(bc);
513#if defined(AMREX_USE_GPU)
515 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
518 }
else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
521 }
else if (bc.first == Boundary::even && bc.second == Boundary::even &&
524 }
else if (bc.first == Boundary::even && bc.second == Boundary::even &&
527 }
else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
528 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
533 int nc = (nex/2) + 1;
535#if defined (AMREX_USE_CUDA)
539 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
540 std::size_t work_size;
542 (cufftMakePlanMany(
plan, 1, &nex,
nullptr, 1, nc*2,
nullptr, 1, nc, fwd_type,
howmany, &work_size));
544#elif defined(AMREX_USE_HIP)
547 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
548 const std::size_t
length = nex;
549 AMREX_ROCFFT_SAFE_CALL
550 (rocfft_plan_create(&
plan, rocfft_placement_inplace,
551 rocfft_transform_type_real_forward, prec, 1,
554#elif defined(AMREX_USE_SYCL)
556 auto*
pp =
new mkl_desc_r(nex);
557#ifndef AMREX_USE_MKL_DFTI_2024
558 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
559 oneapi::mkl::dft::config_value::INPLACE);
561 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
563 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
howmany);
564 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
565 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
566 std::vector<std::int64_t> strides = {0,1};
567#ifndef AMREX_USE_MKL_DFTI_2024
568 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
569 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
571 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
572 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
574 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
575 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
576 pp->commit(amrex::Gpu::Device::streamQueue());
582 auto fftw_kind = get_fftw_kind<D>(bc);
583 if constexpr (std::is_same_v<float,T>) {
584 plan = fftwf_plan_many_r2r
585 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
588 plan = fftw_plan_many_r2r
589 (1, &
n,
howmany, p,
nullptr, 1,
n, p,
nullptr, 1,
n, &fftw_kind,
595 template <Direction D>
597 std::pair<Boundary,Boundary>
const& bc)
603#if defined(AMREX_USE_GPU)
605 init_r2r<D>(box, p, bc, 2);
610 kind = get_r2r_kind<D>(bc);
619 auto fftw_kind = get_fftw_kind<D>(bc);
620 if constexpr (std::is_same_v<float,T>) {
621 plan = fftwf_plan_many_r2r
622 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
624 plan2 = fftwf_plan_many_r2r
625 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
628 plan = fftw_plan_many_r2r
629 (1, &
n,
howmany, p,
nullptr, 2,
n*2, p,
nullptr, 2,
n*2, &fftw_kind,
631 plan2 = fftw_plan_many_r2r
632 (1, &
n,
howmany, p+1,
nullptr, 2,
n*2, p+1,
nullptr, 2,
n*2, &fftw_kind,
638 template <Direction D>
649#if defined(AMREX_USE_CUDA)
652 std::size_t work_size = 0;
659 if constexpr (std::is_same_v<float,T>) {
665 if constexpr (std::is_same_v<float,T>) {
673#elif defined(AMREX_USE_HIP)
674 detail::hip_execute(
plan, (
void**)&pi, (
void**)&po);
675#elif defined(AMREX_USE_SYCL)
676 detail::sycl_execute<T,D>(std::get<0>(
plan), pi, po);
679 if constexpr (std::is_same_v<float,T>) {
687 template <Direction D>
695#if defined(AMREX_USE_CUDA)
698 std::size_t work_size = 0;
705 if constexpr (std::is_same_v<float,T>) {
712#elif defined(AMREX_USE_HIP)
713 detail::hip_execute(
plan, (
void**)&p, (
void**)&p);
714#elif defined(AMREX_USE_SYCL)
715 detail::sycl_execute<T,D>(std::get<1>(
plan), p, p);
718 if constexpr (std::is_same_v<float,T>) {
736 amrex::Abort(
"FFT: alloc_scratch_space: unsupported kind");
745 auto*
pdst = (T*) pbuf;
748 int ostride = (
n+1)*2;
752 Long nelems = Long(nex)*
howmany;
756 auto batch = ielem / Long(nex);
757 auto i =
int(ielem - batch*nex);
758 for (
int ir = 0; ir < 2; ++ir) {
759 auto* po =
pdst + (2*batch+ir)*ostride + i;
760 auto const* pi = psrc + 2*batch*istride + ir;
764 *po = sign * pi[(2*norig-1-i)*2];
771 auto batch = ielem / Long(nex);
772 auto i =
int(ielem - batch*nex);
773 auto* po =
pdst + batch*ostride + i;
774 auto const* pi = psrc + batch*istride;
778 *po = sign * pi[2*norig-1-i];
783 int ostride = (2*
n+1)*2;
787 Long nelems = Long(nex)*
howmany;
791 auto batch = ielem / Long(nex);
792 auto i =
int(ielem - batch*nex);
793 for (
int ir = 0; ir < 2; ++ir) {
794 auto* po =
pdst + (2*batch+ir)*ostride + i;
795 auto const* pi = psrc + 2*batch*istride + ir;
798 }
else if (i < (2*norig-1)) {
799 *po = pi[(2*norig-2-i)*2];
800 }
else if (i == (2*norig-1)) {
802 }
else if (i < (3*norig)) {
803 *po = -pi[(i-2*norig)*2];
804 }
else if (i < (4*norig-1)) {
805 *po = -pi[(4*norig-2-i)*2];
814 auto batch = ielem / Long(nex);
815 auto i =
int(ielem - batch*nex);
816 auto* po =
pdst + batch*ostride + i;
817 auto const* pi = psrc + batch*istride;
820 }
else if (i < (2*norig-1)) {
821 *po = pi[2*norig-2-i];
822 }
else if (i == (2*norig-1)) {
824 }
else if (i < (3*norig)) {
825 *po = -pi[i-2*norig];
826 }
else if (i < (4*norig-1)) {
827 *po = -pi[4*norig-2-i];
834 int ostride = (2*
n+1)*2;
838 Long nelems = Long(nex)*
howmany;
842 auto batch = ielem / Long(nex);
843 auto i =
int(ielem - batch*nex);
844 for (
int ir = 0; ir < 2; ++ir) {
845 auto* po =
pdst + (2*batch+ir)*ostride + i;
846 auto const* pi = psrc + 2*batch*istride + ir;
849 }
else if (i == norig) {
851 }
else if (i < (2*norig+1)) {
852 *po = -pi[(2*norig-i)*2];
853 }
else if (i < (3*norig)) {
854 *po = -pi[(i-2*norig)*2];
855 }
else if (i == 3*norig) {
858 *po = pi[(4*norig-i)*2];
865 auto batch = ielem / Long(nex);
866 auto i =
int(ielem - batch*nex);
867 auto* po =
pdst + batch*ostride + i;
868 auto const* pi = psrc + batch*istride;
871 }
else if (i == norig) {
873 }
else if (i < (2*norig+1)) {
874 *po = -pi[2*norig-i];
875 }
else if (i < (3*norig)) {
876 *po = -pi[i-2*norig];
877 }
else if (i == 3*norig) {
885 int ostride = (2*
n+1)*2;
889 Long nelems = Long(nex)*
howmany;
893 auto batch = ielem / Long(nex);
894 auto i =
int(ielem - batch*nex);
895 for (
int ir = 0; ir < 2; ++ir) {
896 auto* po =
pdst + (2*batch+ir)*ostride + i;
897 auto const* pi = psrc + 2*batch*istride + ir;
900 }
else if (i < (2*norig)) {
901 *po = -pi[(2*norig-1-i)*2];
902 }
else if (i < (3*norig)) {
903 *po = -pi[(i-2*norig)*2];
905 *po = pi[(4*norig-1-i)*2];
912 auto batch = ielem / Long(nex);
913 auto i =
int(ielem - batch*nex);
914 auto* po =
pdst + batch*ostride + i;
915 auto const* pi = psrc + batch*istride;
918 }
else if (i < (2*norig)) {
919 *po = -pi[2*norig-1-i];
920 }
else if (i < (3*norig)) {
921 *po = -pi[i-2*norig];
923 *po = pi[4*norig-1-i];
928 int ostride = (2*
n+1)*2;
932 Long nelems = Long(nex)*
howmany;
936 auto batch = ielem / Long(nex);
937 auto i =
int(ielem - batch*nex);
938 for (
int ir = 0; ir < 2; ++ir) {
939 auto* po =
pdst + (2*batch+ir)*ostride + i;
940 auto const* pi = psrc + 2*batch*istride + ir;
943 }
else if (i < (2*norig)) {
944 *po = pi[(2*norig-1-i)*2];
945 }
else if (i < (3*norig)) {
946 *po = -pi[(i-2*norig)*2];
948 *po = -pi[(4*norig-1-i)*2];
955 auto batch = ielem / Long(nex);
956 auto i =
int(ielem - batch*nex);
957 auto* po =
pdst + batch*ostride + i;
958 auto const* pi = psrc + batch*istride;
961 }
else if (i < (2*norig)) {
962 *po = pi[2*norig-1-i];
963 }
else if (i < (3*norig)) {
964 *po = -pi[i-2*norig];
966 *po = -pi[4*norig-1-i];
979 Long nelems = Long(norig)*
howmany;
987 auto batch = ielem / Long(norig);
988 auto k =
int(ielem - batch*norig);
990 for (
int ir = 0; ir < 2; ++ir) {
991 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
992 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
998 auto batch = ielem / Long(norig);
999 auto k =
int(ielem - batch*norig);
1001 auto const& yk = psrc[batch*istride+k+1];
1002 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
1006 int istride = 2*
n+1;
1010 auto batch = ielem / Long(norig);
1011 auto k =
int(ielem - batch*norig);
1013 for (
int ir = 0; ir < 2; ++ir) {
1014 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1015 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
1021 auto batch = ielem / Long(norig);
1022 auto k =
int(ielem - batch*norig);
1024 auto const& yk = psrc[batch*istride+2*k+1];
1025 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
1033 auto batch = ielem / Long(norig);
1034 auto k =
int(ielem - batch*norig);
1036 for (
int ir = 0; ir < 2; ++ir) {
1037 auto const& yk = psrc[(2*batch+ir)*istride+k];
1038 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
1044 auto batch = ielem / Long(norig);
1045 auto k =
int(ielem - batch*norig);
1047 auto const& yk = psrc[batch*istride+k];
1048 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
1052 int istride = 2*
n+1;
1056 auto batch = ielem / Long(norig);
1057 auto k =
int(ielem - batch*norig);
1058 for (
int ir = 0; ir < 2; ++ir) {
1059 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1060 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1066 auto batch = ielem / Long(norig);
1067 auto k =
int(ielem - batch*norig);
1068 auto const& yk = psrc[batch*istride+2*k+1];
1069 pdst[batch*ostride+k] = T(0.5) * yk.real();
1073 int istride = 2*
n+1;
1077 auto batch = ielem / Long(norig);
1078 auto k =
int(ielem - batch*norig);
1080 for (
int ir = 0; ir < 2; ++ir) {
1081 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1082 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1088 auto batch = ielem / Long(norig);
1089 auto k =
int(ielem - batch*norig);
1091 auto const& yk = psrc[batch*istride+2*k+1];
1092 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1096 int istride = 2*
n+1;
1100 auto batch = ielem / Long(norig);
1101 auto k =
int(ielem - batch*norig);
1103 for (
int ir = 0; ir < 2; ++ir) {
1104 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1105 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1111 auto batch = ielem / Long(norig);
1112 auto k =
int(ielem - batch*norig);
1114 auto const& yk = psrc[batch*istride+2*k+1];
1115 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1119 amrex::Abort(
"FFT: unpack_r2r_buffer: unsupported kind");
1124 template <Direction D>
1130#if defined(AMREX_USE_GPU)
1136#if defined(AMREX_USE_CUDA)
1140 std::size_t work_size = 0;
1146 if constexpr (std::is_same_v<float,T>) {
1152#elif defined(AMREX_USE_HIP)
1153 detail::hip_execute(
plan, (
void**)&pscratch, (
void**)&pscratch);
1154#elif defined(AMREX_USE_SYCL)
1155 detail::sycl_execute<T,Direction::forward>(std::get<0>(
plan), (T*)pscratch, (
VendorComplex*)pscratch);
1162#if defined(AMREX_USE_CUDA)
1168 if constexpr (std::is_same_v<float,T>) {
1169 fftwf_execute(
plan);
1181#if defined(AMREX_USE_CUDA)
1183#elif defined(AMREX_USE_HIP)
1184 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(
plan));
1185#elif defined(AMREX_USE_SYCL)
1186 std::visit([](
auto&& p) {
delete p; },
plan);
1188 if constexpr (std::is_same_v<float,T>) {
1189 fftwf_destroy_plan(
plan);
1191 fftw_destroy_plan(
plan);
1207template <
typename T>
1208template <Direction D,
int M>
1219 for (
auto s : fft_size) { n *= s; }
1222#if defined(AMREX_USE_GPU)
1223 Key key = {fft_size.template expand<3>(), ncomp, D, kind};
1226 if constexpr (std::is_same_v<float,T>) {
1232 plan = *cached_plan;
1241 for (
int i = 0; i < M; ++i) {
1242 len[i] = fft_size[M-1-i];
1245 int nc = fft_size[0]/2+1;
1246 for (
int i = 1; i < M; ++i) {
1250#if defined(AMREX_USE_CUDA)
1257 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1261 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1265 std::size_t work_size;
1267 (cufftMakePlanMany(plan, M, len,
nullptr, 1, n_in,
nullptr, 1, n_out, type, howmany, &work_size));
1269#elif defined(AMREX_USE_HIP)
1271 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1273 for (
int idim = 0; idim < M; ++idim) {
length[idim] = fft_size[idim]; }
1275 AMREX_ROCFFT_SAFE_CALL
1276 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1277 rocfft_transform_type_real_forward, prec, M,
1278 length, howmany,
nullptr));
1280 AMREX_ROCFFT_SAFE_CALL
1281 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1282 rocfft_transform_type_real_inverse, prec, M,
1283 length, howmany,
nullptr));
1286#elif defined(AMREX_USE_SYCL)
1290 pp =
new mkl_desc_r(fft_size[0]);
1292 std::vector<std::int64_t> len64(M);
1293 for (
int idim = 0; idim < M; ++idim) {
1294 len64[idim] = len[idim];
1296 pp =
new mkl_desc_r(len64);
1298#ifndef AMREX_USE_MKL_DFTI_2024
1299 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1300 oneapi::mkl::dft::config_value::NOT_INPLACE);
1302 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1304 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1305 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1306 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
1307 std::vector<std::int64_t> strides(M+1);
1310 for (
int i = M-1; i >= 1; --i) {
1311 strides[i] = strides[i+1] * fft_size[M-1-i];
1314#ifndef AMREX_USE_MKL_DFTI_2024
1315 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1318 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1321 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1322 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1323 pp->commit(amrex::Gpu::Device::streamQueue());
1328 if (pf ==
nullptr || pb ==
nullptr) {
1333 if constexpr (std::is_same_v<float,T>) {
1335 plan = fftwf_plan_many_dft_r2c
1336 (M, len, howmany, (
float*)pf,
nullptr, 1, n, (fftwf_complex*)pb,
nullptr, 1, nc,
1339 plan = fftwf_plan_many_dft_c2r
1340 (M, len, howmany, (fftwf_complex*)pb,
nullptr, 1, nc, (
float*)pf,
nullptr, 1, n,
1345 plan = fftw_plan_many_dft_r2c
1346 (M, len, howmany, (
double*)pf,
nullptr, 1, n, (fftw_complex*)pb,
nullptr, 1, nc,
1349 plan = fftw_plan_many_dft_c2r
1350 (M, len, howmany, (fftw_complex*)pb,
nullptr, 1, nc, (
double*)pf,
nullptr, 1, n,
1356#if defined(AMREX_USE_GPU)
1358 if constexpr (std::is_same_v<float,T>) {
1371 template <
typename FA>
1372 typename FA::FABType::value_type *
get_fab (FA& fa)
1375 if (myproc < fa.size()) {
1376 return fa.fabPtr(myproc);
1382 template <
typename FA1,
typename FA2>
1385 bool not_same_fa =
true;
1386 if constexpr (std::is_same_v<FA1,FA2>) {
1387 not_same_fa = (&fa1 != &fa2);
1389 using FAB1 =
typename FA1::FABType::value_type;
1390 using FAB2 =
typename FA2::FABType::value_type;
1391 using T1 =
typename FAB1::value_type;
1392 using T2 =
typename FAB2::value_type;
1394 bool alloc_1 = (myproc < fa1.size());
1395 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1397 if (alloc_1 && alloc_2) {
1398 Box const& box1 = fa1.fabbox(myproc);
1399 Box const& box2 = fa2.fabbox(myproc);
1400 int ncomp1 = fa1.nComp();
1401 int ncomp2 = fa2.nComp();
1403 sizeof(T2)*box2.
numPts()*ncomp2));
1404 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1405 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1406 }
else if (alloc_1) {
1407 Box const& box1 = fa1.fabbox(myproc);
1408 int ncomp1 = fa1.nComp();
1410 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1411 }
else if (alloc_2) {
1412 Box const& box2 = fa2.fabbox(myproc);
1413 int ncomp2 = fa2.nComp();
1415 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1427 return {i.
y, i.x, i.z};
1432 return {i.
y, i.
x, i.
z};
1450 return {i.
z, i.y, i.x};
1455 return {i.
z, i.
y, i.
x};
1474 return {i.
y, i.z, i.x};
1480 return {i.
z, i.
x, i.
y};
1499 return {i.
z, i.x, i.y};
1505 return {i.
y, i.
z, i.
x};
1541 template <
typename T>
1544#if (AMREX_SPACEDIM == 1)
1547#elif (AMREX_SPACEDIM == 2)
1548 if (m_case == case_1n) {
1549 return T{a[1],a[0]};
1554 if (m_case == case_11n) {
1555 return T{a[2],a[0],a[1]};
1556 }
else if (m_case == case_1n1) {
1557 return T{a[1],a[0],a[2]};
1558 }
else if (m_case == case_1nn) {
1559 return T{a[1],a[2],a[0]};
1560 }
else if (m_case == case_n1n) {
1561 return T{a[0],a[2],a[1]};
1570 template <
typename FA>
1573 BoxList bl = mf.boxArray().boxList();
1574 for (
auto&
b : bl) {
1577 auto const& ng =
make_iv(mf.nGrowVect());
1579 using FAB =
typename FA::fab_type;
1581 submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
1586#if (AMREX_SPACEDIM == 2)
1587 enum Case { case_1n, case_other };
1588 int m_case = case_other;
1589#elif (AMREX_SPACEDIM == 3)
1590 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1591 int m_case = case_other;
#define AMREX_ENUM(CLASS,...)
Definition AMReX_Enum.H:133
#define AMREX_CUFFT_SAFE_CALL(call)
Definition AMReX_GpuError.H:92
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1090
#define AMREX_D_TERM(a, b, c)
Definition AMReX_SPACE.H:129
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
A collection of Boxes stored in an Array.
Definition AMReX_BoxArray.H:550
A class for managing a List of Boxes that share a common IndexType. This class implements operations ...
Definition AMReX_BoxList.H:52
AMREX_GPU_HOST_DEVICE IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition AMReX_Box.H:146
AMREX_GPU_HOST_DEVICE Long numPts() const noexcept
Returns the number of points contained in the BoxND.
Definition AMReX_Box.H:346
Calculates the distribution of FABs to MPI processes.
Definition AMReX_DistributionMapping.H:41
Definition AMReX_IntVect.H:48
Definition AMReX_MFIter.H:57
bool isValid() const noexcept
Is the iterator valid i.e. is it associated with a FAB?
Definition AMReX_MFIter.H:141
This provides length of period for periodic domains. 0 means it is not periodic in that direction....
Definition AMReX_Periodicity.H:17
std::unique_ptr< char, DataDeleter > make_mfs_share(FA1 &fa1, FA2 &fa2)
Definition AMReX_FFT_Helper.H:1383
FA::FABType::value_type * get_fab(FA &fa)
Definition AMReX_FFT_Helper.H:1372
DistributionMapping make_iota_distromap(Long n)
Definition AMReX_FFT.cpp:88
Definition AMReX_FFT.cpp:7
Direction
Definition AMReX_FFT_Helper.H:48
void add_vendor_plan_f(Key const &key, PlanF plan)
Definition AMReX_FFT.cpp:78
DomainStrategy
Definition AMReX_FFT_Helper.H:50
typename Plan< float >::VendorPlan PlanF
Definition AMReX_FFT_Helper.H:1199
typename Plan< double >::VendorPlan PlanD
Definition AMReX_FFT_Helper.H:1198
void add_vendor_plan_d(Key const &key, PlanD plan)
Definition AMReX_FFT.cpp:73
Kind
Definition AMReX_FFT_Helper.H:54
PlanF * get_vendor_plan_f(Key const &key)
Definition AMReX_FFT.cpp:64
PlanD * get_vendor_plan_d(Key const &key)
Definition AMReX_FFT.cpp:55
std::tuple< IntVectND< 3 >, int, Direction, Kind > Key
Definition AMReX_FFT_Helper.H:1197
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:218
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE std::pair< double, double > sincospi(double x)
Return sin(pi*x) and cos(pi*x) given x.
Definition AMReX_Math.H:165
int MyProcSub() noexcept
my sub-rank in current frame
Definition AMReX_ParallelContext.H:76
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition AMReX_CTOParallelForImpl.H:191
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:127
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 length(Array4< T > const &a) noexcept
Definition AMReX_Array4.H:322
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
const int[]
Definition AMReX_BLProfiler.cpp:1664
Arena * The_Arena()
Definition AMReX_Arena.cpp:616
Definition AMReX_FabArrayCommI.H:896
Definition AMReX_DataAllocator.H:29
Definition AMReX_Dim3.H:12
int x
Definition AMReX_Dim3.H:12
int z
Definition AMReX_Dim3.H:12
int y
Definition AMReX_Dim3.H:12
Definition AMReX_FFT_Helper.H:58
bool twod_mode
Definition AMReX_FFT_Helper.H:69
Info & setNumProcs(int n)
Definition AMReX_FFT_Helper.H:81
int batch_size
Batched FFT size. Only support in R2C, not R2X.
Definition AMReX_FFT_Helper.H:72
Info & setDomainStrategy(DomainStrategy s)
Definition AMReX_FFT_Helper.H:77
DomainStrategy domain_strategy
Domain composition strategy.
Definition AMReX_FFT_Helper.H:60
int nprocs
Max number of processes to use.
Definition AMReX_FFT_Helper.H:75
int pencil_threshold
Definition AMReX_FFT_Helper.H:64
Info & setBatchSize(int bsize)
Definition AMReX_FFT_Helper.H:80
Info & setPencilThreshold(int t)
Definition AMReX_FFT_Helper.H:78
Info & setTwoDMode(bool x)
Definition AMReX_FFT_Helper.H:79
Definition AMReX_FFT_Helper.H:126
void * pf
Definition AMReX_FFT_Helper.H:161
void unpack_r2r_buffer(T *pdst, void const *pbuf) const
Definition AMReX_FFT_Helper.H:975
std::conditional_t< std::is_same_v< float, T >, cuComplex, cuDoubleComplex > VendorComplex
Definition AMReX_FFT_Helper.H:130
VendorPlan plan2
Definition AMReX_FFT_Helper.H:160
int n
Definition AMReX_FFT_Helper.H:153
void destroy()
Definition AMReX_FFT_Helper.H:171
bool defined2
Definition AMReX_FFT_Helper.H:158
void init_r2r(Box const &box, VendorComplex *pc, std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:596
void pack_r2r_buffer(void *pbuf, T const *psrc) const
Definition AMReX_FFT_Helper.H:743
static void free_scratch_space(void *p)
Definition AMReX_FFT_Helper.H:741
static void destroy_vendor_plan(VendorPlan plan)
Definition AMReX_FFT_Helper.H:1179
Kind get_r2r_kind(std::pair< Boundary, Boundary > const &bc)
Definition AMReX_FFT_Helper.H:474
cufftHandle VendorPlan
Definition AMReX_FFT_Helper.H:128
Kind kind
Definition AMReX_FFT_Helper.H:155
void init_c2c(Box const &box, VendorComplex *p, int ncomp=1, int ndims=1)
Definition AMReX_FFT_Helper.H:313
int howmany
Definition AMReX_FFT_Helper.H:154
void * pb
Definition AMReX_FFT_Helper.H:162
void init_r2c(Box const &box, T *pr, VendorComplex *pc, bool is_2d_transform=false, int ncomp=1)
Definition AMReX_FFT_Helper.H:186
void compute_r2r()
Definition AMReX_FFT_Helper.H:1125
void compute_c2c()
Definition AMReX_FFT_Helper.H:688
bool r2r_data_is_complex
Definition AMReX_FFT_Helper.H:156
void * alloc_scratch_space() const
Definition AMReX_FFT_Helper.H:727
VendorPlan plan
Definition AMReX_FFT_Helper.H:159
void compute_r2c()
Definition AMReX_FFT_Helper.H:639
bool defined
Definition AMReX_FFT_Helper.H:157
void set_ptrs(void *p0, void *p1)
Definition AMReX_FFT_Helper.H:165
void init_r2r(Box const &box, T *p, std::pair< Boundary, Boundary > const &bc, int howmany_initval=1)
Definition AMReX_FFT_Helper.H:500
void init_r2c(IntVectND< M > const &fft_size, void *, void *, bool cache, int ncomp=1)
Definition AMReX_FFT_Helper.H:1209
Definition AMReX_FFT_Helper.H:1495
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1497
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1513
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1503
Definition AMReX_FFT_Helper.H:1470
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1488
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1472
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1478
Definition AMReX_FFT_Helper.H:1424
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1425
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1440
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1430
Definition AMReX_FFT_Helper.H:1447
static constexpr Dim3 Inverse(Dim3 i)
Definition AMReX_FFT_Helper.H:1453
static constexpr IndexType Inverse(IndexType it)
Definition AMReX_FFT_Helper.H:1463
constexpr Dim3 operator()(Dim3 i) const noexcept
Definition AMReX_FFT_Helper.H:1448
Definition AMReX_FFT_Helper.H:1522
T make_array(T const &a) const
Definition AMReX_FFT_Helper.H:1542
Box make_box(Box const &box) const
Definition AMReX_FFT.cpp:142
BoxArray inverse_boxarray(BoxArray const &ba) const
Definition AMReX_FFT.cpp:209
bool ghost_safe(IntVect const &ng) const
Definition AMReX_FFT.cpp:152
GpuArray< int, 3 > xyz_order() const
Definition AMReX_FFT.cpp:326
IntVect inverse_order(IntVect const &order) const
Definition AMReX_FFT.cpp:266
IntVect make_iv(IntVect const &iv) const
Definition AMReX_FFT.cpp:178
FA make_alias_mf(FA const &mf)
Definition AMReX_FFT_Helper.H:1571
Periodicity make_periodicity(Periodicity const &period) const
Definition AMReX_FFT.cpp:147
IntVect make_safe_ghost(IntVect const &ng) const
Definition AMReX_FFT.cpp:183
Definition AMReX_Array.H:34
A host / device complex number type, because std::complex doesn't work in device code with Cuda yet.
Definition AMReX_GpuComplex.H:29
FabArray memory allocation information.
Definition AMReX_FabArray.H:66
MFInfo & SetAlloc(bool a) noexcept
Definition AMReX_FabArray.H:73
Definition AMReX_MFIter.H:20