Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_RungeKutta.H
Go to the documentation of this file.
1#ifndef AMREX_RUNGE_KUTTA_H_
2#define AMREX_RUNGE_KUTTA_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_Concepts.H>
6#include <AMReX_FabArray.H>
7
51
// Do-nothing post-stage callback. It is invoked as post_stage(stage, Unew)
// by the RK2/RK3/RK4 drivers in this file when no custom hook is supplied;
// both arguments are ignored and nothing is modified.
53 template <FabArrayType MF>
54 void operator() (int, MF&) const {}
55};
56
58namespace detail {
60template <typename MF>
61void rk_update (MF& Unew, MF const& Uold, MF const& dUdt, Real dt)
62{
63 auto const& snew = Unew.arrays();
64 auto const& sold = Uold.const_arrays();
65 auto const& sdot = dUdt.const_arrays();
66 amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
67 (int bi, int i, int j, int k, int n) noexcept
68 {
69 snew[bi](i,j,k,n) = sold[bi](i,j,k,n) + dt*sdot[bi](i,j,k,n);
70 });
71 if (!Gpu::inNoSyncRegion()) {
72 Gpu::streamSynchronize();
73 }
74}
75
77template <typename MF>
78void rk_update (MF& Unew, MF const& Uold, MF const& dUdt1, MF const& dUdt2, Real dt)
79{
80 auto const& snew = Unew.arrays();
81 auto const& sold = Uold.const_arrays();
82 auto const& sdot1 = dUdt1.const_arrays();
83 auto const& sdot2 = dUdt2.const_arrays();
84 amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
85 (int bi, int i, int j, int k, int n) noexcept
86 {
87 snew[bi](i,j,k,n) = sold[bi](i,j,k,n) + dt*(sdot1[bi](i,j,k,n) +
88 sdot2[bi](i,j,k,n));
89 });
90 if (!Gpu::inNoSyncRegion()) {
91 Gpu::streamSynchronize();
92 }
93}
94
96template <typename MF>
97void rk2_update_2 (MF& Unew, MF const& Uold, MF const& dUdt, Real dt)
98{
99 auto const& snew = Unew.arrays();
100 auto const& sold = Uold.const_arrays();
101 auto const& sdot = dUdt.const_arrays();
102 amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
103 (int bi, int i, int j, int k, int n) noexcept
104 {
105 snew[bi](i,j,k,n) = Real(0.5)*(snew[bi](i,j,k,n) +
106 sold[bi](i,j,k,n) +
107 sdot[bi](i,j,k,n) * dt);
108 });
109 if (!Gpu::inNoSyncRegion()) {
110 Gpu::streamSynchronize();
111 }
112}
113
115template <typename MF>
116void rk3_update_3 (MF& Unew, MF const& Uold, Array<MF,3> const& rkk, Real dt6)
117{
118 auto const& snew = Unew.arrays();
119 auto const& sold = Uold.const_arrays();
120 auto const& k1 = rkk[0].const_arrays();
121 auto const& k2 = rkk[1].const_arrays();
122 auto const& k3 = rkk[2].const_arrays();
123 amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
124 (int bi, int i, int j, int k, int n) noexcept
125 {
126 snew[bi](i,j,k,n) = sold[bi](i,j,k,n)
127 + dt6 * (k1[bi](i,j,k,n) + k2[bi](i,j,k,n)
128 + Real(4.) * k3[bi](i,j,k,n));
129 });
130 if (!Gpu::inNoSyncRegion()) {
131 Gpu::streamSynchronize();
132 }
133}
134
136template <typename MF>
137void rk4_update_4 (MF& Unew, MF const& Uold, Array<MF,4> const& rkk, Real dt6)
138{
139 auto const& snew = Unew.arrays();
140 auto const& sold = Uold.const_arrays();
141 auto const& k1 = rkk[0].const_arrays();
142 auto const& k2 = rkk[1].const_arrays();
143 auto const& k3 = rkk[2].const_arrays();
144 auto const& k4 = rkk[3].const_arrays();
145 amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
146 (int bi, int i, int j, int k, int n) noexcept
147 {
148 snew[bi](i,j,k,n) = sold[bi](i,j,k,n)
149 + dt6 * ( k1[bi](i,j,k,n) + k4[bi](i,j,k,n)
150 + Real(2.)*(k2[bi](i,j,k,n) + k3[bi](i,j,k,n)));
151 });
152 if (!Gpu::inNoSyncRegion()) {
153 Gpu::streamSynchronize();
154 }
155}
156}
158
170template <typename MF, typename F, typename FB, typename P = PostStageNoOp>
171void RK2 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
172 P const& post_stage = PostStageNoOp())
173{
174 BL_PROFILE("RungeKutta2");
175
176 MF dUdt(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
177 MFInfo().SetArena(The_Async_Arena()), Unew.Factory());
178
179 // RK2 stage 1
180 fillbndry(1, Uold, time);
181 frhs(1, dUdt, Uold, time, Real(0.5)*dt);
182 // Unew = Uold + dt * dUdt
183 detail::rk_update(Unew, Uold, dUdt, dt);
184 post_stage(1, Unew);
185
186 // RK2 stage 2
187 fillbndry(2, Unew, time+dt);
188 frhs(2, dUdt, Unew, time+dt, Real(0.5)*dt);
189 // Unew = (Uold+Unew)/2 + dUdt_2 * dt/2,
190 // which is Unew = Uold + dt/2 * (dUdt_1 + dUdt_2)
191 detail::rk2_update_2(Unew, Uold, dUdt, dt);
192 post_stage(2, Unew);
193}
194
208template <typename MF, typename F, typename FB, typename R,
209 typename P = PostStageNoOp>
210void RK3 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
211 R const& store_crse_data, P const& post_stage = PostStageNoOp())
212{
213 BL_PROFILE("RungeKutta3");
214
215 Array<MF,3> rkk;
216 for (auto& mf : rkk) {
217 mf.define(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
218 MFInfo().SetArena(The_Async_Arena()), Unew.Factory());
219 }
220
221 // RK3 stage 1
222 fillbndry(1, Uold, time);
223 frhs(1, rkk[0], Uold, time, dt/Real(6.));
224 // Unew = Uold + k1 * dt
225 detail::rk_update(Unew, Uold, rkk[0], dt);
226 post_stage(1, Unew);
227
228 // RK3 stage 2
229 fillbndry(2, Unew, time+dt);
230 frhs(2, rkk[1], Unew, time+dt, dt/Real(6.));
231 // Unew = Uold + (k1+k2) * dt/4
232 detail::rk_update(Unew, Uold, rkk[0], rkk[1], Real(0.25)*dt);
233 post_stage(2, Unew);
234
235 // RK3 stage 3
236 Real t_half = time + Real(0.5)*dt;
237 fillbndry(3, Unew, t_half);
238 frhs(3, rkk[2], Unew, t_half, dt*Real(2./3.));
239 // Unew = Uold + (k1/6 + k2/6 + k3*(2/3)) * dt
240 detail::rk3_update_3(Unew, Uold, rkk, Real(1./6.)*dt);
241 post_stage(3, Unew);
242
243 store_crse_data(rkk);
244}
245
259template <typename MF, typename F, typename FB, typename R,
260 typename P = PostStageNoOp>
261void RK4 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
262 R const& store_crse_data, P const& post_stage = PostStageNoOp())
263{
264 BL_PROFILE("RungeKutta4");
265
266 Array<MF,4> rkk;
267 for (auto& mf : rkk) {
268 mf.define(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
269 MFInfo().SetArena(The_Async_Arena()), Unew.Factory());
270 }
271
272 // RK4 stage 1
273 fillbndry(1, Uold, time);
274 frhs(1, rkk[0], Uold, time, dt/Real(6.));
275 // Unew = Uold + k1 * dt/2
276 detail::rk_update(Unew, Uold, rkk[0], Real(0.5)*dt);
277 post_stage(1, Unew);
278
279 // RK4 stage 2
280 Real t_half = time + Real(0.5)*dt;
281 fillbndry(2, Unew, t_half);
282 frhs(2, rkk[1], Unew, t_half, dt/Real(3.));
283 // Unew = Uold + k2 * dt/2
284 detail::rk_update(Unew, Uold, rkk[1], Real(0.5)*dt);
285 post_stage(2, Unew);
286
287 // RK4 stage 3
288 fillbndry(3, Unew, t_half);
289 frhs(3, rkk[2], Unew, t_half, dt/Real(3.));
290 // Unew = Uold + k3 * dt;
291 detail::rk_update(Unew, Uold, rkk[2], dt);
292 post_stage(3, Unew);
293
294 // RK4 stage 4
295 fillbndry(4, Unew, time+dt);
296 frhs(4, rkk[3], Unew, time+dt, dt/Real(6.));
297 // Unew = Uold + (k1/6 + k2/3 + k3/3 + k4/6) * dt
298 detail::rk4_update_4(Unew, Uold, rkk, Real(1./6.)*dt);
299 post_stage(4, Unew);
300
301 store_crse_data(rkk);
302}
303
304}
305
306#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
amrex_real Real
Floating Point Type for Fields.
Definition AMReX_REAL.H:79
std::array< T, N > Array
Definition AMReX_Array.H:26
Arena * The_Async_Arena()
Definition AMReX_Arena.cpp:830
Functions for Runge-Kutta methods.
Definition AMReX_RungeKutta.H:50
void RK2(MF &Uold, MF &Unew, Real time, Real dt, F const &frhs, FB const &fillbndry, P const &post_stage=PostStageNoOp())
Time stepping with RK2.
Definition AMReX_RungeKutta.H:171
void RK3(MF &Uold, MF &Unew, Real time, Real dt, F const &frhs, FB const &fillbndry, R const &store_crse_data, P const &post_stage=PostStageNoOp())
Time stepping with RK3.
Definition AMReX_RungeKutta.H:210
void ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition AMReX_CTOParallelForImpl.H:202
IntVectND< 3 > IntVect
IntVect is an alias for amrex::IntVectND instantiated with AMREX_SPACEDIM.
Definition AMReX_BaseFwd.H:33
FabArray memory allocation information.
Definition AMReX_FabArray.H:68
MFInfo & SetArena(Arena *ar) noexcept
Definition AMReX_FabArray.H:79
Definition AMReX_RungeKutta.H:52
void operator()(int, MF &) const
Definition AMReX_RungeKutta.H:54