1#ifndef AMREX_SUNDIALS_INTEGRATOR_H
2#define AMREX_SUNDIALS_INTEGRATOR_H
6#include <AMReX_Config.H>
14#include <nvector/nvector_manyvector.h>
15#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
16#include <sunlinsol/sunlinsol_spgmr.h>
17#include <arkode/arkode_arkstep.h>
18#include <arkode/arkode_mristep.h>
44namespace SundialsUserFun {
45 static int f (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
47 return udata->
f(t, y_data, y_rhs, user_data);
50 static int fi (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
52 return udata->
fi(t, y_data, y_rhs, user_data);
55 static int fe (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
57 return udata->
fe(t, y_data, y_rhs, user_data);
60 static int ff (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
62 return udata->
ff(t, y_data, y_rhs, user_data);
67 return udata->
post_stage(t, y_data, user_data);
72 return udata->
post_step(t, y_data, user_data);
93 std::string type =
"ERK";
96 std::string method =
"DEFAULT";
97 std::string method_e =
"DEFAULT";
98 std::string method_i =
"DEFAULT";
101 std::string fast_type =
"ERK";
102 std::string fast_method =
"DEFAULT";
105 std::string nonlinear_solver =
"Newton";
106 int max_nonlinear_iters = 0;
108 std::string fast_nonlinear_solver =
"Newton";
109 int fast_max_nonlinear_iters = 0;
112 std::string linear_solver =
"GMRES";
113 int max_linear_iters = 0;
115 std::string fast_linear_solver =
"GMRES";
116 int fast_max_linear_iters = 0;
119 bool use_ark =
false;
120 bool use_mri =
false;
129 ::sundials::Context sunctx;
132 void *arkode_mem =
nullptr;
133 SUNLinearSolver LS =
nullptr;
134 SUNNonlinearSolver NLS =
nullptr;
137 void *arkode_fast_mem =
nullptr;
138 MRIStepInnerStepper fast_stepper =
nullptr;
139 SUNLinearSolver fast_LS =
nullptr;
140 SUNNonlinearSolver fast_NLS =
nullptr;
143 bool set_stop_time =
false;
149 void initialize_parameters ()
155 pp.
query(
"method_e", method_e);
156 pp.
query(
"method_i", method_i);
158 pp.
query(
"fast_type", fast_type);
159 pp.
query(
"fast_method", fast_method);
161 if (type ==
"ERK" || type ==
"DIRK" || type ==
"IMEX-RK") {
164 else if (type ==
"EX-MRI" || type ==
"IM-MRI" || type ==
"IMEX-MRI") {
168 std::string msg(
"Unknown method type: ");
173 pp.
query(
"nonlinear_solver", nonlinear_solver);
174 pp.
query(
"max_nonlinear_iters", max_nonlinear_iters);
176 pp.
query(
"fast_nonlinear_solver", fast_nonlinear_solver);
177 pp.
query(
"fast_max_nonlinear_iters", fast_max_nonlinear_iters);
179 pp.
query(
"linear_solver", linear_solver);
180 pp.
query(
"max_linear_iters", max_linear_iters);
182 pp.
query(
"fast_linear_solver", fast_linear_solver);
183 pp.
query(
"fast_max_linear_iters", fast_max_linear_iters);
185 set_stop_time =
pp.
query(
"stop_time", stop_time);
187 pp.
query(
"max_num_steps", max_num_steps);
200 if (method !=
"DEFAULT") {
201 flag = ARKStepSetTableName(arkode_mem,
"ARKODE_DIRK_NONE", method.c_str());
205 else if (type ==
"DIRK") {
209 if (method !=
"DEFAULT") {
210 flag = ARKStepSetTableName(arkode_mem, method.c_str(),
"ARKODE_ERK_NONE");
214 else if (type ==
"IMEX-RK") {
216 << method_e <<
"\n"; }
219 if (method_e !=
"DEFAULT" && method_i !=
"DEFAULT")
221 flag = ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str());
227 flag = ARKStepSetUserData(arkode_mem, &udata);
241 if (type ==
"DIRK" || type ==
"IMEX-RK") {
243 amrex::Print() <<
"Nonlinear solver: " << nonlinear_solver <<
"\n";
244 amrex::Print() <<
"Max nonlinear iters: " << max_nonlinear_iters <<
"\n";
246 if (nonlinear_solver ==
"fixed-point") {
247 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
249 flag = ARKStepSetNonlinearSolver(arkode_mem, NLS);
252 flag = ARKStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
255 if (nonlinear_solver ==
"Newton") {
257 amrex::Print() <<
"Linear solver: " << linear_solver <<
"\n";
258 amrex::Print() <<
"Max linear iters: " << max_linear_iters <<
"\n";
260 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
262 flag = ARKStepSetLinearSolver(arkode_mem, LS,
nullptr);
276 flag = ARKStepSetStopTime(arkode_mem, stop_time);
281 ARKStepSetMaxNumSteps(arkode_mem, max_num_steps);
291 if (fast_type ==
"ERK") {
295 if (fast_method !=
"DEFAULT") {
296 flag = ARKStepSetTableName(arkode_fast_mem,
"ARKODE_DIRK_NONE", fast_method.c_str());
300 else if (fast_type ==
"DIRK") {
304 if (fast_method !=
"DEFAULT") {
305 flag = ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(),
"ARKODE_ERK_NONE");
310 amrex::Print() <<
"Fast nonlinear solver: " << fast_nonlinear_solver <<
"\n";
311 amrex::Print() <<
"Fast max nonlinear iters: " << fast_max_nonlinear_iters <<
"\n";
313 if (fast_nonlinear_solver ==
"fixed-point") {
314 fast_NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
316 flag = ARKStepSetNonlinearSolver(arkode_mem, fast_NLS);
319 flag = ARKStepSetMaxNonlinIters(arkode_mem, fast_max_nonlinear_iters);
322 if (fast_nonlinear_solver ==
"Newton") {
324 amrex::Print() <<
"Linear solver: " << fast_linear_solver <<
"\n";
325 amrex::Print() <<
"Max linear iters: " << fast_max_linear_iters <<
"\n";
327 fast_LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, fast_max_linear_iters, sunctx);
329 flag = ARKStepSetLinearSolver(arkode_mem, fast_LS,
nullptr);
335 flag = ARKStepSetUserData(arkode_fast_mem, &udata);
355 ARKStepSetMaxNumSteps(arkode_fast_mem, max_num_steps);
359 flag = ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper);
363 if (type ==
"EX-MRI") {
366 fast_stepper, sunctx);
369 else if (type ==
"IM-MRI") {
372 fast_stepper, sunctx);
375 else if (type ==
"IMEX-MRI") {
378 time, y_data, fast_stepper, sunctx);
383 if (method !=
"DEFAULT") {
384 MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str());
386 flag = MRIStepSetCoupling(arkode_mem, MRIC);
388 MRIStepCoupling_Free(MRIC);
392 flag = MRIStepSetUserData(arkode_mem, &udata);
406 if (type ==
"IM-MRI" || type ==
"IMEX-MRI") {
408 amrex::Print() <<
"Nonlinear solver: " << nonlinear_solver <<
"\n";
409 amrex::Print() <<
"Max nonlinear iters: " << max_nonlinear_iters <<
"\n";
411 if (nonlinear_solver ==
"fixed-point") {
412 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
414 flag = ARKStepSetNonlinearSolver(arkode_mem, NLS);
417 flag = ARKStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
420 if (nonlinear_solver ==
"Newton") {
422 amrex::Print() <<
"Linear solver: " << linear_solver <<
"\n";
423 amrex::Print() <<
"Max linear iters: " << max_linear_iters <<
"\n";
425 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
427 flag = ARKStepSetLinearSolver(arkode_mem, LS,
nullptr);
441 flag = ARKStepSetStopTime(arkode_mem, stop_time);
446 ARKStepSetMaxNumSteps(arkode_mem, max_num_steps);
457 const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data);
458 S_data.resize(num_vecs);
460 for(
int i = 0; i < num_vecs; i++)
472 auto get_length = [&](
int index) -> sunindextype {
473 auto* p_mf = &S_data[index];
474 return p_mf->nComp() * (p_mf->boxArray()).numPts();
477 sunindextype NV_len = S_data.
size();
478 N_Vector* NV_array =
new N_Vector[NV_len];
480 for (
int i = 0; i < NV_len; ++i) {
482 &S_data[i], &sunctx);
485 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
495 auto get_length = [&](
int index) -> sunindextype {
496 auto* p_mf = &S_data[index];
497 return p_mf->nComp() * (p_mf->boxArray()).numPts();
500 sunindextype NV_len = S_data.
size();
501 N_Vector* NV_array =
new N_Vector[NV_len];
503 for (
int i = 0; i < NV_len; ++i) {
519 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
576 initialize_parameters();
578#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7)
580 sunctx = ::sundials::Context(&mpi_comm);
582 sunctx = ::sundials::Context(
nullptr);
586 sunctx = ::sundials::Context(mpi_comm);
588 sunctx = ::sundials::Context(SUN_COMM_NULL);
593 udata.
f = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
597 unpack_vector(y_data, S_data);
600 unpack_vector(y_rhs, S_rhs);
607 udata.
fi = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
611 unpack_vector(y_data, S_data);
614 unpack_vector(y_rhs, S_rhs);
621 udata.
fe = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
625 unpack_vector(y_data, S_data);
628 unpack_vector(y_rhs, S_rhs);
635 udata.
ff = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
639 unpack_vector(y_data, S_data);
642 unpack_vector(y_rhs, S_rhs);
653 unpack_vector(y_data, S_data);
664 unpack_vector(y_data, S_data);
675 unpack_vector(y_data, S_data);
686 unpack_vector(y_data, S_data);
693 N_Vector y_data = copy_data(S_data);
696 SetupRK(time, y_data);
700 SetupMRI(time, y_data);
709 if (type ==
"EX-MRI" || type ==
"IM-MRI" || type ==
"IMEX-MRI") {
712 MRIStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
716 ARKStepPrintAllStats(arkode_fast_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
721 ARKStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
728 SUNLinSolFree(fast_LS);
729 SUNNonlinSolFree(NLS);
730 SUNNonlinSolFree(fast_NLS);
731 MRIStepInnerStepper_Free(&fast_stepper);
732 MRIStepFree(&arkode_fast_mem);
733 ARKStepFree(&arkode_mem);
741 N_Vector y_old = wrap_data(S_old);
742 N_Vector y_new = wrap_data(S_new);
745 ARKStepReset(arkode_mem, time, y_old);
746 ARKStepSetFixedStep(arkode_mem, dt);
747 int flag = ARKStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
751 MRIStepReset(arkode_mem, time, y_old);
752 MRIStepSetFixedStep(arkode_mem, dt);
753 int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
756 Error(
"SUNDIALS integrator type not specified.");
770 N_Vector y_out = wrap_data(S_out);
776 flag = ARKStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
786 flag = MRIStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
789 Error(
"SUNDIALS integrator type not specified.");
797 void map_data (std::function<
void(T&)> )
override {}
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Long numPts() const noexcept
Returns the total number of cells contained in all boxes in the BoxArray.
Definition AMReX_BoxArray.cpp:394
int nGrow(int direction=0) const noexcept
Return the grow factor that defines the region of definition.
Definition AMReX_FabArrayBase.H:78
const DistributionMapping & DistributionMap() const noexcept
Return constant reference to associated DistributionMapping.
Definition AMReX_FabArrayBase.H:131
int nComp() const noexcept
Return number of variables (aka components) associated with each point.
Definition AMReX_FabArrayBase.H:83
const BoxArray & boxArray() const noexcept
Return a constant reference to the BoxArray that defines the valid region associated with this FabArr...
Definition AMReX_FabArrayBase.H:95
Definition AMReX_IntegratorBase.H:164
bool use_adaptive_fast_time_step
Flag to enable/disable adaptive time stepping at the fast time scale in multirate methods (bool)
Definition AMReX_IntegratorBase.H:246
amrex::Real fast_rel_tol
Relative tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:278
amrex::Real rel_tol
Relative tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:267
std::function< void(T &, amrex::Real)> post_fast_stage_action
The post_stage_action function is called by the integrator on the computed stage just after it is com...
Definition AMReX_IntegratorBase.H:218
amrex::Real fast_abs_tol
Absolute tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:284
amrex::Real fast_time_step
Current integrator fast time scale time step size with multirate methods (Real)
Definition AMReX_IntegratorBase.H:252
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsEx
RhsEx is the explicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:194
std::function< void(T &, amrex::Real)> post_step_action
The post_step_action function is called by the integrator on the computed state just after it is comp...
Definition AMReX_IntegratorBase.H:212
std::function< void(T &, amrex::Real)> post_fast_step_action
The post_step_action function is called by the integrator on the computed state just after it is comp...
Definition AMReX_IntegratorBase.H:224
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsIm
RhsIm is the implicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:188
std::function< void(T &rhs, T &state, const amrex::Real time)> Rhs
Rhs is the right-hand-side function the integrator will use.
Definition AMReX_IntegratorBase.H:182
bool use_adaptive_time_step
Flag to enable/disable adaptive time stepping in single rate methods or at the slow time scale in mul...
Definition AMReX_IntegratorBase.H:230
std::function< void(T &, amrex::Real)> post_stage_action
The post_stage_action function is called by the integrator on the computed stage just after it is com...
Definition AMReX_IntegratorBase.H:206
amrex::Real time_step
Current integrator time step size (Real)
Definition AMReX_IntegratorBase.H:235
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsFast
RhsFast is the fast timescale right-hand-side function a multirate integrator will use.
Definition AMReX_IntegratorBase.H:200
amrex::Real abs_tol
Absolute tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:272
A collection (stored as an array) of FArrayBox objects.
Definition AMReX_MultiFab.H:40
static void Copy(MultiFab &dst, const MultiFab &src, int srccomp, int dstcomp, int numcomp, int nghost)
Copy from src to dst including nghost ghost cells. The two MultiFabs MUST have the same underlying Bo...
Definition AMReX_MultiFab.cpp:193
Parse Parameters From Command Line and Input Files.
Definition AMReX_ParmParse.H:346
int query(std::string_view name, bool &ref, int ival=FIRST) const
Same as querykth() but searches for the last occurrence of name.
Definition AMReX_ParmParse.cpp:1450
This class provides the user with a few print options.
Definition AMReX_Print.H:35
Definition AMReX_SundialsIntegrator.H:88
void time_interpolate(const T &, const T &, amrex::Real, T &) override
Definition AMReX_SundialsIntegrator.H:795
SundialsIntegrator()
Definition AMReX_SundialsIntegrator.H:567
void initialize(const T &S_data, const amrex::Real time=0.0)
Definition AMReX_SundialsIntegrator.H:574
amrex::Real advance(T &S_old, T &S_new, amrex::Real time, const amrex::Real dt) override
Take a single time step from (time, S_old) to (time + dt, S_new) with the given step size.
Definition AMReX_SundialsIntegrator.H:736
void evolve(T &S_out, const amrex::Real time_out) override
Evolve the current (internal) integrator state to time_out.
Definition AMReX_SundialsIntegrator.H:765
virtual ~SundialsIntegrator()
Definition AMReX_SundialsIntegrator.H:706
SundialsIntegrator(const T &S_data, const amrex::Real time=0.0)
Definition AMReX_SundialsIntegrator.H:569
void map_data(std::function< void(T &)>) override
Definition AMReX_SundialsIntegrator.H:797
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition AMReX_Vector.H:28
Long size() const noexcept
Definition AMReX_Vector.H:53
amrex_real Real
Floating Point Type for Fields.
Definition AMReX_REAL.H:79
amrex_long Long
Definition AMReX_INT.H:30
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition AMReX_ParallelDescriptor.H:289
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition AMReX_ParallelContext.H:70
static int fi(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:50
static int fe(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:55
static int post_fast_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:80
static int f(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:45
static int post_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:70
static int ff(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:60
static int post_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:65
static int post_fast_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:75
int MPI_Comm
Definition AMReX_ccse-mpi.H:51
N_Vector N_VMake_MultiFab(sunindextype length, amrex::MultiFab *v_mf, ::sundials::Context *sunctx)
Definition AMReX_NVector_MultiFab.cpp:103
amrex::MultiFab *& getMFptr(N_Vector v)
Definition AMReX_NVector_MultiFab.cpp:228
N_Vector N_VNew_MultiFab(sunindextype length, const amrex::BoxArray &ba, const amrex::DistributionMapping &dm, sunindextype nComp, sunindextype nGhost, ::sundials::Context *sunctx)
Definition AMReX_NVector_MultiFab.cpp:78
Definition AMReX_Amr.cpp:49
@ make_alias
Definition AMReX_MakeType.H:7
int nComp(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2854
DistributionMapping const & DistributionMap(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2869
void Error(const std::string &msg)
Print out message to cerr and exit via amrex::Abort().
Definition AMReX.cpp:224
int Verbose() noexcept
Definition AMReX.cpp:169
const int[]
Definition AMReX_BLProfiler.cpp:1664
BoxArray const & boxArray(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2864
Definition AMReX_SundialsIntegrator.H:22
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fi
Definition AMReX_SundialsIntegrator.H:29
std::function< int(amrex::Real, N_Vector, void *)> post_fast_stage
Definition AMReX_SundialsIntegrator.H:40
std::function< int(amrex::Real, N_Vector, void *)> post_step
Definition AMReX_SundialsIntegrator.H:37
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> f
Definition AMReX_SundialsIntegrator.H:25
std::function< int(amrex::Real, N_Vector, void *)> post_stage
Definition AMReX_SundialsIntegrator.H:36
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> ff
Definition AMReX_SundialsIntegrator.H:33
std::function< int(amrex::Real, N_Vector, void *)> post_fast_step
Definition AMReX_SundialsIntegrator.H:41
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fe
Definition AMReX_SundialsIntegrator.H:30