Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_FFT_LocalR2C.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_LOCAL_R2C_H_
2#define AMREX_FFT_LOCAL_R2C_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_Arena.H>
6#include <AMReX_FFT_Helper.H>
7
8namespace amrex::FFT
9{
10
// Local Discrete Fourier Transform (real-to-complex) class declaration.
//
// Template parameters:
//   T - real scalar type; real data is T, spectral data is GpuComplex<T>.
//   D - which transform directions are enabled. forward() is only available
//       for Direction::forward / Direction::both, backward() only for
//       Direction::backward / Direction::both (see the enable_if guards).
//   M - number of dimensions of the transform (sizes are IntVectND<M>).
//
// The type is move-only: copy construction and copy assignment are deleted.
//
// NOTE(review): in this listing the class-head line and several member
// declarations that the definitions use (m_fft_fwd, m_fft_bwd, m_p_bwd,
// m_real_size, m_spectral_size) are not visible; confirm against the
// original header.
25template <typename T, FFT::Direction D = FFT::Direction::both,
26 int M = AMREX_SPACEDIM>
28{
29public:
// Construct a transform for a real-space domain of size fft_size.
//   p_fwd / p_bwd - optional real and complex buffers the plans are created
//                   for; both default to nullptr.
//   cache_plan    - defaults to true when AMREX_USE_GPU is defined and to
//                   false otherwise (the constructor ignores it in non-GPU
//                   builds).
49 explicit LocalR2C (IntVectND<M> const& fft_size,
50 T* p_fwd = nullptr,
51 GpuComplex<T>* p_bwd = nullptr,
52#ifdef AMREX_USE_GPU
53 bool cache_plan = true);
54#else
55 bool cache_plan = false);
56#endif
57
// Releases the plans by calling clear().
58 ~LocalR2C ();
59
// Default construction and move operations are supported; copies are not.
60 LocalR2C () = default;
61 LocalR2C (LocalR2C &&) noexcept;
62 LocalR2C& operator= (LocalR2C &&) noexcept;
63
64 LocalR2C (LocalR2C const&) = delete;
65 LocalR2C& operator= (LocalR2C const&) = delete;
66
// Forward transform: reads real data from indata and writes complex data to
// outdata. Only participates in overload resolution when DIR is forward or
// both.
77 template <Direction DIR=D, std::enable_if_t<DIR == Direction::forward ||
78 DIR == Direction::both, int> = 0>
79 void forward (T const* indata, GpuComplex<T>* outdata);
80
// Destroys the plans (unless they are cached) and resets both plan objects.
81 void clear ();
82
// Backward transform: reads complex data from indata and writes real data to
// outdata. Only participates in overload resolution when DIR is backward or
// both.
93 template <Direction DIR=D, std::enable_if_t<DIR == Direction::backward ||
94 DIR == Direction::both, int> = 0>
95 void backward (GpuComplex<T> const* indata, T* outdata);
96
// Returns 1 / (product of the real-space sizes).
100 [[nodiscard]] T scalingFactor () const;
101
// Spectral domain size. The constructor sets component 0 to
// real_size[0]/2 + 1; the other components equal the real-space size.
103 [[nodiscard]] IntVectND<M> const& spectralSize () const {
104 return m_spectral_size;
105 }
106
107private:
108
111
// Real-space buffer the current plans were created for (may be nullptr).
112 T* m_p_fwd = nullptr;
114
// SYCL builds only: the GPU stream recorded at construction; forward() and
// backward() switch to it while they run.
115#if defined(AMREX_USE_SYCL)
116 gpuStream_t m_gpu_stream{};
117#endif
118
121
// When true, clear() does not destroy the plans.
122 bool m_cache_plan = false;
123};
124
// Constructor: records the buffer pointers and sizes, derives the spectral
// size, and initializes the FFT plan(s).
//
//   fft_size   - size of the real-space domain.
//   p_fwd      - real-space buffer passed to plan creation (may be nullptr).
//   p_bwd      - spectral-space buffer passed to plan creation (may be nullptr).
//   cache_plan - stored in m_cache_plan only when AMREX_USE_GPU is defined;
//                otherwise it is ignored and m_cache_plan keeps its in-class
//                default of false.
125template <typename T, FFT::Direction D, int M>
126LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, T* p_fwd,
127 GpuComplex<T>* p_bwd, bool cache_plan)
128 : m_p_fwd(p_fwd),
129 m_p_bwd(p_bwd),
130 m_real_size(fft_size),
131 m_spectral_size(fft_size)
132#if defined(AMREX_USE_GPU)
133 , m_cache_plan(cache_plan)
134#endif
135{
136#if !defined(AMREX_USE_GPU)
137 amrex::ignore_unused(cache_plan);
138#endif
139
140 BL_PROFILE("FFT::LocalR2C");
// Only dimension 0 is shortened in spectral space: n/2 + 1 complex values.
141 m_spectral_size[0] = m_real_size[0]/2 + 1;
142
143#if defined(AMREX_USE_SYCL)
144
// Remember the caller's stream so it can be restored at the end of the
// constructor, and record the stream this object will use from now on.
// NOTE(review): the listing jumps from line 145 to 147, so a statement
// between these two lines is not visible here — confirm against the
// original header.
145 auto current_stream = Gpu::gpuStream();
147 m_gpu_stream = Gpu::gpuStream();
148
149#endif
150
// Plan initialization takes untyped pointers.
151 auto* pf = (void*)m_p_fwd;
152 auto* pb = (void*)m_p_bwd;
153
154#ifdef AMREX_USE_SYCL
// NOTE(review): the listing jumps from line 155 to 157; any set-up of
// m_fft_bwd on the SYCL path is not visible here — confirm.
155 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
157#else
// Create only the plans for the directions enabled by D.
158 if constexpr (D == Direction::both || D == Direction::forward) {
159 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, m_cache_plan);
160 }
161 if constexpr (D == Direction::both || D == Direction::backward) {
162 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, m_cache_plan);
163 }
164#endif
165
166#if defined(AMREX_USE_SYCL)
// Put back the stream that was current when the constructor was entered.
167 Gpu::Device::setStream(current_stream);
168#endif
169}
170
// clear(): release the plans and reset both plan objects to a
// default-constructed state.
// NOTE(review): the signature line (listing line 172) is not visible in this
// extract; the cross-reference index at the end of the page maps line 172 to
// `void clear()`.
171template <typename T, FFT::Direction D, int M>
173{
// Plans are destroyed here only when they are not flagged as cached
// (presumably cached plans are owned by a plan cache — confirm in
// AMReX_FFT_Helper.H).
174 if (!m_cache_plan) {
// If both plan objects hold the same underlying plan, destroy it only once
// (through m_fft_fwd below).
175 if (m_fft_bwd.plan != m_fft_fwd.plan) {
176 m_fft_bwd.destroy();
177 }
178 m_fft_fwd.destroy();
179 }
180
// Always leave both plan objects value-initialized.
181 m_fft_fwd = Plan<T>{};
182 m_fft_bwd = Plan<T>{};
183}
184
// Destructor: releases the plans via clear().
// NOTE(review): the signature line (listing line 186) is not visible in this
// extract; the cross-reference index maps line 186 to `~LocalR2C()`.
185template <typename T, FFT::Direction D, int M>
187{
// Compile-time check: only 1, 2 or 3 dimensions are accepted.
188 static_assert(M >= 1 && M <= 3);
189 clear();
190}
191
// Move constructor: copies every member (buffer pointers, plan objects, sizes,
// cache flag and, on SYCL, the stream) from rhs, then marks rhs as "cached".
// clear() skips plan destruction when m_cache_plan is true, so the plans now
// held by *this are not destroyed when rhs is cleared or destructed.
// NOTE(review): the signature line (listing line 193) is not visible in this
// extract; presumably `LocalR2C<T,D,M>::LocalR2C (LocalR2C && rhs) noexcept`
// — confirm.
192template <typename T, FFT::Direction D, int M>
194 : m_p_fwd(rhs.m_p_fwd),
195 m_p_bwd(rhs.m_p_bwd),
196 m_fft_fwd(rhs.m_fft_fwd),
197 m_fft_bwd(rhs.m_fft_bwd),
198#if defined(AMREX_USE_SYCL)
199 m_gpu_stream(rhs.m_gpu_stream),
200#endif
201 m_real_size(rhs.m_real_size),
202 m_spectral_size(rhs.m_spectral_size),
203 m_cache_plan(rhs.m_cache_plan)
204{
205 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
206}
207
// Move assignment: releases the plans currently held by *this, takes over all
// members of rhs, and marks rhs as "cached" so that rhs no longer destroys the
// plans that *this now holds. Returns *this.
// NOTE(review): the signature line (listing line 209) is not visible in this
// extract; the cross-reference index maps line 209 to
// `LocalR2C & operator=(LocalR2C &&) noexcept`.
208template <typename T, FFT::Direction D, int M>
210{
// Self-assignment must not clear the plans we are about to keep.
211 if (this == &rhs) { return *this; }
212
// Release whatever *this currently holds before overwriting it.
213 this->clear();
214
215 m_p_fwd = rhs.m_p_fwd;
216 m_p_bwd = rhs.m_p_bwd;
217 m_fft_fwd = rhs.m_fft_fwd;
218 m_fft_bwd = rhs.m_fft_bwd;
219#if defined(AMREX_USE_SYCL)
220 m_gpu_stream = rhs.m_gpu_stream;
221#endif
222 m_real_size = rhs.m_real_size;
223 m_spectral_size = rhs.m_spectral_size;
224 m_cache_plan = rhs.m_cache_plan;
225
226 rhs.m_cache_plan = true; // So that plans in rhs are not destroyed.
227
228 return *this;
229}
230
// Forward transform.
//
//   indata  - real input data.
//   outdata - complex output data.
// (Presumably indata spans the real-space size and outdata the spectral size
// returned by spectralSize() — confirm against the class documentation.)
//
// Enabled only when DIR is Direction::forward or Direction::both.
231template <typename T, FFT::Direction D, int M>
232template <Direction DIR, std::enable_if_t<DIR == Direction::forward ||
233 DIR == Direction::both, int> >
234void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
235{
236 BL_PROFILE("FFT::LocalR2C::forward");
237
238#if defined(AMREX_USE_GPU)
239
// GPU builds: point the existing forward plan at the caller's buffers.
240 m_fft_fwd.set_ptrs((void*)indata, (void*)outdata);
241
242#if defined(AMREX_USE_SYCL)
// SYCL: run on the stream recorded at construction; the caller's stream is
// restored at the end of this function.
// NOTE(review): the listing jumps from line 244 to 246, so a statement at
// the top of this `if` body is not visible here — confirm.
243 auto current_stream = Gpu::gpuStream();
244 if (current_stream != m_gpu_stream) {
246 Gpu::Device::setStream(m_gpu_stream);
247 }
248#endif
249
250#else /* FFTW */
251
// Non-GPU builds: the plans were created for the buffer pair recorded in
// m_p_fwd / m_p_bwd. If the caller passes a different pair, record it and
// re-create the forward plan (and the backward plan too when both directions
// are enabled) for the new buffers. const is cast away from indata only so
// the pointer can be stored in the non-const member m_p_fwd.
252 if (((T*)indata != m_p_fwd) || (outdata != m_p_bwd)) {
253 m_p_fwd = (T*)indata;
254 m_p_bwd = outdata;
255 auto* pf = (void*)m_p_fwd;
256 auto* pb = (void*)m_p_bwd;
257 m_fft_fwd.destroy();
258 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
259 if constexpr (D == Direction::both) {
260 m_fft_bwd.destroy();
261 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
262 }
263 }
264
265#endif
266
// Execute the forward plan.
267 m_fft_fwd.template compute_r2c<Direction::forward>();
268
269#if defined(AMREX_USE_SYCL)
// Restore the caller's stream if it was switched above.
270 if (current_stream != m_gpu_stream) {
271 Gpu::Device::setStream(current_stream);
272 }
273#endif
274}
275
// Backward transform.
//
//   indata  - complex input data.
//   outdata - real output data.
// (Presumably indata spans the spectral size returned by spectralSize() and
// outdata the real-space size — confirm against the class documentation.)
//
// Enabled only when DIR is Direction::backward or Direction::both.
276template <typename T, FFT::Direction D, int M>
277template <Direction DIR, std::enable_if_t<DIR == Direction::backward ||
278 DIR == Direction::both, int> >
279void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
280{
281 BL_PROFILE("FFT::LocalR2C::backward");
282
283#if defined(AMREX_USE_GPU)
284
// GPU builds: point the existing backward plan at the caller's buffers.
// set_ptrs is called with the real buffer first and the complex buffer
// second, the same order forward() uses.
285 m_fft_bwd.set_ptrs((void*)outdata, (void*)indata);
286
287#if defined(AMREX_USE_SYCL)
// SYCL: run on the stream recorded at construction; the caller's stream is
// restored at the end of this function.
// NOTE(review): the listing jumps from line 289 to 291, so a statement at
// the top of this `if` body is not visible here — confirm.
288 auto current_stream = Gpu::gpuStream();
289 if (current_stream != m_gpu_stream) {
291 Gpu::Device::setStream(m_gpu_stream);
292 }
293#endif
294
295#else /* FFTW */
296
// Non-GPU builds: the plans were created for the buffer pair recorded in
// m_p_fwd / m_p_bwd. If the caller passes a different pair, record it and
// re-create the backward plan (and the forward plan too when both directions
// are enabled) for the new buffers. const is cast away from indata only so
// the pointer can be stored in the non-const member m_p_bwd.
297 if (((GpuComplex<T>*)indata != m_p_bwd) || (outdata != m_p_fwd)) {
298 m_p_fwd = outdata;
299 m_p_bwd = (GpuComplex<T>*)indata;
300 auto* pf = (void*)m_p_fwd;
301 auto* pb = (void*)m_p_bwd;
302 m_fft_bwd.destroy();
303 m_fft_bwd.template init_r2c<Direction::backward,M>(m_real_size, pf, pb, false);
304 if constexpr (D == Direction::both) {
305 m_fft_fwd.destroy();
306 m_fft_fwd.template init_r2c<Direction::forward,M>(m_real_size, pf, pb, false);
307 }
308 }
309
310#endif
311
// Execute the backward plan.
312 m_fft_bwd.template compute_r2c<Direction::backward>();
313
314#if defined(AMREX_USE_SYCL)
// Restore the caller's stream if it was switched above.
315 if (current_stream != m_gpu_stream) {
316 Gpu::Device::setStream(current_stream);
317 }
318#endif
319}
320
// scalingFactor(): returns 1 / (product of all real-space sizes) as a T.
// Presumably this is the factor that normalizes a forward transform followed
// by a backward transform — confirm against the class documentation.
// NOTE(review): the signature line (listing line 322) is not visible in this
// extract; the cross-reference index maps line 322 to
// `T scalingFactor() const`.
321template <typename T, FFT::Direction D, int M>
323{
// Accumulate the number of real-space points in floating point.
324 T r = 1;
325 for (auto s : m_real_size) {
326 r *= T(s);
327 }
328 return T(1)/r;
329}
330
331}
332
333#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
if(!(yy_init))
Definition amrex_iparser.lex.nolint.H:935
Local Discrete Fourier Transform.
Definition AMReX_FFT_LocalR2C.H:28
void backward(GpuComplex< T > const *indata, T *outdata)
Backward transform.
Definition AMReX_FFT_LocalR2C.H:279
T scalingFactor() const
Definition AMReX_FFT_LocalR2C.H:322
void forward(T const *indata, GpuComplex< T > *outdata)
Forward transform.
Definition AMReX_FFT_LocalR2C.H:234
Plan< T > m_fft_bwd
Definition AMReX_FFT_LocalR2C.H:110
void clear()
Definition AMReX_FFT_LocalR2C.H:172
LocalR2C & operator=(LocalR2C &&) noexcept
Definition AMReX_FFT_LocalR2C.H:209
GpuComplex< T > * m_p_bwd
Definition AMReX_FFT_LocalR2C.H:113
Plan< T > m_fft_fwd
Definition AMReX_FFT_LocalR2C.H:109
bool m_cache_plan
Definition AMReX_FFT_LocalR2C.H:122
IntVectND< M > m_spectral_size
Definition AMReX_FFT_LocalR2C.H:120
T * m_p_fwd
Definition AMReX_FFT_LocalR2C.H:112
~LocalR2C()
Definition AMReX_FFT_LocalR2C.H:186
IntVectND< M > m_real_size
Definition AMReX_FFT_LocalR2C.H:119
IntVectND< M > const & spectralSize() const
Spectral domain size.
Definition AMReX_FFT_LocalR2C.H:103
static gpuStream_t setStream(gpuStream_t s) noexcept
Definition AMReX_GpuDevice.cpp:655
static void resetStreamIndex() noexcept
Definition AMReX_GpuDevice.H:76
Definition AMReX_IntVect.H:48
Definition AMReX_FFT.cpp:7
Direction
Definition AMReX_FFT_Helper.H:48
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:218
cudaStream_t gpuStream_t
Definition AMReX_GpuControl.H:77
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
Suppresses compiler warnings about unused variables.
Definition AMReX.H:127
Definition AMReX_FFT_Helper.H:126
A host/device complex number type, provided because std::complex does not yet work in CUDA device code.
Definition AMReX_GpuComplex.H:29