// NOTE(review): this appears to be a Doxygen source-listing extraction. The
// leading digits on each line (17, 27, 29, ...) look like the original
// header's line numbers fused into the text. Intervening lines, including the
// matching #endif directives, are not present in this view.
//
// The content guarded here is compiled only when MFEM_USE_SUNDIALS is defined.
17#ifdef MFEM_USE_SUNDIALS
// The SUNDIALS_VERSION_* checks below expect the version macros to be
// available after this include.
27#include <sundials/sundials_config.h>
// Minimum supported SUNDIALS is 5.0.0. This also rejects a configuration
// that does not define SUNDIALS_VERSION_MAJOR at all.
29#if !defined(SUNDIALS_VERSION_MAJOR) || (SUNDIALS_VERSION_MAJOR < 5)
30#error MFEM requires SUNDIALS version 5.0.0 or newer!
// CUDA builds against a 5.x SUNDIALS need at least 5.4.
32#if defined(MFEM_USE_CUDA) && ((SUNDIALS_VERSION_MAJOR == 5) && (SUNDIALS_VERSION_MINOR < 4))
33#error MFEM requires SUNDIALS version 5.4.0 or newer when MFEM_USE_CUDA=TRUE!
// HIP builds against a 5.x SUNDIALS need at least 5.7.
35#if defined(MFEM_USE_HIP) && ((SUNDIALS_VERSION_MAJOR == 5) && (SUNDIALS_VERSION_MINOR < 7))
36#error MFEM requires SUNDIALS version 5.7.0 or newer when MFEM_USE_HIP=TRUE!
// A CUDA-enabled MFEM needs SUNDIALS configured with SUNDIALS_NVECTOR_CUDA.
38#if defined(MFEM_USE_CUDA) && !defined(SUNDIALS_NVECTOR_CUDA)
39#error MFEM_USE_CUDA=TRUE requires SUNDIALS to be built with CUDA support
// A HIP-enabled MFEM needs SUNDIALS configured with SUNDIALS_NVECTOR_HIP.
41#if defined(MFEM_USE_HIP) && !defined(SUNDIALS_NVECTOR_HIP)
42#error MFEM_USE_HIP=TRUE requires SUNDIALS to be built with HIP support
44#include <sundials/sundials_matrix.h>
45#include <sundials/sundials_linearsolver.h>
46#include <arkode/arkode_arkstep.h>
47#include <cvodes/cvodes.h>
48#include <kinsol/kinsol.h>
49#if defined(MFEM_USE_CUDA)
50#include <sunmemory/sunmemory_cuda.h>
51#elif defined(MFEM_USE_HIP)
52#include <sunmemory/sunmemory_hip.h>
// NOTE(review): Doxygen listing fragment. The leading digits look like
// original header line numbers, and the lines between them are missing.
//
// Section selected only for SUNDIALS major versions older than 6. Its body is
// not shown in this view.
57#if (SUNDIALS_VERSION_MAJOR < 6)
// Section compiled only for CUDA or HIP builds of MFEM.
76#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
// Wrapper around a SUNDIALS SUNMemoryHelper handle. Only part of the class
// body appears below.
81class SundialsMemHelper
// Implicit conversion to the wrapped SUNMemoryHelper handle `h`.
110 operator SUNMemoryHelper()
const {
return h; }
// NOTE(review): this is the tail of the allocation callback's parameter list.
// The `114#` / `if (...)` pair (and the `120#` pair) appear to be a split
// `#if (SUNDIALS_VERSION_MAJOR >= 6)` that adds a trailing `void *queue`
// parameter for SUNDIALS >= 6.
113 size_t memsize, SUNMemoryType mem_type
114#
if (SUNDIALS_VERSION_MAJOR >= 6)
120#
if (SUNDIALS_VERSION_MAJOR >= 6)
// NOTE(review): Doxygen listing fragment of SundialsNVector members. The
// leading digits look like original header line numbers.
//
// Parallel constructor taking an MPI communicator, a data pointer, a local
// size and a global size.
// NOTE(review): whether the object takes ownership of data_ is not shown
// here; confirm in the implementation.
223 SundialsNVector(MPI_Comm comm,
double *data_,
int loc_size,
long glob_size);
// Returns the N_Vector_ID of the given N_Vector x_ by forwarding to
// N_VGetVectorID.
237 inline N_Vector_ID
GetNVectorID(N_Vector x_)
const {
return N_VGetVectorID(x_); }
// Returns the MPI communicator of the internal N_Vector x.
// NOTE(review): this dereferences the result of N_VGetCommunicator(x) as an
// MPI_Comm*. That assumes the call returns a pointer to the communicator;
// verify against the SUNDIALS version in use.
241 inline MPI_Comm
GetComm()
const {
return *
static_cast<MPI_Comm*
>(N_VGetCommunicator(
x)); }
// Resizes the vector to local size s. The optional glob_size defaults to 0.
248 void SetSize(
int s,
long glob_size = 0);
// Implicit conversion to the internal N_Vector x.
280 operator N_Vector()
const {
return x; }
// Brings mfem::Vector's assignment operators into this class's scope.
294 using Vector::operator=;
// Factory that creates an N_Vector on communicator comm.
// NOTE(review): use_device presumably selects device-backed storage; confirm.
311 static N_Vector
MakeNVector(MPI_Comm comm,
bool use_device);
// The following declarations are compiled only for CUDA or HIP builds.
314#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
// NOTE(review): Doxygen listing fragment of CVODESolver members. The leading
// digits look like original header line numbers.
//
// Static callback wrapper that evaluates the ODE right-hand side for state y
// at time t. The result presumably goes into ydot.
393 static int RHS(realtype
t,
const N_Vector y, N_Vector ydot,
void *user_data);
// Static callback that sets up the linear system A x = b at time t.
// gamma, jok and jcur are supplied by CVODE.
// NOTE(review): exact jok/jcur semantics follow the CVODE linear-solver
// interface; confirm there.
396 static int LinSysSetup(realtype
t, N_Vector y, N_Vector fy, SUNMatrix
A,
397 booleantype jok, booleantype *jcur,
398 realtype gamma,
void *user_data, N_Vector tmp1,
399 N_Vector tmp2, N_Vector tmp3);
// Static callback that solves the linear system A x = b to tolerance tol.
402 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
403 N_Vector
b, realtype tol);
// Root-finding callback for CVODE. It evaluates the root functions into gout
// for state y at time t.
406 static int root(realtype
t, N_Vector y, realtype *gout,
void *user_data);
// Integrates the ODE with CVODE using the configured step mode. The state x,
// time t and step size dt are passed by reference.
463 virtual void Step(
Vector &x,
double &
t,
double &dt);
// NOTE(review): Doxygen listing fragment of CVODESSolver members. The leading
// digits look like original header line numbers.
//
// Wrapper that computes the forward quadrature right-hand side (qdot) for
// state y at time t.
528 static int RHSQ(realtype
t,
const N_Vector y, N_Vector qdot,
void *user_data);
// Wrapper that computes the backward (adjoint) ODE right-hand side (yBdot)
// given the forward state y and the backward state yB.
531 static int RHSB(realtype
t, N_Vector y,
532 N_Vector
yB, N_Vector yBdot,
void *user_dataB);
// Wrapper that computes the backward quadrature right-hand side (qBdot).
535 static int RHSQB(realtype
t, N_Vector y, N_Vector
yB,
536 N_Vector qBdot,
void *user_dataB);
// Error-weight callback: computes the weights w for state y.
539 static int ewt(N_Vector y, N_Vector w,
void *user_data);
// Advances the forward problem by one Step call. x, t and dt are passed by
// reference.
589 virtual void Step(
Vector &x,
double &
t,
double &dt);
// Tail of the forward quadrature initializer's parameter list, with default
// tolerances reltolQ = 1e-3 and abstolQ = 1e-8.
599 double reltolQ = 1e-3,
600 double abstolQ = 1e-8);
// Tail of the backward quadrature initializer's parameter list, with default
// abstolQB = 1e-8.
604 double abstolQB = 1e-8);
// Static callback that sets up the backward linear system A x = b.
651 static int LinSysSetupB(realtype
t, N_Vector y, N_Vector
yB, N_Vector fyB,
653 booleantype jok, booleantype *jcur,
654 realtype gamma,
void *user_data, N_Vector tmp1,
655 N_Vector tmp2, N_Vector tmp3);
// Static callback that solves the backward linear system A x = b to
// tolerance tol.
658 static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
659 N_Vector
b, realtype tol);
// NOTE(review): Doxygen listing fragment of ARKStepSolver members. The
// leading digits look like original header line numbers.
//
// Right-hand-side callback wrappers for the two parts of the ODE.
// NOTE(review): which part each one evaluates (e.g. explicit vs. implicit
// for IMEX) is not shown here; confirm in the implementation.
692 static int RHS1(realtype
t,
const N_Vector y, N_Vector ydot,
void *user_data);
693 static int RHS2(realtype
t,
const N_Vector y, N_Vector ydot,
void *user_data);
// Static callback that sets up the linear system A x = b. The mass matrix M
// is also passed.
697 static int LinSysSetup(realtype
t, N_Vector y, N_Vector fy, SUNMatrix
A,
698 SUNMatrix
M, booleantype jok, booleantype *jcur,
699 realtype gamma,
void *user_data, N_Vector tmp1,
700 N_Vector tmp2, N_Vector tmp3);
// Static callback that solves the linear system A x = b to tolerance tol.
703 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix
A, N_Vector x,
704 N_Vector
b, realtype tol);
// Tail of the mass-system setup callback's parameter list (scratch vectors).
708 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
// Static callback that solves the mass system M x = b to tolerance tol.
711 static int MassSysSolve(SUNLinearSolver LS, SUNMatrix
M, N_Vector x,
712 N_Vector
b, realtype tol);
// Computes the mass matrix-vector product v = M x.
715 static int MassMult1(SUNMatrix
M, N_Vector x, N_Vector v);
// Computes the mass matrix-vector product at time t.
// NOTE(review): this declaration is truncated in this view; its remaining
// parameter(s) are not shown.
718 static int MassMult2(N_Vector x, N_Vector v, realtype
t,
// Integrates the ODE with ARKode using the configured step mode. x, t and dt
// are passed by reference.
766 virtual void Step(
Vector &x,
double &
t,
double &dt);
// NOTE(review): Doxygen listing fragment of KINSolver members. The leading
// digits look like original header line numbers.
//
// Wrapper that computes the nonlinear residual F(u). The result presumably
// goes into fu.
860 static int Mult(
const N_Vector
u, N_Vector fu,
void *user_data);
// Tail of the Jacobian-vector product callback's parameter list.
864 booleantype *new_u,
void *user_data);
// Static callback that sets up the linear system with Jacobian J at u.
867 static int LinSysSetup(N_Vector
u, N_Vector fu, SUNMatrix J,
868 void *user_data, N_Vector tmp1, N_Vector tmp2);
// Static callback that solves the linear system J u = b to tolerance tol.
871 static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector
u,
872 N_Vector
b, realtype tol);
// Serial constructor: strategy selects the KINSOL global strategy, and
// oper_grad (default true) selects use of the operator's gradient.
899 KINSolver(
int strategy,
bool oper_grad =
true);
// Constructor taking an MPI communicator. The other parameters are as in the
// serial constructor.
909 KINSolver(MPI_Comm comm,
int strategy,
bool oper_grad =
true);
Interface to ARKode's ARKStep module – additive Runge-Kutta methods.
void SetMaxStep(double dt_max)
Set the maximum time step.
void PrintInfo() const
Print various ARKStep statistics.
Type rk_type
Runge-Kutta type.
void Init(TimeDependentOperator &f_)
Initialize ARKode: calls ARKStepCreate() to create the ARKStep memory and set some defaults.
ARKStepSolver(Type type=EXPLICIT)
Construct a serial wrapper to SUNDIALS' ARKode integrator.
void SetOrder(int order)
Chooses integration order for all explicit / implicit / IMEX methods.
static int MassMult2(N_Vector x, N_Vector v, realtype t, void *mtimes_data)
Compute the matrix-vector product $v = M(t)\,x$ at time $t$.
void SetStepMode(int itask)
Select the ARKode step mode: ARK_NORMAL (default) or ARK_ONE_STEP.
Type
Types of ARKODE solvers.
@ IMPLICIT
Implicit RK method.
@ IMEX
Implicit-explicit ARK method.
@ EXPLICIT
Explicit RK method.
void UseSundialsLinearSolver()
Attach a SUNDIALS GMRES linear solver to ARKode.
static int RHS2(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
virtual void Step(Vector &x, double &t, double &dt)
Integrate the ODE with ARKode using the specified step mode.
void SetIRKTableNum(ARKODE_DIRKTableID table_id)
Choose a specific Butcher table for a diagonally implicit RK method.
void SetFixedStep(double dt)
Use a fixed time step size (disable temporal adaptivity).
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $A x = b$.
void SetERKTableNum(ARKODE_ERKTableID table_id)
Choose a specific Butcher table for an explicit RK method.
int step_mode
ARKStep step mode (ARK_NORMAL or ARK_ONE_STEP).
void UseMFEMMassLinearSolver(int tdep)
Attach mass matrix linear system setup, solve, and matrix-vector product methods from the TimeDepende...
bool use_implicit
True for implicit or imex integration.
static int LinSysSetup(realtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $A x = b$.
static int MassSysSolve(SUNLinearSolver LS, SUNMatrix M, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $M x = b$.
void UseSundialsMassLinearSolver(int tdep)
Attach the SUNDIALS GMRES linear solver and the mass matrix matrix-vector product method from the Tim...
virtual ~ARKStepSolver()
Destroy the associated ARKode memory and SUNDIALS objects.
void SetIMEXTableNum(ARKODE_ERKTableID etable_id, ARKODE_DIRKTableID itable_id)
Choose a specific Butcher table for an IMEX RK method.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
static int MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
Compute the matrix-vector product $v = M x$.
static int MassSysSetup(realtype t, SUNMatrix M, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $M x = b$.
static int RHS1(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
void EvalQuadIntegrationB(double t, Vector &dG_dp)
Evaluate the backward quadrature solution.
void EvalQuadIntegration(double t, Vector &q)
Evaluate the forward quadrature solution.
long GetNumSteps()
Get the number of steps taken by the forward solve.
static int RHSB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
Wrapper to compute the ODE RHS backward function.
static constexpr double default_abs_tolB
Default scalar backward absolute tolerance.
static constexpr double default_abs_tolQB
Default scalar backward absolute quadrature tolerance.
static int RHSQB(realtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_dataB)
Wrapper to compute the ODE RHS Backwards Quadrature function.
void SetMaxNStepsB(int mxstepsB)
Set the maximum number of backward steps.
static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system A x = b.
void InitB(TimeDependentAdjointOperator &f_)
Initialize the adjoint problem.
SundialsNVector * q
Quadrature vector.
int indexB
backward problem index
void GetForwardSolution(double tB, mfem::Vector &yy)
Get the interpolated forward solution y at backward integration time tB.
SUNLinearSolver LSB
Linear solver for the backward system matrix AB.
virtual void Step(Vector &x, double &t, double &dt)
void SetSVtolerancesB(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances for the adjoint problem.
void UseSundialsLinearSolverB()
Attach a built-in SUNDIALS linear solver to the backward (adjoint) problem.
void SetWFTolerances(EWTFunction func)
Set multiplicative error weights.
SundialsNVector * yy
State vector.
void Init(TimeDependentAdjointOperator &f_)
SUNMatrix AB
Linear system $A = I - \gamma J$, $M - \gamma J$, or $J$.
int ncheck
number of checkpoints used so far
SundialsNVector * qB
Backward quadrature vector.
static int LinSysSetupB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix A, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system A x = b.
static int ewt(N_Vector y, N_Vector w, void *user_data)
Error control function.
static int RHSQ(realtype t, const N_Vector y, N_Vector qdot, void *user_data)
Wrapper to compute the ODE RHS Quadrature function.
virtual void StepB(Vector &w, double &t, double &dt)
Solve one adjoint time step.
void InitQuadIntegrationB(mfem::Vector &qB0, double reltolQB=1e-3, double abstolQB=1e-8)
Initialize Quadrature Integration (Adjoint)
static constexpr double default_rel_tolB
Default scalar backward relative tolerance.
void InitAdjointSolve(int steps, int interpolation)
Initialize Adjoint.
SundialsNVector * yB
Backward (adjoint) state vector.
void InitQuadIntegration(mfem::Vector &q0, double reltolQ=1e-3, double abstolQ=1e-8)
virtual ~CVODESSolver()
Destroy the associated CVODES memory and SUNDIALS objects.
void SetSStolerancesB(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances for the adjoint problem.
void UseMFEMLinearSolverB()
Set Linear Solver for the backward problem.
Interface to the CVODE library – linear multi-step methods.
void SetStepMode(int itask)
Select the CVODE step mode: CV_NORMAL (default) or CV_ONE_STEP.
void SetRootFinder(int components, RootFunction func)
Initialize Root Finder.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
virtual ~CVODESolver()
Destroy the associated CVODE memory and SUNDIALS objects.
void SetMaxNSteps(int steps)
Set the maximum number of time steps.
CVODESolver(int lmm)
Construct a serial wrapper to SUNDIALS' CVODE integrator.
long GetNumSteps()
Get the number of internal steps taken so far.
std::function< int(realtype t, Vector y, Vector gout, CVODESolver *)> RootFunction
Typedef for root finding functions.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $A x = b$.
EWTFunction ewt_func
A class member to facilitate pointing to a user-specified error weight function.
void SetMaxStep(double dt_max)
Set the maximum time step.
void Init(TimeDependentOperator &f_)
Initialize CVODE: calls CVodeCreate() to create the CVODE memory and set some defaults.
int lmm_type
Linear multistep method type.
virtual void Step(Vector &x, double &t, double &dt)
Integrate the ODE with CVODE using the specified step mode.
void PrintInfo() const
Print various CVODE statistics.
void UseSundialsLinearSolver()
Attach SUNDIALS GMRES linear solver to CVODE.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
RootFunction root_func
A class member to facilitate pointing to a user-specified root function.
std::function< int(Vector y, Vector w, CVODESolver *)> EWTFunction
Typedef declaration for error weight functions.
int step_mode
CVODE step mode (CV_NORMAL or CV_ONE_STEP).
static int root(realtype t, N_Vector y, realtype *gout, void *user_data)
Prototype to define root finding for CVODE.
void SetSVtolerances(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances.
static int RHS(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
Wrapper to compute the ODE right-hand side function.
void SetMaxOrder(int max_order)
Set the maximum method order.
static int LinSysSetup(realtype t, N_Vector y, N_Vector fy, SUNMatrix A, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $A x = b$.
static MemoryType GetDeviceMemoryType()
Get the current Device MemoryType. This is the MemoryType used by most MFEM classes when allocating m...
Wrapper for hypre's parallel vector class.
int print_level
(DEPRECATED) Legacy print level definition, which is left for compatibility with custom iterative sol...
Interface to the KINSOL library – nonlinear solver methods.
SundialsNVector * f_scale
scaling vectors
KINSolver(int strategy, bool oper_grad=true)
Construct a serial wrapper to SUNDIALS' KINSOL nonlinear solver.
void SetJFNK(bool use_jfnk)
Set the Jacobian-free Newton–Krylov (JFNK) flag. The default is false.
void SetJFNKSolver(Solver &solver)
virtual ~KINSolver()
Destroy the associated KINSOL memory.
virtual void SetPrintLevel(int print_lvl)
Set the print level for the KINSetPrintLevel function.
static int Mult(const N_Vector u, N_Vector fu, void *user_data)
Wrapper to compute the nonlinear residual $F(u)$.
static int GradientMult(N_Vector v, N_Vector Jv, N_Vector u, booleantype *new_u, void *user_data)
Wrapper to compute the Jacobian-vector product $J(u)\,v$.
int global_strategy
KINSOL solution strategy.
virtual void SetOperator(const Operator &op)
Set the nonlinear Operator of the system and initialize KINSOL.
int maa
number of acceleration vectors
static int PrecSolve(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void *user_data)
Solve the preconditioner equation $P z = v$.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector u, N_Vector b, realtype tol)
Solve the linear system $J u = b$.
int maxlrs
Maximum linear solver restarts.
const Operator * jacobian
stores oper->GetGradient()
int maxli
Maximum linear iterations.
Vector wrk
Work vector needed for the JFNK PC.
virtual void SetPreconditioner(Solver &solver)
Equivalent to SetSolver(solver).
void SetScaledStepTol(double sstol)
Set KINSOL's scaled step tolerance.
virtual void SetSolver(Solver &solver)
Set the linear solver for inverting the Jacobian.
SundialsNVector * y_scale
void SetMaxSetupCalls(int max_calls)
Set maximum number of nonlinear iterations without a Jacobian update.
static int LinSysSetup(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
Setup the linear system $J u = b$.
void SetMAA(int maa)
Set the number of acceleration vectors to use with KIN_FP or KIN_PICARD.
static int PrecSetup(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void *user_data)
Setup the preconditioner.
void SetLSMaxRestarts(int m)
Set the maximum number of linear solver restarts.
void SetLSMaxIter(int m)
Set the maximum number of linear solver iterations.
bool use_oper_grad
use the Jv prod function
Newton's method for solving F(x)=b for a given operator F.
Abstract class for solving systems of ODEs: dx/dt = f(x,t)
SundialsMemHelper & operator=(const SundialsMemHelper &)=delete
Disable copy assignment.
SundialsMemHelper(const SundialsMemHelper &that_helper)=delete
Disable copy construction.
SundialsMemHelper()=default
Default constructor – object must be moved to.
static int SundialsMemHelper_Alloc(SUNMemoryHelper helper, SUNMemory *memptr, size_t memsize, SUNMemoryType mem_type, void *queue) — the trailing `void *queue` parameter is present only when SUNDIALS_VERSION_MAJOR >= 6.
static int SundialsMemHelper_Dealloc(SUNMemoryHelper helper, SUNMemory sunmem, void *queue) — the trailing `void *queue` parameter is present only when SUNDIALS_VERSION_MAJOR >= 6.
SundialsMemHelper(SUNContext context)
Vector interface for SUNDIALS N_Vectors.
N_Vector_ID GetNVectorID(N_Vector x_) const
Returns the N_Vector_ID for the N_Vector x_.
N_Vector StealNVector()
Changes the ownership of the vector.
long GlobalSize() const
Returns the MPI global length for the internal N_Vector x.
MPI_Comm GetComm() const
Returns the MPI communicator for the internal N_Vector x.
void MakeRef(Vector &base, int offset, int s)
Reset the Vector to be a reference to a sub-vector of base.
void SetSize(int s, long glob_size=0)
Resize the vector to size s.
static bool UseManagedMemory()
static N_Vector MakeNVector(bool use_device)
Create an N_Vector.
int GetOwnership() const
Gets ownership of the internal N_Vector.
~SundialsNVector()
Calls SUNDIALS N_VDestroy function if the N_Vector is owned by 'this'.
void SetDataAndSize(double *d, int s, long glob_size=0)
Set the vector data and size.
void SetOwnership(int own)
Sets ownership of the internal N_Vector.
N_Vector_ID GetNVectorID() const
Returns the N_Vector_ID for the internal N_Vector.
void _SetNvecDataAndSize_(long glob_size=0)
Set data and length of internal N_Vector x from 'this'.
void _SetDataAndSize_()
Set data and length from the internal N_Vector x.
N_Vector x
The actual SUNDIALS object.
void MakeRef(Vector &base, int offset)
Reset the Vector to be a reference to a sub-vector of base without changing its current size.
SundialsNVector()
Creates an empty SundialsNVector.
Base class for interfacing with SUNDIALS packages.
static constexpr double default_abs_tol
Default scalar absolute tolerance.
SUNMatrix M
Mass matrix M.
int flag
Last flag returned from a call to SUNDIALS.
long saved_global_size
Global vector length on last initialization.
SundialsSolver()
Protected constructor: objects of this type should be constructed only as part of a derived class.
static constexpr double default_rel_tol
Default scalar relative tolerance.
void * GetMem() const
Access the SUNDIALS memory structure.
bool reinit
Flag to signal that memory reinitialization is needed.
void AllocateEmptyNVector(N_Vector &y)
SundialsNVector * Y
State vector.
void * sundials_mem
SUNDIALS mem structure.
void AllocateEmptyNVector(N_Vector &y, MPI_Comm comm)
SUNLinearSolver LSA
Linear solver for A.
SUNLinearSolver LSM
Linear solver for M.
int GetFlag() const
Returns the last flag returned by a call to a SUNDIALS function.
SUNNonlinearSolver NLS
Nonlinear solver.
Singleton class for SUNContext and SundialsMemHelper objects.
Sundials(Sundials &other)=delete
Disable copy construction.
static SUNContext & GetContext()
Provides access to the SUNContext object.
void operator=(const Sundials &other)=delete
Disable copy assignment.
static SundialsMemHelper & GetMemHelper()
Provides access to the SundialsMemHelper object.
Base abstract class for first order time dependent operators.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
void MakeRef(Vector &base, int offset, int size)
Reset the Vector to be a reference to a sub-vector of base.
real_t u(const Vector &xvec)
constexpr ARKODE_ERKTableID ARKODE_FEHLBERG_13_7_8
constexpr ARKODE_ERKTableID ARKODE_ERK_NONE
constexpr ARKODE_DIRKTableID ARKODE_DIRK_NONE