1 #ifndef AMREX_FFT_HELPER_H_
2 #define AMREX_FFT_HELPER_H_
3 #include <AMReX_Config.H>
16 #if defined(AMREX_USE_CUDA)
18 # include <cuComplex.h>
19 #elif defined(AMREX_USE_HIP)
20 # if __has_include(<rocfft/rocfft.h>)
21 # include <rocfft/rocfft.h>
25 # include <hip/hip_complex.h>
26 #elif defined(AMREX_USE_SYCL)
27 # if __has_include(<oneapi/mkl/dft.hpp>)
28 # include <oneapi/mkl/dft.hpp>
30 # define AMREX_USE_MKL_DFTI_2024 1
31 # include <oneapi/mkl/dfti.hpp>
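// Backend selection (note): cuFFT on CUDA, rocFFT on HIP, oneMKL DFT on SYCL, and FFTW on the host.
// The detail helpers below wrap the vendor-specific plan-execution calls.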
72 namespace detail { void hip_execute (rocfft_plan plan, void **in, void **out); }
78 template <typename T, Direction direction, typename P, typename TI, typename TO>
79 void sycl_execute (P* plan, TI* in, TO* out)
81 #ifndef AMREX_USE_MKL_DFTI_2024
82 std::int64_t workspaceSize = 0;
84 std::size_t workspaceSize = 0;
86 plan->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
89 plan->set_workspace(buffer);
91 if (std::is_same_v<TI,TO>) {
94 r = oneapi::mkl::dft::compute_forward(*plan, out);
96 r = oneapi::mkl::dft::compute_backward(*plan, out);
100 r = oneapi::mkl::dft::compute_forward(*plan, in, out);
102 r = oneapi::mkl::dft::compute_backward(*plan, in, out);
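// Plan<T> (below): per-precision wrapper around the vendor plan handle and complex type
// (cufftHandle/cuComplex, rocfft_plan/hip complex, oneMKL descriptor variant, or fftwf_plan/fftw_plan).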
111 template <typename T>
114 #if defined(AMREX_USE_CUDA)
117 cuComplex, cuDoubleComplex>;
118 #elif defined(AMREX_USE_HIP)
120 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
122 #elif defined(AMREX_USE_SYCL)
123 using mkl_desc_r = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
124 ? oneapi::mkl::dft::precision::SINGLE
125 : oneapi::mkl::dft::precision::DOUBLE,
126 oneapi::mkl::dft::domain::REAL>;
127 using mkl_desc_c = oneapi::mkl::dft::descriptor<std::is_same_v<float,T>
128 ? oneapi::mkl::dft::precision::SINGLE
129 : oneapi::mkl::dft::precision::DOUBLE,
130 oneapi::mkl::dft::domain::COMPLEX>;
131 using VendorPlan = std::variant<mkl_desc_r*,mkl_desc_c*>;
134 using VendorPlan = std::conditional_t<std::is_same_v<float,T>,
135 fftwf_plan, fftw_plan>;
136 using VendorComplex = std::conditional_t<std::is_same_v<float,T>,
137 fftwf_complex, fftw_complex>;
164 #if !defined(AMREX_USE_GPU)
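// init_r2c (Box overload): builds batched forward (real-to-complex) and backward (complex-to-real)
// plans; nr and nc are the distances between consecutive real and complex batches.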
172 template <Direction D>
177 int rank = is_2d_transform ? 2 : 1;
192 int nr = (rank == 1) ? len[0] : len[0]*len[1];
194 int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
195 #if (AMREX_SPACEDIM == 1)
204 #if defined(AMREX_USE_CUDA)
208 std::size_t work_size;
210 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
212 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
214 cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
216 (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
219 #elif defined(AMREX_USE_HIP)
221 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
223 std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])};
225 AMREX_ROCFFT_SAFE_CALL
226 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
227 rocfft_transform_type_real_forward, prec, rank,
230 AMREX_ROCFFT_SAFE_CALL
231 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
232 rocfft_transform_type_real_inverse, prec, rank,
236 #elif defined(AMREX_USE_SYCL)
240 pp = new mkl_desc_r(len[0]);
242 pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
244 #ifndef AMREX_USE_MKL_DFTI_2024
245 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
246 oneapi::mkl::dft::config_value::NOT_INPLACE);
248 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
250 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
251 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
252 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
253 std::vector<std::int64_t> strides;
254 strides.push_back(0);
255 if (rank == 2) { strides.push_back(len[1]); }
256 strides.push_back(1);
257 #ifndef AMREX_USE_MKL_DFTI_2024
258 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
261 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
264 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
265 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
266 pp->commit(amrex::Gpu::Device::streamQueue());
271 if constexpr (std::is_same_v<float,T>) {
273 plan = fftwf_plan_many_dft_r2c
274 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
275 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
277 plan = fftwf_plan_many_dft_c2r
278 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
279 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
283 plan = fftw_plan_many_dft_r2c
284 (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
285 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
287 plan = fftw_plan_many_dft_c2r
288 (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
289 FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
295 template <Direction D, int M>
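// init_c2c (below): batched 1D complex-to-complex plans, executed in place
// (note rocfft_placement_inplace / DFTI_INPLACE in the respective branches).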
298 template <Direction D>
311 #if defined(AMREX_USE_CUDA)
315 cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
316 std::size_t work_size;
318 (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
320 #elif defined(AMREX_USE_HIP)
322 auto prec = std::is_same_v<float,T> ? rocfft_precision_single
323 : rocfft_precision_double;
325 : rocfft_transform_type_complex_inverse;
327 AMREX_ROCFFT_SAFE_CALL
328 (rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
331 #elif defined(AMREX_USE_SYCL)
333 auto* pp = new mkl_desc_c(n);
334 #ifndef AMREX_USE_MKL_DFTI_2024
335 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
336 oneapi::mkl::dft::config_value::INPLACE);
338 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
340 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
341 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
342 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
343 std::vector<std::int64_t> strides = {0,1};
344 #ifndef AMREX_USE_MKL_DFTI_2024
345 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
346 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
348 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
349 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
351 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
352 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
353 pp->commit(amrex::Gpu::Device::streamQueue());
358 if constexpr (std::is_same_v<float,T>) {
360 plan = fftwf_plan_many_dft
361 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
364 plan = fftwf_plan_many_dft
365 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
370 plan = fftw_plan_many_dft
371 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
374 plan = fftw_plan_many_dft
375 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
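// get_fftw_kind / get_r2r_kind (below): map the pair of boundary conditions (even/odd at each
// end) to the corresponding real-to-real transform kind, an FFTW r2r kind on the host or the
// internal Kind enum otherwise.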
382 #ifndef AMREX_USE_GPU
383 template <Direction D>
384 fftw_r2r_kind get_fftw_kind (std::pair<Boundary,Boundary> const& bc)
386 if (bc.first == Boundary::even && bc.second == Boundary::even)
390 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
394 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
398 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
404 return fftw_r2r_kind{};
410 template <Direction D>
413 if (bc.first == Boundary::even && bc.second == Boundary::even)
417 else if (bc.first == Boundary::even && bc.second == Boundary::odd)
421 else if (bc.first == Boundary::odd && bc.second == Boundary::even)
425 else if (bc.first == Boundary::odd && bc.second == Boundary::odd)
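// init_r2r (real data, below): on the host this builds FFTW r2r plans directly; on GPUs, where
// the vendor FFT libraries do not provide these transforms, each batch is extended by mirroring
// and sign flips according to the boundary conditions and a real-to-complex plan of the extended
// size (nex) is used instead.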
436 template <Direction D>
437 void init_r2r (Box const& box, T* p, std::pair<Boundary,Boundary> const& bc,
438 int howmany_initval = 1)
442 kind = get_r2r_kind<D>(bc);
450 #if defined(AMREX_USE_GPU)
452 if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
455 } else if (bc.first == Boundary::odd && bc.second == Boundary::odd &&
458 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
461 } else if (bc.first == Boundary::even && bc.second == Boundary::even &&
464 } else if ((bc.first == Boundary::even && bc.second == Boundary::odd) ||
465 (bc.first == Boundary::odd && bc.second == Boundary::even)) {
470 int nc = (nex/2) + 1;
472 #if defined (AMREX_USE_CUDA)
476 cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
477 std::size_t work_size;
479 (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
481 #elif defined(AMREX_USE_HIP)
484 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
485 const std::size_t length = nex;
486 AMREX_ROCFFT_SAFE_CALL
487 (rocfft_plan_create(&plan, rocfft_placement_inplace,
488 rocfft_transform_type_real_forward, prec, 1,
491 #elif defined(AMREX_USE_SYCL)
493 auto* pp = new mkl_desc_r(nex);
494 #ifndef AMREX_USE_MKL_DFTI_2024
495 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
496 oneapi::mkl::dft::config_value::INPLACE);
498 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_INPLACE);
500 pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
501 pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nc*2);
502 pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
503 std::vector<std::int64_t> strides = {0,1};
504 #ifndef AMREX_USE_MKL_DFTI_2024
505 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
506 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
508 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
509 pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data());
511 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
512 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
513 pp->commit(amrex::Gpu::Device::streamQueue());
519 auto fftw_kind = get_fftw_kind<D>(bc);
520 if constexpr (std::is_same_v<float,T>) {
521 plan = fftwf_plan_many_r2r
522 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
525 plan = fftw_plan_many_r2r
526 (1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, &fftw_kind,
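// init_r2r (complex data, below): treats the real and imaginary parts as two interleaved real
// sequences; the GPU path doubles howmany, while the FFTW path uses stride 2 and a second plan
// (plan2) offset by one element.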
532 template <Direction D>
534 std::pair<Boundary,Boundary> const& bc)
540 #if defined(AMREX_USE_GPU)
542 init_r2r<D>(box, p, bc, 2);
547 kind = get_r2r_kind<D>(bc);
556 auto fftw_kind = get_fftw_kind<D>(bc);
557 if constexpr (std::is_same_v<float,T>) {
558 plan = fftwf_plan_many_r2r
559 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
561 plan2 = fftwf_plan_many_r2r
562 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
565 plan = fftw_plan_many_r2r
566 (1, &n, howmany, p, nullptr, 2, n*2, p, nullptr, 2, n*2, &fftw_kind,
568 plan2 = fftw_plan_many_r2r
569 (1, &n, howmany, p+1, nullptr, 2, n*2, p+1, nullptr, 2, n*2, &fftw_kind,
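// compute_r2c (below): executes the forward real-to-complex plan on the stored pointers,
// dispatching to cufftExec*/rocFFT/oneMKL/FFTW depending on the backend.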
575 template <Direction D>
586 #if defined(AMREX_USE_CUDA)
589 std::size_t work_size = 0;
596 if constexpr (std::is_same_v<float,T>) {
602 if constexpr (std::is_same_v<float,T>) {
610 #elif defined(AMREX_USE_HIP)
611 detail::hip_execute(plan, (void**)&pi, (void**)&po);
612 #elif defined(AMREX_USE_SYCL)
613 detail::sycl_execute<T,D>(std::get<0>(plan), pi, po);
616 if constexpr (std::is_same_v<float,T>) {
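// compute_c2c (below): executes the in-place complex-to-complex plan, forward or backward
// according to the Direction template parameter.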
624 template <Direction D>
632 #if defined(AMREX_USE_CUDA)
635 std::size_t work_size = 0;
642 if constexpr (std::is_same_v<float,T>) {
649 #elif defined(AMREX_USE_HIP)
650 detail::hip_execute(plan, (void**)&p, (void**)&p);
651 #elif defined(AMREX_USE_SYCL)
652 detail::sycl_execute<T,D>(std::get<1>(plan), p, p);
655 if constexpr (std::is_same_v<float,T>) {
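// alloc_scratch_space / pack_r2r_buffer (below): the scratch buffer holds the extended sequences;
// pack_r2r_buffer fills it by mirroring/negating the input so that a real-to-complex FFT of the
// extended data realizes the requested cosine/sine transform. When r2r_data_is_complex, the
// interleaved real and imaginary parts are handled by the loop over ir (two components).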
673 amrex::Abort("FFT: alloc_scratch_space: unsupported kind");
682 auto* pdst = (T*) pbuf;
685 int ostride = (n+1)*2;
689 Long nelems = Long(nex)*howmany;
693 auto batch = ielem / Long(nex);
694 auto i = int(ielem - batch*nex);
695 for (int ir = 0; ir < 2; ++ir) {
696 auto* po = pdst + (2*batch+ir)*ostride + i;
697 auto const* pi = psrc + 2*batch*istride + ir;
701 *po = sign * pi[(2*norig-1-i)*2];
708 auto batch = ielem / Long(nex);
709 auto i = int(ielem - batch*nex);
710 auto* po = pdst + batch*ostride + i;
711 auto const* pi = psrc + batch*istride;
715 *po = sign * pi[2*norig-1-i];
720 int ostride = (2*n+1)*2;
724 Long nelems = Long(nex)*howmany;
728 auto batch = ielem / Long(nex);
729 auto i = int(ielem - batch*nex);
730 for (int ir = 0; ir < 2; ++ir) {
731 auto* po = pdst + (2*batch+ir)*ostride + i;
732 auto const* pi = psrc + 2*batch*istride + ir;
735 } else if (i < (2*norig-1)) {
736 *po = pi[(2*norig-2-i)*2];
737 } else if (i == (2*norig-1)) {
739 } else if (i < (3*norig)) {
740 *po = -pi[(i-2*norig)*2];
741 } else if (i < (4*norig-1)) {
742 *po = -pi[(4*norig-2-i)*2];
751 auto batch = ielem / Long(nex);
752 auto i = int(ielem - batch*nex);
753 auto* po = pdst + batch*ostride + i;
754 auto const* pi = psrc + batch*istride;
757 } else if (i < (2*norig-1)) {
758 *po = pi[2*norig-2-i];
759 } else if (i == (2*norig-1)) {
761 } else if (i < (3*norig)) {
762 *po = -pi[i-2*norig];
763 } else if (i < (4*norig-1)) {
764 *po = -pi[4*norig-2-i];
771 int ostride = (2*n+1)*2;
775 Long nelems = Long(nex)*howmany;
779 auto batch = ielem / Long(nex);
780 auto i = int(ielem - batch*nex);
781 for (int ir = 0; ir < 2; ++ir) {
782 auto* po = pdst + (2*batch+ir)*ostride + i;
783 auto const* pi = psrc + 2*batch*istride + ir;
786 } else if (i == norig) {
788 } else if (i < (2*norig+1)) {
789 *po = -pi[(2*norig-i)*2];
790 } else if (i < (3*norig)) {
791 *po = -pi[(i-2*norig)*2];
792 } else if (i == 3*norig) {
795 *po = pi[(4*norig-i)*2];
802 auto batch = ielem / Long(nex);
803 auto i = int(ielem - batch*nex);
804 auto* po = pdst + batch*ostride + i;
805 auto const* pi = psrc + batch*istride;
808 } else if (i == norig) {
810 } else if (i < (2*norig+1)) {
811 *po = -pi[2*norig-i];
812 } else if (i < (3*norig)) {
813 *po = -pi[i-2*norig];
814 } else if (i == 3*norig) {
822 int ostride = (2*n+1)*2;
826 Long nelems = Long(nex)*howmany;
830 auto batch = ielem / Long(nex);
831 auto i = int(ielem - batch*nex);
832 for (int ir = 0; ir < 2; ++ir) {
833 auto* po = pdst + (2*batch+ir)*ostride + i;
834 auto const* pi = psrc + 2*batch*istride + ir;
837 } else if (i < (2*norig)) {
838 *po = -pi[(2*norig-1-i)*2];
839 } else if (i < (3*norig)) {
840 *po = -pi[(i-2*norig)*2];
842 *po = pi[(4*norig-1-i)*2];
849 auto batch = ielem / Long(nex);
850 auto i = int(ielem - batch*nex);
851 auto* po = pdst + batch*ostride + i;
852 auto const* pi = psrc + batch*istride;
855 } else if (i < (2*norig)) {
856 *po = -pi[2*norig-1-i];
857 } else if (i < (3*norig)) {
858 *po = -pi[i-2*norig];
860 *po = pi[4*norig-1-i];
865 int ostride = (2*n+1)*2;
869 Long nelems = Long(nex)*howmany;
873 auto batch = ielem / Long(nex);
874 auto i = int(ielem - batch*nex);
875 for (int ir = 0; ir < 2; ++ir) {
876 auto* po = pdst + (2*batch+ir)*ostride + i;
877 auto const* pi = psrc + 2*batch*istride + ir;
880 } else if (i < (2*norig)) {
881 *po = pi[(2*norig-1-i)*2];
882 } else if (i < (3*norig)) {
883 *po = -pi[(i-2*norig)*2];
885 *po = -pi[(4*norig-1-i)*2];
892 auto batch = ielem / Long(nex);
893 auto i = int(ielem - batch*nex);
894 auto* po = pdst + batch*ostride + i;
895 auto const* pi = psrc + batch*istride;
898 } else if (i < (2*norig)) {
899 *po = pi[2*norig-1-i];
900 } else if (i < (3*norig)) {
901 *po = -pi[i-2*norig];
903 *po = -pi[4*norig-1-i];
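// unpack_r2r_buffer (below): extracts the r2r result from the complex output of the
// extended-sequence FFT; s and c are the phase factors (sine/cosine) required by the
// particular transform kind.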
916 Long nelems = Long(norig)*howmany;
924 auto batch = ielem / Long(norig);
925 auto k = int(ielem - batch*norig);
927 for (int ir = 0; ir < 2; ++ir) {
928 auto const& yk = psrc[(2*batch+ir)*istride+k+1];
929 pdst[2*batch*ostride+ir+k*2] = s * yk.real() - c * yk.imag();
935 auto batch = ielem / Long(norig);
936 auto k = int(ielem - batch*norig);
938 auto const& yk = psrc[batch*istride+k+1];
939 pdst[batch*ostride+k] = s * yk.real() - c * yk.imag();
947 auto batch = ielem / Long(norig);
948 auto k = int(ielem - batch*norig);
950 for (int ir = 0; ir < 2; ++ir) {
951 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
952 pdst[2*batch*ostride+ir+k*2] = T(0.5)*(s * yk.real() - c * yk.imag());
958 auto batch = ielem / Long(norig);
959 auto k = int(ielem - batch*norig);
961 auto const& yk = psrc[batch*istride+2*k+1];
962 pdst[batch*ostride+k] = T(0.5)*(s * yk.real() - c * yk.imag());
970 auto batch = ielem / Long(norig);
971 auto k = int(ielem - batch*norig);
973 for (int ir = 0; ir < 2; ++ir) {
974 auto const& yk = psrc[(2*batch+ir)*istride+k];
975 pdst[2*batch*ostride+ir+k*2] = c * yk.real() + s * yk.imag();
981 auto batch = ielem / Long(norig);
982 auto k = int(ielem - batch*norig);
984 auto const& yk = psrc[batch*istride+k];
985 pdst[batch*ostride+k] = c * yk.real() + s * yk.imag();
993 auto batch = ielem / Long(norig);
994 auto k = int(ielem - batch*norig);
995 for (int ir = 0; ir < 2; ++ir) {
996 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
997 pdst[2*batch*ostride+ir+k*2] = T(0.5) * yk.real();
1003 auto batch = ielem / Long(norig);
1004 auto k = int(ielem - batch*norig);
1005 auto const& yk = psrc[batch*istride+2*k+1];
1006 pdst[batch*ostride+k] = T(0.5) * yk.real();
1010 int istride = 2*n+1;
1014 auto batch = ielem / Long(norig);
1015 auto k = int(ielem - batch*norig);
1017 for (int ir = 0; ir < 2; ++ir) {
1018 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1019 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (c * yk.real() + s * yk.imag());
1025 auto batch = ielem / Long(norig);
1026 auto k = int(ielem - batch*norig);
1028 auto const& yk = psrc[batch*istride+2*k+1];
1029 pdst[batch*ostride+k] = T(0.5) * (c * yk.real() + s * yk.imag());
1033 int istride = 2*n+1;
1037 auto batch = ielem / Long(norig);
1038 auto k = int(ielem - batch*norig);
1040 for (int ir = 0; ir < 2; ++ir) {
1041 auto const& yk = psrc[(2*batch+ir)*istride+2*k+1];
1042 pdst[2*batch*ostride+ir+k*2] = T(0.5) * (s * yk.real() - c * yk.imag());
1048 auto batch = ielem / Long(norig);
1049 auto k = int(ielem - batch*norig);
1051 auto const& yk = psrc[batch*istride+2*k+1];
1052 pdst[batch*ostride+k] = T(0.5) * (s * yk.real() - c * yk.imag());
1056 amrex::Abort("FFT: unpack_r2r_buffer: unsupported kind");
1061 template <Direction D>
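// compute_r2r: on GPUs, allocate scratch space, pack the extended sequences, run the
// real-to-complex plan on the scratch buffer, then unpack the result; the host path
// simply executes the FFTW r2r plan(s).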
1067 #if defined(AMREX_USE_GPU)
1073 #if defined(AMREX_USE_CUDA)
1077 std::size_t work_size = 0;
1083 if constexpr (std::is_same_v<float,T>) {
1089 #elif defined(AMREX_USE_HIP)
1090 detail::hip_execute(plan, (void**)&pscratch, (void**)&pscratch);
1091 #elif defined(AMREX_USE_SYCL)
1092 detail::sycl_execute<T,Direction::forward>(std::get<0>(plan), (T*)pscratch, (VendorComplex*)pscratch);
1099 #if defined(AMREX_USE_CUDA)
1105 if constexpr (std::is_same_v<float,T>) {
1106 fftwf_execute(plan);
1118 #if defined(AMREX_USE_CUDA)
1120 #elif defined(AMREX_USE_HIP)
1121 AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
1122 #elif defined(AMREX_USE_SYCL)
1123 std::visit([](auto&& p) { delete p; }, plan);
1125 if constexpr (std::is_same_v<float,T>) {
1126 fftwf_destroy_plan(plan);
1128 fftw_destroy_plan(plan);
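// init_r2c (IntVectND overload, below): multi-dimensional real-to-complex plans; on GPUs a
// previously created plan can be reused via the (size, direction, kind) cache key.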
1144 template <typename T>
1145 template <Direction D, int M>
1156 for (auto s : fft_size) { n *= s; }
1159 #if defined(AMREX_USE_GPU)
1160 Key key = {fft_size.template expand<3>(), D, kind};
1163 if constexpr (std::is_same_v<float,T>) {
1169 plan = *cached_plan;
1177 #if defined(AMREX_USE_CUDA)
1183 type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1185 type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1187 std::size_t work_size;
1188 if constexpr (M == 1) {
1190 (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
1191 } else if constexpr (M == 2) {
1193 (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
1194 } else if constexpr (M == 3) {
1196 (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
1199 #elif defined(AMREX_USE_HIP)
1201 auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
1203 for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; }
1205 AMREX_ROCFFT_SAFE_CALL
1206 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1207 rocfft_transform_type_real_forward, prec, M,
1208 length, howmany, nullptr));
1210 AMREX_ROCFFT_SAFE_CALL
1211 (rocfft_plan_create(&plan, rocfft_placement_notinplace,
1212 rocfft_transform_type_real_inverse, prec, M,
1213 length, howmany, nullptr));
1216 #elif defined(AMREX_USE_SYCL)
1220 pp = new mkl_desc_r(fft_size[0]);
1222 std::vector<std::int64_t> len(M);
1223 for (int idim = 0; idim < M; ++idim) {
1224 len[idim] = fft_size[M-1-idim];
1226 pp = new mkl_desc_r(len);
1228 #ifndef AMREX_USE_MKL_DFTI_2024
1229 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
1230 oneapi::mkl::dft::config_value::NOT_INPLACE);
1232 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1235 std::vector<std::int64_t> strides(M+1);
1238 for (int i = M-1; i >= 1; --i) {
1239 strides[i] = strides[i+1] * fft_size[M-1-i];
1242 #ifndef AMREX_USE_MKL_DFTI_2024
1243 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
1246 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data());
1249 pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE,
1250 oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
1251 pp->commit(amrex::Gpu::Device::streamQueue());
1256 if (pf == nullptr || pb == nullptr) {
1261 int size_for_row_major[M];
1262 for (int idim = 0; idim < M; ++idim) {
1263 size_for_row_major[idim] = fft_size[M-1-idim];
1266 if constexpr (std::is_same_v<float,T>) {
1268 plan = fftwf_plan_dft_r2c
1269 (M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
1272 plan = fftwf_plan_dft_c2r
1273 (M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
1278 plan = fftw_plan_dft_r2c
1279 (M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
1282 plan = fftw_plan_dft_c2r
1283 (M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
1289 #if defined(AMREX_USE_GPU)
1291 if constexpr (std::is_same_v<float,T>) {
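// get_fab / make_mfs_share (below): get_fab returns the calling rank's FAB pointer when it owns
// one; make_mfs_share builds the FABs of two FabArrays on top of a single shared allocation.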
1304 template <typename FA>
1305 typename FA::FABType::value_type * get_fab (FA& fa)
1308 if (myproc < fa.size()) {
1309 return fa.fabPtr(myproc);
1315 template <typename FA1, typename FA2>
1318 bool not_same_fa = true;
1319 if constexpr (std::is_same_v<FA1,FA2>) {
1320 not_same_fa = (&fa1 != &fa2);
1322 using FAB1 = typename FA1::FABType::value_type;
1323 using FAB2 = typename FA2::FABType::value_type;
1324 using T1 = typename FAB1::value_type;
1325 using T2 = typename FAB2::value_type;
1327 bool alloc_1 = (myproc < fa1.size());
1328 bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
1330 if (alloc_1 && alloc_2) {
1331 Box const& box1 = fa1.fabbox(myproc);
1332 Box const& box2 = fa2.fabbox(myproc);
1333 int ncomp1 = fa1.nComp();
1334 int ncomp2 = fa2.nComp();
1336 sizeof(T2)*box2.numPts()*ncomp2));
1337 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1338 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
1339 } else if (alloc_1) {
1340 Box const& box1 = fa1.fabbox(myproc);
1341 int ncomp1 = fa1.nComp();
1343 fa1.setFab(myproc, FAB1(box1, ncomp1, (T1*)p));
1344 } else if (alloc_2) {
1345 Box const& box2 = fa2.fabbox(myproc);
1346 int ncomp2 = fa2.nComp();
1348 fa2.setFab(myproc, FAB2(box2, ncomp2, (T2*)p));
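// Index-swapping functors (below): each returns the Dim3 with components permuted
// (e.g. {y,x,z}) and provides the corresponding Inverse mapping, used for transposed
// data layouts.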
1360 return {i.y, i.x, i.z};
1365 return {i.y, i.x, i.z};
1383 return {i.z, i.y, i.x};
1388 return {i.z, i.y, i.x};
1407 return {i.y, i.z, i.x};
1413 return {i.z, i.x, i.y};
1432 return {i.z, i.x, i.y};
1438 return {i.y, i.z, i.x};
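// SubHelper (below): handles domains that are degenerate in some directions, permuting arrays,
// boxes, and index vectors according to the detected case (case_1n, case_11n, ...).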
1474 template <typename T>
1477 #if (AMREX_SPACEDIM == 1)
1480 #elif (AMREX_SPACEDIM == 2)
1481 if (m_case == case_1n) {
1482 return T{a[1],a[0]};
1487 if (m_case == case_11n) {
1488 return T{a[2],a[0],a[1]};
1489 } else if (m_case == case_1n1) {
1490 return T{a[1],a[0],a[2]};
1491 } else if (m_case == case_1nn) {
1492 return T{a[1],a[2],a[0]};
1493 } else if (m_case == case_n1n) {
1494 return T{a[0],a[2],a[1]};
1503 template <typename FA>
1506 BoxList bl = mf.boxArray().boxList();
1507 for (auto& b : bl) {
1510 auto const& ng = make_iv(mf.nGrowVect());
1512 using FAB = typename FA::fab_type;
1514 submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
1519 #if (AMREX_SPACEDIM == 2)
1520 enum Case { case_1n, case_other };
1521 int m_case = case_other;
1522 #elif (AMREX_SPACEDIM == 3)
1523 enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1524 int m_case = case_other;