Block-Structured AMR Software Framework
 
Loading...
Searching...
No Matches
AMReX_FFT_LocalR2C.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_LOCAL_R2C_H_
2#define AMREX_FFT_LOCAL_R2C_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_Arena.H>
6#include <AMReX_FFT_Helper.H>
7
8namespace amrex::FFT
9{
10
26template <typename T, FFT::Direction D = FFT::Direction::both,
27 int M = AMREX_SPACEDIM>
29{
30public:
50 explicit LocalR2C (IntVectND<M> const& fft_size,
51 T* p_fwd = nullptr,
52 GpuComplex<T>* p_bwd = nullptr,
53#ifdef AMREX_USE_GPU
54 bool cache_plan = true);
55#else
56 bool cache_plan = false);
57#endif
58
59 ~LocalR2C ();
60
61 LocalR2C () = default;
62 LocalR2C (LocalR2C &&) noexcept;
63 LocalR2C& operator= (LocalR2C &&) noexcept;
64
65 LocalR2C (LocalR2C const&) = delete;
66 LocalR2C& operator= (LocalR2C const&) = delete;
67
78 template <Direction DIR=D, std::enable_if_t<DIR == Direction::forward ||
79 DIR == Direction::both, int> = 0>
80 void forward (T const* indata, GpuComplex<T>* outdata);
81
82 void clear ();
83
94 template <Direction DIR=D, std::enable_if_t<DIR == Direction::backward ||
95 DIR == Direction::both, int> = 0>
96 void backward (GpuComplex<T> const* indata, T* outdata);
97
101 [[nodiscard]] T scalingFactor () const;
102
104 [[nodiscard]] IntVectND<M> const& spectralSize () const {
105 return m_spectral_size;
106 }
107
108private:
109
110 Plan<T> m_fft_fwd;
111 Plan<T> m_fft_bwd;
112
113 T* m_p_fwd = nullptr;
114 GpuComplex<T>* m_p_bwd = nullptr;
115
116#if defined(AMREX_USE_SYCL)
117 gpuStream_t m_gpu_stream{};
118#endif
119
120 IntVectND<M> m_real_size;
121 IntVectND<M> m_spectral_size;
122
123 bool m_cache_plan = false;
124};
125
126template <typename T, FFT::Direction D, int M>
127LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, T* p_fwd,
128 GpuComplex<T>* p_bwd, bool cache_plan)
129 : m_p_fwd(p_fwd),
130 m_p_bwd(p_bwd),
131 m_real_size(fft_size),
132 m_spectral_size(fft_size)
133#if defined(AMREX_USE_GPU)
134 , m_cache_plan(cache_plan)
135#endif
136{
137#if !defined(AMREX_USE_GPU)
138 amrex::ignore_unused(cache_plan);
139#endif
140
141 BL_PROFILE("FFT::LocalR2C");
142 m_spectral_size[0] = m_real_size[0]/2 + 1;
143
144#if defined(AMREX_USE_SYCL)
145
146 auto current_stream = Gpu::gpuStream();
148 m_gpu_stream = Gpu::gpuStream();
149
150#endif
151
152 auto* pf = (void*)m_p_fwd;
153 auto* pb = (void*)m_p_bwd;
154
155#ifdef AMREX_USE_SYCL
156 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
157 m_fft_bwd = m_fft_fwd;
158#else
159 if constexpr (D == Direction::both || D == Direction::forward) {
160 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
161 }
162 if constexpr (D == Direction::both || D == Direction::backward) {
163 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, m_cache_plan);
164 }
165#endif
166
167#if defined(AMREX_USE_SYCL)
168 Gpu::Device::setStream(current_stream);
169#endif
170}
171
172template <typename T, FFT::Direction D, int M>
174{
175 if (!m_cache_plan) {
176 if (m_fft_bwd.plan != m_fft_fwd.plan) {
177 m_fft_bwd.destroy();
178 }
179 m_fft_fwd.destroy();
180 }
181
182 m_fft_fwd = Plan<T>{};
183 m_fft_bwd = Plan<T>{};
184}
185
186template <typename T, FFT::Direction D, int M>
188{
189 static_assert(M >= 1 && M <= 3);
190 clear();
191}
192
193template <typename T, FFT::Direction D, int M>
195 : m_p_fwd(rhs.m_p_fwd),
196 m_p_bwd(rhs.m_p_bwd),
197 m_fft_fwd(rhs.m_fft_fwd),
198 m_fft_bwd(rhs.m_fft_bwd),
199#if defined(AMREX_USE_SYCL)
200 m_gpu_stream(rhs.m_gpu_stream),
201#endif
202 m_real_size(rhs.m_real_size),
203 m_spectral_size(rhs.m_spectral_size),
204 m_cache_plan(rhs.m_cache_plan)
205{
206 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
207}
208
209template <typename T, FFT::Direction D, int M>
211{
212 if (this == &rhs) { return *this; }
213
214 this->clear();
215
216 m_p_fwd = rhs.m_p_fwd;
217 m_p_bwd = rhs.m_p_bwd;
218 m_fft_fwd = rhs.m_fft_fwd;
219 m_fft_bwd = rhs.m_fft_bwd;
220#if defined(AMREX_USE_SYCL)
221 m_gpu_stream = rhs.m_gpu_stream;
222#endif
223 m_real_size = rhs.m_real_size;
224 m_spectral_size = rhs.m_spectral_size;
225 m_cache_plan = rhs.m_cache_plan;
226
227 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
228
229 return *this;
230}
231
232template <typename T, FFT::Direction D, int M>
233template <Direction DIR, std::enable_if_t<DIR == Direction::forward ||
234 DIR == Direction::both, int> >
235void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
236{
237 BL_PROFILE("FFT::LocalR2C::forward");
238
239#if defined(AMREX_USE_GPU)
240
241 m_fft_fwd.set_ptrs((void*)indata, (void*)outdata);
242
243#if defined(AMREX_USE_SYCL)
244 auto current_stream = Gpu::gpuStream();
245 if (current_stream != m_gpu_stream) {
247 Gpu::Device::setStream(m_gpu_stream);
248 }
249#endif
250
251#else /* FFTW */
252
253 if (((T*)indata != m_p_fwd) || (outdata != m_p_bwd)) {
254 m_p_fwd = (T*)indata;
255 m_p_bwd = outdata;
256 auto* pf = (void*)m_p_fwd;
257 auto* pb = (void*)m_p_bwd;
258 m_fft_fwd.destroy();
259 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
260 if constexpr (D == Direction::both) {
261 m_fft_bwd.destroy();
262 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
263 }
264 }
265
266#endif
267
268 m_fft_fwd.template compute_r2c<Direction::forward>();
269
270#if defined(AMREX_USE_SYCL)
271 if (current_stream != m_gpu_stream) {
272 Gpu::Device::setStream(current_stream);
273 }
274#endif
275}
276
277template <typename T, FFT::Direction D, int M>
278template <Direction DIR, std::enable_if_t<DIR == Direction::backward ||
279 DIR == Direction::both, int> >
280void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
281{
282 BL_PROFILE("FFT::LocalR2C::backward");
283
284#if defined(AMREX_USE_GPU)
285
286 m_fft_bwd.set_ptrs((void*)outdata, (void*)indata);
287
288#if defined(AMREX_USE_SYCL)
289 auto current_stream = Gpu::gpuStream();
290 if (current_stream != m_gpu_stream) {
292 Gpu::Device::setStream(m_gpu_stream);
293 }
294#endif
295
296#else /* FFTW */
297
298 if (((GpuComplex<T>*)indata != m_p_bwd) || (outdata != m_p_fwd)) {
299 m_p_fwd = outdata;
300 m_p_bwd = (GpuComplex<T>*)indata;
301 auto* pf = (void*)m_p_fwd;
302 auto* pb = (void*)m_p_bwd;
303 m_fft_bwd.destroy();
304 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
305 if constexpr (D == Direction::both) {
306 m_fft_fwd.destroy();
307 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
308 }
309 }
310
311#endif
312
313 m_fft_bwd.template compute_r2c<Direction::backward>();
314
315#if defined(AMREX_USE_SYCL)
316 if (current_stream != m_gpu_stream) {
317 Gpu::Device::setStream(current_stream);
318 }
319#endif
320}
321
322template <typename T, FFT::Direction D, int M>
324{
325 T r = 1;
326 for (auto s : m_real_size) {
327 r *= T(s);
328 }
329 return T(1)/r;
330}
331
332}
333
334#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
Local Discrete Fourier Transform.
Definition AMReX_FFT_LocalR2C.H:29
void backward(GpuComplex< T > const *indata, T *outdata)
Backward transform.
Definition AMReX_FFT_LocalR2C.H:280
T scalingFactor() const
Definition AMReX_FFT_LocalR2C.H:323
void forward(T const *indata, GpuComplex< T > *outdata)
Forward transform.
Definition AMReX_FFT_LocalR2C.H:235
void clear()
Definition AMReX_FFT_LocalR2C.H:173
LocalR2C & operator=(LocalR2C &&) noexcept
Definition AMReX_FFT_LocalR2C.H:210
~LocalR2C()
Definition AMReX_FFT_LocalR2C.H:187
IntVectND< M > const & spectralSize() const
Spectral domain size.
Definition AMReX_FFT_LocalR2C.H:104
static gpuStream_t setStream(gpuStream_t s) noexcept
Definition AMReX_GpuDevice.cpp:731
static void resetStreamIndex() noexcept
Definition AMReX_GpuDevice.H:96
An Integer Vector in dim-Dimensional Space.
Definition AMReX_IntVect.H:57
Definition AMReX_FFT_Helper.H:46
Direction
Definition AMReX_FFT_Helper.H:48
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:263
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:244
__host__ __device__ void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:138
cudaStream_t gpuStream_t
Definition AMReX_GpuControl.H:83
Definition AMReX_FFT_Helper.H:134
A host/device complex number type, because std::complex doesn't work in device code with CUDA yet.
Definition AMReX_GpuComplex.H:30