14#ifdef MFEM_USE_SUNDIALS
22#include <nvector/nvector_serial.h>
23#if defined(MFEM_USE_CUDA)
24#include <nvector/nvector_cuda.h>
25#elif defined(MFEM_USE_HIP)
26#include <nvector/nvector_hip.h>
29#include <nvector/nvector_mpiplusx.h>
30#include <nvector/nvector_parallel.h>
34#include <sunlinsol/sunlinsol_spgmr.h>
35#include <sunlinsol/sunlinsol_spfgmr.h>
38#define GET_CONTENT(X) ( X->content )
40#if defined(MFEM_USE_CUDA)
41#define SUN_Hip_OR_Cuda(X) X##_Cuda
42#define SUN_HIP_OR_CUDA(X) X##_CUDA
43#elif defined(MFEM_USE_HIP)
44#define SUN_Hip_OR_Cuda(X) X##_Hip
45#define SUN_HIP_OR_CUDA(X) X##_HIP
50#if (SUNDIALS_VERSION_MAJOR < 6)
98MFEM_DEPRECATED
void*
ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, realtype t0,
116 sunindextype local_length,
117 sunindextype global_length,
125#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
130 booleantype use_managed_mem,
131 SUNMemoryHelper helper,
146#if defined(MFEM_USE_MPI) && (defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP))
166 Sundials::Instance();
177 return Sundials::Instance().context;
182 return Sundials::Instance().memHelper;
185#if (SUNDIALS_VERSION_MAJOR >= 6)
190 MPI_Comm communicator = MPI_COMM_WORLD;
191 int return_val = SUNContext_Create((
void*) &communicator, &context);
193 int return_val = SUNContext_Create(
nullptr, &context);
195 MFEM_VERIFY(return_val == 0,
"Call to SUNContext_Create failed");
197 memHelper = std::move(actual_helper);
202 SUNContext_Free(&context);
219#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
228 h->ops->copy = SUN_Hip_OR_Cuda(SUNMemoryHelper_Copy);
229 h->ops->copyasync = SUN_Hip_OR_Cuda(SUNMemoryHelper_CopyAsync);
234 this->h = that_helper.h;
235 that_helper.h =
nullptr;
246 SUNMemory* memptr,
size_t memsize,
247 SUNMemoryType mem_type
248#
if (SUNDIALS_VERSION_MAJOR >= 6)
253 SUNMemory sunmem = SUNMemoryNewEmpty();
256 sunmem->own = SUNTRUE;
259 if (mem_type == SUNMEMTYPE_HOST)
264 sunmem->type = SUNMEMTYPE_HOST;
267 else if (mem_type == SUNMEMTYPE_DEVICE || mem_type == SUNMEMTYPE_UVM)
272 sunmem->type = mem_type;
287#
if (SUNDIALS_VERSION_MAJOR >= 6)
292 if (sunmem->ptr && sunmem->own && !
mm.
IsKnown(sunmem->ptr))
294 if (sunmem->type == SUNMEMTYPE_HOST)
300 else if (sunmem->type == SUNMEMTYPE_DEVICE || sunmem->type == SUNMEMTYPE_UVM)
308 MFEM_ABORT(
"Invalid SUNMEMTYPE");
326 N_Vector local_x =
MPIPlusX() ? N_VGetLocalVector_MPIPlusX(
x) :
x;
328 N_Vector local_x =
x;
330 N_Vector_ID
id = N_VGetVectorID(local_x);
335 case SUNDIALS_NVEC_SERIAL:
337 MFEM_ASSERT(NV_OWN_DATA_S(local_x) == SUNFALSE,
"invalid serial N_Vector");
339 NV_LENGTH_S(local_x) =
size;
342#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
343 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
345 SUN_Hip_OR_Cuda(N_VSetHostArrayPointer)(
HostReadWrite(), local_x);
346 SUN_Hip_OR_Cuda(N_VSetDeviceArrayPointer)(
ReadWrite(), local_x);
347 static_cast<SUN_Hip_OR_Cuda(N_VectorContent)
>(GET_CONTENT(
348 local_x))->length =
size;
353 case SUNDIALS_NVEC_PARALLEL:
355 MFEM_ASSERT(NV_OWN_DATA_P(
x) == SUNFALSE,
"invalid parallel N_Vector");
357 NV_LOCLENGTH_P(
x) =
size;
362 if (glob_size == 0 && glob_size !=
size)
364 long local_size =
size;
365 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
369 NV_GLOBLENGTH_P(
x) = glob_size;
374 MFEM_ABORT(
"N_Vector type " <<
id <<
" is not supported");
384 if (glob_size == 0 && glob_size !=
size)
386 long local_size =
size;
387 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
391 static_cast<N_VectorContent_MPIManyVector
>(GET_CONTENT(
x))->global_length =
400 N_Vector local_x =
MPIPlusX() ? N_VGetLocalVector_MPIPlusX(
x) :
x;
402 N_Vector local_x =
x;
404 N_Vector_ID
id = N_VGetVectorID(local_x);
409 case SUNDIALS_NVEC_SERIAL:
411 const bool known =
mm.
IsKnown(NV_DATA_S(local_x));
412 size = NV_LENGTH_S(local_x);
417#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
418 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
420 double *h_ptr = SUN_Hip_OR_Cuda(N_VGetHostArrayPointer)(local_x);
421 double *d_ptr = SUN_Hip_OR_Cuda(N_VGetDeviceArrayPointer)(local_x);
423 size = SUN_Hip_OR_Cuda(N_VGetLength)(local_x);
431 case SUNDIALS_NVEC_PARALLEL:
// Parallel N_Vector: read the per-rank length with the parallel accessor,
// matching NV_LOCLENGTH_P used by the setter and by the Wrap() call below.
434 size = NV_LOCLENGTH_P(
x);
435 data.
Wrap(NV_DATA_P(
x), NV_LOCLENGTH_P(
x),
false);
441 MFEM_ABORT(
"N_Vector type " <<
id <<
" is not supported");
500 :
SundialsNVector(vec.GetComm(), vec.GetData(), vec.Size(), vec.GlobalSize())
511 N_VDestroy(N_VGetLocalVector_MPIPlusX(
x));
539#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
554 MFEM_VERIFY(
x,
"Error in SundialsNVector::MakeNVector.");
564 if (comm == MPI_COMM_NULL)
570#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
588 MFEM_VERIFY(
x,
"Error in SundialsNVector::MakeNVector.");
600static SUNMatrix_ID MatGetID(SUNMatrix)
602 return (SUNMATRIX_CUSTOM);
605static void MatDestroy(SUNMatrix A)
607 if (A->content) { A->content = NULL; }
608 if (A->ops) { free(A->ops); A->ops = NULL; }
618static SUNLinearSolver_Type LSGetType(SUNLinearSolver)
620 return (SUNLINEARSOLVER_MATRIX_ITERATIVE);
623static int LSFree(SUNLinearSolver LS)
625 if (LS->content) { LS->content = NULL; }
626 if (LS->ops) { free(LS->ops); LS->ops = NULL; }
645 self->
f->
Mult(mfem_y, mfem_ydot);
655 if (!self->
root_func) {
return CV_RTFUNC_FAIL; }
660 return self->
root_func(
t, mfem_y, mfem_gout, self);
668 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in SetRootFinder()");
672 booleantype jok, booleantype *jcur, realtype gamma,
673 void*, N_Vector, N_Vector, N_Vector)
686 N_Vector
b, realtype tol)
696 : lmm_type(lmm), step_mode(CV_NORMAL)
703 : lmm_type(lmm), step_mode(CV_NORMAL)
715 long local_size = f_.
Height();
718 long global_size = 0;
721 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
735 resize = (
Y->
Size() != local_size);
740 int l_resize = (
Y->
Size() != local_size) ||
742 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
768 Y->
SetSize(local_size, global_size);
779 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeInit()");
783 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetUserData()");
787 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetSStolerances()");
800 MFEM_VERIFY(
Y->
Size() == x.
Size(),
"size mismatch");
806 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeReInit()");
812 double tout =
t + dt;
814 MFEM_VERIFY(
flag >= 0,
"error in CVode()");
821 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetLastStep()");
827 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
828 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
832 MFEM_VERIFY(
LSA,
"error in SUNLinSolNewEmpty()");
835 LSA->ops->gettype = LSGetType;
837 LSA->ops->free = LSFree;
840 MFEM_VERIFY(
A,
"error in SUNMatNewEmpty()");
843 A->ops->getid = MatGetID;
844 A->ops->destroy = MatDestroy;
848 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinearSolver()");
852 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinSysFn()");
858 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
859 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
863 MFEM_VERIFY(
LSA,
"error in SUNLinSol_SPGMR()");
867 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinearSolver()");
878 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSStolerances()");
884 "abs tolerance is not the same size.");
890 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSVtolerances()");
896 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetMaxStep()");
902 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetMaxNumSteps()");
909 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetNumSteps()");
916 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetMaxOrd()");
921 long int nsteps, nfevals, nlinsetups, netfails;
923 double hinused, hlast, hcur, tcur;
924 long int nniters, nncfails;
938 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetIntegratorStats()");
944 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetNonlinSolvStats()");
948 "num steps: " << nsteps <<
"\n"
949 "num rhs evals: " << nfevals <<
"\n"
950 "num lin setups: " << nlinsetups <<
"\n"
951 "num nonlin sol iters: " << nniters <<
"\n"
952 "num nonlin conv fail: " << nncfails <<
"\n"
953 "num error test fails: " << netfails <<
"\n"
954 "last order: " << qlast <<
"\n"
955 "current order: " << qcur <<
"\n"
956 "initial dt: " << hinused <<
"\n"
957 "last dt: " << hlast <<
"\n"
958 "current dt: " << hcur <<
"\n"
959 "current t: " << tcur <<
"\n" << endl;
969 SUNNonlinSolFree(
NLS);
1007 MFEM_VERIFY(
t <= f->GetTime(),
"t > current forward solver time");
1010 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetQuad()");
1017 MFEM_VERIFY(
t <= f->GetTime(),
"t > current forward solver time");
1020 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetQuadB()");
1022 dG_dp.
Set(-1., *
qB);
1030 MFEM_VERIFY(
flag >= 0,
"error in CVodeGetAdjY()");
1050 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeCreateB()");
1054 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeInit()");
1058 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetUserDataB()");
1063 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetSStolerancesB()");
1075 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeAdjInit");
1081 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeSetMaxNumStepsB()");
1090 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeQuadInit()");
1093 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeSetQuadErrCon");
1096 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeQuadSStolerances");
1105 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeQuadInitB()");
1108 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeSetQuadErrConB");
1111 MFEM_VERIFY(
flag == CV_SUCCESS,
"Error in CVodeQuadSStolerancesB");
1117 if (
AB != NULL) { SUNMatDestroy(
AB);
AB = NULL; }
1118 if (
LSB != NULL) { SUNLinSolFree(
LSB);
LSB = NULL; }
1122 MFEM_VERIFY(
LSB,
"error in SUNLinSolNewEmpty()");
1124 LSB->content =
this;
1125 LSB->ops->gettype = LSGetType;
1127 LSB->ops->free = LSFree;
1130 MFEM_VERIFY(
AB,
"error in SUNMatNewEmpty()");
1133 AB->ops->getid = MatGetID;
1134 AB->ops->destroy = MatDestroy;
1138 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinearSolverB()");
1143 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinSysFn()");
1149 if (
AB != NULL) { SUNMatDestroy(
AB);
AB = NULL; }
1150 if (
LSB != NULL) { SUNLinSolFree(
LSB);
LSB = NULL; }
1154 MFEM_VERIFY(
LSB,
"error in SUNLinSol_SPGMR()");
1158 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSetLinearSolverB()");
1162 N_Vector fyB, SUNMatrix AB,
1163 booleantype jokB, booleantype *jcurB,
1164 realtype gammaB,
void *user_data, N_Vector tmp1,
1165 N_Vector tmp2, N_Vector tmp3)
1176 return (
f->SUNImplicitSetupB(
t, mfem_y, mfem_yB, mfem_fyB, jokB, jcurB,
1181 N_Vector Rb, realtype tol)
1189 int ret =
f->SUNImplicitSolveB(mfem_yB, mfem_Rb, tol);
1196 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSStolerancesB()");
1202 "abs tolerance is not the same size.");
1208 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeSVtolerancesB()");
1228 f->QuadratureIntegration(mfem_y, mfem_qdot);
1242 f->QuadratureSensitivityMult(mfem_y, mfem_yB, mfem_qBdot);
1258 f->AdjointRateMult(mfem_y, mfem_yB, mfem_yBdot);
1269 return self->
ewt_func(mfem_y, mfem_w, self);
1276 MFEM_VERIFY(
Y->
Size() == x.
Size(),
"size mismatch");
1282 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeReInit()");
1289 double tout =
t + dt;
1291 MFEM_VERIFY(
flag >= 0,
"error in CVodeF()");
1298 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeGetLastStep()");
1310 MFEM_VERIFY(
flag == CV_SUCCESS,
"error in CVodeReInit()");
1317 double tout = tB - dtB;
1319 MFEM_VERIFY(
flag >= 0,
"error in CVodeB()");
1323 MFEM_VERIFY(
flag >= 0,
"error in CVodeGetB()");
1358 self->
f->
Mult(mfem_y, mfem_ydot);
1375 self->
f->
Mult(mfem_y, mfem_ydot);
1382 SUNMatrix, booleantype jok, booleantype *jcur,
1384 void*, N_Vector, N_Vector, N_Vector)
1401 N_Vector
b, realtype tol)
1416 void*, N_Vector, N_Vector, N_Vector)
1426 N_Vector
b, realtype tol)
1459 : rk_type(type), step_mode(ARK_NORMAL),
1460 use_implicit(type == IMPLICIT || type == IMEX)
1467 : rk_type(type), step_mode(ARK_NORMAL),
1468 use_implicit(type == IMPLICIT || type == IMEX)
1480 long local_size = f_.
Height();
1488 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
1502 resize = (
Y->
Size() != local_size);
1507 int l_resize = (
Y->
Size() != local_size) ||
1509 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
1531 Y->
SetSize(local_size, global_size);
1556 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetUserData()");
1560 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetSStolerances()");
1573 MFEM_VERIFY(
Y->
Size() == x.
Size(),
"size mismatch");
1591 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepReInit()");
1598 double tout =
t + dt;
1600 MFEM_VERIFY(
flag >= 0,
"error in ARKStepEvolve()");
1607 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepGetLastStep()");
1613 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
1614 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
1618 MFEM_VERIFY(
LSA,
"error in SUNLinSolNewEmpty()");
1620 LSA->content =
this;
1621 LSA->ops->gettype = LSGetType;
1623 LSA->ops->free = LSFree;
1626 MFEM_VERIFY(
A,
"error in SUNMatNewEmpty()");
1629 A->ops->getid = MatGetID;
1630 A->ops->destroy = MatDestroy;
1634 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetLinearSolver()");
1638 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetLinSysFn()");
1644 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
1645 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
1649 MFEM_VERIFY(
LSA,
"error in SUNLinSol_SPGMR()");
1653 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetLinearSolver()");
1659 if (
M != NULL) { SUNMatDestroy(
M);
M = NULL; }
1660 if (
LSM != NULL) { SUNLinSolFree(
LSM);
LSM = NULL; }
1664 MFEM_VERIFY(
LSM,
"error in SUNLinSolNewEmpty()");
1666 LSM->content =
this;
1667 LSM->ops->gettype = LSGetType;
// Install the free op on the mass solver LSM being configured here; setting it
// on LSA left LSM without a free op and dereferenced LSA, which may be NULL.
1669 LSM->ops->free = LSFree;
1672 MFEM_VERIFY(
M,
"error in SUNMatNewEmpty()");
1675 M->ops->getid = MatGetID;
1677 M->ops->destroy = MatDestroy;
1681 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetLinearSolver()");
1685 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetMassFn()");
// Free any existing mass matrix M (the guard checks M, so destroy M, not the
// system matrix A, which would otherwise be freed while still in use).
1691 if (
M != NULL) { SUNMatDestroy(
M);
M = NULL; }
1692 if (
LSM != NULL) { SUNLinSolFree(
LSM);
LSM = NULL; }
1696 MFEM_VERIFY(
LSM,
"error in SUNLinSol_SPGMR()");
1700 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetMassLinearSolver()");
1705 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetMassTimes()");
1716 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSStolerances()");
1722 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetMaxStep()");
1728 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetOrder()");
1734 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetTableNum()");
1740 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetTableNum()");
1747 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetTableNum()");
1753 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepSetFixedStep()");
1758 long int nsteps, expsteps, accsteps, step_attempts;
1759 long int nfe_evals, nfi_evals;
1760 long int nlinsetups, netfails;
1761 double hinused, hlast, hcur, tcur;
1762 long int nniters, nncfails;
1773 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepGetTimestepperStats()");
1786 MFEM_VERIFY(
flag == ARK_SUCCESS,
"error in ARKStepGetNonlinSolvStats()");
1790 "num steps: " << nsteps <<
"\n"
1791 "num exp rhs evals: " << nfe_evals <<
"\n"
1792 "num imp rhs evals: " << nfi_evals <<
"\n"
1793 "num lin setups: " << nlinsetups <<
"\n"
1794 "num nonlin sol iters: " << nniters <<
"\n"
1795 "num nonlin conv fail: " << nncfails <<
"\n"
1796 "num steps attempted: " << step_attempts <<
"\n"
1797 "num acc limited steps: " << accsteps <<
"\n"
1798 "num exp limited stepfails: " << expsteps <<
"\n"
1799 "num error test fails: " << netfails <<
"\n"
1800 "initial dt: " << hinused <<
"\n"
1801 "last dt: " << hlast <<
"\n"
1802 "current dt: " << hcur <<
"\n"
1803 "current t: " << tcur <<
"\n" << endl;
1813 SUNNonlinSolFree(
NLS);
1837 booleantype *new_u,
void *user_data)
1860 void *, N_Vector, N_Vector )
1877 N_Vector
b, realtype)
1928 : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
1929 f_scale(NULL), jacobian(NULL), maa(0)
1936 abs_tol = pow(UNIT_ROUNDOFF, 1.0/3.0);
1942 : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
1943 f_scale(NULL), jacobian(NULL), maa(0)
1950 abs_tol = pow(UNIT_ROUNDOFF, 1.0/3.0);
1963 long local_size =
height;
1971 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
1982 resize = (
Y->
Size() != local_size);
1987 int l_resize = (
Y->
Size() != local_size) ||
1989 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
2011 Y->
SetSize(local_size, global_size);
2026 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetMAA()");
2031 MFEM_VERIFY(
flag == KIN_SUCCESS,
"error in KINInit()");
2035 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetUserData()");
2045 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
2046 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
2049 MFEM_VERIFY(
LSA,
"error in SUNLinSol_SPGMR()");
2052 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetLinearSolver()");
2058 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetJacTimesVecFn()");
2076 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
2077 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
2081 MFEM_VERIFY(
LSA,
"error in SUNLinSolNewEmpty()");
2083 LSA->content =
this;
2084 LSA->ops->gettype = LSGetType;
2086 LSA->ops->free = LSFree;
2089 MFEM_VERIFY(
A,
"error in SUNMatNewEmpty()");
2092 A->ops->getid = MatGetID;
2093 A->ops->destroy = MatDestroy;
2097 MFEM_VERIFY(
flag == KIN_SUCCESS,
"error in KINSetLinearSolver()");
2101 MFEM_VERIFY(
flag == KIN_SUCCESS,
"error in KINSetJacFn()");
2113 if (
A != NULL) { SUNMatDestroy(
A);
A = NULL; }
2114 if (
LSA != NULL) { SUNLinSolFree(
LSA);
LSA = NULL; }
2119 MFEM_VERIFY(
LSA,
"error in SUNLinSol_SPFGMR()");
2122 MFEM_VERIFY(
flag == SUNLS_SUCCESS,
"error in SUNLinSol_SPFGMR()");
2125 MFEM_VERIFY(
flag == KIN_SUCCESS,
"error in KINSetLinearSolver()");
2132 MFEM_VERIFY(
flag == KIN_SUCCESS,
"error in KINSetPreconditioner()");
2139 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetScaledStepTol()");
2145 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetMaxSetupCalls()");
2156 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetMAA()");
2162 MFEM_ABORT(
"this method is not supported! Use SetPrintLevel(int) instead.");
2187 double lnorm = norm;
2211 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINSetFuncNormTol()");
2222 MFEM_ASSERT(
flag == KIN_SUCCESS,
"KINSetNumMaxIters() failed!");
2236 MPI_Comm_rank(
Y->
GetComm(), &rank);
2243 MFEM_VERIFY(
flag == KIN_SUCCESS,
"KINSetPrintLevel() failed!");
2245#ifdef SUNDIALS_BUILD_WITH_MONITORING
2248 flag = SUNLinSolSetInfoFile_SPFGMR(
LSA, stdout);
2249 MFEM_VERIFY(
flag == SUNLS_SUCCESS,
2250 "error in SUNLinSolSetInfoFile_SPFGMR()");
2252 flag = SUNLinSolSetPrintLevel_SPFGMR(
LSA, 1);
2253 MFEM_VERIFY(
flag == SUNLS_SUCCESS,
2254 "error in SUNLinSolSetPrintLevel_SPFGMR()");
2271 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINGetNumNonlinSolvIters()");
2276 MFEM_ASSERT(
flag == KIN_SUCCESS,
"error in KINGetFuncNorm()");
Interface to ARKode's ARKStep module – additive Runge-Kutta methods.
void SetMaxStep(double dt_max)
Set the maximum time step.
void PrintInfo() const
Print various ARKStep statistics.
Type rk_type
Runge-Kutta type.
void Init(TimeDependentOperator &f_)
Initialize ARKode: calls ARKStepCreate() to create the ARKStep memory and set some defaults.
ARKStepSolver(Type type=EXPLICIT)
Construct a serial wrapper to SUNDIALS' ARKode integrator.
void SetOrder(int order)
Chooses integration order for all explicit / implicit / IMEX methods.
static int MassMult2(N_Vector x, N_Vector v, realtype t, void *mtimes_data)
Compute the matrix-vector product $v = M x$ at time t.
void SetStepMode(int itask)
Select the ARKode step mode: ARK_NORMAL (default) or ARK_ONE_STEP.
Type
Types of ARKODE solvers.
@ IMPLICIT
Implicit RK method.
@ IMEX
Implicit-explicit ARK method.
@ EXPLICIT
Explicit RK method.
void UseSundialsLinearSolver()
Attach a SUNDIALS GMRES linear solver to ARKode.
static int RHS2(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
virtual void Step(Vector &x, double &t, double &dt)
Integrate the ODE with ARKode using the specified step mode.
void SetIRKTableNum(ARKODE_DIRKTableID table_id)
Choose a specific Butcher table for a diagonally implicit RK method.
void SetFixedStep(double dt)
Use a fixed time step size (disable temporal adaptivity).
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $A x = b$.
void SetERKTableNum(ARKODE_ERKTableID table_id)
Choose a specific Butcher table for an explicit RK method.
int step_mode
ARKStep step mode (ARK_NORMAL or ARK_ONE_STEP).
void UseMFEMMassLinearSolver(int tdep)
Attach mass matrix linear system setup, solve, and matrix-vector product methods from the TimeDepende...
bool use_implicit
True for implicit or imex integration.
static int LinSysSetup(realtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $A x = b$.
static int MassSysSolve(SUNLinearSolver LS, SUNMatrix M, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $M x = b$.
void UseSundialsMassLinearSolver(int tdep)
Attach the SUNDIALS GMRES linear solver and the mass matrix matrix-vector product method from the Tim...
virtual ~ARKStepSolver()
Destroy the associated ARKode memory and SUNDIALS objects.
void SetIMEXTableNum(ARKODE_ERKTableID etable_id, ARKODE_DIRKTableID itable_id)
Choose a specific Butcher table for an IMEX RK method.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
static int MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
Compute the matrix-vector product $v = M x$.
static int MassSysSetup(realtype t, SUNMatrix M, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $M x = b$.
static int RHS1(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
void EvalQuadIntegrationB(double t, Vector &dG_dp)
Evaluate Quadrature solution.
void EvalQuadIntegration(double t, Vector &q)
Evaluate Quadrature.
static int RHSB(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
Wrapper to compute the ODE RHS backward function.
static constexpr double default_abs_tolB
Default scalar backward absolute tolerance.
static int RHSQB(realtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_dataB)
Wrapper to compute the ODE RHS Backwards Quadrature function.
void SetMaxNStepsB(int mxstepsB)
Set the maximum number of backward steps.
static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system A x = b.
void InitB(TimeDependentAdjointOperator &f_)
Initialize the adjoint problem.
SundialsNVector * q
Quadrature vector.
int indexB
backward problem index
void GetForwardSolution(double tB, mfem::Vector &yy)
Get Interpolated Forward solution y at backward integration time tB.
SUNLinearSolver LSB
Linear solver for A.
virtual void Step(Vector &x, double &t, double &dt)
void SetSVtolerancesB(double reltol, Vector abstol)
Tolerance specification functions for the adjoint problem.
void UseSundialsLinearSolverB()
Use the built-in SUNDIALS GMRES (SPGMR) linear solver for the backward problem.
void SetWFTolerances(EWTFunction func)
Set multiplicative error weights.
SundialsNVector * yy
State vector.
void Init(TimeDependentAdjointOperator &f_)
SUNMatrix AB
Linear system A = I - gamma J, M - gamma J, or J.
int ncheck
number of checkpoints used so far
SundialsNVector * qB
Backward quadrature vector.
static int LinSysSetupB(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix A, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system A x = b.
static int ewt(N_Vector y, N_Vector w, void *user_data)
Error control function.
static int RHSQ(realtype t, const N_Vector y, N_Vector qdot, void *user_data)
Wrapper to compute the ODE RHS Quadrature function.
virtual void StepB(Vector &w, double &t, double &dt)
Solve one adjoint time step.
void InitQuadIntegrationB(mfem::Vector &qB0, double reltolQB=1e-3, double abstolQB=1e-8)
Initialize Quadrature Integration (Adjoint)
static constexpr double default_rel_tolB
Default scalar backward relative tolerance.
void InitAdjointSolve(int steps, int interpolation)
Initialize Adjoint.
SundialsNVector * yB
State vector.
void InitQuadIntegration(mfem::Vector &q0, double reltolQ=1e-3, double abstolQ=1e-8)
virtual ~CVODESSolver()
Destroy the associated CVODES memory and SUNDIALS objects.
void SetSStolerancesB(double reltol, double abstol)
Tolerance specification functions for the adjoint problem.
void UseMFEMLinearSolverB()
Set Linear Solver for the backward problem.
Interface to the CVODE library – linear multi-step methods.
void SetStepMode(int itask)
Select the CVODE step mode: CV_NORMAL (default) or CV_ONE_STEP.
void SetRootFinder(int components, RootFunction func)
Initialize Root Finder.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
virtual ~CVODESolver()
Destroy the associated CVODE memory and SUNDIALS objects.
void SetMaxNSteps(int steps)
Set the maximum number of time steps.
CVODESolver(int lmm)
Construct a serial wrapper to SUNDIALS' CVODE integrator.
long GetNumSteps()
Get the number of internal steps taken so far.
std::function< int(realtype t, Vector y, Vector gout, CVODESolver *)> RootFunction
Typedef for root finding functions.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
Solve the linear system $A x = b$.
EWTFunction ewt_func
A class member to facilitate pointing to a user-specified error weight function.
void SetMaxStep(double dt_max)
Set the maximum time step.
void Init(TimeDependentOperator &f_)
Initialize CVODE: calls CVodeCreate() to create the CVODE memory and set some defaults.
int lmm_type
Linear multistep method type.
virtual void Step(Vector &x, double &t, double &dt)
Integrate the ODE with CVODE using the specified step mode.
void PrintInfo() const
Print various CVODE statistics.
void UseSundialsLinearSolver()
Attach SUNDIALS GMRES linear solver to CVODE.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
RootFunction root_func
A class member to facilitate pointing to a user-specified root function.
std::function< int(Vector y, Vector w, CVODESolver *)> EWTFunction
Typedef declaration for error weight functions.
int step_mode
CVODE step mode (CV_NORMAL or CV_ONE_STEP).
static int root(realtype t, N_Vector y, realtype *gout, void *user_data)
Prototype to define root finding for CVODE.
void SetSVtolerances(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances.
static int RHS(realtype t, const N_Vector y, N_Vector ydot, void *user_data)
Wrapper to compute the ODE right-hand side function.
void SetMaxOrder(int max_order)
Set the maximum method order.
static int LinSysSetup(realtype t, N_Vector y, N_Vector fy, SUNMatrix A, booleantype jok, booleantype *jcur, realtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $A x = b$.
static MemoryType GetHostMemoryType()
Get the current Host MemoryType. This is the MemoryType used by most MFEM classes when allocating mem...
static bool IsAvailable()
Return true if an actual device (e.g. GPU) has been configured.
static MemoryType GetDeviceMemoryType()
Get the current Device MemoryType. This is the MemoryType used by most MFEM classes when allocating m...
Wrapper for hypre's parallel vector class.
real_t abs_tol
Absolute tolerance.
real_t rel_tol
Relative tolerance.
int print_level
(DEPRECATED) Legacy print level definition, which is left for compatibility with custom iterative sol...
int max_iter
Limit for the number of iterations the solver is allowed to do.
Interface to the KINSOL library – nonlinear solver methods.
SundialsNVector * f_scale
scaling vectors
KINSolver(int strategy, bool oper_grad=true)
Construct a serial wrapper to SUNDIALS' KINSOL nonlinear solver.
void SetJFNKSolver(Solver &solver)
virtual ~KINSolver()
Destroy the associated KINSOL memory.
virtual void SetPrintLevel(int print_lvl)
Set the print level for the KINSetPrintLevel function.
static int Mult(const N_Vector u, N_Vector fu, void *user_data)
Wrapper to compute the nonlinear residual $F(u)$ of the system $F(u) = 0$.
static int GradientMult(N_Vector v, N_Vector Jv, N_Vector u, booleantype *new_u, void *user_data)
Wrapper to compute the Jacobian-vector product $Jv = J(u) v$.
int global_strategy
KINSOL solution strategy.
virtual void SetOperator(const Operator &op)
Set the nonlinear Operator of the system and initialize KINSOL.
int maa
number of acceleration vectors
static int PrecSolve(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void *user_data)
Solve the preconditioner equation $P z = v$.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector u, N_Vector b, realtype tol)
Solve the linear system $J u = b$.
int maxlrs
Maximum linear solver restarts.
const Operator * jacobian
stores oper->GetGradient()
int maxli
Maximum linear iterations.
Vector wrk
Work vector needed for the JFNK PC.
void SetScaledStepTol(double sstol)
Set KINSOL's scaled step tolerance.
virtual void SetSolver(Solver &solver)
Set the linear solver for inverting the Jacobian.
SundialsNVector * y_scale
void SetMaxSetupCalls(int max_calls)
Set maximum number of nonlinear iterations without a Jacobian update.
static int LinSysSetup(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
Set up the linear system $J u = b$.
void SetMAA(int maa)
Set the number of acceleration vectors to use with KIN_FP or KIN_PICARD.
static int PrecSetup(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void *user_data)
Setup the preconditioner.
bool use_oper_grad
use the Jv prod function
bool IsKnown(const void *h_ptr)
Return true if the pointer is known by the memory manager.
Class used by MFEM to store pointers to host and/or device memory.
void SetHostPtrOwner(bool own) const
Set/clear the ownership flag for the host pointer. Ownership indicates whether the pointer will be de...
void SetDevicePtrOwner(bool own) const
Set/clear the ownership flag for the device pointer. Ownership indicates whether the pointer will be ...
void Wrap(T *ptr, int size, bool own)
Wrap an externally allocated host pointer, ptr with the current host memory type returned by MemoryMa...
void Delete()
Delete the owned pointers and reset the Memory object.
void ClearOwnerFlags() const
Clear the ownership flags for the host and device pointers, as well as any internal data allocated by...
virtual void SetOperator(const Operator &op)
Also calls SetOperator for the preconditioner if there is one.
TimeDependentOperator * f
Pointer to the associated TimeDependentOperator.
virtual void Init(TimeDependentOperator &f_)
Associate a TimeDependentOperator with the ODE solver.
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
int height
Dimension of the output / number of rows in the matrix.
virtual void Mult(const Vector &x, Vector &y) const =0
Operator application: y=A(x).
virtual Operator & GetGradient(const Vector &x) const
Evaluate the gradient operator at the point x. The default behavior in class Operator is to generate ...
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
virtual void SetOperator(const Operator &op)=0
Set/update the solver for the given operator.
SundialsMemHelper & operator=(const SundialsMemHelper &)=delete
Disable copy assignment.
SundialsMemHelper()=default
Default constructor – object must be moved to.
static int SundialsMemHelper_Alloc(SUNMemoryHelper helper, SUNMemory *memptr, size_t memsize, SUNMemoryType mem_type, void *queue) (the `queue` parameter is present only when SUNDIALS_VERSION_MAJOR >= 6)
static int SundialsMemHelper_Dealloc(SUNMemoryHelper helper, SUNMemory sunmem #if(SUNDIALS_VERSION_MAJOR >=6), void *queue #endif)
Vector interface for SUNDIALS N_Vectors.
long GlobalSize() const
Returns the MPI global length for the internal N_Vector x.
MPI_Comm GetComm() const
Returns the MPI communicator for the internal N_Vector x.
void MakeRef(Vector &base, int offset, int s)
Reset the Vector to be a reference to a sub-vector of base.
void SetSize(int s, long glob_size=0)
Resize the vector to size s.
static bool UseManagedMemory()
static N_Vector MakeNVector(bool use_device)
Create a N_Vector.
~SundialsNVector()
Calls SUNDIALS N_VDestroy function if the N_Vector is owned by 'this'.
void SetDataAndSize(double *d, int s, long glob_size=0)
Set the vector data and size.
void _SetNvecDataAndSize_(long glob_size=0)
Set data and length of internal N_Vector x from 'this'.
void _SetDataAndSize_()
Set data and length from the internal N_Vector x.
N_Vector x
The actual SUNDIALS object.
SundialsNVector()
Creates an empty SundialsNVector.
static constexpr double default_abs_tol
Default scalar absolute tolerance.
SUNMatrix M
Mass matrix M.
int flag
Last flag returned from a call to SUNDIALS.
long saved_global_size
Global vector length on last initialization.
static constexpr double default_rel_tol
Default scalar relative tolerance.
bool reinit
Flag to signal memory reinitialization is needed.
SundialsNVector * Y
State vector.
void * sundials_mem
SUNDIALS mem structure.
SUNLinearSolver LSA
Linear solver for A.
SUNLinearSolver LSM
Linear solver for M.
SUNNonlinearSolver NLS
Nonlinear solver.
Singleton class for SUNContext and SundialsMemHelper objects.
static SUNContext & GetContext()
Provides access to the SUNContext object.
static SundialsMemHelper & GetMemHelper()
Provides access to the SundialsMemHelper object.
int GetAdjointHeight()
Returns the size of the adjoint problem state space.
Base abstract class for first order time dependent operators.
virtual int SUNImplicitSolve(const Vector &b, Vector &x, real_t tol)
Solve the ODE linear system as set up by the method SUNImplicitSetup().
virtual int SUNMassMult(const Vector &x, Vector &v)
Compute the mass matrix-vector product $v = M x$.
virtual int SUNMassSetup()
Set up the mass matrix in the ODE system $M y' = f(y,t)$.
virtual void Mult(const Vector &x, Vector &y) const
Perform the action of the operator: y = k = f(x, t), where k solves the algebraic equation F(x,...
virtual int SUNMassSolve(const Vector &b, Vector &x, real_t tol)
Solve the mass matrix linear system as set up by the method SUNMassSetup().
virtual void SetEvalMode(const EvalMode new_eval_mode)
Set the evaluation mode of the time-dependent operator.
virtual void SetTime(const real_t t_)
Set the current time.
virtual int SUNImplicitSetup(const Vector &x, const Vector &fx, int jok, int *jcur, real_t gamma)
Set up the ODE linear system $A(x,t) = (I - \gamma J)$ or $A = (M - \gamma J)$, where $J(x,t) = \partial f / \partial x$.
virtual real_t GetTime() const
Read the currently set time.
virtual const real_t * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
void SetDataAndSize(real_t *d, int s)
Set the Vector data and size.
real_t Normlinf() const
Returns the l_infinity norm of the vector.
Vector & Set(const real_t a, const Vector &x)
(*this) = a * x
virtual bool UseDevice() const
Return the device flag of the Memory object used by the Vector.
int Size() const
Returns the size of the vector.
void SetSize(int s)
Resize the vector to size s.
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
T * HostReadWrite(Memory< T > &mem, int size)
Shortcut to ReadWrite(Memory<T> &mem, int size, false)
real_t u(const Vector &xvec)
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
T * ReadWrite(Memory< T > &mem, int size, bool on_dev=true)
Get a pointer for read+write access to mem with the mfem::Device's DeviceMemoryClass,...
MemoryManager mm
The (single) global memory manager object.
Settings for the output behavior of the IterativeSolver.
Helper struct to convert a C++ type to an MPI type.
MFEM_DEPRECATED SUNLinearSolver SUNLinSolNewEmpty(SUNContext)
MFEM_DEPRECATED N_Vector N_VMake_MPIPlusX(MPI_Comm comm, N_Vector local_vector, SUNContext)
MFEM_DEPRECATED void * KINCreate(SUNContext)
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype, int maxl, SUNContext)
MFEM_DEPRECATED void * CVodeCreate(int lmm, SUNContext)
MFEM_DEPRECATED N_Vector N_VNewEmpty_Parallel(MPI_Comm comm, sunindextype local_length, sunindextype global_length, SUNContext)
MFEM_DEPRECATED SUNMatrix SUNMatNewEmpty(SUNContext)
MFEM_DEPRECATED SUNMemoryHelper SUNMemoryHelper_NewEmpty(SUNContext)
MFEM_DEPRECATED void * ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, realtype t0, N_Vector y0, SUNContext)
MFEM_DEPRECATED N_Vector SUN_Hip_OR_Cuda N_VNewWithMemHelp(sunindextype length, booleantype use_managed_mem, SUNMemoryHelper helper, SUNContext)
MFEM_DEPRECATED N_Vector N_VNewEmpty_Serial(sunindextype vec_length, SUNContext)
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype, int maxl, SUNContext)
constexpr ARKODE_ERKTableID ARKODE_ERK_NONE
constexpr ARKODE_DIRKTableID ARKODE_DIRK_NONE