Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_FFT_Stokes.H
Go to the documentation of this file.
1#ifndef AMREX_FFT_STOKES_H_
2#define AMREX_FFT_STOKES_H_
3
4#include <AMReX_FFT.H>
5#include <AMReX_Geometry.H>
6
7namespace amrex::FFT
8{
9
14template <typename MF = MultiFab>
15class Stokes
16{
17public:
18
19 static_assert(AMREX_SPACEDIM >= 2, "FFT::Stokes requires 2D or 3D");
20
22
31 Stokes (Geometry const& geom,
32 Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc)
33 requires (IsFabArray_v<MF>)
34 : m_domain_lo(geom.Domain().smallEnd()),
35 m_geom(detail::shift_geom(geom)),
36 m_bc(bc)
37 {
38 bool all_periodic = true;
39 for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
40 all_periodic = all_periodic
41 && (bc[idim].first == Boundary::periodic)
42 && (bc[idim].second == Boundary::periodic);
43 }
44 if (all_periodic) {
45 m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
46 } else {
47 amrex::Abort("FFT::Stokes: only supports periodic BC");
48 }
49 }
50
56 explicit Stokes (Geometry const& geom)
57 requires (IsFabArray_v<MF>)
58 : m_domain_lo(geom.Domain().smallEnd()),
59 m_geom(detail::shift_geom(geom)),
60 m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic),
61 std::make_pair(Boundary::periodic,Boundary::periodic),
62 std::make_pair(Boundary::periodic,Boundary::periodic))}
63 {
64 if (m_geom.isAllPeriodic()) {
65 m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
66 } else {
67 amrex::Abort("FFT::Stokes: only supports periodic BC");
68 }
69 }
70
92 void solve (AMREX_D_DECL(MF& U, MF& V, MF& W), MF& p,
93 AMREX_D_DECL(MF const& rhsx, MF const& rhsy, MF const& rhsz),
94 typename MF::value_type alpha, typename MF::value_type eta);
95
96private:
97 IntVect m_domain_lo;
98 Geometry m_geom;
99 Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> m_bc;
100 std::unique_ptr<R2C<typename MF::value_type>> m_r2c;
101};
102
103
// Stokes<MF>::solve -- spectral solve of the periodic Stokes problem.
//
// Index types (checked by the AMREX_ASSERT below in debug builds):
//   p                   cell-centred
//   U  / rhsx           nodal in direction 0
//   V  / rhsy           nodal in direction 1
//   W  / rhsz (3D only) nodal in direction 2
//
// Notation. k_d is the signed wave index times 2*pi/n_d. For each Fourier
// mode the kernel uses:
//   L    = delsqk   = sum_d 2*(cos k_d - 1)/dx_d^2
//   D_d  = delkp[d] = ((cos k_d - 1) + i sin k_d)/dx_d
//   G_d  = delkm[d] = ((1 - cos k_d) + i sin k_d)/dx_d
//   s    = r2c.scalingFactor()
//
// It computes:
//   phat = s * (D . rhat) / L                     (0 if |L| <= tol)
//   uhat = (s * rhat - G * phat) / (alpha - eta*L)
//                                   (0 if alpha - eta*L <= tol)
//
// Since D_d * G_d equals the d-th term of L, D . uhat = 0 whenever
// |L| > tol, so the returned velocity is discretely divergence-free.
//
// NOTE(review): this reads as the MAC discretisation of
//   alpha*u - eta*Lap(u) + grad(p) = rhs,   div(u) = 0,
// with D and G acting as forward/backward difference symbols. Confirm this
// against the sign convention of R2C.
//
// NOTE(review): header lines 105 (start of the signature), 124 (presumably
// the opening of the per-dimension terms of the assertion, which is why the
// listed condition shows stray ", &&" pairs), 218 and 226 (declarations of
// delkp and delkm) are absent from this listing.
104template <typename MF>
106 MF& V,
107#if (BL_SPACEDIM == 3)
108 MF& W,
109#endif
110 MF& p,
111 MF const& rhsx,
112 MF const& rhsy,
113#if (BL_SPACEDIM == 3)
114 MF const& rhsz,
115#endif
116 typename MF::value_type alpha,
117 typename MF::value_type eta)
118{
119 BL_PROFILE("FFT::Stokes::solve");
120
121 using T = typename MF::value_type;
122
123 AMREX_ASSERT(p.ixType() == IndexType::TheCellType() &&
125 rhsx.ixType() == U.ixType(),
126 && V.ixType() == IndexType(IntVect::TheDimensionVector(1)) &&
127 rhsy.ixType() == V.ixType(),
128 && W.ixType() == IndexType(IntVect::TheDimensionVector(2)) &&
129 rhsz.ixType() == W.ixType()));
130
131 MF* Umf = &U;
132 MF* Vmf = &V;
133#if (BL_SPACEDIM == 3)
134 MF* Wmf = &W;
135#endif
136 MF* pmf = &p;
137 MF const* rhsxmf = &rhsx;
138 MF const* rhsymf = &rhsy;
139#if (BL_SPACEDIM == 3)
140 MF const* rhszmf = &rhsz;
141#endif
142 MF Utmp, rhsxtmp;
143 MF Vtmp, rhsytmp;
144#if (BL_SPACEDIM == 3)
145 MF Wtmp, rhsztmp;
146#endif
147 MF ptmp;
// m_geom was shifted in the constructor (detail::shift_geom). When the
// caller's domain does not start at the origin, work through shifted
// counterparts of every input and output instead.
// NOTE(review): results are written only through Umf/Vmf/Wmf/pmf, so
// detail::shift_mfs / detail::shift_mf presumably produce aliases of the
// caller's data rather than copies -- confirm in those helpers.
148 if (m_domain_lo != 0) {
149 detail::shift_mfs(m_domain_lo, U, rhsx, Utmp, rhsxtmp);
150 detail::shift_mfs(m_domain_lo, V, rhsy, Vtmp, rhsytmp);
151#if (BL_SPACEDIM == 3)
152 detail::shift_mfs(m_domain_lo, W, rhsz, Wtmp, rhsztmp);
153#endif
154 detail::shift_mf(m_domain_lo, p, ptmp);
155 Umf = &Utmp;
156 Vmf = &Vtmp;
157#if (BL_SPACEDIM == 3)
158 Wmf = &Wtmp;
159#endif
160 pmf = &ptmp;
161 rhsxmf = &rhsxtmp;
162 rhsymf = &rhsytmp;
163#if (BL_SPACEDIM == 3)
164 rhszmf = &rhsztmp;
165#endif
166 }
167
168 auto& r2c = *m_r2c;
169 auto const& dxinv = m_geom.InvCellSizeArray();
170 auto const scaling = r2c.scalingFactor();
171 auto const& [cba, cdm] = r2c.getSpectralDataLayout();
172
// Spectral arrays on the R2C spectral layout. rxhat/ryhat/rzhat receive the
// forward transforms of the rhs. The kernel below overwrites them in place
// with the velocity solution.
173 cMF phat(cba, cdm, 1, 0);
174
175 cMF rxhat(cba,cdm,1,0);
176 r2c.forward(*rhsxmf, rxhat);
177
178 cMF ryhat(cba,cdm,1,0);
179 r2c.forward(*rhsymf, ryhat);
180
181#if (BL_SPACEDIM == 3)
182 cMF rzhat(cba,cdm,1,0);
183 r2c.forward(*rhszmf, rzhat);
184#endif
185
186 using Complex = GpuComplex<T>;
// Absolute threshold used to treat a mode's operator as singular.
187 T constexpr tol = std::numeric_limits<T>::epsilon() * T(10);
188 int const nx = m_geom.Domain().length(0);
189#if (AMREX_SPACEDIM >= 2)
190 int const ny = m_geom.Domain().length(1);
191#endif
192#if (AMREX_SPACEDIM == 3)
193 int const nz = m_geom.Domain().length(2);
194#endif
195 for (MFIter mfi(phat); mfi.isValid(); ++mfi) {
196 auto const& pb = phat[mfi].box();
197 auto const& parr = phat[mfi].array();
198 auto const& rx = rxhat[mfi].array();
199 auto const& ry = ryhat[mfi].array();
200#if (BL_SPACEDIM == 3)
201 auto const& rz = rzhat[mfi].array();
202#endif
// Angular increment 2*pi/n between consecutive wave indices in each direction.
203 AMREX_D_TERM(T kwx = T(2)*Math::pi<T>()/T(nx);,
204 T kwy = T(2)*Math::pi<T>()/T(ny);,
205 T kwz = T(2)*Math::pi<T>()/T(nz);)
206 ParallelFor(pb, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
207 {
// Signed wave index: array indices above n/2 stand for negative frequencies.
208 AMREX_D_TERM(int ik = (i <= nx/2) ? i : i - nx;,
209 int jk = (j <= ny/2) ? j : j - ny;,
210 int kk = (k <= nz/2) ? k : k - nz);
211
212 GpuArray<T,AMREX_SPACEDIM> kwave{AMREX_D_DECL(ik*kwx,jk*kwy,kk*kwz)};
213
// delsqk is L, the symbol of the discrete Laplacian. The next two initializers
// (for delkp = D and delkm = G, declared on lines missing from this listing)
// hold the per-direction difference symbols described at the top of solve.
214 T delsqk = AMREX_D_TERM(T(2)*(std::cos(kwave[0])-T(1))*(dxinv[0]*dxinv[0]),
215 + T(2)*(std::cos(kwave[1])-T(1))*(dxinv[1]*dxinv[1]),
216 + T(2)*(std::cos(kwave[2])-T(1))*(dxinv[2]*dxinv[2]));
217
219 {AMREX_D_DECL(Complex((std::cos(kwave[0])-T(1))*dxinv[0],
220 std::sin(kwave[0]) *dxinv[0]),
221 Complex((std::cos(kwave[1])-T(1))*dxinv[1],
222 std::sin(kwave[1]) *dxinv[1]),
223 Complex((std::cos(kwave[2])-T(1))*dxinv[2],
224 std::sin(kwave[2]) *dxinv[2]))};
225
227 {AMREX_D_DECL(Complex((T(1)-std::cos(kwave[0]))*dxinv[0],
228 std::sin(kwave[0]) *dxinv[0]),
229 Complex((T(1)-std::cos(kwave[1]))*dxinv[1],
230 std::sin(kwave[1]) *dxinv[1]),
231 Complex((T(1)-std::cos(kwave[2]))*dxinv[2],
232 std::sin(kwave[2]) *dxinv[2]))};
233
234 AMREX_D_TERM(Complex const rxk = rx(i,j,k);,
235 Complex const ryk = ry(i,j,k);,
236 Complex const rzk = rz(i,j,k);)
237
// s * (D . rhat): spectral divergence of the right-hand side.
238 Complex rhsdotdp = scaling * (AMREX_D_TERM(rxk*delkp[0],
239 +ryk*delkp[1],
240 +rzk*delkp[2]));
241
// Pressure mode. Modes with |L| <= tol (e.g. k = 0) get zero pressure.
242 if (std::abs(delsqk) > tol) {
243 parr(i,j,k)= rhsdotdp / delsqk;
244 } else {
245 parr(i,j,k)= T(0.);
246 }
247
// Velocity modes, written back into the rhs spectral arrays. Modes where
// alpha - eta*L <= tol are zeroed; for alpha == 0 this includes k = 0, so
// the mean velocity is zero.
248 T diffop = alpha - eta*delsqk;
249 if (diffop > tol) {
250 rx(i,j,k) = (scaling*rxk - parr(i,j,k)*delkm[0])/diffop;
251 ry(i,j,k) = (scaling*ryk - parr(i,j,k)*delkm[1])/diffop;
252#if (BL_SPACEDIM == 3)
253 rz(i,j,k) = (scaling*rzk - parr(i,j,k)*delkm[2])/diffop;
254#endif
255 } else {
256 rx(i,j,k) = T(0.);
257 ry(i,j,k) = T(0.);
258#if (BL_SPACEDIM == 3)
259 rz(i,j,k) = T(0.);
260#endif
261 }
262
263 });
264 }
265
// Transform back into the outputs (shifted counterparts when m_domain_lo != 0).
266 r2c.backward(phat, *pmf);
267 r2c.backward(rxhat, *Umf);
268 r2c.backward(ryhat, *Vmf);
269#if (BL_SPACEDIM == 3)
270 r2c.backward(rzhat, *Wmf);
271#endif
272}
273
274} // namespace amrex::FFT
275
276#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define AMREX_ASSERT(EX)
Definition AMReX_BLassert.H:38
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
#define AMREX_D_TERM(a, b, c)
Definition AMReX_SPACE.H:172
#define AMREX_D_DECL(a, b, c)
Definition AMReX_SPACE.H:171
#define BL_SPACEDIM
Definition AMReX_SPACE.H:15
Stokes solver for periodic domains using FFT.
Definition AMReX_FFT_Stokes.H:16
Stokes(Geometry const &geom, Array< std::pair< Boundary, Boundary >, 3 > const &bc)
Construct a Stokes solver with explicit boundary types.
Definition AMReX_FFT_Stokes.H:31
Stokes(Geometry const &geom)
Construct a purely periodic Stokes solver.
Definition AMReX_FFT_Stokes.H:56
void solve(MF &U, MF &V, MF &W, MF &p, MF const &rhsx, MF const &rhsy, MF const &rhsz, typename MF::value_type alpha, typename MF::value_type eta)
Solve the generalized Stokes problem in spectral space.
Definition AMReX_FFT_Stokes.H:105
Box box(int K) const noexcept
Return the Kth Box in the BoxArray. That is, the valid region of the Kth grid.
Definition AMReX_FabArrayBase.H:101
Array4< typename FabArray< FAB >::value_type const > array(const MFIter &mfi) const noexcept
Definition AMReX_FabArray.H:561
Rectangular problem domain geometry.
Definition AMReX_Geometry.H:75
const Box & Domain() const noexcept
Returns our rectangular domain.
Definition AMReX_Geometry.H:216
bool isAllPeriodic() const noexcept
Is domain periodic in all directions?
Definition AMReX_Geometry.H:344
__host__ __device__ constexpr CellIndex ixType(int dir) const noexcept
Returns the CellIndex in direction dir.
Definition AMReX_IndexType.H:117
__host__ static __device__ constexpr IndexTypeND< dim > TheCellType() noexcept
This static member function returns an IndexTypeND object of value IndexTypeND::CELL....
Definition AMReX_IndexType.H:150
__host__ static __device__ constexpr IntVectND< dim > TheDimensionVector(int d) noexcept
This static member function returns a reference to a constant IntVectND object, all of whose dim argu...
Definition AMReX_IntVect.H:790
Iterator for looping over tiles and boxes of amrex::FabArray-based containers.
Definition AMReX_MFIter.H:88
bool isValid() const noexcept
Is the iterator valid, i.e., is it associated with a FAB?
Definition AMReX_MFIter.H:172
std::array< T, N > Array
Definition AMReX_Array.H:26
Definition AMReX_FFT_Helper.H:53
Boundary
Definition AMReX_FFT_Helper.H:59
void ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition AMReX_CTOParallelForImpl.H:202
double second() noexcept
Definition AMReX_Utility.cpp:940
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:241
Fixed-size array that can be used on GPU.
Definition AMReX_Array.H:43
A host / device complex number type, because std::complex doesn't work in device code with CUDA yet.
Definition AMReX_GpuComplex.H:30