Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_FFT_LocalR2C.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_LOCAL_R2C_H_
2#define AMREX_FFT_LOCAL_R2C_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_Arena.H>
6#include <AMReX_FFT_Helper.H>
7
8namespace amrex::FFT
9{
10
// LocalR2C -- real-to-complex FFT wrapper (the page index describes the class
// as "Local Discrete Fourier Transform").
//
// Template parameters:
//   T  real scalar type; complex data is stored as GpuComplex<T>.
//   D  selects the usable transforms: forward() is enabled for
//      Direction::forward / Direction::both, backward() for
//      Direction::backward / Direction::both.
//   M  number of dimensions, defaulting to AMREX_SPACEDIM.
//
// NOTE(review): listing line 34 (the class-head, presumably `class LocalR2C`)
// is absent from this extract, and the gaps in the listing numbers suggest
// that documentation comment blocks were stripped as well -- confirm against
// the real header before relying on this text.
32template <typename T, FFT::Direction D = FFT::Direction::both,
33 int M = AMREX_SPACEDIM>
35{
36public:
    // Set up the plan(s) for a transform whose real-space extents are fft_size.
    //   p_fwd, p_bwd  optional real / complex buffers handed to plan creation;
    //                 both default to nullptr.
    //   cache_plan    defaults to true when AMREX_USE_GPU is defined and to
    //                 false otherwise; the out-of-class definition ignores the
    //                 argument on non-GPU builds.
56 explicit LocalR2C (IntVectND<M> const& fft_size,
57 T* p_fwd = nullptr,
58 GpuComplex<T>* p_bwd = nullptr,
59#ifdef AMREX_USE_GPU
60 bool cache_plan = true);
61#else
62 bool cache_plan = false);
63#endif
64
    // Calls clear() (see the out-of-class definition).
65 ~LocalR2C ();
66
    // Default construction leaves the members at their in-class defaults.
67 LocalR2C () = default;
68
    // Move construction: takes over the plan handles of the source object.
72 LocalR2C (LocalR2C &&) noexcept;
73
    // Move assignment: releases the current plans, then adopts those of rhs.
80 LocalR2C& operator= (LocalR2C &&) noexcept;
81
    // Copying is disabled; only moves are allowed.
82 LocalR2C (LocalR2C const&) = delete;
83 LocalR2C& operator= (LocalR2C const&) = delete;
84
    // Forward transform: real input `indata` -> complex output `outdata`.
    // Only participates in overload resolution when D is forward or both.
95 template <Direction DIR=D, std::enable_if_t<DIR == Direction::forward ||
96 DIR == Direction::both, int> = 0>
97 void forward (T const* indata, GpuComplex<T>* outdata);
98
    // Release the plans (unless plan caching is on) and reset both handles.
102 void clear ();
103
    // Backward transform: complex input `indata` -> real output `outdata`.
    // Only participates in overload resolution when D is backward or both.
114 template <Direction DIR=D, std::enable_if_t<DIR == Direction::backward ||
115 DIR == Direction::both, int> = 0>
116 void backward (GpuComplex<T> const* indata, T* outdata);
117
    // Returns 1 / (product of the real-space extents); see the definition.
127 [[nodiscard]] T scalingFactor () const;
128
    // Extents of the complex (spectral) data. The constructor sets these to
    // the real extents, except that component 0 becomes real_size[0]/2 + 1.
134 [[nodiscard]] IntVectND<M> const& spectralSize () const {
135 return m_spectral_size;
136 }
137
138private:
139
    // Plan objects for the forward and backward transforms.
140 Plan<T> m_fft_fwd;
141 Plan<T> m_fft_bwd;
142
    // Real-side and complex-side buffer pointers the plans were last built for
    // (the non-GPU code path compares against these before each transform).
143 T* m_p_fwd = nullptr;
144 GpuComplex<T>* m_p_bwd = nullptr;
145
146#if defined(AMREX_USE_SYCL)
    // SYCL only: stream recorded in the constructor and made current around
    // each transform.
147 gpuStream_t m_gpu_stream{};
148#endif
149
    // Extents of the real data and of the complex (spectral) data.
150 IntVectND<M> m_real_size;
151 IntVectND<M> m_spectral_size;
152
    // When true, clear() does not destroy the plans. Only set from the
    // constructor argument on GPU builds; also forced to true on a moved-from
    // object.
153 bool m_cache_plan = false;
154};
155
// Main constructor: records the buffers and sizes, derives the spectral
// extents, and initializes the forward and/or backward plan.
//   fft_size    real-space extents; copied into m_real_size.
//   p_fwd       real buffer pointer (may be nullptr).
//   p_bwd       complex buffer pointer (may be nullptr).
//   cache_plan  stored in m_cache_plan on GPU builds only; on non-GPU builds
//               it is ignored and m_cache_plan keeps its in-class default.
156template <typename T, FFT::Direction D, int M>
157LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, T* p_fwd,
158 GpuComplex<T>* p_bwd, bool cache_plan)
159 : m_p_fwd(p_fwd),
160 m_p_bwd(p_bwd),
161 m_real_size(fft_size),
162 m_spectral_size(fft_size)
163#if defined(AMREX_USE_GPU)
164 , m_cache_plan(cache_plan)
165#endif
166{
167#if !defined(AMREX_USE_GPU)
    // Silence the unused-parameter warning on non-GPU builds.
168 amrex::ignore_unused(cache_plan);
169#endif
170
171 BL_PROFILE("FFT::LocalR2C");
    // Only the first dimension is shortened in the complex output: n0/2 + 1.
172 m_spectral_size[0] = m_real_size[0]/2 + 1;
173
174#if defined(AMREX_USE_SYCL)
175
    // Remember the stream that is current on entry; it is restored at the end.
176 auto current_stream = Gpu::gpuStream();
    // NOTE(review): listing line 177 is missing from this extract. As shown,
    // m_gpu_stream would simply equal current_stream; presumably the missing
    // statement changes the active stream first (the page's cross-reference
    // list mentions Gpu::Device::resetStreamIndex) -- confirm.
178 m_gpu_stream = Gpu::gpuStream();
179
180#endif
181
    // The plan API takes untyped pointers.
182 auto* pf = (void*)m_p_fwd;
183 auto* pb = (void*)m_p_bwd;
184
185#ifdef AMREX_USE_SYCL
    // SYCL: a single plan is initialized and the backward handle is a copy of
    // the forward one, regardless of D.
186 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
187 m_fft_bwd = m_fft_fwd;
188#else
    // Otherwise only the plans for the direction(s) selected by D are built.
189 if constexpr (D == Direction::both || D == Direction::forward) {
190 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
191 }
192 if constexpr (D == Direction::both || D == Direction::backward) {
193 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, m_cache_plan);
194 }
195#endif
196
197#if defined(AMREX_USE_SYCL)
    // Put back the stream that was current when the constructor was entered.
198 Gpu::Device::setStream(current_stream);
199#endif
200}
201
// clear(): release the plans held by this object and reset both handles.
// NOTE(review): the signature line (listing line 203) is missing from this
// extract; the class declares it as `void clear ();`.
202template <typename T, FFT::Direction D, int M>
204{
    // Plans are destroyed here only when plan caching is off.
205 if (!m_cache_plan) {
        // Destroy the backward plan separately only if its handle differs
        // from the forward one, so the same handle is not destroyed twice
        // (the SYCL constructor path assigns m_fft_bwd = m_fft_fwd).
206 if (m_fft_bwd.plan != m_fft_fwd.plan) {
207 m_fft_bwd.destroy();
208 }
209 m_fft_fwd.destroy();
210 }
211
    // In every case both handles are reset to default-constructed plans.
212 m_fft_fwd = Plan<T>{};
213 m_fft_bwd = Plan<T>{};
214}
215
// Destructor: releases the plans through clear().
// NOTE(review): the signature line (listing line 217) is missing from this
// extract; the class declares it as `~LocalR2C ();`.
216template <typename T, FFT::Direction D, int M>
218{
    // Compile-time restriction: only 1 to 3 dimensions are accepted.
219 static_assert(M >= 1 && M <= 3);
220 clear();
221}
222
// Move constructor: copies every member from rhs, then marks rhs so that it
// will not destroy the plans this object now uses.
// NOTE(review): the signature line (listing line 224) is missing from this
// extract; the class declares it as `LocalR2C (LocalR2C &&) noexcept;` and the
// body names the parameter `rhs`.
// NOTE(review): the initializer list below names the pointers before the
// plans, while the class declares m_fft_fwd/m_fft_bwd before m_p_fwd/m_p_bwd.
// Members are always initialized in declaration order, so this is harmless
// here (no initializer reads another member) but may trigger -Wreorder.
223template <typename T, FFT::Direction D, int M>
225 : m_p_fwd(rhs.m_p_fwd),
226 m_p_bwd(rhs.m_p_bwd),
227 m_fft_fwd(rhs.m_fft_fwd),
228 m_fft_bwd(rhs.m_fft_bwd),
229#if defined(AMREX_USE_SYCL)
230 m_gpu_stream(rhs.m_gpu_stream),
231#endif
232 m_real_size(rhs.m_real_size),
233 m_spectral_size(rhs.m_spectral_size),
234 m_cache_plan(rhs.m_cache_plan)
235{
    // clear() skips destruction when m_cache_plan is true, so the moved-from
    // object leaves the transferred plans alone.
236 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
237}
238
// Move assignment: releases this object's plans, copies every member from
// rhs, and marks rhs so that it will not destroy the transferred plans.
// Returns *this.
// NOTE(review): the signature line (listing line 240) is missing from this
// extract; the class declares it as
// `LocalR2C& operator= (LocalR2C &&) noexcept;` and the body names the
// parameter `rhs`.
239template <typename T, FFT::Direction D, int M>
241{
    // Self-assignment must not clear the plans that are about to be kept.
242 if (this == &rhs) { return *this; }
243
    // Release whatever this object currently holds before overwriting it.
244 this->clear();
245
246 m_p_fwd = rhs.m_p_fwd;
247 m_p_bwd = rhs.m_p_bwd;
248 m_fft_fwd = rhs.m_fft_fwd;
249 m_fft_bwd = rhs.m_fft_bwd;
250#if defined(AMREX_USE_SYCL)
251 m_gpu_stream = rhs.m_gpu_stream;
252#endif
253 m_real_size = rhs.m_real_size;
254 m_spectral_size = rhs.m_spectral_size;
255 m_cache_plan = rhs.m_cache_plan;
256
    // clear() skips destruction when m_cache_plan is true.
257 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
258
259 return *this;
260}
261
// Forward transform.
//   indata   real input buffer.
//   outdata  complex output buffer.
// Available only when D is Direction::forward or Direction::both.
262template <typename T, FFT::Direction D, int M>
263template <Direction DIR, std::enable_if_t<DIR == Direction::forward ||
264 DIR == Direction::both, int> >
265void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
266{
267 BL_PROFILE("FFT::LocalR2C::forward");
268
269#if defined(AMREX_USE_GPU)
270
    // GPU builds: the existing plan is kept and only pointed at the buffers
    // of this call.
271 m_fft_fwd.set_ptrs((void*)indata, (void*)outdata);
272
273#if defined(AMREX_USE_SYCL)
    // SYCL: run on the stream recorded in the constructor, switching to it if
    // another stream is current.
274 auto current_stream = Gpu::gpuStream();
275 if (current_stream != m_gpu_stream) {
    // NOTE(review): listing line 276 is missing from this extract; presumably
    // a synchronization call (the page's cross-reference list mentions
    // streamSynchronize) -- confirm.
277 Gpu::Device::setStream(m_gpu_stream);
278 }
279#endif
280
281#else /* FFTW */
282
    // Non-GPU builds: the plans are tied to the buffers they were created
    // for, so they are rebuilt whenever either pointer differs from the
    // remembered one. The const on indata is cast away so the pointer can be
    // compared with and stored in the non-const m_p_fwd.
283 if (((T*)indata != m_p_fwd) || (outdata != m_p_bwd)) {
284 m_p_fwd = (T*)indata;
285 m_p_bwd = outdata;
286 auto* pf = (void*)m_p_fwd;
287 auto* pb = (void*)m_p_bwd;
288 m_fft_fwd.destroy();
289 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
    // When both directions are in use, the backward plan is rebuilt for the
    // same pair of buffers as well.
290 if constexpr (D == Direction::both) {
291 m_fft_bwd.destroy();
292 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
293 }
294 }
295
296#endif
297
    // Execute the forward plan.
298 m_fft_fwd.template compute_r2c<Direction::forward>();
299
300#if defined(AMREX_USE_SYCL)
    // Restore the stream that was current on entry, if it was changed above.
301 if (current_stream != m_gpu_stream) {
302 Gpu::Device::setStream(current_stream);
303 }
304#endif
305}
306
// Backward transform.
//   indata   complex input buffer.
//   outdata  real output buffer.
// Available only when D is Direction::backward or Direction::both.
307template <typename T, FFT::Direction D, int M>
308template <Direction DIR, std::enable_if_t<DIR == Direction::backward ||
309 DIR == Direction::both, int> >
310void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
311{
312 BL_PROFILE("FFT::LocalR2C::backward");
313
314#if defined(AMREX_USE_GPU)
315
    // GPU builds: the existing plan is kept and only pointed at the buffers
    // of this call. The real-side pointer is passed first and the
    // complex-side pointer second, the same order forward() uses.
316 m_fft_bwd.set_ptrs((void*)outdata, (void*)indata);
317
318#if defined(AMREX_USE_SYCL)
    // SYCL: run on the stream recorded in the constructor, switching to it if
    // another stream is current.
319 auto current_stream = Gpu::gpuStream();
320 if (current_stream != m_gpu_stream) {
    // NOTE(review): listing line 321 is missing from this extract; presumably
    // a synchronization call (the page's cross-reference list mentions
    // streamSynchronize) -- confirm.
322 Gpu::Device::setStream(m_gpu_stream);
323 }
324#endif
325
326#else /* FFTW */
327
    // Non-GPU builds: the plans are tied to the buffers they were created
    // for, so they are rebuilt whenever either pointer differs from the
    // remembered one. The const on indata is cast away so the pointer can be
    // compared with and stored in the non-const m_p_bwd.
328 if (((GpuComplex<T>*)indata != m_p_bwd) || (outdata != m_p_fwd)) {
329 m_p_fwd = outdata;
330 m_p_bwd = (GpuComplex<T>*)indata;
331 auto* pf = (void*)m_p_fwd;
332 auto* pb = (void*)m_p_bwd;
333 m_fft_bwd.destroy();
334 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
    // When both directions are in use, the forward plan is rebuilt for the
    // same pair of buffers as well.
335 if constexpr (D == Direction::both) {
336 m_fft_fwd.destroy();
337 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
338 }
339 }
340
341#endif
342
    // Execute the backward plan.
343 m_fft_bwd.template compute_r2c<Direction::backward>();
344
345#if defined(AMREX_USE_SYCL)
    // Restore the stream that was current on entry, if it was changed above.
346 if (current_stream != m_gpu_stream) {
347 Gpu::Device::setStream(current_stream);
348 }
349#endif
350}
351
// scalingFactor(): returns 1 divided by the product of all real-space
// extents, computed in type T (the page index describes it as the scaling
// factor for normalization).
// NOTE(review): the signature line (listing line 353) is missing from this
// extract; the class declares it as `[[nodiscard]] T scalingFactor () const;`.
352template <typename T, FFT::Direction D, int M>
354{
    // Each extent is converted to T before multiplying, so the product is
    // accumulated in T rather than in the integer extent type.
355 T r = 1;
356 for (auto s : m_real_size) {
357 r *= T(s);
358 }
359 return T(1)/r;
360}
361
362}
363
364#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
Local Discrete Fourier Transform.
Definition AMReX_FFT_LocalR2C.H:35
void backward(GpuComplex< T > const *indata, T *outdata)
Backward transform.
Definition AMReX_FFT_LocalR2C.H:310
T scalingFactor() const
Scaling factor for normalization.
Definition AMReX_FFT_LocalR2C.H:353
void forward(T const *indata, GpuComplex< T > *outdata)
Forward transform.
Definition AMReX_FFT_LocalR2C.H:265
void clear()
Destroy the FFT plans unless they were created with plan caching enabled, then reset both plan handles; leaves the object in an uninitialized state.
Definition AMReX_FFT_LocalR2C.H:203
LocalR2C & operator=(LocalR2C &&) noexcept
Move assignment; releases current plans and adopts those from rhs.
Definition AMReX_FFT_LocalR2C.H:240
~LocalR2C()
Definition AMReX_FFT_LocalR2C.H:217
IntVectND< M > const & spectralSize() const
Spectral domain extents associated with this plan.
Definition AMReX_FFT_LocalR2C.H:134
static gpuStream_t setStream(gpuStream_t s) noexcept
Definition AMReX_GpuDevice.cpp:735
static void resetStreamIndex() noexcept
Definition AMReX_GpuDevice.H:96
An Integer Vector in dim-Dimensional Space.
Definition AMReX_IntVect.H:57
Definition AMReX_FFT_Helper.H:52
Direction
Definition AMReX_FFT_Helper.H:54
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:263
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:244
__host__ __device__ void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:139
cudaStream_t gpuStream_t
Definition AMReX_GpuControl.H:83
Definition AMReX_FFT_Helper.H:180
A host / device complex number type, because std::complex doesn't work in device code with CUDA yet.
Definition AMReX_GpuComplex.H:30