Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_FFT_OpenBCSolver.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_OPENBC_SOLVER_H_
2#define AMREX_FFT_OPENBC_SOLVER_H_
3
4#include <AMReX_FFT_R2C.H>
5
6namespace amrex::FFT
7{
8
// Convolution-based solver for open boundary conditions using Green's functions.
// T is the floating-point type of the data and defaults to Real.
// NOTE(review): listing line 25 (presumably the `class OpenBCSolver` head) is
// missing from this extract -- confirm against the original header.
24template <typename T = Real>
26{
27public:
 // Real and complex container types, taken from the underlying R2C transform.
28 using MF = typename R2C<T>::MF;
29 using cMF = typename R2C<T>::cMF;
30
 // Build a solver over `domain` using the FFT Info settings in `info`.
37 explicit OpenBCSolver (Box const& domain, Info const& info = Info{});
38
 // Populate the spectral Green's function used by subsequent solves.
 // `greens_function` is called with three integer indices (i,j,k).
45 template <class F>
46 void setGreensFunction (F const& greens_function);
47
 // Solve for `phi` given right-hand side `rho`.
54 void solve (MF& phi, MF const& rho);
55
 // Access the physical domain this solver was built for.
61 [[nodiscard]] Box const& Domain () const { return m_domain; }
62
 // Access the one-sided padded length used to build the internal FFT domain.
68 [[nodiscard]] IntVect const& PaddedLength () const { return m_padded_length; }
69
70private:
 // Per-direction domain length, optionally rounded up to a fast FFT length.
71 static IntVect make_padded_length (Box const& domain, Info const& info);
 // Box starting at the domain's small end whose padded length is doubled
 // in every transformed direction.
72 static Box make_grown_domain (Box const& domain, IntVect const& padded_len,
73 Info const& info);
74
 // Members are initialized in declaration order; the constructor relies on
 // m_info and m_padded_length being ready before m_r2c is built.
75 Box m_domain;
76 Info m_info;
77 IntVect m_padded_length;
 // Transform on the grown domain; used by solve() and, outside twod_mode,
 // by setGreensFunction().
78 R2C<T> m_r2c;
 // Spectral Green's function that solve() multiplies into the transformed rho.
79 cMF m_G_fft;
 // Separate transform for the Green's function; only created in twod_mode.
80 std::unique_ptr<R2C<T>> m_r2c_green;
81};
82
// Compute the one-sided length of the solve region in each direction.
// Starts from domain.length(); when info.openbc_padding is set, each
// transformed direction is replaced by FFT::nextFastLen of its length
// (the smallest fast FFT length greater than or equal to the target).
// In 3D with info.twod_mode only the first two directions are transformed,
// so the third keeps the domain length.
83template <typename T>
84IntVect OpenBCSolver<T>::make_padded_length (Box const& domain, Info const& info)
85{
86 IntVect len = domain.length();
87 int ndims = AMREX_SPACEDIM;
88#if (AMREX_SPACEDIM == 3)
89 if (info.twod_mode) { ndims = 2; }
90#else
// NOTE(review): listing line 91 is missing from this extract -- confirm
// against the original header what the non-3D branch contains.
92#endif
93 if (info.openbc_padding) {
94 for (int idim = 0; idim < ndims; ++idim) {
95 len[idim] = FFT::nextFastLen(len[idim], info.openbc_padding_nfactors);
96 }
97 }
98 return len;
99}
100
// Build the box on which the internal FFT operates.
// The box shares the small end and index type of `domain`; its length is
// `padded_len` doubled in every transformed direction (all directions, or
// only the first two in 3D when info.twod_mode is set).
101template <typename T>
102Box OpenBCSolver<T>::make_grown_domain (Box const& domain, IntVect const& padded_len,
103 Info const& info)
104{
105 IntVect len = padded_len;
106 int ndims = AMREX_SPACEDIM;
107#if (AMREX_SPACEDIM == 3)
108 if (info.twod_mode) { ndims = 2; }
109#else
// NOTE(review): listing line 110 is missing from this extract -- confirm
// against the original header what the non-3D branch contains.
111#endif
112 for (int idim = 0; idim < ndims; ++idim) {
// NOTE(review): listing line 113 is missing from this extract; the two lines
// below are the condition and message of a check that presumably opens there
// (likely AMREX_ALWAYS_ASSERT_WITH_MESSAGE). The condition guards the doubling
// on the next statement against int overflow.
114 len[idim] <= std::numeric_limits<int>::max()/2,
115 "FFT::OpenBCSolver: padded domain length exceeds int range");
116 len[idim] *= 2;
117 }
// The big end is small end + length - 1, i.e. an inclusive upper bound.
118 return Box(domain.smallEnd(), domain.smallEnd()+len-IntVect(1), domain.ixType());
119}
120
// Build a solver over `domain` using the FFT Info settings in `info`.
// Stores the domain and a copy of `info`, computes the padded length, builds
// the R2C transform on the grown domain with the slab domain strategy, and
// allocates (or aliases) the storage for the spectral Green's function.
121template <typename T>
122OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
123 : m_domain(domain),
124 m_info(info),
125 m_padded_length(OpenBCSolver<T>::make_padded_length(domain, info)),
// m_info and m_padded_length are declared before m_r2c, so both are already
// initialized when they are used here.
126 m_r2c(OpenBCSolver<T>::make_grown_domain(domain, m_padded_length, info),
127 m_info.setDomainStrategy(FFT::DomainStrategy::slab))
128{
// NOTE(review): listing line 129 is missing from this extract; the line below
// is the message of a check that presumably opens there and rejects
// batch_size > 1 -- confirm against the original header.
130 "FFT::OpenBCSolver does not support FFT::Info::batch_size > 1");
131
132#if (AMREX_SPACEDIM == 3)
133 if (m_info.twod_mode) {
// twod_mode: the Green's function gets its own transform on a box whose
// extent in direction 2 is reset to [0, nprocs-1], where nprocs is the
// smallest of the ranks in the current frame, m_info.nprocs, and the
// domain length in direction 2.
134 auto gdom = make_grown_domain(domain, m_padded_length, m_info);
135 gdom.enclosedCells(2);
136 gdom.setSmall(2, 0);
137 int nprocs = std::min({ParallelContext::NProcsSub(),
138 m_info.nprocs,
139 m_domain.length(2)});
140 gdom.setBig(2, nprocs-1);
141 m_r2c_green = std::make_unique<R2C<T>>(gdom,m_info);
142 auto [sd, ord] = m_r2c_green->getSpectralData();
// m_G_fft aliases the spectral data of m_r2c_green rather than owning a copy
// (arguments 0, 1 presumably select start component 0 and one component).
143 m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
144 } else
145#endif
146 {
// Otherwise the Green's function shares m_r2c, so m_G_fft gets its own
// storage laid out like m_r2c's spectral data (arguments 1, 0 presumably
// mean one component and no ghost cells).
147 amrex::ignore_unused(m_r2c_green);
148 auto [sd, ord] = m_r2c.getSpectralData();
// NOTE(review): listing line 149 is missing from this extract.
150 m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
151 }
152}
153
154template <typename T>
155template <class F>
156void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
157{
158 BL_PROFILE("OpenBCSolver::setGreensFunction");
159
160 auto* infab = m_info.twod_mode ? detail::get_fab(m_r2c_green->m_rx)
161 : detail::get_fab(m_r2c.m_rx);
162 auto const& lo = m_domain.smallEnd();
163 auto const& lo3 = lo.dim3();
164 auto const len3d = m_padded_length.dim3();
165 GpuArray<int,3> len{len3d.x, len3d.y, len3d.z};
166 if (infab) {
167 auto const& a = infab->array();
168 auto box = infab->box();
169 GpuArray<int,3> nimages{1,1,1};
170 int ndims = m_info.twod_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
171 for (int idim = 0; idim < ndims; ++idim) {
172 if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
173 box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
174 nimages[idim] = 2;
175 }
176 }
177 AMREX_ASSERT(nimages[0] == 2);
178 box.shift(-lo);
179 amrex::ParallelForOMP(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
180 {
181 T G;
182 if (i == len[0] || j == len[1] || k == len[2]) {
183 G = 0;
184 } else {
185 auto ii = i;
186 auto jj = (j > len[1]) ? 2*len[1]-j : j;
187 auto kk = (k > len[2]) ? 2*len[2]-k : k;
188 G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
189 }
190 for (int koff = 0; koff < nimages[2]; ++koff) {
191 int k2 = (koff == 0) ? k : 2*len[2]-k;
192 if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) {
193 continue;
194 }
195 for (int joff = 0; joff < nimages[1]; ++joff) {
196 int j2 = (joff == 0) ? j : 2*len[1]-j;
197 if ((j2 == 2*len[1]) || (joff == 1 && j == len[1])) {
198 continue;
199 }
200 for (int ioff = 0; ioff < nimages[0]; ++ioff) {
201 int i2 = (ioff == 0) ? i : 2*len[0]-i;
202 if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) {
203 continue;
204 }
205 a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
206 }
207 }
208 }
209 });
210 }
211
212 if (m_info.twod_mode) {
213 m_r2c_green->forward(m_r2c_green->m_rx);
214 } else {
215 m_r2c.forward(m_r2c.m_rx);
216 }
217
218 if (!m_info.twod_mode) {
219 auto [sd, ord] = m_r2c.getSpectralData();
221 auto const* srcfab = detail::get_fab(*sd);
222 if (srcfab) {
223 auto* dstfab = detail::get_fab(m_G_fft);
224 if (dstfab) {
225 Gpu::dtod_memcpy_async(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
226 } else {
227 amrex::Abort("FFT::OpenBCSolver: how did this happen");
228 }
229 }
230
231 m_r2c.prepare_openbc();
232 }
233}
234
// Solve for `phi` given right-hand side `rho`.
// Steps: copy rho into the zeroed input of m_r2c, forward transform, multiply
// the spectrum point by point with the stored Green's function spectrum and
// the transform's scaling factor, then transform back into phi.
235template <typename T>
236void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
237{
238 BL_PROFILE("OpenBCSolver::solve");
239
// Zero the whole input first so everything not overwritten by the copy of
// rho (component 0, one component) stays zero.
240 auto& inmf = m_r2c.m_rx;
241 inmf.setVal(T(0));
242 inmf.ParallelCopy(rho, 0, 0, 1);
243
// m_openbc_half is raised only for the duration of the transform and only
// outside twod_mode; it is reset right after.
244 m_r2c.m_openbc_half = !m_info.twod_mode;
245 m_r2c.forward(inmf);
246 m_r2c.m_openbc_half = false;
247
248 auto scaling_factor = m_r2c.scalingFactor();
249
250 auto const* gfab = detail::get_fab(m_G_fft);
251 if (gfab) {
252 auto [sd, ord] = m_r2c.getSpectralData();
// NOTE(review): listing line 253 is missing from this extract.
254 auto* rhofab = detail::get_fab(*sd);
255 if (rhofab) {
256 auto* pdst = rhofab->dataPtr();
257 auto const* psrc = gfab->dataPtr();
258 Box const& rhobox = rhofab->box();
259#if (AMREX_SPACEDIM == 3)
// leng is the number of points in the Green's function FAB. In twod_mode it
// must be a single plane matching rhobox in directions 0 and 1; otherwise it
// must match rhobox exactly.
260 Long leng = gfab->box().numPts();
261 if (m_info.twod_mode) {
262 AMREX_ASSERT(gfab->box().length(2) == 1 &&
263 leng == (rhobox.length(0) * rhobox.length(1)));
264 } else {
265 AMREX_ASSERT(leng == rhobox.numPts());
266 }
267#endif
// NOTE(review): listing line 268 is missing from this extract; it presumably
// opens the loop/kernel launch that supplies the flat index `i` used below.
269 {
270#if (AMREX_SPACEDIM == 3)
// The source index wraps with period leng, so a single-plane Green's
// function (twod_mode) is reused for every plane of the spectrum.
271 Long isrc = i % leng;
272#else
273 Long isrc = i;
274#endif
275 pdst[i] *= psrc[isrc] * scaling_factor;
276 });
277 } else {
// A rank that holds Green's function data but no spectral data is treated
// as an internal error.
278 amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");
279 }
280 }
281
// Inverse transform into phi, passing phi's ghost-cell vector along.
282 m_r2c.m_openbc_half = !m_info.twod_mode;
283 m_r2c.backward_doit(phi, phi.nGrowVect());
284 m_r2c.m_openbc_half = false;
285}
286
287}
288
289#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define AMREX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition AMReX_BLassert.H:49
#define AMREX_ASSERT(EX)
Definition AMReX_BLassert.H:38
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1140
__host__ __device__ Long numPts() const noexcept
Return the number of points contained in the BoxND.
Definition AMReX_Box.H:364
__host__ __device__ IntVectND< dim > length() const noexcept
Return the length of the BoxND.
Definition AMReX_Box.H:155
Convolution-based solver for open boundary conditions using Green's functions.
Definition AMReX_FFT_OpenBCSolver.H:26
Box const & Domain() const
Access the physical domain this solver was built for.
Definition AMReX_FFT_OpenBCSolver.H:61
typename R2C< T >::MF MF
Definition AMReX_FFT_OpenBCSolver.H:28
void solve(MF &phi, MF const &rho)
Solve for phi given right-hand side rho.
Definition AMReX_FFT_OpenBCSolver.H:236
void setGreensFunction(F const &greens_function)
Populate the spectral Green's function used by subsequent solves.
Definition AMReX_FFT_OpenBCSolver.H:156
IntVect const & PaddedLength() const
Access the one-sided padded length used to build the internal FFT domain.
Definition AMReX_FFT_OpenBCSolver.H:68
typename R2C< T >::cMF cMF
Definition AMReX_FFT_OpenBCSolver.H:29
OpenBCSolver(Box const &domain, Info const &info=Info{})
Build a solver over domain using the FFT Info settings in info.
Definition AMReX_FFT_OpenBCSolver.H:122
Parallel Discrete Fourier Transform.
Definition AMReX_FFT_R2C.H:48
std::conditional_t< C, cMF, std::conditional_t< std::is_same_v< T, Real >, MultiFab, FabArray< BaseFab< T > > > > MF
Definition AMReX_FFT_R2C.H:53
Open Boundary Poisson Solver.
Definition AMReX_OpenBC.H:70
OpenBCSolver()=default
Construct an empty solver; call define() before solving.
amrex_long Long
Definition AMReX_INT.H:30
void ParallelForOMP(T n, L const &f) noexcept
Performance-portable kernel launch function with optional OpenMP threading.
Definition AMReX_GpuLaunch.H:328
Definition AMReX_FFT_Helper.H:53
int nextFastLen(int target, int nfactors=FastNumPrimeFactors())
Return the smallest fast FFT length greater than or equal to target.
Definition AMReX_FFT_Helper.H:286
DomainStrategy
Definition AMReX_FFT_Helper.H:57
void dtod_memcpy_async(void *p_d_dst, const void *p_d_src, const std::size_t sz) noexcept
Definition AMReX_GpuDevice.H:449
int NProcsSub() noexcept
number of ranks in current frame
Definition AMReX_ParallelContext.H:74
@ make_alias
Definition AMReX_MakeType.H:7
__host__ __device__ void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:139
BoxND< 3 > Box
Box is an alias for amrex::BoxND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:30
IntVectND< 3 > IntVect
IntVect is an alias for amrex::IntVectND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:33
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:241
Definition AMReX_FFT_Helper.H:83
bool twod_mode
Definition AMReX_FFT_Helper.H:94
int batch_size
Batched FFT size. Only supported in R2C, not R2X.
Definition AMReX_FFT_Helper.H:106
int nprocs
Max number of processes to use.
Definition AMReX_FFT_Helper.H:109
Fixed-size array that can be used on GPU.
Definition AMReX_Array.H:43