17#ifdef MFEM_USE_SUNDIALS
27#include <sundials/sundials_config.h>
29#if !defined(SUNDIALS_VERSION_MAJOR) || (SUNDIALS_VERSION_MAJOR < 5)
30#error MFEM requires SUNDIALS version 5.0.0 or newer!
32#if defined(MFEM_USE_CUDA) && ((SUNDIALS_VERSION_MAJOR == 5) && (SUNDIALS_VERSION_MINOR < 4))
33#error MFEM requires SUNDIALS version 5.4.0 or newer when MFEM_USE_CUDA=TRUE!
35#if defined(MFEM_USE_HIP) && ((SUNDIALS_VERSION_MAJOR == 5) && (SUNDIALS_VERSION_MINOR < 7))
36#error MFEM requires SUNDIALS version 5.7.0 or newer when MFEM_USE_HIP=TRUE!
38#if defined(MFEM_USE_CUDA) && !defined(SUNDIALS_NVECTOR_CUDA)
39#error MFEM_USE_CUDA=TRUE requires SUNDIALS to be built with CUDA support
41#if defined(MFEM_USE_HIP) && !defined(SUNDIALS_NVECTOR_HIP)
42#error MFEM_USE_HIP=TRUE requires SUNDIALS to be built with HIP support
44#include <sundials/sundials_matrix.h>
45#include <sundials/sundials_linearsolver.h>
46#include <arkode/arkode_arkstep.h>
47#include <cvodes/cvodes.h>
48#include <kinsol/kinsol.h>
49#if defined(MFEM_USE_CUDA)
50#include <sunmemory/sunmemory_cuda.h>
51#elif defined(MFEM_USE_HIP)
52#include <sunmemory/sunmemory_hip.h>
57#define MFEM_SUNDIALS_VERSION \
58 (SUNDIALS_VERSION_MAJOR*10000 + SUNDIALS_VERSION_MINOR*100 + \
59 SUNDIALS_VERSION_PATCH)
61#if (SUNDIALS_VERSION_MAJOR < 6)
91#if (SUNDIALS_VERSION_MAJOR < 7)
103#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
108class SundialsMemHelper
/// Implicit conversion to SUNMemoryHelper: returns the wrapped handle (member h),
/// so this object can be passed directly where a SUNMemoryHelper is expected.
137 operator SUNMemoryHelper()
const {
return h; }
140 size_t memsize, SUNMemoryType mem_type
141#
if (SUNDIALS_VERSION_MAJOR >= 6)
147#
if (SUNDIALS_VERSION_MAJOR >= 6)
250 SundialsNVector(MPI_Comm comm,
double *data_,
int loc_size,
long glob_size);
/// Returns the N_Vector_ID of the given N_Vector x_, as reported by
/// SUNDIALS' N_VGetVectorID (i.e. which N_Vector implementation x_ uses).
264 inline N_Vector_ID
GetNVectorID(N_Vector x_)
const {
return N_VGetVectorID(x_); }
270#if SUNDIALS_VERSION_MAJOR < 7
271 return *
static_cast<MPI_Comm*
>(N_VGetCommunicator(
x));
273 return N_VGetCommunicator(
x);
282 void SetSize(
int s,
long glob_size = 0);
/// Implicit conversion to the internal SUNDIALS N_Vector (member x), allowing
/// this object to be passed directly to SUNDIALS functions taking an N_Vector.
314 operator N_Vector()
const {
return x; }
328 using Vector::operator=;
345 static N_Vector
MakeNVector(MPI_Comm comm,
bool use_device);
348#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
434 N_Vector tmp2, N_Vector tmp3);
437 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
500 void Step(
Vector &x,
double &t,
double &dt)
override;
570 N_Vector
yB, N_Vector yBdot,
void *user_dataB);
574 N_Vector qBdot,
void *user_dataB);
577 static int ewt(N_Vector y, N_Vector w,
void *user_data);
627 void Step(
Vector &x,
double &t,
double &dt)
override;
630 virtual void StepB(
Vector &w,
double &t,
double &dt);
637 double reltolQ = 1e-3,
638 double abstolQ = 1e-8);
642 double abstolQB = 1e-8);
693 N_Vector tmp2, N_Vector tmp3);
696 static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
740 N_Vector tmp2, N_Vector tmp3);
743 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
748 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
751 static int MassSysSolve(SUNLinearSolver LS, SUNMatrix
M, N_Vector x,
755 static int MassMult1(SUNMatrix
M, N_Vector x, N_Vector v);
904 static int Mult(
const N_Vector
u, N_Vector fu,
void *user_data);
911 static int LinSysSetup(N_Vector
u, N_Vector fu, SUNMatrix J,
912 void *user_data, N_Vector tmp1, N_Vector tmp2);
915 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector
u,
943 KINSolver(
int strategy,
bool oper_grad =
true);
953 KINSolver(MPI_Comm comm,
int strategy,
bool oper_grad =
true);
997 double damping = 1.0);
Interface to ARKode's ARKStep module – additive Runge-Kutta methods.
static int RHS2(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void SetMaxStep(double dt_max)
Set the maximum time step.
void PrintInfo() const
Print various ARKStep statistics.
Type rk_type
Runge-Kutta type.
ARKStepSolver(Type type=EXPLICIT)
Construct a serial wrapper to SUNDIALS' ARKode integrator.
void SetOrder(int order)
Chooses integration order for all explicit / implicit / IMEX methods.
void SetStepMode(int itask)
Select the ARKode step mode: ARK_NORMAL (default) or ARK_ONE_STEP.
Type
Types of ARKODE solvers.
@ IMPLICIT
Implicit RK method.
@ IMEX
Implicit-explicit ARK method.
@ EXPLICIT
Explicit RK method.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $ A x = b $.
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $ A x = b $.
void UseSundialsLinearSolver()
Attach a SUNDIALS GMRES linear solver to ARKode.
void Init(TimeDependentOperator &f_) override
Initialize ARKode: calls ARKStepCreate() to create the ARKStep memory and set some defaults.
static int MassMult2(N_Vector x, N_Vector v, sunrealtype t, void *mtimes_data)
Compute the matrix-vector product $ v = M x $ at time t.
void SetIRKTableNum(ARKODE_DIRKTableID table_id)
Choose a specific Butcher table for a diagonally implicit RK method.
void SetFixedStep(double dt)
Use a fixed time step size (disable temporal adaptivity).
void SetERKTableNum(ARKODE_ERKTableID table_id)
Choose a specific Butcher table for an explicit RK method.
int step_mode
ARKStep step mode (ARK_NORMAL or ARK_ONE_STEP).
static int RHS1(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void UseMFEMMassLinearSolver(int tdep)
Attach mass matrix linear system setup, solve, and matrix-vector product methods from the TimeDepende...
bool use_implicit
True for implicit or imex integration.
void UseSundialsMassLinearSolver(int tdep)
Attach the SUNDIALS GMRES linear solver and the mass matrix matrix-vector product method from the Tim...
virtual ~ARKStepSolver()
Destroy the associated ARKode memory and SUNDIALS objects.
void SetIMEXTableNum(ARKODE_ERKTableID etable_id, ARKODE_DIRKTableID itable_id)
Choose a specific Butcher table for an IMEX RK method.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
static int MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
Compute the matrix-vector product $ v = M x $.
static int MassSysSolve(SUNLinearSolver LS, SUNMatrix M, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $ M x = b $.
static int MassSysSetup(sunrealtype t, SUNMatrix M, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $ M x = b $.
void Step(Vector &x, real_t &t, real_t &dt) override
Integrate the ODE with ARKode using the specified step mode.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
static int RHSB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
Wrapper to compute the ODE RHS backward function.
void EvalQuadIntegrationB(double t, Vector &dG_dp)
Evaluate Quadrature solution.
void EvalQuadIntegration(double t, Vector &q)
Evaluate Quadrature.
static int LinSysSetupB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $ A x = b $.
static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $ A x = b $.
long GetNumSteps()
Get Number of Steps for ForwardSolve.
static constexpr double default_abs_tolB
Default scalar backward absolute tolerance.
static constexpr double default_abs_tolQB
Default scalar backward absolute quadrature tolerance.
void SetMaxNStepsB(int mxstepsB)
Set the maximum number of backward steps.
static int RHSQ(sunrealtype t, const N_Vector y, N_Vector qdot, void *user_data)
Wrapper to compute the ODE RHS Quadrature function.
void Step(Vector &x, double &t, double &dt) override
void InitB(TimeDependentAdjointOperator &f_)
Initialize the adjoint problem.
static int RHSQB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_dataB)
Wrapper to compute the ODE RHS Backwards Quadrature function.
SundialsNVector * q
Quadrature vector.
int indexB
backward problem index
void GetForwardSolution(double tB, mfem::Vector &yy)
Get Interpolated Forward solution y at backward integration time tB.
SUNLinearSolver LSB
Linear solver for A.
void SetSVtolerancesB(double reltol, Vector abstol)
Tolerance specification functions for the adjoint problem.
void UseSundialsLinearSolverB()
Use the built-in SUNDIALS linear solver for the backward problem.
void SetWFTolerances(EWTFunction func)
Set multiplicative error weights.
SundialsNVector * yy
State vector.
void Init(TimeDependentAdjointOperator &f_)
SUNMatrix AB
Linear system A = I - gamma J, M - gamma J, or J.
int ncheck
number of checkpoints used so far
SundialsNVector * qB
Backward quadrature vector.
static int ewt(N_Vector y, N_Vector w, void *user_data)
Error control function.
virtual void StepB(Vector &w, double &t, double &dt)
Solve one adjoint time step.
void InitQuadIntegrationB(mfem::Vector &qB0, double reltolQB=1e-3, double abstolQB=1e-8)
Initialize Quadrature Integration (Adjoint)
static constexpr double default_rel_tolB
Default scalar backward relative tolerance.
void InitAdjointSolve(int steps, int interpolation)
Initialize Adjoint.
SundialsNVector * yB
State vector.
void InitQuadIntegration(mfem::Vector &q0, double reltolQ=1e-3, double abstolQ=1e-8)
virtual ~CVODESSolver()
Destroy the associated CVODES memory and SUNDIALS objects.
void SetSStolerancesB(double reltol, double abstol)
Tolerance specification functions for the adjoint problem.
void UseMFEMLinearSolverB()
Set Linear Solver for the backward problem.
Interface to the CVODE library – linear multi-step methods.
void SetStepMode(int itask)
Select the CVODE step mode: CV_NORMAL (default) or CV_ONE_STEP.
void SetRootFinder(int components, RootFunction func)
Initialize Root Finder.
std::function< int(sunrealtype t, Vector y, Vector gout, CVODESolver *)> RootFunction
Typedef for root finding functions.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
void Init(TimeDependentOperator &f_) override
Initialize CVODE: calls CVodeCreate() to create the CVODE memory and set some defaults.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $ A x = b $.
virtual ~CVODESolver()
Destroy the associated CVODE memory and SUNDIALS objects.
static int root(sunrealtype t, N_Vector y, sunrealtype *gout, void *user_data)
Prototype to define root finding for CVODE.
void SetMaxNSteps(int steps)
Set the maximum number of time steps.
CVODESolver(int lmm)
Construct a serial wrapper to SUNDIALS' CVODE integrator.
long GetNumSteps()
Get the number of internal steps taken so far.
void Step(Vector &x, double &t, double &dt) override
Integrate the ODE with CVODE using the specified step mode.
EWTFunction ewt_func
A class member to facilitate pointing to a user-specified error weight function.
static int RHS(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
Wrapper to compute the ODE right-hand side function.
void SetMaxStep(double dt_max)
Set the maximum time step.
int lmm_type
Linear multistep method type.
void PrintInfo() const
Print various CVODE statistics.
void UseSundialsLinearSolver()
Attach SUNDIALS GMRES linear solver to CVODE.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
RootFunction root_func
A class member to facilitate pointing to a user-specified root function.
std::function< int(Vector y, Vector w, CVODESolver *)> EWTFunction
Typedef declaration for error weight functions.
int step_mode
CVODE step mode (CV_NORMAL or CV_ONE_STEP).
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $ A x = b $.
void SetSVtolerances(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances.
void SetMaxOrder(int max_order)
Set the maximum method order.
static MemoryType GetDeviceMemoryType()
Get the current Device MemoryType. This is the MemoryType used by most MFEM classes when allocating m...
Wrapper for hypre's parallel vector class.
int print_level
(DEPRECATED) Legacy print level definition, which is left for compatibility with custom iterative sol...
Interface to the KINSOL library – nonlinear solver methods.
SundialsNVector * f_scale
scaling vectors
KINSolver(int strategy, bool oper_grad=true)
Construct a serial wrapper to SUNDIALS' KINSOL nonlinear solver.
double aa_damping
Anderson Acceleration damping.
void SetJFNK(bool use_jfnk)
Set the Jacobian Free Newton Krylov flag. The default is false.
void SetJFNKSolver(Solver &solver)
virtual ~KINSolver()
Destroy the associated KINSOL memory.
static int Mult(const N_Vector u, N_Vector fu, void *user_data)
Wrapper to compute the nonlinear residual $ fu = F(u) $.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector u, N_Vector b, sunrealtype tol)
Solve the linear system $ J u = b $.
int aa_delay
Anderson Acceleration delay.
int global_strategy
KINSOL solution strategy.
static int PrecSolve(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void *user_data)
Solve the preconditioner equation $ P z = v $.
int maxlrs
Maximum linear solver restarts.
const Operator * jacobian
stores oper->GetGradient()
int maxli
Maximum linear iterations.
Vector wrk
Work vector needed for the JFNK PC.
void SetScaledStepTol(double sstol)
Set KINSOL's scaled step tolerance.
void SetSolver(Solver &solver) override
Set the linear solver for inverting the Jacobian.
SundialsNVector * y_scale
void SetOperator(const Operator &op) override
Set the nonlinear Operator of the system and initialize KINSOL.
void SetMaxSetupCalls(int max_calls)
Set maximum number of nonlinear iterations without a Jacobian update.
static int LinSysSetup(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
Set up the linear system $ J u = b $.
int aa_orth
Anderson Acceleration orthogonalization routine.
static int PrecSetup(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void *user_data)
Setup the preconditioner.
void SetDamping(double damping)
void SetLSMaxRestarts(int m)
Set the maximum number of linear solver restarts.
double fp_damping
Fixed Point or Picard damping parameter.
void EnableAndersonAcc(int n, int orth=KIN_ORTH_MGS, int delay=0, double damping=1.0)
Enable Anderson Acceleration for KIN_FP or KIN_PICARD.
void SetPrintLevel(int print_lvl) override
Set the print level for the KINSetPrintLevel function.
void SetPreconditioner(Solver &solver) override
Equivalent to SetSolver(solver).
void SetLSMaxIter(int m)
Set the maximum number of linear solver iterations.
bool use_oper_grad
use the Jv prod function
int aa_n
number of acceleration vectors
static int GradientMult(N_Vector v, N_Vector Jv, N_Vector u, sunbooleantype *new_u, void *user_data)
Wrapper to compute the Jacobian-vector product $ Jv = J(u) v $.
Newton's method for solving F(x)=b for a given operator F.
Abstract class for solving systems of ODEs: dx/dt = f(x,t)
SundialsMemHelper & operator=(const SundialsMemHelper &)=delete
Disable copy assignment.
SundialsMemHelper(const SundialsMemHelper &that_helper)=delete
Disable copy construction.
SundialsMemHelper()=default
Default constructor – object must be moved to.
static int SundialsMemHelper_Alloc(SUNMemoryHelper helper, SUNMemory *memptr, size_t memsize, SUNMemoryType mem_type #if(SUNDIALS_VERSION_MAJOR >=6), void *queue #endif)
static int SundialsMemHelper_Dealloc(SUNMemoryHelper helper, SUNMemory sunmem #if(SUNDIALS_VERSION_MAJOR >=6), void *queue #endif)
SundialsMemHelper(SUNContext context)
Vector interface for SUNDIALS N_Vectors.
N_Vector_ID GetNVectorID(N_Vector x_) const
Returns the N_Vector_ID for the N_Vector x_.
N_Vector StealNVector()
Changes the ownership of the vector.
long GlobalSize() const
Returns the MPI global length for the internal N_Vector x.
MPI_Comm GetComm() const
Returns the MPI communicator for the internal N_Vector x.
void MakeRef(Vector &base, int offset, int s)
Reset the Vector to be a reference to a sub-vector of base.
void SetSize(int s, long glob_size=0)
Resize the vector to size s.
static bool UseManagedMemory()
static N_Vector MakeNVector(bool use_device)
Create a N_Vector.
int GetOwnership() const
Gets ownership of the internal N_Vector.
~SundialsNVector()
Calls SUNDIALS N_VDestroy function if the N_Vector is owned by 'this'.
void SetDataAndSize(double *d, int s, long glob_size=0)
Set the vector data and size.
void SetOwnership(int own)
Sets ownership of the internal N_Vector.
N_Vector_ID GetNVectorID() const
Returns the N_Vector_ID for the internal N_Vector.
void _SetNvecDataAndSize_(long glob_size=0)
Set data and length of internal N_Vector x from 'this'.
void _SetDataAndSize_()
Set data and length from the internal N_Vector x.
N_Vector x
The actual SUNDIALS object.
void MakeRef(Vector &base, int offset)
Reset the Vector to be a reference to a sub-vector of base without changing its current size.
SundialsNVector()
Creates an empty SundialsNVector.
Base class for interfacing with SUNDIALS packages.
static constexpr double default_abs_tol
Default scalar absolute tolerance.
SUNMatrix M
Mass matrix M.
int flag
Last flag returned from a call to SUNDIALS.
long saved_global_size
Global vector length on last initialization.
SundialsSolver()
Protected constructor: objects of this type should be constructed only as part of a derived class.
static constexpr double default_rel_tol
Default scalar relative tolerance.
void * GetMem() const
Access the SUNDIALS memory structure.
bool reinit
Flag to signal that memory reinitialization is needed.
void AllocateEmptyNVector(N_Vector &y)
SundialsNVector * Y
State vector.
void * sundials_mem
SUNDIALS mem structure.
void AllocateEmptyNVector(N_Vector &y, MPI_Comm comm)
SUNLinearSolver LSA
Linear solver for A.
SUNLinearSolver LSM
Linear solver for M.
int GetFlag() const
Returns the last flag returned by a call to a SUNDIALS function.
SUNNonlinearSolver NLS
Nonlinear solver.
Singleton class for SUNContext and SundialsMemHelper objects.
Sundials(Sundials &other)=delete
Disable copy construction.
static SUNContext & GetContext()
Provides access to the SUNContext object.
void operator=(const Sundials &other)=delete
Disable copy assignment.
static SundialsMemHelper & GetMemHelper()
Provides access to the SundialsMemHelper object.
Base abstract class for first order time dependent operators.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
void MakeRef(Vector &base, int offset, int size)
Reset the Vector to be a reference to a sub-vector of base.
real_t u(const Vector &xvec)
realtype sunrealtype
'sunrealtype' was first introduced in v6.0.0
constexpr ARKODE_ERKTableID ARKODE_FEHLBERG_13_7_8
booleantype sunbooleantype
'sunbooleantype' was first introduced in v6.0.0
constexpr ARKODE_ERKTableID ARKODE_ERK_NONE
constexpr ARKODE_DIRKTableID ARKODE_DIRK_NONE