1#ifndef AMREX_SUNDIALS_INTEGRATOR_H
2#define AMREX_SUNDIALS_INTEGRATOR_H
6#include <AMReX_Config.H>
14#include <nvector/nvector_manyvector.h>
15#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
16#include <sunlinsol/sunlinsol_spgmr.h>
17#include <arkode/arkode_arkstep.h>
18#include <arkode/arkode_mristep.h>
46namespace SundialsUserFun {
47 static int f (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
49 return udata->
f(t, y_data, y_rhs, user_data);
52 static int fi (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
54 return udata->
fi(t, y_data, y_rhs, user_data);
57 static int fe (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
59 return udata->
fe(t, y_data, y_rhs, user_data);
62 static int ff (
amrex::Real t, N_Vector y_data, N_Vector y_rhs,
void *user_data) {
64 return udata->
ff(t, y_data, y_rhs, user_data);
69 return udata->
post_stage(t, y_data, user_data);
74 return udata->
post_step(t, y_data, user_data);
100 std::string type =
"ERK";
103 std::string method =
"DEFAULT";
104 std::string method_e =
"DEFAULT";
105 std::string method_i =
"DEFAULT";
108 std::string fast_type =
"ERK";
109 std::string fast_method =
"DEFAULT";
112 std::string nonlinear_solver =
"Newton";
113 int max_nonlinear_iters = 0;
115 std::string fast_nonlinear_solver =
"Newton";
116 int fast_max_nonlinear_iters = 0;
119 std::string linear_solver =
"GMRES";
120 int max_linear_iters = 0;
122 std::string fast_linear_solver =
"GMRES";
123 int fast_max_linear_iters = 0;
126 bool use_ark =
false;
127 bool use_mri =
false;
136 ::sundials::Context sunctx;
139 void *arkode_mem =
nullptr;
140 SUNLinearSolver LS =
nullptr;
141 SUNNonlinearSolver NLS =
nullptr;
144 void *arkode_fast_mem =
nullptr;
145 MRIStepInnerStepper fast_stepper =
nullptr;
146 SUNLinearSolver fast_LS =
nullptr;
147 SUNNonlinearSolver fast_NLS =
nullptr;
150 bool set_stop_time =
false;
156 void initialize_parameters ()
162 pp.
query(
"method_e", method_e);
163 pp.
query(
"method_i", method_i);
165 pp.
query(
"fast_type", fast_type);
166 pp.
query(
"fast_method", fast_method);
168 if (type ==
"ERK" || type ==
"DIRK" || type ==
"IMEX-RK") {
171 else if (type ==
"EX-MRI" || type ==
"IM-MRI" || type ==
"IMEX-MRI") {
175 std::string msg(
"Unknown method type: ");
180 pp.
query(
"nonlinear_solver", nonlinear_solver);
181 pp.
query(
"max_nonlinear_iters", max_nonlinear_iters);
183 pp.
query(
"fast_nonlinear_solver", fast_nonlinear_solver);
184 pp.
query(
"fast_max_nonlinear_iters", fast_max_nonlinear_iters);
186 pp.
query(
"linear_solver", linear_solver);
187 pp.
query(
"max_linear_iters", max_linear_iters);
189 pp.
query(
"fast_linear_solver", fast_linear_solver);
190 pp.
query(
"fast_max_linear_iters", fast_max_linear_iters);
192 set_stop_time =
pp.
query(
"stop_time", stop_time);
194 pp.
query(
"max_num_steps", max_num_steps);
207 if (method !=
"DEFAULT") {
208 flag = ARKStepSetTableName(arkode_mem,
"ARKODE_DIRK_NONE", method.c_str());
212 else if (type ==
"DIRK") {
216 if (method !=
"DEFAULT") {
217 flag = ARKStepSetTableName(arkode_mem, method.c_str(),
"ARKODE_ERK_NONE");
221 else if (type ==
"IMEX-RK") {
223 << method_e <<
"\n"; }
226 if (method_e !=
"DEFAULT" && method_i !=
"DEFAULT")
228 flag = ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str());
234 flag = ARKStepSetUserData(arkode_mem, &udata);
248 if (type ==
"DIRK" || type ==
"IMEX-RK") {
250 amrex::Print() <<
"Nonlinear solver: " << nonlinear_solver <<
"\n";
251 amrex::Print() <<
"Max nonlinear iters: " << max_nonlinear_iters <<
"\n";
253 if (nonlinear_solver ==
"fixed-point") {
254 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
256 flag = ARKStepSetNonlinearSolver(arkode_mem, NLS);
259 flag = ARKStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
262 if (nonlinear_solver ==
"Newton") {
264 amrex::Print() <<
"Linear solver: " << linear_solver <<
"\n";
265 amrex::Print() <<
"Max linear iters: " << max_linear_iters <<
"\n";
267 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
269 flag = ARKStepSetLinearSolver(arkode_mem, LS,
nullptr);
283 flag = ARKStepSetStopTime(arkode_mem, stop_time);
288 flag = ARKStepSetMaxNumSteps(arkode_mem, max_num_steps);
298 if (fast_type ==
"ERK") {
302 if (fast_method !=
"DEFAULT") {
303 flag = ARKStepSetTableName(arkode_fast_mem,
"ARKODE_DIRK_NONE", fast_method.c_str());
307 else if (fast_type ==
"DIRK") {
311 if (fast_method !=
"DEFAULT") {
312 flag = ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(),
"ARKODE_ERK_NONE");
317 amrex::Print() <<
"Fast nonlinear solver: " << fast_nonlinear_solver <<
"\n";
318 amrex::Print() <<
"Fast max nonlinear iters: " << fast_max_nonlinear_iters <<
"\n";
320 if (fast_nonlinear_solver ==
"fixed-point") {
321 fast_NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
323 flag = ARKStepSetNonlinearSolver(arkode_fast_mem, fast_NLS);
326 flag = ARKStepSetMaxNonlinIters(arkode_fast_mem, fast_max_nonlinear_iters);
329 if (fast_nonlinear_solver ==
"Newton") {
331 amrex::Print() <<
"Linear solver: " << fast_linear_solver <<
"\n";
332 amrex::Print() <<
"Max linear iters: " << fast_max_linear_iters <<
"\n";
334 fast_LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, fast_max_linear_iters, sunctx);
336 flag = ARKStepSetLinearSolver(arkode_fast_mem, fast_LS,
nullptr);
342 flag = ARKStepSetUserData(arkode_fast_mem, &udata);
362 flag = ARKStepSetMaxNumSteps(arkode_fast_mem, max_num_steps);
366 flag = ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper);
370 if (type ==
"EX-MRI") {
373 fast_stepper, sunctx);
376 else if (type ==
"IM-MRI") {
379 fast_stepper, sunctx);
382 else if (type ==
"IMEX-MRI") {
385 time, y_data, fast_stepper, sunctx);
390 if (method !=
"DEFAULT") {
391 MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str());
393 flag = MRIStepSetCoupling(arkode_mem, MRIC);
395 MRIStepCoupling_Free(MRIC);
399 flag = MRIStepSetUserData(arkode_mem, &udata);
413 if (type ==
"IM-MRI" || type ==
"IMEX-MRI") {
415 amrex::Print() <<
"Nonlinear solver: " << nonlinear_solver <<
"\n";
416 amrex::Print() <<
"Max nonlinear iters: " << max_nonlinear_iters <<
"\n";
418 if (nonlinear_solver ==
"fixed-point") {
419 NLS = SUNNonlinSol_FixedPoint(y_data, 0, sunctx);
421 flag = MRIStepSetNonlinearSolver(arkode_mem, NLS);
424 flag = MRIStepSetMaxNonlinIters(arkode_mem, max_nonlinear_iters);
427 if (nonlinear_solver ==
"Newton") {
429 amrex::Print() <<
"Linear solver: " << linear_solver <<
"\n";
430 amrex::Print() <<
"Max linear iters: " << max_linear_iters <<
"\n";
432 LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, max_linear_iters, sunctx);
434 flag = MRIStepSetLinearSolver(arkode_mem, LS,
nullptr);
448 flag = MRIStepSetStopTime(arkode_mem, stop_time);
453 flag = MRIStepSetMaxNumSteps(arkode_mem, max_num_steps);
464 const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data);
465 S_data.resize(num_vecs);
467 for(
int i = 0; i < num_vecs; i++)
479 auto get_length = [&](
int index) -> sunindextype {
480 auto* p_mf = &S_data[index];
481 return p_mf->nComp() * (p_mf->boxArray()).numPts();
484 sunindextype NV_len = S_data.
size();
485 N_Vector* NV_array =
new N_Vector[NV_len];
487 for (
int i = 0; i < NV_len; ++i) {
489 &S_data[i], &sunctx);
492 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
502 auto get_length = [&](
int index) -> sunindextype {
503 auto* p_mf = &S_data[index];
504 return p_mf->nComp() * (p_mf->boxArray()).numPts();
507 sunindextype NV_len = S_data.
size();
508 N_Vector* NV_array =
new N_Vector[NV_len];
510 for (
int i = 0; i < NV_len; ++i) {
526 N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx);
598 initialize_parameters();
600#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7)
602 sunctx = ::sundials::Context(&mpi_comm);
604 sunctx = ::sundials::Context(
nullptr);
608 sunctx = ::sundials::Context(mpi_comm);
610 sunctx = ::sundials::Context(SUN_COMM_NULL);
615 udata.
f = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
619 unpack_vector(y_data, S_data);
622 unpack_vector(y_rhs, S_rhs);
629 udata.
fi = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
633 unpack_vector(y_data, S_data);
636 unpack_vector(y_rhs, S_rhs);
643 udata.
fe = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
647 unpack_vector(y_data, S_data);
650 unpack_vector(y_rhs, S_rhs);
657 udata.
ff = [&](
amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs,
661 unpack_vector(y_data, S_data);
664 unpack_vector(y_rhs, S_rhs);
675 unpack_vector(y_data, S_data);
686 unpack_vector(y_data, S_data);
697 unpack_vector(y_data, S_data);
708 unpack_vector(y_data, S_data);
715 N_Vector y_data = copy_data(S_data);
718 SetupRK(time, y_data);
722 SetupMRI(time, y_data);
734 if (type ==
"EX-MRI" || type ==
"IM-MRI" || type ==
"IMEX-MRI") {
737 MRIStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
741 ARKStepPrintAllStats(arkode_fast_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
746 ARKStepPrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
753 SUNLinSolFree(fast_LS);
754 SUNNonlinSolFree(NLS);
755 SUNNonlinSolFree(fast_NLS);
756 MRIStepInnerStepper_Free(&fast_stepper);
757 MRIStepFree(&arkode_fast_mem);
758 ARKStepFree(&arkode_mem);
775 N_Vector y_old = wrap_data(S_old);
776 N_Vector y_new = wrap_data(S_new);
779 ARKStepReset(arkode_mem, time, y_old);
780 ARKStepSetFixedStep(arkode_mem, dt);
781 int flag = ARKStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
785 MRIStepReset(arkode_mem, time, y_old);
786 MRIStepSetFixedStep(arkode_mem, dt);
787 int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP);
790 Error(
"SUNDIALS integrator type not specified.");
810 N_Vector y_out = wrap_data(S_out);
816 flag = ARKStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
826 flag = MRIStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL);
829 Error(
"SUNDIALS integrator type not specified.");
843 void map_data (std::function<
void(T&)> )
override {}
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
amrex::ParmParse pp
Input file parser instance for the given namespace.
Definition AMReX_HypreIJIface.cpp:15
Long numPts() const noexcept
Returns the total number of cells contained in all boxes in the BoxArray.
Definition AMReX_BoxArray.cpp:394
int nGrow(int direction=0) const noexcept
Return the grow factor that defines the region of definition.
Definition AMReX_FabArrayBase.H:78
const DistributionMapping & DistributionMap() const noexcept
Return constant reference to associated DistributionMapping.
Definition AMReX_FabArrayBase.H:131
int nComp() const noexcept
Return number of variables (aka components) associated with each point.
Definition AMReX_FabArrayBase.H:83
const BoxArray & boxArray() const noexcept
Return a constant reference to the BoxArray that defines the valid region associated with this FabArray.
Definition AMReX_FabArrayBase.H:95
Definition AMReX_IntegratorBase.H:164
bool use_adaptive_fast_time_step
Flag to enable/disable adaptive time stepping at the fast time scale in multirate methods (bool)
Definition AMReX_IntegratorBase.H:246
amrex::Real fast_rel_tol
Relative tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:278
amrex::Real rel_tol
Relative tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:267
std::function< void(T &, amrex::Real)> post_fast_stage_action
The post_fast_stage_action function is called by the integrator on each computed fast-time-scale stage (multirate methods) just after it is computed.
Definition AMReX_IntegratorBase.H:218
amrex::Real fast_abs_tol
Absolute tolerance for adaptive time stepping at the fast time scale (Real)
Definition AMReX_IntegratorBase.H:284
amrex::Real fast_time_step
Current integrator fast time scale time step size with multirate methods (Real)
Definition AMReX_IntegratorBase.H:252
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsEx
RhsEx is the explicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:194
std::function< void(T &, amrex::Real)> post_step_action
The post_step_action function is called by the integrator on the computed state just after it is computed.
Definition AMReX_IntegratorBase.H:212
std::function< void(T &, amrex::Real)> post_fast_step_action
The post_fast_step_action function is called by the integrator on the computed state after each fast-time-scale step (multirate methods), just after it is computed.
Definition AMReX_IntegratorBase.H:224
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsIm
RhsIm is the implicit right-hand-side function an ImEx integrator will use.
Definition AMReX_IntegratorBase.H:188
std::function< void(T &rhs, T &state, const amrex::Real time)> Rhs
Rhs is the right-hand-side function the integrator will use.
Definition AMReX_IntegratorBase.H:182
bool use_adaptive_time_step
Flag to enable/disable adaptive time stepping in single rate methods or at the slow time scale in multirate methods (bool)
Definition AMReX_IntegratorBase.H:230
std::function< void(T &, amrex::Real)> post_stage_action
The post_stage_action function is called by the integrator on the computed stage just after it is computed.
Definition AMReX_IntegratorBase.H:206
amrex::Real time_step
Current integrator time step size (Real)
Definition AMReX_IntegratorBase.H:235
std::function< void(T &rhs, T &state, const amrex::Real time)> RhsFast
RhsFast is the fast timescale right-hand-side function a multirate integrator will use.
Definition AMReX_IntegratorBase.H:200
amrex::Real abs_tol
Absolute tolerance for adaptive time stepping (Real)
Definition AMReX_IntegratorBase.H:272
A collection (stored as an array) of FArrayBox objects.
Definition AMReX_MultiFab.H:40
static void Copy(MultiFab &dst, const MultiFab &src, int srccomp, int dstcomp, int numcomp, int nghost)
Copy from src to dst including nghost ghost cells. The two MultiFabs MUST have the same underlying BoxArray.
Definition AMReX_MultiFab.cpp:193
Parse Parameters From Command Line and Input Files.
Definition AMReX_ParmParse.H:348
int query(std::string_view name, bool &ref, int ival=FIRST) const
Same as querykth() but searches for the last occurrence of name.
Definition AMReX_ParmParse.cpp:1946
This class provides the user with a few print options.
Definition AMReX_Print.H:35
IntegratorBase implementation powered by SUNDIALS ARKStep/MRIStep.
Definition AMReX_SundialsIntegrator.H:95
void time_interpolate(const T &, const T &, amrex::Real, T &) override
Interpolate between SUNDIALS stages (not yet implemented for this integrator).
Definition AMReX_SundialsIntegrator.H:838
SundialsIntegrator()
Construct an uninitialized integrator; call initialize() before use.
Definition AMReX_SundialsIntegrator.H:577
void initialize(const T &S_data, const amrex::Real time=0.0)
Configure (or reconfigure) the SUNDIALS integrator for the provided state.
Definition AMReX_SundialsIntegrator.H:596
amrex::Real advance(T &S_old, T &S_new, amrex::Real time, const amrex::Real dt) override
Take a single time step of size dt starting from S_old.
Definition AMReX_SundialsIntegrator.H:770
void evolve(T &S_out, const amrex::Real time_out) override
Evolve the solution in S_out up to time_out using ARKStep/MRIStep.
Definition AMReX_SundialsIntegrator.H:805
virtual ~SundialsIntegrator()
Destroy the integrator, printing summary statistics when verbose.
Definition AMReX_SundialsIntegrator.H:731
SundialsIntegrator(const T &S_data, const amrex::Real time=0.0)
Construct and immediately configure the integrator with S_data at time time.
Definition AMReX_SundialsIntegrator.H:585
void map_data(std::function< void(T &)>) override
Apply a user-supplied mapping to every MultiFab in the integrator (unused placeholder).
Definition AMReX_SundialsIntegrator.H:843
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition AMReX_Vector.H:28
Long size() const noexcept
Definition AMReX_Vector.H:53
amrex_real Real
Floating Point Type for Fields.
Definition AMReX_REAL.H:79
amrex_long Long
Definition AMReX_INT.H:30
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition AMReX_ParallelDescriptor.H:289
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition AMReX_ParallelContext.H:70
static int fi(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:52
static int fe(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:57
static int post_fast_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:82
static int f(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:47
static int post_step(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:72
static int ff(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition AMReX_SundialsIntegrator.H:62
static int post_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:67
static int post_fast_stage(amrex::Real t, N_Vector y_data, void *user_data)
Definition AMReX_SundialsIntegrator.H:77
int MPI_Comm
Definition AMReX_ccse-mpi.H:51
N_Vector N_VMake_MultiFab(sunindextype length, amrex::MultiFab *v_mf, ::sundials::Context *sunctx)
Wrap an existing MultiFab v_mf as an N_Vector without copying.
Definition AMReX_NVector_MultiFab.cpp:105
amrex::MultiFab *& getMFptr(N_Vector v)
Access the MultiFab pointer stored inside v (non-const).
Definition AMReX_NVector_MultiFab.cpp:233
N_Vector N_VNew_MultiFab(sunindextype length, const amrex::BoxArray &ba, const amrex::DistributionMapping &dm, sunindextype nComp, sunindextype nGhost, ::sundials::Context *sunctx)
Allocate a MultiFab-backed N_Vector of the given length.
Definition AMReX_NVector_MultiFab.cpp:80
Definition AMReX_Amr.cpp:49
@ make_alias
Definition AMReX_MakeType.H:7
int nComp(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2851
DistributionMapping const & DistributionMap(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2866
void Error(const std::string &msg)
Print out message to cerr and exit via amrex::Abort().
Definition AMReX.cpp:234
int Verbose() noexcept
Definition AMReX.cpp:179
const int[]
Definition AMReX_BLProfiler.cpp:1664
BoxArray const & boxArray(FabArrayBase const &fa)
Definition AMReX_FabArrayBase.cpp:2861
User-supplied callbacks consumed by the AMReX/SUNDIALS bridge.
Definition AMReX_SundialsIntegrator.H:35
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fi
Implicit RHS for ImEx schemes.
Definition AMReX_SundialsIntegrator.H:37
std::function< int(amrex::Real, N_Vector, void *)> post_fast_stage
Hook for MRI fast stages.
Definition AMReX_SundialsIntegrator.H:42
std::function< int(amrex::Real, N_Vector, void *)> post_step
Hook invoked after each time step.
Definition AMReX_SundialsIntegrator.H:41
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> f
ERK/DIRK RHS or MRI slow RHS.
Definition AMReX_SundialsIntegrator.H:36
std::function< int(amrex::Real, N_Vector, void *)> post_stage
Hook invoked after each stage.
Definition AMReX_SundialsIntegrator.H:40
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> ff
MRI fast-scale RHS.
Definition AMReX_SundialsIntegrator.H:39
std::function< int(amrex::Real, N_Vector, void *)> post_fast_step
Hook for MRI fast steps.
Definition AMReX_SundialsIntegrator.H:43
std::function< int(amrex::Real, N_Vector, N_Vector, void *)> fe
Explicit RHS for ImEx schemes.
Definition AMReX_SundialsIntegrator.H:38