Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_SundialsIntegrator.H
Go to the documentation of this file.
1#ifndef AMREX_SUNDIALS_INTEGRATOR_H
2#define AMREX_SUNDIALS_INTEGRATOR_H
3
4#include <functional>
5
6#include <AMReX_Config.H>
7#include <AMReX_REAL.H>
8#include <AMReX_Vector.H>
9#include <AMReX_ParmParse.H>
12#include <AMReX_Sundials.H>
13
14#include <nvector/nvector_manyvector.h>
15#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
16#include <sunlinsol/sunlinsol_spgmr.h>
17#include <arkode/arkode_arkstep.h>
18#include <arkode/arkode_mristep.h>
19
27namespace amrex {
28
// SundialsUserData: the std::function callbacks that the SundialsUserFun
// trampolines forward to. A pointer to this object is registered with
// ARKStep/MRIStep as SUNDIALS "user data".
// NOTE(review): the struct header (source line 35) and its preceding comment
// lines are missing from this copy -- confirm against the original header.
// RHS for ERK/DIRK single-rate methods, or the slow RHS for MRI methods.
36 std::function<int(amrex::Real, N_Vector, N_Vector, void*)> f;
// Implicit RHS for IMEX schemes (IMEX-RK, IMEX-MRI).
37 std::function<int(amrex::Real, N_Vector, N_Vector, void*)> fi;
// Explicit RHS for IMEX schemes (IMEX-RK, IMEX-MRI).
38 std::function<int(amrex::Real, N_Vector, N_Vector, void*)> fe;
// Fast time-scale RHS used by the inner ARKStep integrator of MRI methods.
39 std::function<int(amrex::Real, N_Vector, N_Vector, void*)> ff;
// Hook invoked after each stage of the single-rate / slow integrator.
40 std::function<int(amrex::Real, N_Vector, void*)> post_stage;
// Hook invoked after each step of the single-rate / slow integrator.
41 std::function<int(amrex::Real, N_Vector, void*)> post_step;
// Hook invoked after each stage of the fast (inner) MRI integrator.
42 std::function<int(amrex::Real, N_Vector, void*)> post_fast_stage;
// Hook invoked after each step of the fast (inner) MRI integrator.
43 std::function<int(amrex::Real, N_Vector, void*)> post_fast_step;
44};
45
46namespace SundialsUserFun {
47 static int f (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) {
48 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
49 return udata->f(t, y_data, y_rhs, user_data);
50 }
51
52 static int fi (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) {
53 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
54 return udata->fi(t, y_data, y_rhs, user_data);
55 }
56
57 static int fe (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) {
58 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
59 return udata->fe(t, y_data, y_rhs, user_data);
60 }
61
62 static int ff (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) {
63 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
64 return udata->ff(t, y_data, y_rhs, user_data);
65 }
66
67 static int post_stage (amrex::Real t, N_Vector y_data, void *user_data) {
68 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
69 return udata->post_stage(t, y_data, user_data);
70 }
71
72 static int post_step (amrex::Real t, N_Vector y_data, void *user_data) {
73 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
74 return udata->post_step(t, y_data, user_data);
75 }
76
77 static int post_fast_stage (amrex::Real t, N_Vector y_data, void *user_data) {
78 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
79 return udata->post_fast_stage(t, y_data, user_data);
80 }
81
82 static int post_fast_step (amrex::Real t, N_Vector y_data, void *user_data) {
83 SundialsUserData* udata = static_cast<SundialsUserData*>(user_data);
84 return udata->post_fast_step(t, y_data, user_data);
85 }
86}
87
93template<class T>
95{
96private:
98
99 // Method type: ERK, DIRK, IMEX-RK, EX-MRI, IM-MRI, IMEX-MRI
100 std::string type = "ERK";
101
102 // Use SUNDIALS default methods
103 std::string method = "DEFAULT"; // ERK, DIRK, or slow method with MRI
104 std::string method_e = "DEFAULT"; // ERK in IMEX-RK
105 std::string method_i = "DEFAULT"; // DIRK in IMEX-RK
106
107 // Fast method type (ERK or DIRK) and method
108 std::string fast_type = "ERK";
109 std::string fast_method = "DEFAULT";
110
111 // Nonlinear solver
112 std::string nonlinear_solver = "Newton";
113 int max_nonlinear_iters = 0;
114
115 std::string fast_nonlinear_solver = "Newton";
116 int fast_max_nonlinear_iters = 0;
117
118 // Linear solver
119 std::string linear_solver = "GMRES";
120 int max_linear_iters = 0;
121
122 std::string fast_linear_solver = "GMRES";
123 int fast_max_linear_iters = 0;
124
125 // SUNDIALS package flags, set based on type
126 bool use_ark = false;
127 bool use_mri = false;
128
129 // structure for interfacing with user-supplied functions
130 SundialsUserData udata;
131
132 // SUNDIALS context
133 //
134 // We should probably use context created by amrex:sundials::Initialize but
135 // that context is not MPI-aware
136 ::sundials::Context sunctx;
137
138 // Single rate or slow time scale
139 void *arkode_mem = nullptr;
140 SUNLinearSolver LS = nullptr;
141 SUNNonlinearSolver NLS = nullptr;
142
143 // Fast time scale
144 void *arkode_fast_mem = nullptr;
145 MRIStepInnerStepper fast_stepper = nullptr;
146 SUNLinearSolver fast_LS = nullptr;
147 SUNNonlinearSolver fast_NLS = nullptr;
148
149 // Integrator stop time
150 bool set_stop_time = false;
151 amrex::Real stop_time = 0.0;
152
153 // Max steps between returns
154 amrex::Long max_num_steps = 0;
155
156 void initialize_parameters ()
157 {
158 amrex::ParmParse pp("integration.sundials");
159
160 pp.query("type", type);
161 pp.query("method", method);
162 pp.query("method_e", method_e);
163 pp.query("method_i", method_i);
164
165 pp.query("fast_type", fast_type);
166 pp.query("fast_method", fast_method);
167
168 if (type == "ERK" || type == "DIRK" || type == "IMEX-RK") {
169 use_ark = true;
170 }
171 else if (type == "EX-MRI" || type == "IM-MRI" || type == "IMEX-MRI") {
172 use_mri = true;
173 }
174 else {
175 std::string msg("Unknown method type: ");
176 msg += type;
177 amrex::Error(msg.c_str());
178 }
179
180 pp.query("nonlinear_solver", nonlinear_solver);
181 pp.query("max_nonlinear_iters", max_nonlinear_iters);
182
183 pp.query("fast_nonlinear_solver", fast_nonlinear_solver);
184 pp.query("fast_max_nonlinear_iters", fast_max_nonlinear_iters);
185
186 pp.query("linear_solver", linear_solver);
187 pp.query("max_linear_iters", max_linear_iters);
188
189 pp.query("fast_linear_solver", fast_linear_solver);
190 pp.query("fast_max_linear_iters", fast_max_linear_iters);
191
192 set_stop_time = pp.query("stop_time", stop_time);
193
194 pp.query("max_num_steps", max_num_steps);
195 }
196
197 void SetupRK (amrex::Real time, N_Vector y_data)
198 {
199 if (amrex::Verbose()) { amrex::Print() << "Using SUNDIALS time integrator\n"; }
200 int flag = 0;
201
202 // Create integrator and select method
203 if (type == "ERK") {
204 if (amrex::Verbose()) { amrex::Print() << "ERK method: " << method << "\n"; }
205 arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, y_data, sunctx);
206 AMREX_ALWAYS_ASSERT(arkode_mem != nullptr);
207 if (method != "DEFAULT") {
208 flag = ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str());
209 AMREX_ALWAYS_ASSERT(flag == 0);
210 }
211 }
212 else if (type == "DIRK") {
213 if (amrex::Verbose()) { amrex::Print() << "DIRK method: " << method << "\n"; }
214 arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, y_data, sunctx);
215 AMREX_ALWAYS_ASSERT(arkode_mem != nullptr);
216 if (method != "DEFAULT") {
217 flag = ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE");
218 AMREX_ALWAYS_ASSERT(flag == 0);
219 }
220 }
221 else if (type == "IMEX-RK") {
222 if (amrex::Verbose()) { amrex::Print() << "IMEX-RK method: " << method_i << " and "
223 << method_e << "\n"; }
224 arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, y_data, sunctx);
225 AMREX_ALWAYS_ASSERT(arkode_mem != 0);
226 if (method_e != "DEFAULT" && method_i != "DEFAULT")
227 {
228 flag = ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str());
229 AMREX_ALWAYS_ASSERT(flag == 0);
230 }
231 }
232
233 // Attach structure with user-supplied function wrappers
234 flag = ARKStepSetUserData(arkode_mem, &udata);
235 AMREX_ALWAYS_ASSERT(flag == 0);
236
237 // Set integrator tolerances
238 if (BaseT::use_adaptive_time_step || type == "DIRK" || type == "IMEX-RK") {
239 if (amrex::Verbose()) {
240 amrex::Print() << "Relative tolerance: " << BaseT::rel_tol << "\n";
241 amrex::Print() << "Absolute tolerance: " << BaseT::abs_tol << "\n";
242 }
243 flag = ARKStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol);
244 AMREX_ALWAYS_ASSERT(flag == 0);
245 }
246
247 // Create and attach linear solver for implicit methods
248 if (type == "DIRK" || type == "IMEX-RK") {
249 if (amrex::Verbose()) {
250 amrex::Print() << "Nonlinear solver: " << nonlinear_solver << "\n";
251 amrex::Print() << "Max nonlinear iters: " << max_nonlinear_iters << "\n";
252 }
253 if (nonlinear_solver == "fixed-point") {
254 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
255 AMREX_ALWAYS_ASSERT(NLS != nullptr);
256 flag = ARKStepSetNonlinearSolver(arkode_mem, NLS);
257 AMREX_ALWAYS_ASSERT(flag == 0);
258 }
259 flag = ARKStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
260 AMREX_ALWAYS_ASSERT(flag == 0);
261
262 if (nonlinear_solver == "Newton") {
263 if (amrex::Verbose()) {
264 amrex::Print() << "Linear solver: " << linear_solver << "\n";
265 amrex::Print() << "Max linear iters: " << max_linear_iters << "\n";
266 }
267 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
268 AMREX_ALWAYS_ASSERT(LS != nullptr);
269 flag = ARKStepSetLinearSolver(arkode_mem, LS, nullptr);
270 AMREX_ALWAYS_ASSERT(flag == 0);
271 }
272 }
273
274 // Set post stage and step function
275 flag = ARKStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage);
276 AMREX_ALWAYS_ASSERT(flag == 0);
277 flag = ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step);
278 AMREX_ALWAYS_ASSERT(flag == 0);
279
280 // Set a stop time
281 if (set_stop_time) {
282 if (amrex::Verbose()) { amrex::Print() << "Stop time: " << stop_time << "\n"; }
283 flag = ARKStepSetStopTime(arkode_mem, stop_time);
284 AMREX_ALWAYS_ASSERT(flag == 0);
285 }
286
287 // Set max number of steps between returns
288 flag = ARKStepSetMaxNumSteps(arkode_mem, max_num_steps);
289 AMREX_ALWAYS_ASSERT(flag == 0);
290 }
291
292 void SetupMRI (amrex::Real time, N_Vector y_data)
293 {
294 if (amrex::Verbose()) { amrex::Print() << "Using SUNDIALS multirate time integrator\n"; }
295 int flag = 0;
296
297 // Create the fast integrator and select method
298 if (fast_type == "ERK") {
299 if (amrex::Verbose()) { amrex::Print() << "Fast ERK method: " << fast_method << "\n"; }
300 arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, y_data, sunctx);
301 AMREX_ALWAYS_ASSERT(arkode_fast_mem != nullptr);
302 if (fast_method != "DEFAULT") {
303 flag = ARKStepSetTableName(arkode_fast_mem, "ARKODE_DIRK_NONE", fast_method.c_str());
304 AMREX_ALWAYS_ASSERT(flag == 0);
305 }
306 }
307 else if (fast_type == "DIRK") {
308 if (amrex::Verbose()) { amrex::Print() << "Fast DIRK method: " << fast_method << "\n"; }
309 arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, y_data, sunctx);
310 AMREX_ALWAYS_ASSERT(arkode_fast_mem != nullptr);
311 if (fast_method != "DEFAULT") {
312 flag = ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE");
313 AMREX_ALWAYS_ASSERT(flag == 0);
314 }
315
316 if (amrex::Verbose()) {
317 amrex::Print() << "Fast nonlinear solver: " << fast_nonlinear_solver << "\n";
318 amrex::Print() << "Fast max nonlinear iters: " << fast_max_nonlinear_iters << "\n";
319 }
320 if (fast_nonlinear_solver == "fixed-point") {
321 fast_NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
322 AMREX_ALWAYS_ASSERT(fast_NLS != nullptr);
323 flag = ARKStepSetNonlinearSolver(arkode_fast_mem, fast_NLS);
324 AMREX_ALWAYS_ASSERT(flag == 0);
325 }
326 flag = ARKStepSetMaxNonlinIters(arkode_fast_mem, fast_max_nonlinear_iters);
327 AMREX_ALWAYS_ASSERT(flag == 0);
328
329 if (fast_nonlinear_solver == "Newton") {
330 if (amrex::Verbose()) {
331 amrex::Print() << "Linear solver: " << fast_linear_solver << "\n";
332 amrex::Print() << "Max linear iters: " << fast_max_linear_iters << "\n";
333 }
334 fast_LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, fast_max_linear_iters, sunctx);
335 AMREX_ALWAYS_ASSERT(fast_LS != nullptr);
336 flag = ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr);
337 AMREX_ALWAYS_ASSERT(flag == 0);
338 }
339 }
340
341 // Attach structure with user-supplied function wrappers
342 flag = ARKStepSetUserData(arkode_fast_mem, &udata);
343 AMREX_ALWAYS_ASSERT(flag == 0);
344
345 // Set integrator tolerances
346 if (BaseT::use_adaptive_fast_time_step || fast_type == "DIRK" || fast_type == "IMEX-RK") {
347 if (amrex::Verbose()) {
348 amrex::Print() << "Fast relative tolerance: " << BaseT::fast_rel_tol << "\n";
349 amrex::Print() << "Fast absolute tolerance: " << BaseT::fast_abs_tol << "\n";
350 }
351 flag = ARKStepSStolerances(arkode_fast_mem, BaseT::fast_rel_tol, BaseT::fast_abs_tol);
352 AMREX_ALWAYS_ASSERT(flag == 0);
353 }
354
355 // Set post stage and step function
356 flag = ARKStepSetPostprocessStageFn(arkode_fast_mem, SundialsUserFun::post_fast_stage);
357 AMREX_ALWAYS_ASSERT(flag == 0);
358 flag = ARKStepSetPostprocessStepFn(arkode_fast_mem, SundialsUserFun::post_fast_step);
359 AMREX_ALWAYS_ASSERT(flag == 0);
360
361 // Set max number of steps between returns
362 flag = ARKStepSetMaxNumSteps(arkode_fast_mem, max_num_steps);
363 AMREX_ALWAYS_ASSERT(flag == 0);
364
365 // Wrap fast integrator as an inner stepper
366 flag = ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper);
367 AMREX_ALWAYS_ASSERT(flag == 0);
368
369 // Create slow integrator
370 if (type == "EX-MRI") {
371 if (amrex::Verbose()) { amrex::Print() << "EX-MRI method: " << method << "\n"; }
372 arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, y_data,
373 fast_stepper, sunctx);
374 AMREX_ALWAYS_ASSERT(arkode_mem != nullptr);
375 }
376 else if (type == "IM-MRI") {
377 if (amrex::Verbose()) { amrex::Print() << "IM-MRI method: " << method << "\n"; }
378 arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, y_data,
379 fast_stepper, sunctx);
380 AMREX_ALWAYS_ASSERT(arkode_mem != nullptr);
381 }
382 else if (type == "IMEX-MRI") {
383 if (amrex::Verbose()) { amrex::Print() << "IMEX-MRI method: " << method << "\n"; }
384 arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi,
385 time, y_data, fast_stepper, sunctx);
386 AMREX_ALWAYS_ASSERT(arkode_mem != nullptr);
387 }
388
389 // Set method
390 if (method != "DEFAULT") {
391 MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str());
392 AMREX_ALWAYS_ASSERT(MRIC != nullptr);
393 flag = MRIStepSetCoupling(arkode_mem, MRIC);
394 AMREX_ALWAYS_ASSERT(flag == 0);
395 MRIStepCoupling_Free(MRIC);
396 }
397
398 // Attach structure with user-supplied function wrappers
399 flag = MRIStepSetUserData(arkode_mem, &udata);
400 AMREX_ALWAYS_ASSERT(flag == 0);
401
402 // Set integrator tolerances
403 if (BaseT::use_adaptive_time_step || fast_type == "IM-MRI" || fast_type == "IMEX-MRI") {
404 if (amrex::Verbose()) {
405 amrex::Print() << "Relative tolerance: " << BaseT::rel_tol << "\n";
406 amrex::Print() << "Absolute tolerance: " << BaseT::abs_tol << "\n";
407 }
408 flag = MRIStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol);
409 AMREX_ALWAYS_ASSERT(flag == 0);
410 }
411
412 // Create and attach linear solver
413 if (type == "IM-MRI" || type == "IMEX-MRI") {
414 if (amrex::Verbose()) {
415 amrex::Print() << "Nonlinear solver: " << nonlinear_solver << "\n";
416 amrex::Print() << "Max nonlinear iters: " << max_nonlinear_iters << "\n";
417 }
418 if (nonlinear_solver == "fixed-point") {
419 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
420 AMREX_ALWAYS_ASSERT(NLS != nullptr);
421 flag = MRIStepSetNonlinearSolver(arkode_mem, NLS);
422 AMREX_ALWAYS_ASSERT(flag == 0);
423 }
424 flag = MRIStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
425 AMREX_ALWAYS_ASSERT(flag == 0);
426
427 if (nonlinear_solver == "Newton") {
428 if (amrex::Verbose()) {
429 amrex::Print() << "Linear solver: " << linear_solver << "\n";
430 amrex::Print() << "Max linear iters: " << max_linear_iters << "\n";
431 }
432 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
433 AMREX_ALWAYS_ASSERT(LS != nullptr);
434 flag = MRIStepSetLinearSolver(arkode_mem, LS, nullptr);
435 AMREX_ALWAYS_ASSERT(flag == 0);
436 }
437 }
438
439 // Set post stage and step function
440 flag = MRIStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage);
441 AMREX_ALWAYS_ASSERT(flag == 0);
442 flag = MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step);
443 AMREX_ALWAYS_ASSERT(flag == 0);
444
445 // Set a stop time
446 if (set_stop_time) {
447 if (amrex::Verbose()) { amrex::Print() << "Stop time: " << stop_time << "\n"; }
448 flag = MRIStepSetStopTime(arkode_mem, stop_time);
449 AMREX_ALWAYS_ASSERT(flag == 0);
450 }
451
452 // Set max number of steps between returns
453 flag = MRIStepSetMaxNumSteps(arkode_mem, max_num_steps);
454 AMREX_ALWAYS_ASSERT(flag == 0);
455 }
456
457 // -------------------------------------
458 // Vector<MultiFab> / N_Vector Utilities
459 // -------------------------------------
460
461 // Utility to unpack a SUNDIALS ManyVector into a vector of MultiFabs
// Rebuild S_data with one MultiFab per subvector of the SUNDIALS ManyVector
// y_data: S_data is resized to the number of subvectors, and each entry is
// reassigned from the MultiFab stored in the corresponding subvector.
// NOTE(review): the second MultiFab constructor argument (source line 470) is
// missing from this copy. It is presumably amrex::make_alias (listed in the
// page tooltips), which would make each entry alias rather than copy the
// SUNDIALS data -- confirm against the original header.
462 void unpack_vector (N_Vector y_data, amrex::Vector<amrex::MultiFab>& S_data)
463 {
464 const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data);
465 S_data.resize(num_vecs);
466
467 for(int i = 0; i < num_vecs; i++)
468 {
// getMFptr is evaluated twice per subvector: once for the source MultiFab,
// once for its component count.
469 S_data.at(i) = amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)),
471 0,
472 amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp());
473 }
474 };
475
476 // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector
477 N_Vector wrap_data (amrex::Vector<amrex::MultiFab>& S_data)
478 {
479 auto get_length = [&](int index) -> sunindextype {
480 auto* p_mf = &S_data[index];
481 return p_mf->nComp() * (p_mf->boxArray()).numPts();
482 };
483
484 sunindextype NV_len = S_data.size();
485 N_Vector* NV_array = new N_Vector[NV_len];
486
487 for (int i = 0; i < NV_len; ++i) {
488 NV_array[i] = amrex::sundials::N_VMake_MultiFab(get_length(i),
489 &S_data[i], &sunctx);
490 }
491
492 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
493
494 delete[] NV_array;
495
496 return y_data;
497 };
498
499 // Utility to copy a vector of MultiFabs into a newly allocated SUNDIALS ManyVector
// Each subvector is a fresh MultiFab-backed N_Vector (N_VNew_MultiFab) with the
// same BoxArray, DistributionMapping, number of components and ghost cells as
// S_data[i]; S_data itself is only read. Returns a new ManyVector that the
// caller destroys with N_VDestroy.
// NOTE(review): source line 518 (the start of the copy call whose arguments
// follow) is missing from this copy; presumably
// amrex::MultiFab::Copy(*amrex::sundials::getMFptr(NV_array[i]), -- confirm.
// NOTE(review): confirm whether N_VDestroy on a ManyVector created with
// N_VNew_ManyVector also destroys its subvectors; if not, the MultiFab
// storage allocated here would be leaked.
500 N_Vector copy_data (const amrex::Vector<amrex::MultiFab>& S_data)
501 {
// Subvector length: valid cells times components (ghost cells not counted)
502 auto get_length = [&](int index) -> sunindextype {
503 auto* p_mf = &S_data[index];
504 return p_mf->nComp() * (p_mf->boxArray()).numPts();
505 };
506
507 sunindextype NV_len = S_data.size();
508 N_Vector* NV_array = new N_Vector[NV_len];
509
510 for (int i = 0; i < NV_len; ++i) {
511 NV_array[i] = amrex::sundials::N_VNew_MultiFab(get_length(i),
512 S_data[i].boxArray(),
513 S_data[i].DistributionMap(),
514 S_data[i].nComp(),
515 S_data[i].nGrow(),
516 &sunctx);
517
// Copy all components including nGrow() ghost cells from S_data[i]
519 S_data[i],
520 0,
521 0,
522 S_data[i].nComp(),
523 S_data[i].nGrow());
524 }
525
526 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
527
// Only the temporary pointer array is freed here; the subvectors themselves
// are still referenced by y_data.
528 delete[] NV_array;
529
530 return y_data;
531 };
532
533 // -----------------------------
534 // MultiFab / N_Vector Utilities
535 // -----------------------------
536
537 // Utility to unpack a SUNDIALS Vector into a MultiFab
// Rebuild S_data from the MultiFab stored inside the SUNDIALS vector y_data.
// NOTE(review): source lines 540, 541 and 543 are missing from this copy, so
// the body shown here is incomplete. By analogy with the ManyVector overload
// it presumably assigns S_data = amrex::MultiFab(*getMFptr(y_data),
// <make type>, 0, <ncomp>) -- confirm against the original header.
538 void unpack_vector (N_Vector y_data, amrex::MultiFab& S_data)
539 {
542 0,
544 };
545
546 // Utility to wrap a MultiFab as a SUNDIALS Vector
547 N_Vector wrap_data (amrex::MultiFab& S_data)
548 {
549 return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(),
550 &S_data, &sunctx);
551 };
552
553 // Utility to copy a MultiFab into a newly allocated SUNDIALS Vector
// Allocates a MultiFab-backed N_Vector (N_VNew_MultiFab) with the same
// BoxArray, DistributionMapping, component count and ghost cells as S_data,
// then fills it from S_data. Returns a new N_Vector owned by the caller
// (initialize() releases its copy with N_VDestroy).
// NOTE(review): source line 563 (the start of the copy call whose arguments
// follow) is missing from this copy; presumably
// amrex::MultiFab::Copy(*amrex::sundials::getMFptr(y_data), -- confirm.
554 N_Vector copy_data (const amrex::MultiFab& S_data)
555 {
556 N_Vector y_data = amrex::sundials::N_VNew_MultiFab(S_data.nComp() * S_data.boxArray().numPts(),
557 S_data.boxArray(),
558 S_data.DistributionMap(),
559 S_data.nComp(),
560 S_data.nGrow(),
561 &sunctx);
562
// Copy all components including nGrow() ghost cells
564 S_data,
565 0,
566 0,
567 S_data.nComp(),
568 S_data.nGrow());
569
570 return y_data;
571 };
572
573public:
578
585 SundialsIntegrator (const T& S_data, const amrex::Real time = 0.0)
586 {
587 initialize(S_data, time);
588 }
589
// Configure (or reconfigure) the SUNDIALS integrator for the state S_data at
// the given start time. It reads the runtime parameters, creates the SUNDIALS
// context, installs the user-callback wrappers in udata, and builds the
// ARKStep or MRIStep integrator from a temporary copy of S_data.
// NOTE(review): SetupRK/SetupMRI assign arkode_mem, LS, NLS, etc. without
// freeing earlier values, so calling initialize() a second time appears to
// leak the previously created SUNDIALS objects -- confirm.
596 void initialize (const T& S_data, const amrex::Real time = 0.0)
597 {
598 initialize_parameters();
// NOTE(review): source line 599 is missing from this copy. It presumably
// declares the mpi_comm used below (the page tooltips list
// ParallelContext::CommunicatorSub and MPI_Comm) -- confirm.
// SUNDIALS < 7 takes a pointer to the communicator (or nullptr); SUNDIALS 7+
// takes the communicator by value (or SUN_COMM_NULL).
600#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7)
601# ifdef AMREX_USE_MPI
602 sunctx = ::sundials::Context(&mpi_comm);
603# else
604 sunctx = ::sundials::Context(nullptr);
605# endif
606#else
607# ifdef AMREX_USE_MPI
608 sunctx = ::sundials::Context(mpi_comm);
609# else
610 sunctx = ::sundials::Context(SUN_COMM_NULL);
611# endif
612#endif
613
// Each wrapper below rebuilds T views of the SUNDIALS vectors with
// unpack_vector and forwards to the matching IntegratorBase callback. Every
// wrapper returns 0 (success) unconditionally. The lambdas capture by
// reference (including `this`), so they are only valid while this integrator
// is alive. The local `T S_data` inside each lambda shadows the S_data
// parameter of initialize().
614 // Right-hand side function wrappers
615 udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
616 void * /* user_data */) -> int {
617
618 T S_data;
619 unpack_vector(y_data, S_data);
620
621 T S_rhs;
622 unpack_vector(y_rhs, S_rhs);
623
624 BaseT::Rhs(S_rhs, S_data, rhs_time);
625
626 return 0;
627 };
628
// Implicit part of an IMEX RHS
629 udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
630 void * /* user_data */) -> int {
631
632 T S_data;
633 unpack_vector(y_data, S_data);
634
635 T S_rhs;
636 unpack_vector(y_rhs, S_rhs);
637
638 BaseT::RhsIm(S_rhs, S_data, rhs_time);
639
640 return 0;
641 };
642
// Explicit part of an IMEX RHS
643 udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
644 void * /* user_data */) -> int {
645
646 T S_data;
647 unpack_vector(y_data, S_data);
648
649 T S_rhs;
650 unpack_vector(y_rhs, S_rhs);
651
652 BaseT::RhsEx(S_rhs, S_data, rhs_time);
653
654 return 0;
655 };
656
// Fast time-scale RHS (MRI inner integrator)
657 udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
658 void * /* user_data */) -> int {
659
660 T S_data;
661 unpack_vector(y_data, S_data);
662
663 T S_rhs;
664 unpack_vector(y_rhs, S_rhs);
665
666 BaseT::RhsFast(S_rhs, S_data, rhs_time);
667
668 return 0;
669 };
670
// Post-stage / post-step hooks. Their `time` parameter shadows the `time`
// argument of initialize().
671 udata.post_stage = [&](amrex::Real time, N_Vector y_data,
672 void * /* user_data */) -> int {
673
674 T S_data;
675 unpack_vector(y_data, S_data);
676
677 BaseT::post_stage_action(S_data, time);
678
679 return 0;
680 };
681
682 udata.post_step = [&](amrex::Real time, N_Vector y_data,
683 void * /* user_data */) -> int {
684
685 T S_data;
686 unpack_vector(y_data, S_data);
687
688 BaseT::post_step_action(S_data, time);
689
690 return 0;
691 };
692
693 udata.post_fast_stage = [&](amrex::Real time, N_Vector y_data,
694 void * /* user_data */) -> int {
695
696 T S_data;
697 unpack_vector(y_data, S_data);
698
699 BaseT::post_fast_stage_action(S_data, time);
700
701 return 0;
702 };
703
704 udata.post_fast_step = [&](amrex::Real time, N_Vector y_data,
705 void * /* user_data */) -> int {
706
707 T S_data;
708 unpack_vector(y_data, S_data);
709
710 BaseT::post_fast_step_action(S_data, time);
711
712 return 0;
713 };
714
// S_data is const, so the integrators are set up from a temporary copy; the
// copy is destroyed as soon as setup is finished.
715 N_Vector y_data = copy_data(S_data); // ideally just wrap and ignore const
716
717 if (use_ark) {
718 SetupRK(time, y_data);
719 }
720 else if (use_mri)
721 {
722 SetupMRI(time, y_data);
723 }
724
725 N_VDestroy(y_data);
726 }
727
// Destructor body: when verbose, print SUNDIALS statistics (slow and fast
// integrators for MRI types, the single ARKStep integrator otherwise), then
// free every SUNDIALS object this integrator created.
// NOTE(review): the destructor signature (source lines 728-731) and source
// lines 736, 740 and 745 are missing from this copy. The unmatched closing
// braces on 738, 742 and 747 suggest each PrintAllStats call was guarded by a
// condition on those lines (possibly ParallelDescriptor::IOProcessor(), which
// the page tooltips list) -- confirm against the original header.
732 // Print integrator statistics
733 if (amrex::Verbose()) {
734 if (type == "EX-MRI" || type == "IM-MRI" || type == "IMEX-MRI") {
735 amrex::Print() << "Slow Time Integrator Stats\n";
737 MRIStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
738 }
739 amrex::Print() << "Fast Time Integrator Stats\n";
741 ARKStepPrintAllStats(arkode_fast_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
742 }
743 } else {
744 amrex::Print() << "Time Integrator Stats\n";
746 ARKStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
747 }
748 }
749 }
750
// NOTE(review): SetupMRI creates arkode_fast_mem with ARKStepCreate, but it is
// freed below with MRIStepFree. In the MRI case arkode_mem is created with
// MRIStepCreate but freed with ARKStepFree. Check that the SUNDIALS version in
// use tolerates freeing through the other stepper's Free routine; otherwise
// use ARKStepFree for the fast memory and dispatch on use_mri for arkode_mem.
751 // Clean up allocated memory
752 SUNLinSolFree(LS);
753 SUNLinSolFree(fast_LS);
754 SUNNonlinSolFree(NLS);
755 SUNNonlinSolFree(fast_NLS);
756 MRIStepInnerStepper_Free(&fast_stepper);
757 MRIStepFree(&arkode_fast_mem);
758 ARKStepFree(&arkode_mem);
759 }
760
770 amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override
771 {
772 amrex::Real tout = time + dt;
773 amrex::Real tret;
774
775 N_Vector y_old = wrap_data(S_old);
776 N_Vector y_new = wrap_data(S_new);
777
778 if (use_ark) {
779 ARKStepReset(arkode_mem, time, y_old); // should probably resize
780 ARKStepSetFixedStep(arkode_mem, dt);
781 int flag = ARKStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
782 AMREX_ALWAYS_ASSERT(flag >= 0);
783 }
784 else if (use_mri) {
785 MRIStepReset(arkode_mem, time, y_old); // should probably resize -- need to resize inner stepper
786 MRIStepSetFixedStep(arkode_mem, dt);
787 int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
788 AMREX_ALWAYS_ASSERT(flag >= 0);
789 } else {
790 Error("SUNDIALS integrator type not specified.");
791 }
792
793 N_VDestroy(y_old);
794 N_VDestroy(y_new);
795
796 return dt;
797 }
798
// Evolve S_out (wrapped, not copied) from the integrator's current time up to
// time_out with ARK_NORMAL mode. Unlike advance(), no reset is performed, so
// integration continues from wherever the integrator last stopped.
// NOTE(review): source lines 813, 820 and 823 are missing from this copy. They
// presumably open the conditionals that the stray closing braces on 815, 822
// and 825 belong to (e.g. on !BaseT::use_adaptive_time_step /
// !BaseT::use_adaptive_fast_time_step) -- confirm against the original header.
// NOTE(review): the return values of ARKStepSetFixedStep/MRIStepSetFixedStep
// are not checked here.
805 void evolve (T& S_out, const amrex::Real time_out) override
806 {
807 int flag = 0; // SUNDIALS return status
808 amrex::Real time_ret; // SUNDIALS return time
809
810 N_Vector y_out = wrap_data(S_out);
811
812 if (use_ark) {
814 ARKStepSetFixedStep(arkode_mem, BaseT::time_step);
815 }
816 flag = ARKStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
817 AMREX_ALWAYS_ASSERT(flag >= 0);
818 }
819 else if (use_mri) {
// Slow step size, then fast step size for the inner ARKStep integrator
821 MRIStepSetFixedStep(arkode_mem, BaseT::time_step);
822 }
824 ARKStepSetFixedStep(arkode_fast_mem, BaseT::fast_time_step);
825 }
826 flag = MRIStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
827 AMREX_ALWAYS_ASSERT(flag >= 0);
828 } else {
829 Error("SUNDIALS integrator type not specified.");
830 }
831
832 N_VDestroy(y_out);
833 }
834
838 void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override {}
839
843 void map_data (std::function<void(T&)> /* Map */) override {}
844};
845
846}
847
848#endif
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Long numPts() const noexcept
Returns the total number of cells contained in all boxes in the BoxArray.
Definition AMReX_BoxArray.cpp:394
int nGrow(int direction=0) const noexcept
Return the grow factor that defines the region of definition.
Definition AMReX_FabArrayBase.H:78
const DistributionMapping & DistributionMap() const noexcept
Return constant reference to associated DistributionMapping.
Definition AMReX_FabArrayBase.H:131
int nComp() const noexcept
Return number of variables (aka components) associated with each point.
Definition AMReX_FabArrayBase.H:83
const BoxArray & boxArray() const noexcept
Return a constant reference to the BoxArray that defines the valid region associated with this FabArr...
Definition AMReX_FabArrayBase.H:95
Definition AMReX_IntegratorBase.H:164
bool use_adaptive_fast_time_step
Flag to enable/disable adaptive time stepping at the fast time scale in multirate methods (bool)
Definition AMReX_IntegratorBase.H:246
amrex::Real fast_rel_tol
Relative tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:278
amrex::Real rel_tol
Relative tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:267
std::function< void(T &, amrex::Real)> post_fast_stage_action
The post_fast_stage_action function is called by the integrator on each computed fast-time-scale stage just after it is computed.
Definition AMReX_IntegratorBase.H:218
amrex::Real fast_abs_tol
Absolute tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:284
amrex::Real fast_time_step
Current integrator fast time scale time step size with multirate methods (Real)
Definition AMReX_IntegratorBase.H:252
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsEx
RhsEx is the explicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:194
std::function< void(T &, amrex::Real)> post_step_action
The post_step_action function is called by the integrator on the computed state just after it is comp...
Definition AMReX_IntegratorBase.H:212
std::function< void(T &, amrex::Real)> post_fast_step_action
The post_fast_step_action function is called by the integrator on the computed fast-time-scale state just after it is computed.
Definition AMReX_IntegratorBase.H:224
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsIm
RhsIm is the implicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:188
std::function< void(T &rhs, T &state, const amrex::Real time)> Rhs
Rhs is the right-hand-side function the integrator will use.
Definition AMReX_IntegratorBase.H:182
bool use_adaptive_time_step
Flag to enable/disable adaptive time stepping in single rate methods or at the slow time scale in mul...
Definition AMReX_IntegratorBase.H:230
std::function< void(T &, amrex::Real)> post_stage_action
The post_stage_action function is called by the integrator on the computed stage just after it is com...
Definition AMReX_IntegratorBase.H:206
amrex::Real time_step
Current integrator time step size (Real)
Definition AMReX_IntegratorBase.H:235
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsFast
RhsFast is the fast timescale right-hand-side function a multirate integrator will use.
Definition AMReX_IntegratorBase.H:200
amrex::Real abs_tol
Absolute tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:272
A collection (stored as an array) of FArrayBox objects.
Definition AMReX_MultiFab.H:40
static void Copy(MultiFab &dst, const MultiFab &src, int srccomp, int dstcomp, int numcomp, int nghost)
Copy from src to dst including nghost ghost cells. The two MultiFabs MUST have the same underlying Bo...
Definition AMReX_MultiFab.cpp:193
Parse Parameters From Command Line and Input Files.
Definition AMReX_ParmParse.H:348
int query(std::string_view name, bool &ref, int ival=FIRST) const
Same as querykth() but searches for the last occurrence of name.
Definition AMReX_ParmParse.cpp:1946
This class provides the user with a few print options.
Definition AMReX_Print.H:35
IntegratorBase implementation powered by SUNDIALS ARKStep/MRIStep.
Definition AMReX_SundialsIntegrator.H:95
void time_interpolate(const T &, const T &, amrex::Real, T &) override
Interpolate between SUNDIALS stages (not yet implemented for this integrator).
Definition AMReX_SundialsIntegrator.H:838
SundialsIntegrator()
Construct an uninitialized integrator; call initialize() before use.
Definition AMReX_SundialsIntegrator.H:577
void initialize(const T &S_data, const amrex::Real time=0.0)
Configure (or reconfigure) the SUNDIALS integrator for the provided state.
Definition AMReX_SundialsIntegrator.H:596
amrex::Real advance(T &S_old, T &S_new, amrex::Real time, const amrex::Real dt) override
Take a single time step of size dt starting from S_old.
Definition AMReX_SundialsIntegrator.H:770
void evolve(T &S_out, const amrex::Real time_out) override
Evolve the solution in S_out up to time_out using ARKStep/MRIStep.
Definition AMReX_SundialsIntegrator.H:805
virtual ~SundialsIntegrator()
Destroy the integrator, printing summary statistics when verbose.
Definition AMReX_SundialsIntegrator.H:731
SundialsIntegrator(const T &S_data, const amrex::Real time=0.0)
Construct and immediately configure the integrator with S_data at time time.
Definition AMReX_SundialsIntegrator.H:585
void map_data(std::function< void(T &)>) override
Apply a user-supplied mapping to every MultiFab in the integrator (unused placeholder).
Definition AMReX_SundialsIntegrator.H:843
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition AMReX_Vector.H:28
Long size() const noexcept
Definition AMReX_Vector.H:53
amrex_real Real
Floating Point Type for Fields.
Definition AMReX_REAL.H:79
amrex_long Long
Definition AMReX_INT.H:30
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition AMReX_ParallelDescriptor.H:289
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition AMReX_ParallelContext.H:70
static int fi(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:52
static int fe(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:57
static int post_fast_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:82
static int f(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:47
static int post_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:72
static int ff(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:62
static int post_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:67
static int post_fast_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:77
int MPI_Comm
Definition AMReX_ccse-mpi.H:51
N_Vector N_VMake_MultiFab(sunindextype length, amrex::MultiFab *v_mf, ::sundials::Context *sunctx)
Wrap an existing MultiFab v_mf as an N_Vector without copying.
Definition AMReX_NVector_MultiFab.cpp:105
amrex::MultiFab *& getMFptr(N_Vector v)
Access the MultiFab pointer stored inside v (non-const).
Definition AMReX_NVector_MultiFab.cpp:233
N_Vector N_VNew_MultiFab(sunindextype length, const amrex::BoxArray &ba, const amrex::DistributionMapping &dm, sunindextype nComp, sunindextype nGhost, ::sundials::Context *sunctx)
Allocate a MultiFab-backed N_Vector of the given length.
Definition AMReX_NVector_MultiFab.cpp:80
Definition AMReX_Amr.cpp:49
@ make_alias
Definition AMReX_MakeType.H:7
int nComp(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2851
DistributionMapping const & DistributionMap(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2866
void Error(const std::string &msg)
Print out message to cerr and exit via amrex::Abort().
Definition AMReX.cpp:234
int Verbose() noexcept
Definition AMReX.cpp:179
const int[]
Definition AMReX_BLProfiler.cpp:1664
BoxArray const & boxArray(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2861
User-supplied callbacks consumed by the AMReX/SUNDIALS bridge.
Definition AMReX_SundialsIntegrator.H:35
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fi
Implicit RHS for ImEx schemes.
Definition AMReX_SundialsIntegrator.H:37
std::function< int(amrex::Real, N_Vector, void *)> post_fast_stage
Hook for MRI fast stages.
Definition AMReX_SundialsIntegrator.H:42
std::function< int(amrex::Real, N_Vector, void *)> post_step
Hook invoked after each time step.
Definition AMReX_SundialsIntegrator.H:41
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> f
ERK/DIRK RHS or MRI slow RHS.
Definition AMReX_SundialsIntegrator.H:36
std::function< int(amrex::Real, N_Vector, void *)> post_stage
Hook invoked after each stage.
Definition AMReX_SundialsIntegrator.H:40
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> ff
MRI fast-scale RHS.
Definition AMReX_SundialsIntegrator.H:39
std::function< int(amrex::Real, N_Vector, void *)> post_fast_step
Hook for MRI fast steps.
Definition AMReX_SundialsIntegrator.H:43
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fe
Explicit RHS for ImEx schemes.
Definition AMReX_SundialsIntegrator.H:38