Block-Structured AMR Software Framework
AMReX_RungeKutta.H
Go to the documentation of this file.
1 #ifndef AMREX_RUNGE_KUTTA_H_
2 #define AMREX_RUNGE_KUTTA_H_
3 #include <AMReX_Config.H>
4 
5 #include <AMReX_FabArray.H>
6 
49 namespace amrex::RungeKutta {
50 
51 struct PostStageNoOp {
52  template <typename MF>
53  std::enable_if_t<IsFabArray<MF>::value> operator() (int, MF&) const {}
54 };
55 
56 namespace detail {
58 template <typename MF>
59 void rk_update (MF& Unew, MF const& Uold, MF const& dUdt, Real dt)
60 {
61  auto const& snew = Unew.arrays();
62  auto const& sold = Uold.const_arrays();
63  auto const& sdot = dUdt.const_arrays();
64  amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
65  (int bi, int i, int j, int k, int n) noexcept
66  {
67  snew[bi](i,j,k,n) = sold[bi](i,j,k,n) + dt*sdot[bi](i,j,k,n);
68  });
70 }
71 
73 template <typename MF>
74 void rk_update (MF& Unew, MF const& Uold, MF const& dUdt1, MF const& dUdt2, Real dt)
75 {
76  auto const& snew = Unew.arrays();
77  auto const& sold = Uold.const_arrays();
78  auto const& sdot1 = dUdt1.const_arrays();
79  auto const& sdot2 = dUdt2.const_arrays();
80  amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
81  (int bi, int i, int j, int k, int n) noexcept
82  {
83  snew[bi](i,j,k,n) = sold[bi](i,j,k,n) + dt*(sdot1[bi](i,j,k,n) +
84  sdot2[bi](i,j,k,n));
85  });
87 }
88 
90 template <typename MF>
91 void rk2_update_2 (MF& Unew, MF const& Uold, MF const& dUdt, Real dt)
92 {
93  auto const& snew = Unew.arrays();
94  auto const& sold = Uold.const_arrays();
95  auto const& sdot = dUdt.const_arrays();
96  amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
97  (int bi, int i, int j, int k, int n) noexcept
98  {
99  snew[bi](i,j,k,n) = Real(0.5)*(snew[bi](i,j,k,n) +
100  sold[bi](i,j,k,n) +
101  sdot[bi](i,j,k,n) * dt);
102  });
104 }
105 
107 template <typename MF>
108 void rk3_update_3 (MF& Unew, MF const& Uold, Array<MF,3> const& rkk, Real dt6)
109 {
110  auto const& snew = Unew.arrays();
111  auto const& sold = Uold.const_arrays();
112  auto const& k1 = rkk[0].const_arrays();
113  auto const& k2 = rkk[1].const_arrays();
114  auto const& k3 = rkk[2].const_arrays();
115  amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
116  (int bi, int i, int j, int k, int n) noexcept
117  {
118  snew[bi](i,j,k,n) = sold[bi](i,j,k,n)
119  + dt6 * (k1[bi](i,j,k,n) + k2[bi](i,j,k,n)
120  + Real(4.) * k3[bi](i,j,k,n));
121  });
123 }
124 
126 template <typename MF>
127 void rk4_update_4 (MF& Unew, MF const& Uold, Array<MF,4> const& rkk, Real dt6)
128 {
129  auto const& snew = Unew.arrays();
130  auto const& sold = Uold.const_arrays();
131  auto const& k1 = rkk[0].const_arrays();
132  auto const& k2 = rkk[1].const_arrays();
133  auto const& k3 = rkk[2].const_arrays();
134  auto const& k4 = rkk[3].const_arrays();
135  amrex::ParallelFor(Unew, IntVect(0), Unew.nComp(), [=] AMREX_GPU_DEVICE
136  (int bi, int i, int j, int k, int n) noexcept
137  {
138  snew[bi](i,j,k,n) = sold[bi](i,j,k,n)
139  + dt6 * ( k1[bi](i,j,k,n) + k4[bi](i,j,k,n)
140  + Real(2.)*(k2[bi](i,j,k,n) + k3[bi](i,j,k,n)));
141  });
143 }
144 }
145 
157 template <typename MF, typename F, typename FB, typename P = PostStageNoOp>
158 void RK2 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
159  P const& post_stage = PostStageNoOp())
160 {
161  BL_PROFILE("RungeKutta2");
162 
163  MF dUdt(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
164  MFInfo(), Unew.Factory());
165 
166  // RK2 stage 1
167  fillbndry(1, Uold, time);
168  frhs(1, dUdt, Uold, time, Real(0.5)*dt);
169  // Unew = Uold + dt * dUdt
170  detail::rk_update(Unew, Uold, dUdt, dt);
171  post_stage(1, Unew);
172 
173  // RK2 stage 2
174  fillbndry(2, Unew, time+dt);
175  frhs(2, dUdt, Unew, time, Real(0.5)*dt);
176  // Unew = (Uold+Unew)/2 + dUdt_2 * dt/2,
177  // which is Unew = Uold + dt/2 * (dUdt_1 + dUdt_2)
178  detail::rk2_update_2(Unew, Uold, dUdt, dt);
179  post_stage(2, Unew);
180 }
181 
194 template <typename MF, typename F, typename FB, typename R,
195  typename P = PostStageNoOp>
196 void RK3 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
197  R const& store_crse_data, P const& post_stage = PostStageNoOp())
198 {
199  BL_PROFILE("RungeKutta3");
200 
201  Array<MF,3> rkk;
202  for (auto& mf : rkk) {
203  mf.define(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
204  MFInfo(), Unew.Factory());
205  }
206 
207  // RK3 stage 1
208  fillbndry(1, Uold, time);
209  frhs(1, rkk[0], Uold, time, dt/Real(6.));
210  // Unew = Uold + k1 * dt
211  detail::rk_update(Unew, Uold, rkk[0], dt);
212  post_stage(1, Unew);
213 
214  // RK3 stage 2
215  fillbndry(2, Unew, time+dt);
216  frhs(2, rkk[1], Unew, time+dt, dt/Real(6.));
217  // Unew = Uold + (k1+k2) * dt/4
218  detail::rk_update(Unew, Uold, rkk[0], rkk[1], Real(0.25)*dt);
219  post_stage(2, Unew);
220 
221  // RK3 stage 3
222  Real t_half = time + Real(0.5)*dt;
223  fillbndry(3, Unew, t_half);
224  frhs(3, rkk[2], Unew, t_half, dt*Real(2./3.));
225  // Unew = Uold + (k1/6 + k2/6 + k3*(2/3)) * dt
226  detail::rk3_update_3(Unew, Uold, rkk, Real(1./6.)*dt);
227  post_stage(3, Unew);
228 
229  store_crse_data(rkk);
230 }
231 
244 template <typename MF, typename F, typename FB, typename R,
245  typename P = PostStageNoOp>
246 void RK4 (MF& Uold, MF& Unew, Real time, Real dt, F const& frhs, FB const& fillbndry,
247  R const& store_crse_data, P const& post_stage = PostStageNoOp())
248 {
249  BL_PROFILE("RungeKutta4");
250 
251  Array<MF,4> rkk;
252  for (auto& mf : rkk) {
253  mf.define(Unew.boxArray(), Unew.DistributionMap(), Unew.nComp(), 0,
254  MFInfo(), Unew.Factory());
255  }
256 
257  // RK4 stage 1
258  fillbndry(1, Uold, time);
259  frhs(1, rkk[0], Uold, time, dt/Real(6.));
260  // Unew = Uold + k1 * dt/2
261  detail::rk_update(Unew, Uold, rkk[0], Real(0.5)*dt);
262  post_stage(1, Unew);
263 
264  // RK4 stage 2
265  Real t_half = time + Real(0.5)*dt;
266  fillbndry(2, Unew, t_half);
267  frhs(2, rkk[1], Unew, t_half, dt/Real(3.));
268  // Unew = Uold + k2 * dt/2
269  detail::rk_update(Unew, Uold, rkk[1], Real(0.5)*dt);
270  post_stage(2, Unew);
271 
272  // RK4 stage 3
273  fillbndry(3, Unew, t_half);
274  frhs(3, rkk[2], Unew, t_half, dt/Real(3.));
275  // Unew = Uold + k3 * dt;
276  detail::rk_update(Unew, Uold, rkk[2], dt);
277  post_stage(3, Unew);
278 
279  // RK4 stage 4
280  fillbndry(4, Unew, time+dt);
281  frhs(4, rkk[3], Unew, time+dt, dt/Real(6.));
282  // Unew = Uold + (k1/6 + k2/3 + k3/3 + k4/6) * dt
283  detail::rk4_update_4(Unew, Uold, rkk, Real(1./6.)*dt);
284  post_stage(4, Unew);
285 
286  store_crse_data(rkk);
287 }
288 
289 }
290 
291 #endif
#define BL_PROFILE(a)
Definition: AMReX_BLProfiler.H:551
#define AMREX_GPU_DEVICE
Definition: AMReX_GpuQualifiers.H:18
void streamSynchronize() noexcept
Definition: AMReX_GpuDevice.H:237
void rk3_update_3(MF &Unew, MF const &Uold, Array< MF, 3 > const &rkk, Real dt6)
Unew = Uold + (k1 + k2 + 4*k3) * dt6, where dt6 = dt/6.
Definition: AMReX_RungeKutta.H:108
void rk2_update_2(MF &Unew, MF const &Uold, MF const &dUdt, Real dt)
Unew = (Uold+Unew)/2 + dUdt * dt/2.
Definition: AMReX_RungeKutta.H:91
void rk4_update_4(MF &Unew, MF const &Uold, Array< MF, 4 > const &rkk, Real dt6)
Unew = Uold + (k1+k4+2*(k2+k3))*dt6, where dt6 = dt/6.
Definition: AMReX_RungeKutta.H:127
void rk_update(MF &Unew, MF const &Uold, MF const &dUdt, Real dt)
Unew = Uold + dUdt * dt.
Definition: AMReX_RungeKutta.H:59
Functions for Runge-Kutta methods.
Definition: AMReX_RungeKutta.H:49
void RK2(MF &Uold, MF &Unew, Real time, Real dt, F const &frhs, FB const &fillbndry, P const &post_stage=PostStageNoOp())
Time stepping with RK2.
Definition: AMReX_RungeKutta.H:158
void RK4(MF &Uold, MF &Unew, Real time, Real dt, F const &frhs, FB const &fillbndry, R const &store_crse_data, P const &post_stage=PostStageNoOp())
Time stepping with RK4.
Definition: AMReX_RungeKutta.H:246
void RK3(MF &Uold, MF &Unew, Real time, Real dt, F const &frhs, FB const &fillbndry, R const &store_crse_data, P const &post_stage=PostStageNoOp())
Time stepping with RK3.
Definition: AMReX_RungeKutta.H:196
static int post_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition: AMReX_SundialsIntegrator.H:64
static constexpr int P
Definition: AMReX_OpenBC.H:14
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition: AMReX_CTOParallelForImpl.H:200
IntVectND< AMREX_SPACEDIM > IntVect
Definition: AMReX_BaseFwd.H:30
std::array< T, N > Array
Definition: AMReX_Array.H:24
Definition: AMReX_FabArrayCommI.H:896
FabArray memory allocation information.
Definition: AMReX_FabArray.H:66
Definition: AMReX_RungeKutta.H:51
std::enable_if_t< IsFabArray< MF >::value > operator()(int, MF &) const
Definition: AMReX_RungeKutta.H:53