Block-Structured AMR Software Framework
AMReX_RKIntegrator.H
Go to the documentation of this file.
1 #ifndef AMREX_RK_INTEGRATOR_H
2 #define AMREX_RK_INTEGRATOR_H
3 #include <AMReX_REAL.H>
4 #include <AMReX_Vector.H>
5 #include <AMReX_ParmParse.H>
6 #include <AMReX_IntegratorBase.H>
7 #include <functional>
8 
9 namespace amrex {
10 
// Selects which Butcher tableau an RKIntegrator uses. The value is obtained by
// casting the integer runtime parameter "integration.rk.type" to this enum.
// Trapezoid, SSPRK3 and RK4 name built-in coefficient sets.
// NOTE(review): listing line 13 is absent here, so an enumerator between User
// and Trapezoid may be missing from this view (the preset switch defines four
// coefficient sets but only three presets are named below) — confirm against
// the real header before relying on the numeric values.
11 enum struct ButcherTableauTypes {
// Presumably "coefficients supplied through input parameters" — confirm.
12  User = 0,
14  Trapezoid,
15  SSPRK3,
16  RK4,
// Presumably a count/upper-bound sentinel rather than a selectable scheme — confirm.
17  NumTypes
18 };
19 
20 template<class T>
21 class RKIntegrator : public IntegratorBase<T>
22 {
23 private:
25 
26  // Butcher tableau identifiers
28 
29  // Number of RK stages
31 
32  // A matrix from Butcher tableau
34 
35  // b vector from Butcher tableau
37 
38  // c vector from Butcher tableau
40 
41  // RK embedded method b vector
43 
44  // RK stage right-hand sides
46 
47  // Current (internal) state and time
// Time associated with the integrator's internal state: assigned in
// initialize_stages() and incremented by dt after each step taken in evolve().
// (The declaration of the internal state itself is not visible in this listing.)
49  amrex::Real time_current;
50 
// Body of the built-in tableau setup (the cross-reference index names it
// initialize_preset_tableau(); its signature line and every `case` label are
// absent from this listing).
// Each switch arm fills the three coefficient containers for one scheme:
//   nodes   - stage times c_i, as fractions of the step
//   tableau - rows of the A matrix, lower triangle with the diagonal included
//   weights - combination weights b_i for the final update
// Any tableau_type without an arm aborts through amrex::Error.
52  {
53  switch (tableau_type)
54  {
// One stage: c = {0}, b = {1} (the explicit Euler coefficients).
56  nodes = {0.0};
57  tableau = {{0.0}};
58  weights = {1.0};
59  break;
// Two stages, both weighted 1/2 — presumably the Trapezoid preset (label not visible).
61  nodes = {0.0,
62  1.0};
63  tableau = {{0.0},
64  {1.0, 0.0}};
65  weights = {0.5, 0.5};
66  break;
// Three stages — presumably the SSPRK3 preset (label not visible).
68  nodes = {0.0,
69  1.0,
70  0.5};
71  tableau = {{0.0},
72  {1.0, 0.0},
73  {0.25, 0.25, 0.0}};
74  weights = {1./6., 1./6., 2./3.};
75  break;
// Four stages with weights 1/6, 1/3, 1/3, 1/6 — presumably the RK4 preset (label not visible).
77  nodes = {0.0,
78  0.5,
79  0.5,
80  1.0};
81  tableau = {{0.0},
82  {0.5, 0.0},
83  {0.0, 0.5, 0.0},
84  {0.0, 0.0, 1.0, 0.0}};
85  weights = {1./6., 1./3., 1./3., 1./6.};
86  break;
87  default:
88  amrex::Error("Invalid RK Integrator tableau type");
89  break;
90  }
91 
// NOTE(review): listing line 92 is absent here; presumably it records the stage
// count (number_nodes) from the chosen coefficients — confirm.
93  }
94 
// Body of the runtime-parameter setup (the cross-reference index names it
// initialize_parameters(); its signature line is absent from this listing).
// Reads the "integration.rk" parameter namespace: "type" selects the tableau,
// and in the first branch below the coefficients themselves are read from
// "weights", "nodes", "tableau" and the optional "extended_weights".
// The conditions of the if / else-if chain (listing lines 107 and 152) are not
// visible; presumably they test for the User type and for a valid preset type.
96  {
97  amrex::ParmParse pp("integration.rk");
98 
99  // Read an integrator type, if not recognized, then read weights/nodes/butcher tableau
100  int _tableau_type = 0;
// The integer is cast to the enum without a range check at this point; the
// final else branch below is what rejects unsupported values.
101  pp.get("type", _tableau_type);
102  tableau_type = static_cast<ButcherTableauTypes>(_tableau_type);
103 
104  // By default, define no extended weights
105  extended_weights = {};
106 
108  {
109  // Read weights/nodes/butcher tableau
// "extended_weights" is read with queryarr while the others use getarr;
// presumably queryarr tolerates a missing key and getarr does not — confirm
// against the ParmParse documentation.
110  pp.getarr("weights", weights);
111  pp.queryarr("extended_weights", extended_weights);
112  pp.getarr("nodes", nodes);
113 
114  amrex::Vector<amrex::Real> btable; // flattened into row major format
115  pp.getarr("tableau", btable);
116 
117  // Sanity check the inputs
118  if (weights.size() != nodes.size())
119  {
120  amrex::Error("integration.rk.weights should be the same length as integration.rk.nodes");
121  } else {
// Listing line 122 is absent here; number_nodes is used immediately below, so
// it is presumably assigned on that line — confirm.
// A lower-triangular matrix with its diagonal has n*(n+1)/2 entries.
123  const int nTableau = (number_nodes * (number_nodes + 1)) / 2; // includes diagonal
124  if (btable.size() != nTableau)
125  {
126  amrex::Error("integration.rk.tableau incorrect length - should include the Butcher Tableau diagonal.");
127  }
128  }
129 
130  // Fill tableau from the flattened entries
// Row i receives i+1 consecutive entries of btable (columns 0..i); k walks the
// flattened array.
// NOTE(review): rows are appended to `tableau` and no clear is visible, so a
// second call would accumulate rows on top of the existing ones — confirm
// whether re-initialization is expected to be supported.
131  int k = 0;
132  for (int i = 0; i < number_nodes; ++i)
133  {
134  amrex::Vector<amrex::Real> stage_row;
135  for (int j = 0; j <= i; ++j)
136  {
137  stage_row.push_back(btable[k]);
138  ++k;
139  }
140 
141  tableau.push_back(stage_row);
142  }
143 
144  // Check that this is an explicit method
// The last entry of each stored row is its diagonal element; the check rejects
// any tableau whose diagonal is nonzero.
145  for (const auto& astage : tableau)
146  {
147  if (astage.back() != 0.0)
148  {
149  amrex::Error("RKIntegrator currently only supports explicit Butcher tableaus.");
150  }
151  }
// Listing lines 152 and 154 are absent here; presumably this is the branch that
// accepts a built-in type and loads its preset coefficients — confirm.
153  {
155  } else {
156  amrex::Error("RKIntegrator received invalid input for integration.rk.type");
157  }
158  }
159 
// Prepares per-stage storage and the integrator's internal state.
//   S_data - state copied into the internal current state (S_current[0])
//   time   - time assigned to that state (stored in time_current)
// NOTE(review): listing lines 165 and 169 are absent here, so the statements
// that actually allocate the stage right-hand sides and S_current are not
// visible; as shown, the loop body is empty and S_current[0] is dereferenced
// without a visible allocation — confirm against the real header.
160  void initialize_stages (const T& S_data, const amrex::Real time)
161  {
162  // Create data for stage RHS
163  for (int i = 0; i < number_nodes; ++i)
164  {
166  }
167 
168  // Create and initialize data for current state
170  IntegratorOps<T>::Copy(*S_current[0], S_data);
171 
172  // Set the initial time
173  time_current = time;
174  }
175 
176 public:
178 
179  RKIntegrator (const T& S_data, const amrex::Real time = 0.0)
180  {
181  initialize(S_data, time);
182  }
183 
// (Re)initializes the integrator from an initial state.
//   S_data - initial state, forwarded to initialize_stages()
//   time   - time associated with S_data (defaults to 0.0)
// NOTE(review): listing line 186 is absent here; presumably the runtime
// parameters / tableau are set up on that line before the stages are created
// — confirm against the real header.
184  void initialize (const T& S_data, const amrex::Real time = 0.0)
185  {
187  initialize_stages(S_data, time);
188  }
189 
190  virtual ~RKIntegrator () {}
191 
// Takes one explicit Runge-Kutta step of size dt from (time, S_old) and writes
// the result at time + dt into S_new.
//   S_old - state at `time`; only read (copied from) in this function
//   S_new - output state; also used as scratch storage for each stage value
//   time  - time of S_old
//   dt    - step size
// Returns dt unchanged.
// For each stage i the stage value S_old + dt * sum_{j<i} A[i][j] * F[j] is
// formed in S_new, then the right-hand side F[i] is evaluated there at
// time + dt * c[i]. The final state is S_old + dt * sum_i b[i] * F[i].
192  amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override
193  {
194  // Assume before step() that S_old is valid data at the current time ("time" argument)
195  // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS
196  // We need this from S_old. This is convenient for S_new to have so we can use it
197  // as scratch space for stage values without creating a new scratch MultiFab with ghost cells.
198 
199  // Fill the RHS F_nodes at each stage
200  for (int i = 0; i < number_nodes; ++i)
201  {
202  // Get current stage time, t = t_old + h * Ci
203  amrex::Real stage_time = time + dt * nodes[i];
204 
205  // Fill S_new with the solution value for evaluating F at the current stage
206  // Copy S_new = S_old
207  IntegratorOps<T>::Copy(S_new, S_old);
// The first stage (i == 0) evaluates the RHS on the unmodified copy of S_old,
// and post_stage_action is not invoked for it.
208  if (i > 0) {
209  // Saxpy across the tableau row:
210  // S_new += h * Aij * Fj
211  // We should fuse these kernels ...
// Only columns j < i are applied; the stored diagonal entry is never read here.
212  for (int j = 0; j < i; ++j)
213  {
214  IntegratorOps<T>::Saxpy(S_new, dt * tableau[i][j], *F_nodes[j]);
215  }
216 
217  BaseT::post_stage_action(S_new, stage_time);
218  }
219 
220  // Fill F[i], the RHS at the current stage
221  // F[i] = RHS(y, t) at y = stage_value, t = stage_time
222  BaseT::Rhs(*F_nodes[i], S_new, stage_time);
223  }
224 
225  // Fill new State, starting with S_new = S_old.
226  // Then Saxpy S_new += h * Wi * Fi for integration weights Wi
227  // We should fuse these kernels ...
228  IntegratorOps<T>::Copy(S_new, S_old);
229  for (int i = 0; i < number_nodes; ++i)
230  {
231  IntegratorOps<T>::Saxpy(S_new, dt * weights[i], *F_nodes[i]);
232  }
233 
234  BaseT::post_step_action(S_new, time + dt);
235 
236  // If we are working with an extended Butcher tableau, we can estimate the error here,
237  // and then calculate an adaptive time step.
238 
239  // Save last completed step size for time_interpolate
// NOTE(review): listing line 240 is absent here; presumably it stores dt in the
// base class's previous_time_step — confirm against the real header.
241 
242  // Return time step
243  return dt;
244  }
245 
// Advances the integrator's internal state from time_current to time_out by
// repeated calls to advance(), starting with step size BaseT::time_step and
// taking at most BaseT::max_steps steps.
//   S_out    - receives the state after each step (the final state on return)
//   time_out - target time
// After every step the internal state S_current[0] is overwritten with S_out
// and time_current is advanced by the step size used.
// If BaseT::max_steps is not positive the loop never runs: S_out is left
// untouched and no error is raised.
246  void evolve (T& S_out, const amrex::Real time_out) override
247  {
248  amrex::Real dt = BaseT::time_step;
249  bool stop = false;
250 
251  for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number)
252  {
253  // Adjust step size to reach output time
254  // protect against roundoff
// The final step is shortened to land exactly on time_out; almostEqual avoids
// leaving a tiny leftover step when the remaining interval is within roundoff
// of dt.
// NOTE(review): when time_out is already behind time_current the first test is
// true, so a single step with a negative dt is taken — confirm this is intended.
255  if ((time_out-time_current) < dt || almostEqual(time_out-time_current,dt,1000)) {
256  dt = time_out - time_current;
257  stop = true;
258  }
259 
260  // Call the time integrator step
261  advance(*S_current[0], S_out, time_current, dt);
262 
263  // Update current state S_current = S_out
264  IntegratorOps<T>::Copy(*S_current[0], S_out);
265 
266  // Update time
267  time_current += dt;
268 
269  // Save last completed step size for time_interpolate
// NOTE(review): listing line 270 is absent here; presumably it stores dt in the
// base class's previous_time_step — confirm against the real header.
271 
// NOTE(review): this test does not look at `stop`, so a run whose last
// permitted step (step_number == max_steps - 1) did reach time_out still
// reports failure. Confirm whether the condition should also require !stop.
272  if (step_number == BaseT::max_steps - 1) {
273  Error("Did not reach output time in max steps.");
274  }
275  }
276  }
277 
// Interpolates the solution inside the most recent step.
//   (first argument) - new-time state; unnamed and unused in the visible code
//   S_old             - state at the start of the step; copied into `data`
//   timestep_fraction - position inside the step (chi in the comments below)
//   data              - output
// The visible code copies S_old into `data` and computes four polynomial
// coefficients in chi, one per RK stage (k1..k4). At chi = 1 they evaluate to
// 1/6, 1/3, 1/3, 1/6, i.e. the 4-stage preset weights.
// NOTE(review): listing lines 289, 299, 303, 307 and 311 are absent here. As
// shown, each value assigned to `c` is overwritten without being used, so the
// statements that add c * k_i to `data` (and the check implied by the
// "only ... 4th order RK" comment) are presumably on those lines — confirm
// against the real header.
278  void time_interpolate (const T& /* S_new */, const T& S_old, amrex::Real timestep_fraction, T& data) override
279  {
280  // data = S_old*(1-time_step_fraction) + S_new*(time_step_fraction)
281  /*
282  data.setVal(0);
283 
284  IntegratorOps<T>::Saxpy(data, 1-timestep_fraction, S_old);
285  IntegratorOps<T>::Saxpy(data, timestep_fraction, S_new);
286  */
287 
288  // currently we only do this for 4th order RK
290 
291  // fill data using MC Equation 39 at time + timestep_fraction * dt
292  amrex::Real c = 0;
293 
294  // data = S_old
295  IntegratorOps<T>::Copy(data, S_old);
296 
297  // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1
298  c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3);
300 
301  // data += (chi^2 - 2/3 * chi^3) * k2
302  c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3);
304 
305  // data += (chi^2 - 2/3 * chi^3) * k3
306  c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3);
308 
309  // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4
310  c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3);
312 
313  }
314 
315  void map_data (std::function<void(T&)> Map) override
316  {
317  for (auto& F : F_nodes) {
318  Map(*F);
319  }
320  }
321 
322 };
323 
324 }
325 
326 #endif
#define AMREX_ASSERT(EX)
Definition: AMReX_BLassert.H:38
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition: AMReX_HypreIJIface.cpp:15
Definition: AMReX_IntegratorBase.H:163
std::function< void(T &, amrex::Real)> post_step_action
The post_step_action function is called by the integrator on the computed state just after it is comp...
Definition: AMReX_IntegratorBase.H:198
int max_steps
Max number of internal steps before an error is returned (Long)
Definition: AMReX_IntegratorBase.H:248
std::function< void(T &rhs, T &state, const amrex::Real time)> Rhs
Rhs is the right-hand-side function the integrator will use.
Definition: AMReX_IntegratorBase.H:168
std::function< void(T &, amrex::Real)> post_stage_action
The post_stage_action function is called by the integrator on the computed stage just after it is com...
Definition: AMReX_IntegratorBase.H:192
amrex::Real time_step
Current integrator time step size (Real)
Definition: AMReX_IntegratorBase.H:221
amrex::Real previous_time_step
Step size of the last completed step (Real)
Definition: AMReX_IntegratorBase.H:226
Parse Parameters From Command Line and Input Files.
Definition: AMReX_ParmParse.H:320
void get(const char *name, bool &ref, int ival=FIRST) const
Same as getkth() but searches for the last occurrence of name.
Definition: AMReX_ParmParse.cpp:1292
int queryarr(const char *name, std::vector< int > &ref, int start_ix=FIRST, int num_val=ALL) const
Same as queryktharr() but searches for last occurrence of name.
Definition: AMReX_ParmParse.cpp:1376
void getarr(const char *name, std::vector< int > &ref, int start_ix=FIRST, int num_val=ALL) const
Same as getktharr() but searches for last occurrence of name.
Definition: AMReX_ParmParse.cpp:1362
Definition: AMReX_RKIntegrator.H:22
virtual ~RKIntegrator()
Definition: AMReX_RKIntegrator.H:190
void time_interpolate(const T &, const T &S_old, amrex::Real timestep_fraction, T &data) override
Definition: AMReX_RKIntegrator.H:278
amrex::Vector< amrex::Real > extended_weights
Definition: AMReX_RKIntegrator.H:42
ButcherTableauTypes tableau_type
Definition: AMReX_RKIntegrator.H:27
amrex::Vector< amrex::Real > nodes
Definition: AMReX_RKIntegrator.H:39
amrex::Vector< std::unique_ptr< T > > F_nodes
Definition: AMReX_RKIntegrator.H:45
RKIntegrator(const T &S_data, const amrex::Real time=0.0)
Definition: AMReX_RKIntegrator.H:179
amrex::Vector< std::unique_ptr< T > > S_current
Definition: AMReX_RKIntegrator.H:48
amrex::Vector< amrex::Real > weights
Definition: AMReX_RKIntegrator.H:36
void initialize_parameters()
Definition: AMReX_RKIntegrator.H:95
void map_data(std::function< void(T &)> Map) override
Definition: AMReX_RKIntegrator.H:315
void initialize_preset_tableau()
Definition: AMReX_RKIntegrator.H:51
void initialize_stages(const T &S_data, const amrex::Real time)
Definition: AMReX_RKIntegrator.H:160
amrex::Real advance(T &S_old, T &S_new, amrex::Real time, const amrex::Real dt) override
Take a single time step from (time, S_old) to (time + dt, S_new) with the given step size.
Definition: AMReX_RKIntegrator.H:192
amrex::Vector< amrex::Vector< amrex::Real > > tableau
Definition: AMReX_RKIntegrator.H:33
int number_nodes
Definition: AMReX_RKIntegrator.H:30
void evolve(T &S_out, const amrex::Real time_out) override
Evolve the current (internal) integrator state to time_out.
Definition: AMReX_RKIntegrator.H:246
void initialize(const T &S_data, const amrex::Real time=0.0)
Definition: AMReX_RKIntegrator.H:184
amrex::Real time_current
Definition: AMReX_RKIntegrator.H:49
RKIntegrator()
Definition: AMReX_RKIntegrator.H:177
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition: AMReX_Vector.H:27
Long size() const noexcept
Definition: AMReX_Vector.H:50
Definition: AMReX_Amr.cpp:49
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition: AMReX_FabArrayUtility.H:1646
void Copy(FabArray< DFAB > &dst, FabArray< SFAB > const &src, int srccomp, int dstcomp, int numcomp, int nghost)
Definition: AMReX_FabArray.H:179
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE std::enable_if_t< std::is_floating_point_v< T >, bool > almostEqual(T x, T y, int ulp=2)
Definition: AMReX_Algorithm.H:93
void Error(const std::string &msg)
Print out message to cerr and exit via amrex::Abort().
Definition: AMReX.cpp:219
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE GpuComplex< T > pow(const GpuComplex< T > &a_z, const T &a_y) noexcept
Raise a complex number to a (real) power.
Definition: AMReX_GpuComplex.H:418
ButcherTableauTypes
Definition: AMReX_RKIntegrator.H:11
Definition: AMReX_IntegratorBase.H:17