MFEM v4.8.0
Finite element discretization library
Loading...
Searching...
No Matches
sundials.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#include "sundials.hpp"
13
14#ifdef MFEM_USE_SUNDIALS
15
16#include "solvers.hpp"
17#ifdef MFEM_USE_MPI
18#include "hypre.hpp"
19#endif
20
21// SUNDIALS vectors
22#include <nvector/nvector_serial.h>
23#if defined(MFEM_USE_CUDA)
24#include <nvector/nvector_cuda.h>
25#elif defined(MFEM_USE_HIP)
26#include <nvector/nvector_hip.h>
27#endif
28#ifdef MFEM_USE_MPI
29#include <nvector/nvector_mpiplusx.h>
30#include <nvector/nvector_parallel.h>
31#endif
32
33// SUNDIALS linear solvers
34#include <sunlinsol/sunlinsol_spgmr.h>
35#include <sunlinsol/sunlinsol_spfgmr.h>
36
37// Access SUNDIALS object's content pointer
38#define GET_CONTENT(X) ( X->content )
39
40#if defined(MFEM_USE_CUDA)
41#define SUN_Hip_OR_Cuda(X) X##_Cuda
42#define SUN_HIP_OR_CUDA(X) X##_CUDA
43#elif defined(MFEM_USE_HIP)
44#define SUN_Hip_OR_Cuda(X) X##_Hip
45#define SUN_HIP_OR_CUDA(X) X##_HIP
46#endif
47
48using namespace std;
49
50#if (SUNDIALS_VERSION_MAJOR < 6)
51
52/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
53/// version < 6
54MFEM_DEPRECATED N_Vector N_VNewEmpty_Serial(sunindextype vec_length, SUNContext)
55{
56 return N_VNewEmpty_Serial(vec_length);
57}
58
59/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
60/// version < 6
61MFEM_DEPRECATED SUNMatrix SUNMatNewEmpty(SUNContext)
62{
63 return SUNMatNewEmpty();
64}
65
66/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
67/// version < 6
68MFEM_DEPRECATED SUNLinearSolver SUNLinSolNewEmpty(SUNContext)
69{
70 return SUNLinSolNewEmpty();
71}
72
73/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
74/// version < 6
75MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype,
76 int maxl, SUNContext)
77{
78 return SUNLinSol_SPGMR(y, pretype, maxl);
79}
80
81/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
82/// version < 6
83MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype,
84 int maxl, SUNContext)
85{
86 return SUNLinSol_SPFGMR(y, pretype, maxl);
87}
88
89/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
90/// version < 6
91MFEM_DEPRECATED void* CVodeCreate(int lmm, SUNContext)
92{
93 return CVodeCreate(lmm);
94}
95
96/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
97/// version < 6
98MFEM_DEPRECATED void* ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, sunrealtype t0,
99 N_Vector y0, SUNContext)
100{
101 return ARKStepCreate(fe, fi, t0, y0);
102}
103
104/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
105/// version < 6
106MFEM_DEPRECATED void* KINCreate(SUNContext)
107{
108 return KINCreate();
109}
110
111#ifdef MFEM_USE_MPI
112
113/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
114/// version < 6
115MFEM_DEPRECATED N_Vector N_VNewEmpty_Parallel(MPI_Comm comm,
116 sunindextype local_length,
117 sunindextype global_length,
119{
120 return N_VNewEmpty_Parallel(comm, local_length, global_length);
121}
122
123#endif // MFEM_USE_MPI
124
125#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
126
127/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
128/// version < 6
129MFEM_DEPRECATED N_Vector SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(sunindextype length,
130 sunbooleantype use_managed_mem,
131 SUNMemoryHelper helper,
133{
134 return SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(length, use_managed_mem, helper);
135}
136
137/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
138/// version < 6
139MFEM_DEPRECATED SUNMemoryHelper SUNMemoryHelper_NewEmpty(SUNContext)
140{
142}
143
144#endif // MFEM_USE_CUDA || MFEM_USE_HIP
145
146#if defined(MFEM_USE_MPI) && (defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP))
147
148/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
149/// version < 6
150MFEM_DEPRECATED N_Vector N_VMake_MPIPlusX(MPI_Comm comm, N_Vector local_vector,
152{
153 return N_VMake_MPIPlusX(comm, local_vector);
154}
155
156#endif // MFEM_USE_MPI && (MFEM_USE_CUDA || MFEM_USE_HIP)
157
158#endif // SUNDIALS_VERSION_MAJOR < 6
159
160#if MFEM_SUNDIALS_VERSION < 70100
161#define MFEM_ARKode(FUNC) ARKStep##FUNC
162#else
163#define MFEM_ARKode(FUNC) ARKode##FUNC
164#endif
165
166// Macro STR(): expand the argument and add double quotes
167#define STR1(s) #s
168#define STR(s) STR1(s)
169
170
171namespace mfem
172{
173
175{
176 Sundials::Instance();
177}
178
179Sundials &Sundials::Instance()
180{
181 static Sundials sundials;
182 return sundials;
183}
184
186{
187 return Sundials::Instance().context;
188}
189
191{
192 return Sundials::Instance().memHelper;
193}
194
195#if (SUNDIALS_VERSION_MAJOR >= 6)
196
197Sundials::Sundials()
198{
199#ifdef MFEM_USE_MPI
200 int mpi_initialized = 0;
201 MPI_Initialized(&mpi_initialized);
202 MPI_Comm communicator = mpi_initialized ? MPI_COMM_WORLD : MPI_COMM_NULL;
203#if SUNDIALS_VERSION_MAJOR < 7
204 int return_val = SUNContext_Create((void*) &communicator, &context);
205#else
206 int return_val = SUNContext_Create(communicator, &context);
207#endif
208#else // #ifdef MFEM_USE_MPI
209#if SUNDIALS_VERSION_MAJOR < 7
210 int return_val = SUNContext_Create(nullptr, &context);
211#else
212 int return_val = SUNContext_Create((SUNComm)(0), &context);
213#endif
214#endif // #ifdef MFEM_USE_MPI
215 MFEM_VERIFY(return_val == 0, "Call to SUNContext_Create failed");
216 SundialsMemHelper actual_helper(context);
217 memHelper = std::move(actual_helper);
218}
219
220Sundials::~Sundials()
221{
222 SUNContext_Free(&context);
223}
224
225#else // SUNDIALS_VERSION_MAJOR >= 6
226
227Sundials::Sundials()
228{
229 // Do nothing
230}
231
232Sundials::~Sundials()
233{
234 // Do nothing
235}
236
237#endif // SUNDIALS_VERSION_MAJOR >= 6
238
239#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
241{
242 /* Allocate helper */
243 h = SUNMemoryHelper_NewEmpty(context);
244
245 /* Set the ops */
246 h->ops->alloc = SundialsMemHelper_Alloc;
247 h->ops->dealloc = SundialsMemHelper_Dealloc;
248 h->ops->copy = SUN_Hip_OR_Cuda(SUNMemoryHelper_Copy);
249 h->ops->copyasync = SUN_Hip_OR_Cuda(SUNMemoryHelper_CopyAsync);
250}
251
253{
254 this->h = that_helper.h;
255 that_helper.h = nullptr;
256}
257
259{
260 this->h = rhs.h;
261 rhs.h = nullptr;
262 return *this;
263}
264
266 SUNMemory* memptr, size_t memsize,
267 SUNMemoryType mem_type
268#if (SUNDIALS_VERSION_MAJOR >= 6)
269 , void*
270#endif
271 )
272{
273#if (SUNDIALS_VERSION_MAJOR < 7)
274 SUNMemory sunmem = SUNMemoryNewEmpty();
275#else
276 SUNMemory sunmem = SUNMemoryNewEmpty(helper->sunctx);
277#endif
278
279 sunmem->ptr = NULL;
280 sunmem->own = SUNTRUE;
281
282 // memsize is the number of bytes to allocate, so we use Memory<char>
283 if (mem_type == SUNMEMTYPE_HOST)
284 {
286 mem.SetHostPtrOwner(false);
287 sunmem->ptr = mfem::HostReadWrite(mem, memsize);
288 sunmem->type = SUNMEMTYPE_HOST;
289 mem.Delete();
290 }
291 else if (mem_type == SUNMEMTYPE_DEVICE || mem_type == SUNMEMTYPE_UVM)
292 {
294 mem.SetDevicePtrOwner(false);
295 sunmem->ptr = mfem::ReadWrite(mem, memsize);
296 sunmem->type = mem_type;
297 mem.Delete();
298 }
299 else
300 {
301 free(sunmem);
302 return -1;
303 }
304
305 *memptr = sunmem;
306 return 0;
307}
308
310 SUNMemory sunmem
311#if (SUNDIALS_VERSION_MAJOR >= 6)
312 , void*
313#endif
314 )
315{
316 if (sunmem->ptr && sunmem->own && !mm.IsKnown(sunmem->ptr))
317 {
318 if (sunmem->type == SUNMEMTYPE_HOST)
319 {
320 Memory<char> mem(static_cast<char*>(sunmem->ptr), 1,
322 mem.Delete();
323 }
324 else if (sunmem->type == SUNMEMTYPE_DEVICE || sunmem->type == SUNMEMTYPE_UVM)
325 {
326 Memory<char> mem(static_cast<char*>(sunmem->ptr), 1,
328 mem.Delete();
329 }
330 else
331 {
332 MFEM_ABORT("Invalid SUNMEMTYPE");
333 return -1;
334 }
335 }
336 free(sunmem);
337 return 0;
338}
339
340#endif // MFEM_USE_CUDA || MFEM_USE_HIP
341
342
343// ---------------------------------------------------------------------------
344// SUNDIALS N_Vector interface functions
345// ---------------------------------------------------------------------------
346
348{
349#ifdef MFEM_USE_MPI
350 N_Vector local_x = MPIPlusX() ? N_VGetLocalVector_MPIPlusX(x) : x;
351#else
352 N_Vector local_x = x;
353#endif
354 N_Vector_ID id = N_VGetVectorID(local_x);
355
356 // Set the N_Vector data and length from the Vector data and size.
357 switch (id)
358 {
359 case SUNDIALS_NVEC_SERIAL:
360 {
361 MFEM_ASSERT(NV_OWN_DATA_S(local_x) == SUNFALSE, "invalid serial N_Vector");
362 NV_DATA_S(local_x) = HostReadWrite();
363 NV_LENGTH_S(local_x) = size;
364 break;
365 }
366#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
367 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
368 {
369 SUN_Hip_OR_Cuda(N_VSetHostArrayPointer)(HostReadWrite(), local_x);
370 SUN_Hip_OR_Cuda(N_VSetDeviceArrayPointer)(ReadWrite(), local_x);
371 static_cast<SUN_Hip_OR_Cuda(N_VectorContent)>(GET_CONTENT(
372 local_x))->length = size;
373 break;
374 }
375#endif
376#ifdef MFEM_USE_MPI
377 case SUNDIALS_NVEC_PARALLEL:
378 {
379 MFEM_ASSERT(NV_OWN_DATA_P(x) == SUNFALSE, "invalid parallel N_Vector");
380 NV_DATA_P(x) = HostReadWrite();
381 NV_LOCLENGTH_P(x) = size;
382 if (glob_size == 0)
383 {
384 glob_size = GlobalSize();
385
386 if (glob_size == 0 && glob_size != size)
387 {
388 long local_size = size;
389 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
390 MPI_SUM, GetComm());
391 }
392 }
393 NV_GLOBLENGTH_P(x) = glob_size;
394 break;
395 }
396#endif
397 default:
398 MFEM_ABORT("N_Vector type " << id << " is not supported");
399 }
400
401#ifdef MFEM_USE_MPI
402 if (MPIPlusX())
403 {
404 if (glob_size == 0)
405 {
406 glob_size = GlobalSize();
407
408 if (glob_size == 0 && glob_size != size)
409 {
410 long local_size = size;
411 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
412 MPI_SUM, GetComm());
413 }
414 }
415 static_cast<N_VectorContent_MPIManyVector>(GET_CONTENT(x))->global_length =
416 glob_size;
417 }
418#endif
419}
420
422{
423#ifdef MFEM_USE_MPI
424 N_Vector local_x = MPIPlusX() ? N_VGetLocalVector_MPIPlusX(x) : x;
425#else
426 N_Vector local_x = x;
427#endif
428 N_Vector_ID id = N_VGetVectorID(local_x);
429
430 // The SUNDIALS NVector owns the data if it created it.
431 switch (id)
432 {
433 case SUNDIALS_NVEC_SERIAL:
434 {
435 const bool known = mm.IsKnown(NV_DATA_S(local_x));
436 size = NV_LENGTH_S(local_x);
437 data.Wrap(NV_DATA_S(local_x), size, false);
438 if (known) { data.ClearOwnerFlags(); }
439 break;
440 }
441#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
442 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
443 {
444 double *h_ptr = SUN_Hip_OR_Cuda(N_VGetHostArrayPointer)(local_x);
445 double *d_ptr = SUN_Hip_OR_Cuda(N_VGetDeviceArrayPointer)(local_x);
446 const bool known = mm.IsKnown(h_ptr);
447 size = SUN_Hip_OR_Cuda(N_VGetLength)(local_x);
448 data.Wrap(h_ptr, d_ptr, size, Device::GetHostMemoryType(), false, false, true);
449 if (known) { data.ClearOwnerFlags(); }
450 UseDevice(true);
451 break;
452 }
453#endif
454#ifdef MFEM_USE_MPI
455 case SUNDIALS_NVEC_PARALLEL:
456 {
457 const bool known = mm.IsKnown(NV_DATA_P(x));
458 size = NV_LENGTH_S(x);
459 data.Wrap(NV_DATA_P(x), NV_LOCLENGTH_P(x), false);
460 if (known) { data.ClearOwnerFlags(); }
461 break;
462 }
463#endif
464 default:
465 MFEM_ABORT("N_Vector type " << id << " is not supported");
466 }
467}
468
470 : Vector()
471{
472 // MFEM creates and owns the data,
473 // and provides it to the SUNDIALS NVector.
476 own_NVector = 1;
477}
478
479SundialsNVector::SundialsNVector(double *data_, int size_)
480 : Vector(data_, size_)
481{
484 own_NVector = 1;
486}
487
489 : x(nv)
490{
492 own_NVector = 0;
493}
494
495#ifdef MFEM_USE_MPI
497 : Vector()
498{
500 x = MakeNVector(comm, UseDevice());
501 own_NVector = 1;
502}
503
504SundialsNVector::SundialsNVector(MPI_Comm comm, int loc_size, long glob_size)
505 : Vector(loc_size)
506{
508 x = MakeNVector(comm, UseDevice());
509 own_NVector = 1;
510 _SetNvecDataAndSize_(glob_size);
511}
512
513SundialsNVector::SundialsNVector(MPI_Comm comm, double *data_, int loc_size,
514 long glob_size)
515 : Vector(data_, loc_size)
516{
518 x = MakeNVector(comm, UseDevice());
519 own_NVector = 1;
520 _SetNvecDataAndSize_(glob_size);
521}
522
524 : SundialsNVector(vec.GetComm(), vec.GetData(), vec.Size(), vec.GlobalSize())
525{}
526#endif
527
529{
530 if (own_NVector)
531 {
532#ifdef MFEM_USE_MPI
533 if (MPIPlusX())
534 {
535 N_VDestroy(N_VGetLocalVector_MPIPlusX(x));
536 }
537#endif
538 N_VDestroy(x);
539 }
540}
541
542void SundialsNVector::SetSize(int s, long glob_size)
543{
545 _SetNvecDataAndSize_(glob_size);
546}
547
549{
552}
553
554void SundialsNVector::SetDataAndSize(double *d, int s, long glob_size)
555{
557 _SetNvecDataAndSize_(glob_size);
558}
559
560N_Vector SundialsNVector::MakeNVector(bool use_device)
561{
562 N_Vector x;
563#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
564 if (use_device)
565 {
566 x = SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(0, UseManagedMemory(),
569 }
570 else
571 {
573 }
574#else
576#endif
577
578 MFEM_VERIFY(x, "Error in SundialsNVector::MakeNVector.");
579
580 return x;
581}
582
583#ifdef MFEM_USE_MPI
584N_Vector SundialsNVector::MakeNVector(MPI_Comm comm, bool use_device)
585{
586 N_Vector x;
587
588 if (comm == MPI_COMM_NULL)
589 {
590 x = MakeNVector(use_device);
591 }
592 else
593 {
594#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
595 if (use_device)
596 {
597 x = N_VMake_MPIPlusX(comm, SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(0,
602 }
603 else
604 {
606 }
607#else
609#endif // MFEM_USE_CUDA || MFEM_USE_HIP
610 }
611
612 MFEM_VERIFY(x, "Error in SundialsNVector::MakeNVector.");
613
614 return x;
615}
616#endif // MFEM_USE_MPI
617
618
619// ---------------------------------------------------------------------------
620// SUNMatrix interface functions
621// ---------------------------------------------------------------------------
622
623// Return the matrix ID
624static SUNMatrix_ID MatGetID(SUNMatrix)
625{
626 return (SUNMATRIX_CUSTOM);
627}
628
629static void MatDestroy(SUNMatrix A)
630{
631 if (A->content) { A->content = NULL; }
632 if (A->ops) { free(A->ops); A->ops = NULL; }
633 free(A); A = NULL;
634 return;
635}
636
637// ---------------------------------------------------------------------------
638// SUNLinearSolver interface functions
639// ---------------------------------------------------------------------------
640
641// Return the linear solver type
642static SUNLinearSolver_Type LSGetType(SUNLinearSolver)
643{
644 return (SUNLINEARSOLVER_MATRIX_ITERATIVE);
645}
646
647static int LSFree(SUNLinearSolver LS)
648{
649 if (LS->content) { LS->content = NULL; }
650 if (LS->ops) { free(LS->ops); LS->ops = NULL; }
651 free(LS); LS = NULL;
652 return (0);
653}
654
655// ---------------------------------------------------------------------------
656// CVODE interface
657// ---------------------------------------------------------------------------
658int CVODESolver::RHS(sunrealtype t, const N_Vector y, N_Vector ydot,
659 void *user_data)
660{
661 // At this point the up-to-date data for N_Vector y and ydot is on the device.
662 const SundialsNVector mfem_y(y);
663 SundialsNVector mfem_ydot(ydot);
664
665 CVODESolver *self = static_cast<CVODESolver*>(user_data);
666
667 // Compute y' = f(t, y)
668 self->f->SetTime(t);
669 self->f->Mult(mfem_y, mfem_ydot);
670
671 // Return success
672 return (0);
673}
674
676 void *user_data)
677{
678 CVODESolver *self = static_cast<CVODESolver*>(user_data);
679
680 if (!self->root_func) { return CV_RTFUNC_FAIL; }
681
682 SundialsNVector mfem_y(y);
683 SundialsNVector mfem_gout(gout, self->root_components);
684
685 return self->root_func(t, mfem_y, mfem_gout, self);
686}
687
688void CVODESolver::SetRootFinder(int components, RootFunction func)
689{
690 root_func = func;
691
692 flag = CVodeRootInit(sundials_mem, components, root);
693 MFEM_VERIFY(flag == CV_SUCCESS, "error in SetRootFinder()");
694}
695
696int CVODESolver::LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy,
697 SUNMatrix A, sunbooleantype jok,
698 sunbooleantype *jcur, sunrealtype gamma,
699 void*, N_Vector, N_Vector, N_Vector)
700{
701 // Get data from N_Vectors
702 const SundialsNVector mfem_y(y);
703 const SundialsNVector mfem_fy(fy);
704 CVODESolver *self = static_cast<CVODESolver*>(GET_CONTENT(A));
705
706 // Compute the linear system
707 self->f->SetTime(t);
708 return (self->f->SUNImplicitSetup(mfem_y, mfem_fy, jok, jcur, gamma));
709}
710
711int CVODESolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
712 N_Vector b, sunrealtype tol)
713{
714 SundialsNVector mfem_x(x);
715 const SundialsNVector mfem_b(b);
716 CVODESolver *self = static_cast<CVODESolver*>(GET_CONTENT(LS));
717 // Solve the linear system
718 return (self->f->SUNImplicitSolve(mfem_b, mfem_x, tol));
719}
720
722 : lmm_type(lmm), step_mode(CV_NORMAL)
723{
724 Y = new SundialsNVector();
725}
726
727#ifdef MFEM_USE_MPI
728CVODESolver::CVODESolver(MPI_Comm comm, int lmm)
729 : lmm_type(lmm), step_mode(CV_NORMAL)
730{
731 Y = new SundialsNVector(comm);
732}
733#endif
734
736{
737 // Initialize the base class
738 ODESolver::Init(f_);
739
740 // Get the vector length
741 long local_size = f_.Height();
742
743#ifdef MFEM_USE_MPI
744 long global_size = 0;
745 if (Parallel())
746 {
747 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
748 Y->GetComm());
749 }
750#endif
751
752 // Get current time
753 double t = f_.GetTime();
754
755 if (sundials_mem)
756 {
757 // Check if the problem size has changed since the last Init() call
758 int resize = 0;
759 if (!Parallel())
760 {
761 resize = (Y->Size() != local_size);
762 }
763 else
764 {
765#ifdef MFEM_USE_MPI
766 int l_resize = (Y->Size() != local_size) ||
767 (saved_global_size != global_size);
768 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
769 Y->GetComm());
770#endif
771 }
772
773 // Free existing solver memory and re-create with new vector size
774 if (resize)
775 {
776 CVodeFree(&sundials_mem);
777 sundials_mem = NULL;
778 }
779 }
780
781 if (!sundials_mem)
782 {
783 // Temporarily set N_Vector wrapper data to create CVODE. The correct
784 // initial condition will be set using CVodeReInit() when Step() is
785 // called.
786
787 if (!Parallel())
788 {
789 Y->SetSize(local_size);
790 }
791#ifdef MFEM_USE_MPI
792 else
793 {
794 Y->SetSize(local_size, global_size);
795 saved_global_size = global_size;
796 }
797#endif
798
799 // Create CVODE
801 MFEM_VERIFY(sundials_mem, "error in CVodeCreate()");
802
803 // Initialize CVODE
804 flag = CVodeInit(sundials_mem, CVODESolver::RHS, t, *Y);
805 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeInit()");
806
807 // Attach the CVODESolver as user-defined data
808 flag = CVodeSetUserData(sundials_mem, this);
809 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetUserData()");
810
811 // Set default tolerances
812 flag = CVodeSStolerances(sundials_mem, default_rel_tol, default_abs_tol);
813 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetSStolerances()");
814
815 // Attach MFEM linear solver by default
817 }
818
819 // Set the reinit flag to call CVodeReInit() in the next Step() call.
820 reinit = true;
821}
822
823void CVODESolver::Step(Vector &x, double &t, double &dt)
824{
825 Y->MakeRef(x, 0, x.Size());
826 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
827
828 // Reinitialize CVODE memory if needed
829 if (reinit)
830 {
831 flag = CVodeReInit(sundials_mem, t, *Y);
832 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
833 // reset flag
834 reinit = false;
835 }
836
837 // Integrate the system
838 double tout = t + dt;
839 flag = CVode(sundials_mem, tout, *Y, &t, step_mode);
840 MFEM_VERIFY(flag >= 0, "error in CVode()");
841
842 // Make sure host is up to date
843 Y->HostRead();
844
845 // Return the last incremental step size
846 flag = CVodeGetLastStep(sundials_mem, &dt);
847 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetLastStep()");
848}
849
851{
852 // Free any existing matrix and linear solver
853 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
854 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
855
856 // Wrap linear solver as SUNLinearSolver and SUNMatrix
858 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
859
860 LSA->content = this;
861 LSA->ops->gettype = LSGetType;
862 LSA->ops->solve = CVODESolver::LinSysSolve;
863 LSA->ops->free = LSFree;
864
866 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
867
868 A->content = this;
869 A->ops->getid = MatGetID;
870 A->ops->destroy = MatDestroy;
871
872 // Attach the linear solver and matrix
873 flag = CVodeSetLinearSolver(sundials_mem, LSA, A);
874 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolver()");
875
876 // Set the linear system evaluation function
877 flag = CVodeSetLinSysFn(sundials_mem, CVODESolver::LinSysSetup);
878 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinSysFn()");
879}
880
882{
883 // Free any existing matrix and linear solver
884 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
885 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
886
887 // Create linear solver
889 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
890
891 // Attach linear solver
892 flag = CVodeSetLinearSolver(sundials_mem, LSA, NULL);
893 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolver()");
894}
895
897{
898 step_mode = itask;
899}
900
901void CVODESolver::SetSStolerances(double reltol, double abstol)
902{
903 flag = CVodeSStolerances(sundials_mem, reltol, abstol);
904 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSStolerances()");
905}
906
907void CVODESolver::SetSVtolerances(double reltol, Vector abstol)
908{
909 MFEM_VERIFY(abstol.Size() == f->Height(),
910 "abs tolerance is not the same size.");
911
912 SundialsNVector mfem_abstol;
913 mfem_abstol.MakeRef(abstol, 0, abstol.Size());
914
915 flag = CVodeSVtolerances(sundials_mem, reltol, mfem_abstol);
916 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSVtolerances()");
917}
918
919void CVODESolver::SetMaxStep(double dt_max)
920{
921 flag = CVodeSetMaxStep(sundials_mem, dt_max);
922 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxStep()");
923}
924
926{
927 flag = CVodeSetMaxNumSteps(sundials_mem, mxsteps);
928 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxNumSteps()");
929}
930
932{
933 long nsteps;
934 flag = CVodeGetNumSteps(sundials_mem, &nsteps);
935 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetNumSteps()");
936 return nsteps;
937}
938
939void CVODESolver::SetMaxOrder(int max_order)
940{
941 flag = CVodeSetMaxOrd(sundials_mem, max_order);
942 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxOrd()");
943}
944
946{
947 long int nsteps, nfevals, nlinsetups, netfails;
948 int qlast, qcur;
949 double hinused, hlast, hcur, tcur;
950 long int nniters, nncfails;
951
952 // Get integrator stats
953 flag = CVodeGetIntegratorStats(sundials_mem,
954 &nsteps,
955 &nfevals,
956 &nlinsetups,
957 &netfails,
958 &qlast,
959 &qcur,
960 &hinused,
961 &hlast,
962 &hcur,
963 &tcur);
964 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetIntegratorStats()");
965
966 // Get nonlinear solver stats
967 flag = CVodeGetNonlinSolvStats(sundials_mem,
968 &nniters,
969 &nncfails);
970 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetNonlinSolvStats()");
971
972 mfem::out <<
973 "CVODE:\n"
974 "num steps: " << nsteps << "\n"
975 "num rhs evals: " << nfevals << "\n"
976 "num lin setups: " << nlinsetups << "\n"
977 "num nonlin sol iters: " << nniters << "\n"
978 "num nonlin conv fail: " << nncfails << "\n"
979 "num error test fails: " << netfails << "\n"
980 "last order: " << qlast << "\n"
981 "current order: " << qcur << "\n"
982 "initial dt: " << hinused << "\n"
983 "last dt: " << hlast << "\n"
984 "current dt: " << hcur << "\n"
985 "current t: " << tcur << "\n" << endl;
986
987 return;
988}
989
991{
992 delete Y;
993 SUNMatDestroy(A);
994 SUNLinSolFree(LSA);
995 SUNNonlinSolFree(NLS);
996 CVodeFree(&sundials_mem);
997}
998
999// ---------------------------------------------------------------------------
1000// CVODESSolver interface
1001// ---------------------------------------------------------------------------
1002
1004 CVODESolver(lmm),
1005 ncheck(0),
1006 indexB(0),
1007 AB(nullptr),
1008 LSB(nullptr)
1009{
1010 q = new SundialsNVector();
1011 qB = new SundialsNVector();
1012 yB = new SundialsNVector();
1013 yy = new SundialsNVector();
1014}
1015
1016#ifdef MFEM_USE_MPI
1017CVODESSolver::CVODESSolver(MPI_Comm comm, int lmm) :
1018 CVODESolver(comm, lmm),
1019 ncheck(0),
1020 indexB(0),
1021 AB(nullptr),
1022 LSB(nullptr)
1023{
1024 q = new SundialsNVector(comm);
1025 qB = new SundialsNVector(comm);
1026 yB = new SundialsNVector(comm);
1027 yy = new SundialsNVector(comm);
1028}
1029#endif
1030
1032{
1033 MFEM_VERIFY(t <= f->GetTime(), "t > current forward solver time");
1034
1035 flag = CVodeGetQuad(sundials_mem, &t, *q);
1036 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetQuad()");
1037
1038 Q.Set(1., *q);
1039}
1040
1042{
1043 MFEM_VERIFY(t <= f->GetTime(), "t > current forward solver time");
1044
1045 flag = CVodeGetQuadB(sundials_mem, indexB, &t, *qB);
1046 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetQuadB()");
1047
1048 dG_dp.Set(-1., *qB);
1049}
1050
1052{
1053 yy->MakeRef(yyy, 0, yyy.Size());
1054
1055 flag = CVodeGetAdjY(sundials_mem, tB, *yy);
1056 MFEM_VERIFY(flag >= 0, "error in CVodeGetAdjY()");
1057}
1058
1059// Implemented to enforce type checking for TimeDependentAdjointOperator
1064
1066{
1067 long local_size = f_.GetAdjointHeight();
1068
1069 // Get current time
1070 double tB = f_.GetTime();
1071
1072 yB->SetSize(local_size);
1073
1074 // Create the solver memory
1075 flag = CVodeCreateB(sundials_mem, CV_BDF, &indexB);
1076 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeCreateB()");
1077
1078 // Initialize
1079 flag = CVodeInitB(sundials_mem, indexB, RHSB, tB, *yB);
1080 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeInit()");
1081
1082 // Attach the CVODESSolver as user-defined data
1083 flag = CVodeSetUserDataB(sundials_mem, indexB, this);
1084 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetUserDataB()");
1085
1086 // Set default tolerances
1087 flag = CVodeSStolerancesB(sundials_mem, indexB, default_rel_tolB,
1089 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetSStolerancesB()");
1090
1091 // Attach MFEM linear solver by default
1093
1094 // Set the reinit flag to call CVodeReInit() in the next Step() call.
1095 reinit = true;
1096}
1097
1098void CVODESSolver::InitAdjointSolve(int steps, int interpolation)
1099{
1100 flag = CVodeAdjInit(sundials_mem, steps, interpolation);
1101 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeAdjInit");
1102}
1103
1105{
1106 flag = CVodeSetMaxNumStepsB(sundials_mem, indexB, mxstepsB);
1107 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetMaxNumStepsB()");
1108}
1109
1111 double abstolQ)
1112{
1113 q->MakeRef(q0, 0, q0.Size());
1114
1115 flag = CVodeQuadInit(sundials_mem, RHSQ, *q);
1116 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadInit()");
1117
1118 flag = CVodeSetQuadErrCon(sundials_mem, SUNTRUE);
1119 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetQuadErrCon");
1120
1121 flag = CVodeQuadSStolerances(sundials_mem, reltolQ, abstolQ);
1122 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadSStolerances");
1123}
1124
1126 double abstolQB)
1127{
1128 qB->MakeRef(qB0, 0, qB0.Size());
1129
1130 flag = CVodeQuadInitB(sundials_mem, indexB, RHSQB, *qB);
1131 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadInitB()");
1132
1133 flag = CVodeSetQuadErrConB(sundials_mem, indexB, SUNTRUE);
1134 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetQuadErrConB");
1135
1136 flag = CVodeQuadSStolerancesB(sundials_mem, indexB, reltolQB, abstolQB);
1137 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadSStolerancesB");
1138}
1139
1141{
1142 // Free any existing linear solver
1143 if (AB != NULL) { SUNMatDestroy(AB); AB = NULL; }
1144 if (LSB != NULL) { SUNLinSolFree(LSB); LSB = NULL; }
1145
1146 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1148 MFEM_VERIFY(LSB, "error in SUNLinSolNewEmpty()");
1149
1150 LSB->content = this;
1151 LSB->ops->gettype = LSGetType;
1152 LSB->ops->solve = CVODESSolver::LinSysSolveB; // JW change
1153 LSB->ops->free = LSFree;
1154
1156 MFEM_VERIFY(AB, "error in SUNMatNewEmpty()");
1157
1158 AB->content = this;
1159 AB->ops->getid = MatGetID;
1160 AB->ops->destroy = MatDestroy;
1161
1162 // Attach the linear solver and matrix
1163 flag = CVodeSetLinearSolverB(sundials_mem, indexB, LSB, AB);
1164 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolverB()");
1165
1166 // Set the linear system evaluation function
1167 flag = CVodeSetLinSysFnB(sundials_mem, indexB,
1168 CVODESSolver::LinSysSetupB); // JW change
1169 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinSysFn()");
1170}
1171
1173{
1174 // Free any existing matrix and linear solver
1175 if (AB != NULL) { SUNMatDestroy(AB); AB = NULL; }
1176 if (LSB != NULL) { SUNLinSolFree(LSB); LSB = NULL; }
1177
1178 // Set default linear solver (Newton is the default Nonlinear Solver)
1180 MFEM_VERIFY(LSB, "error in SUNLinSol_SPGMR()");
1181
1182 /* Attach the matrix and linear solver */
1183 flag = CVodeSetLinearSolverB(sundials_mem, indexB, LSB, NULL);
1184 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolverB()");
1185}
1186
1187int CVODESSolver::LinSysSetupB(sunrealtype t, N_Vector y, N_Vector yB,
1188 N_Vector fyB, SUNMatrix AB,
1189 sunbooleantype jokB, sunbooleantype *jcurB,
1190 sunrealtype gammaB, void *user_data,
1191 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
1192{
1193 // Get data from N_Vectors
1194 const SundialsNVector mfem_y(y);
1195 const SundialsNVector mfem_yB(yB);
1196 SundialsNVector mfem_fyB(fyB);
1197 CVODESSolver *self = static_cast<CVODESSolver*>(GET_CONTENT(AB));
1199 (self->f);
1200 f->SetTime(t);
1201 // Compute the linear system
1202 return (f->SUNImplicitSetupB(t, mfem_y, mfem_yB, mfem_fyB, jokB, jcurB,
1203 gammaB));
1204}
1205
1206int CVODESSolver::LinSysSolveB(SUNLinearSolver LS, SUNMatrix AB, N_Vector yB,
1207 N_Vector Rb, sunrealtype tol)
1208{
1209 SundialsNVector mfem_yB(yB);
1210 const SundialsNVector mfem_Rb(Rb);
1211 CVODESSolver *self = static_cast<CVODESSolver*>(GET_CONTENT(LS));
1213 (self->f);
1214 // Solve the linear system
1215 int ret = f->SUNImplicitSolveB(mfem_yB, mfem_Rb, tol);
1216 return (ret);
1217}
1218
1219void CVODESSolver::SetSStolerancesB(double reltol, double abstol)
1220{
1221 flag = CVodeSStolerancesB(sundials_mem, indexB, reltol, abstol);
1222 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSStolerancesB()");
1223}
1224
1225void CVODESSolver::SetSVtolerancesB(double reltol, Vector abstol)
1226{
1227 MFEM_VERIFY(abstol.Size() == f->Height(),
1228 "abs tolerance is not the same size.");
1229
1230 SundialsNVector mfem_abstol;
1231 mfem_abstol.MakeRef(abstol, 0, abstol.Size());
1232
1233 flag = CVodeSVtolerancesB(sundials_mem, indexB, reltol, mfem_abstol);
1234 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSVtolerancesB()");
1235}
1236
1238{
1239 ewt_func = func;
1240 CVodeWFtolerances(sundials_mem, ewt);
1241}
1242
1243// CVODESSolver static functions
1244
1245int CVODESSolver::RHSQ(sunrealtype t, const N_Vector y, N_Vector qdot,
1246 void *user_data)
1247{
1248 CVODESSolver *self = static_cast<CVODESSolver*>(user_data);
1249 const SundialsNVector mfem_y(y);
1250 SundialsNVector mfem_qdot(qdot);
1252 (self->f);
1253 f->SetTime(t);
1254 f->QuadratureIntegration(mfem_y, mfem_qdot);
1255 return 0;
1256}
1257
1258int CVODESSolver::RHSQB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector qBdot,
1259 void *user_dataB)
1260{
1261 CVODESSolver *self = static_cast<CVODESSolver*>(user_dataB);
1262 SundialsNVector mfem_y(y);
1263 SundialsNVector mfem_yB(yB);
1264 SundialsNVector mfem_qBdot(qBdot);
1266 (self->f);
1267 f->SetTime(t);
1268 f->QuadratureSensitivityMult(mfem_y, mfem_yB, mfem_qBdot);
1269 return 0;
1270}
1271
1272int CVODESSolver::RHSB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector yBdot,
1273 void *user_dataB)
1274{
1275 CVODESSolver *self = static_cast<CVODESSolver*>(user_dataB);
1276 SundialsNVector mfem_y(y);
1277 SundialsNVector mfem_yB(yB);
1278 SundialsNVector mfem_yBdot(yBdot);
1279
1280 mfem_yBdot = 0.;
1282 (self->f);
1283 f->SetTime(t);
1284 f->AdjointRateMult(mfem_y, mfem_yB, mfem_yBdot);
1285 return 0;
1286}
1287
1288int CVODESSolver::ewt(N_Vector y, N_Vector w, void *user_data)
1289{
1290 CVODESSolver *self = static_cast<CVODESSolver*>(user_data);
1291
1292 SundialsNVector mfem_y(y);
1293 SundialsNVector mfem_w(w);
1294
1295 return self->ewt_func(mfem_y, mfem_w, self);
1296}
1297
1298// Pretty much a copy of CVODESolver::Step except we use CVodeF instead of CVode
1299void CVODESSolver::Step(Vector &x, double &t, double &dt)
1300{
1301 Y->MakeRef(x, 0, x.Size());
1302 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
1303
1304 // Reinitialize CVODE memory if needed, initializes the N_Vector y with x
1305 if (reinit)
1306 {
1307 flag = CVodeReInit(sundials_mem, t, *Y);
1308 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
1309
1310 // reset flag
1311 reinit = false;
1312 }
1313
1314 // Integrate the system
1315 double tout = t + dt;
1316 flag = CVodeF(sundials_mem, tout, *Y, &t, step_mode, &ncheck);
1317 MFEM_VERIFY(flag >= 0, "error in CVodeF()");
1318
1319 // Make sure host is up to date
1320 Y->HostRead();
1321
1322 // Return the last incremental step size
1323 flag = CVodeGetLastStep(sundials_mem, &dt);
1324 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetLastStep()");
1325}
1326
1327void CVODESSolver::StepB(Vector &xB, double &tB, double &dtB)
1328{
1329 yB->MakeRef(xB, 0, xB.Size());
1330 MFEM_VERIFY(yB->Size() == xB.Size(), "");
1331
1332 // Reinitialize CVODE memory if needed
1333 if (reinit)
1334 {
1335 flag = CVodeReInitB(sundials_mem, indexB, tB, *yB);
1336 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
1337
1338 // reset flag
1339 reinit = false;
1340 }
1341
1342 // Integrate the system
1343 double tout = tB - dtB;
1344 flag = CVodeB(sundials_mem, tout, step_mode);
1345 MFEM_VERIFY(flag >= 0, "error in CVodeB()");
1346
1347 // Call CVodeGetB to get yB of the backward ODE problem.
1348 flag = CVodeGetB(sundials_mem, indexB, &tB, *yB);
1349 MFEM_VERIFY(flag >= 0, "error in CVodeGetB()");
1350
1351 // Make sure host is up to date
1352 yB->HostRead();
1353}
1354
1356{
1357 delete yB;
1358 delete yy;
1359 delete qB;
1360 delete q;
1361 SUNMatDestroy(AB);
1362 SUNLinSolFree(LSB);
1363}
1364
1365
1366// ---------------------------------------------------------------------------
1367// ARKStep interface
1368// ---------------------------------------------------------------------------
1369
1370int ARKStepSolver::RHS1(sunrealtype t, const N_Vector y, N_Vector result,
1371 void *user_data)
1372{
1373 // Get data from N_Vectors
1374 const SundialsNVector mfem_y(y);
1375 SundialsNVector mfem_result(result);
1376 ARKStepSolver *self = static_cast<ARKStepSolver*>(user_data);
1377
1378 // Compute either f(t, y) in one of
1379 // 1. y' = f(t, y)
1380 // 2. M y' = f(t, y)
1381 // or fe(t, y) in one of
1382 // 1. y' = fe(t, y) + fi(t, y)
1383 // 2. M y' = fe(t, y) + fi(t, y)
1384 self->f->SetTime(t);
1385 if (self->rk_type == IMEX)
1386 {
1388 }
1389 if (self->f->isExplicit()) // ODE is in form 1
1390 {
1391 self->f->Mult(mfem_y, mfem_result);
1392 }
1393 else // ODE is in form 2
1394 {
1395 self->f->ExplicitMult(mfem_y, mfem_result);
1396 }
1397
1398 // Return success
1399 return (0);
1400}
1401
1402int ARKStepSolver::RHS2(sunrealtype t, const N_Vector y, N_Vector result,
1403 void *user_data)
1404{
1405 // Get data from N_Vectors
1406 const SundialsNVector mfem_y(y);
1407 SundialsNVector mfem_result(result);
1408 ARKStepSolver *self = static_cast<ARKStepSolver*>(user_data);
1409
1410 // Compute fi(t, y) in one of
1411 // 1. y' = fe(t, y) + fi(t, y) (ODE is expressed in EXPLICIT form)
1412 // 2. M y' = fe(t, y) + fi(y, t) (ODE is expressed in IMPLICIT form)
1413 self->f->SetTime(t);
1415 if (self->f->isExplicit())
1416 {
1417 self->f->Mult(mfem_y, mfem_result);
1418 }
1419 else
1420 {
1421 self->f->ExplicitMult(mfem_y, mfem_result);
1422 }
1423
1424 // Return success
1425 return (0);
1426}
1427
1428int ARKStepSolver::LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy,
1429 SUNMatrix A, SUNMatrix, sunbooleantype jok,
1430 sunbooleantype *jcur, sunrealtype gamma,
1431 void*, N_Vector, N_Vector, N_Vector)
1432{
1433 // Get data from N_Vectors
1434 const SundialsNVector mfem_y(y);
1435 const SundialsNVector mfem_fy(fy);
1436 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(A));
1437
1438 // Compute the linear system
1439 self->f->SetTime(t);
1440 if (self->rk_type == IMEX)
1441 {
1443 }
1444 return (self->f->SUNImplicitSetup(mfem_y, mfem_fy, jok, jcur, gamma));
1445}
1446
1447int ARKStepSolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
1448 N_Vector b, sunrealtype tol)
1449{
1450 SundialsNVector mfem_x(x);
1451 const SundialsNVector mfem_b(b);
1452 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(LS));
1453
1454 // Solve the linear system
1455 if (self->rk_type == IMEX)
1456 {
1458 }
1459 return (self->f->SUNImplicitSolve(mfem_b, mfem_x, tol));
1460}
1461
1463 void*, N_Vector, N_Vector, N_Vector)
1464{
1465 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(M));
1466
1467 // Compute the mass matrix system
1468 self->f->SetTime(t);
1469 return (self->f->SUNMassSetup());
1470}
1471
1472int ARKStepSolver::MassSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
1473 N_Vector b, sunrealtype tol)
1474{
1475 SundialsNVector mfem_x(x);
1476 const SundialsNVector mfem_b(b);
1477 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(LS));
1478
1479 // Solve the mass matrix system
1480 return (self->f->SUNMassSolve(mfem_b, mfem_x, tol));
1481}
1482
1483int ARKStepSolver::MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
1484{
1485 const SundialsNVector mfem_x(x);
1486 SundialsNVector mfem_v(v);
1487 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(M));
1488
1489 // Compute the mass matrix-vector product
1490 return (self->f->SUNMassMult(mfem_x, mfem_v));
1491}
1492
1493int ARKStepSolver::MassMult2(N_Vector x, N_Vector v, sunrealtype t,
1494 void* mtimes_data)
1495{
1496 const SundialsNVector mfem_x(x);
1497 SundialsNVector mfem_v(v);
1498 ARKStepSolver *self = static_cast<ARKStepSolver*>(mtimes_data);
1499
1500 // Compute the mass matrix-vector product
1501 self->f->SetTime(t);
1502 return (self->f->SUNMassMult(mfem_x, mfem_v));
1503}
1504
1506 : rk_type(type), step_mode(ARK_NORMAL),
1507 use_implicit(type == IMPLICIT || type == IMEX)
1508{
1509 Y = new SundialsNVector();
1510}
1511
#ifdef MFEM_USE_MPI
// Parallel ARKStep wrapper on communicator comm; implicit solves are enabled
// for IMPLICIT and IMEX methods.
ARKStepSolver::ARKStepSolver(MPI_Comm comm, Type type)
   : rk_type(type), step_mode(ARK_NORMAL),
     use_implicit(type == IMPLICIT || type == IMEX)
{
   Y = new SundialsNVector(comm);
}
#endif
1520
1522{
1523 // Initialize the base class
1524 ODESolver::Init(f_);
1525
1526 // Get the vector length
1527 long local_size = f_.Height();
1528#ifdef MFEM_USE_MPI
1529 long global_size;
1530#endif
1531
1532 if (Parallel())
1533 {
1534#ifdef MFEM_USE_MPI
1535 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
1536 Y->GetComm());
1537#endif
1538 }
1539
1540 // Get current time
1541 double t = f_.GetTime();
1542
1543 if (sundials_mem)
1544 {
1545 // Check if the problem size has changed since the last Init() call
1546 int resize = 0;
1547 if (!Parallel())
1548 {
1549 resize = (Y->Size() != local_size);
1550 }
1551 else
1552 {
1553#ifdef MFEM_USE_MPI
1554 int l_resize = (Y->Size() != local_size) ||
1555 (saved_global_size != global_size);
1556 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
1557 Y->GetComm());
1558#endif
1559 }
1560
1561 // Free existing solver memory and re-create with new vector size
1562 if (resize)
1563 {
1564 MFEM_ARKode(Free)(&sundials_mem);
1565 sundials_mem = NULL;
1566 }
1567 }
1568
1569 if (!sundials_mem)
1570 {
1571 if (!Parallel())
1572 {
1573 Y->SetSize(local_size);
1574 }
1575#ifdef MFEM_USE_MPI
1576 else
1577 {
1578 Y->SetSize(local_size, global_size);
1579 saved_global_size = global_size;
1580 }
1581#endif
1582
1583 // Create ARKStep memory
1584 if (rk_type == IMPLICIT)
1585 {
1588 }
1589 else if (rk_type == EXPLICIT)
1590 {
1593 }
1594 else
1595 {
1597 t, *Y, Sundials::GetContext());
1598 }
1599 MFEM_VERIFY(sundials_mem, "error in ARKStepCreate()");
1600
1601 // Attach the ARKStepSolver as user-defined data
1602 flag = MFEM_ARKode(SetUserData)(sundials_mem, this);
1603 MFEM_VERIFY(flag == ARK_SUCCESS,
1604 "error in " STR(MFEM_ARKode(SetUserData)) "()");
1605
1606 // Set default tolerances
1607 flag = MFEM_ARKode(SStolerances)(sundials_mem, default_rel_tol,
1609 MFEM_VERIFY(flag == ARK_SUCCESS,
1610 "error in " STR(MFEM_ARKode(SStolerances)) "()");
1611
1612 // If implicit, attach MFEM linear solver by default
1614 }
1615
1616 // Set the reinit flag to call ARKStepReInit() in the next Step() call.
1617 reinit = true;
1618}
1619
1621{
1622 Y->MakeRef(x, 0, x.Size());
1623 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
1624
1625 // Reinitialize ARKStep memory if needed
1626 if (reinit)
1627 {
1628 if (rk_type == IMPLICIT)
1629 {
1630 flag = ARKStepReInit(sundials_mem, NULL, ARKStepSolver::RHS1, t, *Y);
1631 }
1632 else if (rk_type == EXPLICIT)
1633 {
1634 flag = ARKStepReInit(sundials_mem, ARKStepSolver::RHS1, NULL, t, *Y);
1635 }
1636 else
1637 {
1638 flag = ARKStepReInit(sundials_mem,
1640 }
1641 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepReInit()");
1642
1643 // reset flag
1644 reinit = false;
1645 }
1646
1647 // Integrate the system
1648 double tout = t + dt;
1649 flag = MFEM_ARKode(Evolve)(sundials_mem, tout, *Y, &t, step_mode);
1650 MFEM_VERIFY(flag >= 0, "error in " STR(MFEM_ARKode(Evolve)) "()");
1651
1652 // Make sure host is up to date
1653 Y->HostRead();
1654
1655 // Return the last incremental step size
1656 flag = MFEM_ARKode(GetLastStep)(sundials_mem, &dt);
1657 MFEM_VERIFY(flag == ARK_SUCCESS,
1658 "error in " STR(MFEM_ARKode(GetLastStep)) "()");
1659}
1660
1662{
1663 // Free any existing matrix and linear solver
1664 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
1665 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
1666
1667 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1669 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
1670
1671 LSA->content = this;
1672 LSA->ops->gettype = LSGetType;
1673 LSA->ops->solve = ARKStepSolver::LinSysSolve;
1674 LSA->ops->free = LSFree;
1675
1677 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
1678
1679 A->content = this;
1680 A->ops->getid = MatGetID;
1681 A->ops->destroy = MatDestroy;
1682
1683 // Attach the linear solver and matrix
1684 flag = MFEM_ARKode(SetLinearSolver)(sundials_mem, LSA, A);
1685 MFEM_VERIFY(flag == ARK_SUCCESS,
1686 "error in " STR(MFEM_ARKode(SetLinearSolver)) "()");
1687
1688 // Set the linear system evaluation function
1689 flag = MFEM_ARKode(SetLinSysFn)(sundials_mem, ARKStepSolver::LinSysSetup);
1690 MFEM_VERIFY(flag == ARK_SUCCESS,
1691 "error in " STR(MFEM_ARKode(SetLinSysFn)) "()");
1692}
1693
1695{
1696 // Free any existing matrix and linear solver
1697 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
1698 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
1699
1700 // Create linear solver
1702 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
1703
1704 // Attach linear solver
1705 flag = MFEM_ARKode(SetLinearSolver)(sundials_mem, LSA, NULL);
1706 MFEM_VERIFY(flag == ARK_SUCCESS,
1707 "error in " STR(MFEM_ARKode(SetLinearSolver)) "()");
1708}
1709
1711{
1712 // Free any existing matrix and linear solver
1713 if (M != NULL) { SUNMatDestroy(M); M = NULL; }
1714 if (LSM != NULL) { SUNLinSolFree(LSM); LSM = NULL; }
1715
1716 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1718 MFEM_VERIFY(LSM, "error in SUNLinSolNewEmpty()");
1719
1720 LSM->content = this;
1721 LSM->ops->gettype = LSGetType;
1722 LSM->ops->solve = ARKStepSolver::MassSysSolve;
1723 LSM->ops->free = LSFree;
1724
1726 MFEM_VERIFY(M, "error in SUNMatNewEmpty()");
1727
1728 M->content = this;
1729 M->ops->getid = MatGetID;
1730 M->ops->matvec = ARKStepSolver::MassMult1;
1731 M->ops->destroy = MatDestroy;
1732
1733 // Attach the linear solver and matrix
1734 flag = MFEM_ARKode(SetMassLinearSolver)(sundials_mem, LSM, M, tdep);
1735 MFEM_VERIFY(flag == ARK_SUCCESS,
1736 "error in " STR(MFEM_ARKode(SetMassLinearSolver)) "()");
1737
1738 // Set the linear system function
1739 flag = MFEM_ARKode(SetMassFn)(sundials_mem, ARKStepSolver::MassSysSetup);
1740 MFEM_VERIFY(flag == ARK_SUCCESS,
1741 "error in " STR(MFEM_ARKode(SetMassFn)) "()");
1742
1743 // Check that the ODE is not expressed in EXPLICIT form
1744 MFEM_VERIFY(!f->isExplicit(), "ODE operator is expressed in EXPLICIT form")
1745}
1746
1748{
1749 // Free any existing matrix and linear solver
1750 if (M != NULL) { SUNMatDestroy(A); M = NULL; }
1751 if (LSM != NULL) { SUNLinSolFree(LSM); LSM = NULL; }
1752
1753 // Create linear solver
1755 MFEM_VERIFY(LSM, "error in SUNLinSol_SPGMR()");
1756
1757 // Attach linear solver
1758 flag = MFEM_ARKode(SetMassLinearSolver)(sundials_mem, LSM, NULL, tdep);
1759 MFEM_VERIFY(flag == ARK_SUCCESS,
1760 "error in " STR(MFEM_ARKode(SetMassLinearSolver)) "()");
1761
1762 // Attach matrix multiplication function
1763 flag = MFEM_ARKode(SetMassTimes)(sundials_mem, NULL,
1765 MFEM_VERIFY(flag == ARK_SUCCESS,
1766 "error in " STR(MFEM_ARKode(SetMassTimes)) "()");
1767
1768 // Check that the ODE is not expressed in EXPLICIT form
1769 MFEM_VERIFY(!f->isExplicit(), "ODE operator is expressed in EXPLICIT form")
1770}
1771
1773{
1774 step_mode = itask;
1775}
1776
1777void ARKStepSolver::SetSStolerances(double reltol, double abstol)
1778{
1779 flag = MFEM_ARKode(SStolerances)(sundials_mem, reltol, abstol);
1780 MFEM_VERIFY(flag == ARK_SUCCESS,
1781 "error in " STR(MFEM_ARKode(SStolerances)) "()");
1782}
1783
1784void ARKStepSolver::SetMaxStep(double dt_max)
1785{
1786 flag = MFEM_ARKode(SetMaxStep)(sundials_mem, dt_max);
1787 MFEM_VERIFY(flag == ARK_SUCCESS,
1788 "error in " STR(MFEM_ARKode(SetMaxStep)) "()");
1789}
1790
1792{
1793 flag = MFEM_ARKode(SetOrder)(sundials_mem, order);
1794 MFEM_VERIFY(flag == ARK_SUCCESS,
1795 "error in " STR(MFEM_ARKode(SetOrder)) "()");
1796}
1797
1799{
1800 flag = ARKStepSetTableNum(sundials_mem, ARKODE_DIRK_NONE, table_id);
1801 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1802}
1803
1805{
1806 flag = ARKStepSetTableNum(sundials_mem, table_id, ARKODE_ERK_NONE);
1807 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1808}
1809
1811 ARKODE_DIRKTableID itable_id)
1812{
1813 flag = ARKStepSetTableNum(sundials_mem, itable_id, etable_id);
1814 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1815}
1816
1818{
1819 flag = MFEM_ARKode(SetFixedStep)(sundials_mem, dt);
1820 MFEM_VERIFY(flag == ARK_SUCCESS,
1821 "error in " STR(MFEM_ARKode(SetFixedStep)) "()");
1822}
1823
1825{
1826 long int nsteps, expsteps, accsteps, step_attempts;
1827 long int nfe_evals, nfi_evals;
1828 long int nlinsetups, netfails;
1829 double hinused, hlast, hcur, tcur;
1830 long int nniters, nncfails;
1831
1832 // Get integrator stats
1833 flag = ARKStepGetTimestepperStats(sundials_mem,
1834 &expsteps,
1835 &accsteps,
1836 &step_attempts,
1837 &nfe_evals,
1838 &nfi_evals,
1839 &nlinsetups,
1840 &netfails);
1841 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepGetTimestepperStats()");
1842
1843 flag = MFEM_ARKode(GetStepStats)(sundials_mem,
1844 &nsteps,
1845 &hinused,
1846 &hlast,
1847 &hcur,
1848 &tcur);
1849
1850 // Get nonlinear solver stats
1851 flag = MFEM_ARKode(GetNonlinSolvStats)(sundials_mem,
1852 &nniters,
1853 &nncfails);
1854 MFEM_VERIFY(flag == ARK_SUCCESS,
1855 "error in " STR(MFEM_ARKode(GetNonlinSolvStats)) "()");
1856
1857 mfem::out <<
1858 "ARKStep:\n"
1859 "num steps: " << nsteps << "\n"
1860 "num exp rhs evals: " << nfe_evals << "\n"
1861 "num imp rhs evals: " << nfi_evals << "\n"
1862 "num lin setups: " << nlinsetups << "\n"
1863 "num nonlin sol iters: " << nniters << "\n"
1864 "num nonlin conv fail: " << nncfails << "\n"
1865 "num steps attempted: " << step_attempts << "\n"
1866 "num acc limited steps: " << accsteps << "\n"
1867 "num exp limited stepfails: " << expsteps << "\n"
1868 "num error test fails: " << netfails << "\n"
1869 "initial dt: " << hinused << "\n"
1870 "last dt: " << hlast << "\n"
1871 "current dt: " << hcur << "\n"
1872 "current t: " << tcur << "\n" << endl;
1873
1874 return;
1875}
1876
1878{
1879 delete Y;
1880 SUNMatDestroy(A);
1881 SUNLinSolFree(LSA);
1882 SUNNonlinSolFree(NLS);
1883 MFEM_ARKode(Free)(&sundials_mem);
1884}
1885
1886// ---------------------------------------------------------------------------
1887// KINSOL interface
1888// ---------------------------------------------------------------------------
1889
1890// Wrapper for evaluating the nonlinear residual F(u) = 0
1891int KINSolver::Mult(const N_Vector u, N_Vector fu, void *user_data)
1892{
1893 const SundialsNVector mfem_u(u);
1894 SundialsNVector mfem_fu(fu);
1895 KINSolver *self = static_cast<KINSolver*>(user_data);
1896
1897 // Compute the non-linear action F(u).
1898 self->oper->Mult(mfem_u, mfem_fu);
1899
1900 // Return success
1901 return 0;
1902}
1903
1904// Wrapper for computing Jacobian-vector products
1905int KINSolver::GradientMult(N_Vector v, N_Vector Jv, N_Vector u,
1906 sunbooleantype *new_u, void *user_data)
1907{
1908 const SundialsNVector mfem_v(v);
1909 SundialsNVector mfem_Jv(Jv);
1910 KINSolver *self = static_cast<KINSolver*>(user_data);
1911
1912 // Update Jacobian information if needed
1913 if (*new_u)
1914 {
1915 const SundialsNVector mfem_u(u);
1916 self->jacobian = &self->oper->GetGradient(mfem_u);
1917 *new_u = SUNFALSE;
1918 }
1919
1920 // Compute the Jacobian-vector product
1921 self->jacobian->Mult(mfem_v, mfem_Jv);
1922
1923 // Return success
1924 return 0;
1925}
1926
1927// Wrapper for evaluating linear systems J u = b
1928int KINSolver::LinSysSetup(N_Vector u, N_Vector, SUNMatrix J,
1929 void *, N_Vector, N_Vector )
1930{
1931 const SundialsNVector mfem_u(u);
1932 KINSolver *self = static_cast<KINSolver*>(GET_CONTENT(J));
1933
1934 // Update the Jacobian
1935 self->jacobian = &self->oper->GetGradient(mfem_u);
1936
1937 // Set the Jacobian solve operator
1938 self->prec->SetOperator(*self->jacobian);
1939
1940 // Return success
1941 return (0);
1942}
1943
1944// Wrapper for solving linear systems J u = b
1945int KINSolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector u,
1946 N_Vector b, sunrealtype)
1947{
1948 SundialsNVector mfem_u(u), mfem_b(b);
1949 KINSolver *self = static_cast<KINSolver*>(GET_CONTENT(LS));
1950
1951 // Solve for u = [J(u)]^{-1} b, maybe approximately.
1952 self->prec->Mult(mfem_b, mfem_u);
1953
1954 // Return success
1955 return (0);
1956}
1957
1958int KINSolver::PrecSetup(N_Vector uu,
1959 N_Vector uscale,
1960 N_Vector fval,
1961 N_Vector fscale,
1962 void *user_data)
1963{
1964 SundialsNVector mfem_u(uu);
1965 KINSolver *self = static_cast<KINSolver *>(user_data);
1966
1967 // Update the Jacobian
1968 self->jacobian = &self->oper->GetGradient(mfem_u);
1969
1970 // Set the Jacobian solve operator
1971 self->prec->SetOperator(*self->jacobian);
1972
1973 return 0;
1974}
1975
1976int KINSolver::PrecSolve(N_Vector uu,
1977 N_Vector uscale,
1978 N_Vector fval,
1979 N_Vector fscale,
1980 N_Vector vv,
1981 void *user_data)
1982{
1983 KINSolver *self = static_cast<KINSolver *>(user_data);
1984 SundialsNVector mfem_v(vv);
1985
1986 self->wrk = 0.0;
1987
1988 // Solve for u = P^{-1} v
1989 self->prec->Mult(mfem_v, self->wrk);
1990
1991 mfem_v = self->wrk;
1992
1993 return 0;
1994}
1995
1996KINSolver::KINSolver(int strategy, bool oper_grad)
1997 : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
1998 f_scale(NULL), jacobian(NULL)
1999{
2000 Y = new SundialsNVector();
2001 y_scale = new SundialsNVector();
2002 f_scale = new SundialsNVector();
2003
2004 // Default abs_tol and print_level
2005#if MFEM_SUNDIALS_VERSION < 70000
2006 abs_tol = pow(UNIT_ROUNDOFF, 1.0/3.0);
2007#else
2008 abs_tol = pow(SUN_UNIT_ROUNDOFF, 1.0/3.0);
2009#endif
2010 print_level = 0;
2011}
2012
#ifdef MFEM_USE_MPI
// Parallel KINSOL wrapper on communicator comm. Allocates the solution and
// scaling vectors. The default absolute tolerance is the cube root of the
// unit roundoff, and printing is off by default.
KINSolver::KINSolver(MPI_Comm comm, int strategy, bool oper_grad)
   : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
     f_scale(NULL), jacobian(NULL)
{
   Y = new SundialsNVector(comm);
   y_scale = new SundialsNVector(comm);
   f_scale = new SundialsNVector(comm);

#if MFEM_SUNDIALS_VERSION < 70000
   const double roundoff = UNIT_ROUNDOFF;
#else
   const double roundoff = SUN_UNIT_ROUNDOFF;
#endif
   abs_tol = pow(roundoff, 1.0/3.0);
   print_level = 0;
}
#endif
2031
2032
2034{
2035 // Initialize the base class
2037 jacobian = NULL;
2038
2039 // Get the vector length
2040 long local_size = height;
2041#ifdef MFEM_USE_MPI
2042 long global_size;
2043#endif
2044
2045 if (Parallel())
2046 {
2047#ifdef MFEM_USE_MPI
2048 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
2049 Y->GetComm());
2050#endif
2051 }
2052
2053 if (sundials_mem)
2054 {
2055 // Check if the problem size has changed since the last SetOperator call
2056 int resize = 0;
2057 if (!Parallel())
2058 {
2059 resize = (Y->Size() != local_size);
2060 }
2061 else
2062 {
2063#ifdef MFEM_USE_MPI
2064 int l_resize = (Y->Size() != local_size) ||
2065 (saved_global_size != global_size);
2066 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
2067 Y->GetComm());
2068#endif
2069 }
2070
2071 // Free existing solver memory and re-create with new vector size
2072 if (resize)
2073 {
2074 KINFree(&sundials_mem);
2075 sundials_mem = NULL;
2076 }
2077 }
2078
2079 if (!sundials_mem)
2080 {
2081 if (!Parallel())
2082 {
2083 Y->SetSize(local_size);
2084 }
2085#ifdef MFEM_USE_MPI
2086 else
2087 {
2088 Y->SetSize(local_size, global_size);
2089 y_scale->SetSize(local_size, global_size);
2090 f_scale->SetSize(local_size, global_size);
2091 saved_global_size = global_size;
2092 }
2093#endif
2094
2095 // Create the solver memory
2097 MFEM_VERIFY(sundials_mem, "Error in KINCreate().");
2098
2099 // Enable Anderson Acceleration
2100 if (aa_n > 0)
2101 {
2102 flag = KINSetMAA(sundials_mem, aa_n);
2103 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMAA()");
2104
2105 flag = KINSetDelayAA(sundials_mem, aa_delay);
2106 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDelayAA()");
2107
2108 flag = KINSetDampingAA(sundials_mem, aa_damping);
2109 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDampingAA()");
2110
2111#if SUNDIALS_VERSION_MAJOR >= 6
2112 flag = KINSetOrthAA(sundials_mem, aa_orth);
2113 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetOrthAA()");
2114#endif
2115 }
2116
2117 // Initialize KINSOL
2118 flag = KINInit(sundials_mem, KINSolver::Mult, *Y);
2119 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINInit()");
2120
2121 // Attach the KINSolver as user-defined data
2122 flag = KINSetUserData(sundials_mem, this);
2123 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetUserData()");
2124
2125 flag = KINSetDamping(sundials_mem, fp_damping);
2126 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDamping()");
2127
2128 // Set the linear solver
2129 if (prec || jfnk)
2130 {
2132 }
2133 else
2134 {
2135 // Free any existing linear solver
2136 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2137 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2138
2140 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
2141
2142 flag = KINSetLinearSolver(sundials_mem, LSA, NULL);
2143 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2144
2145 // Set Jacobian-vector product function
2146 if (use_oper_grad)
2147 {
2148 flag = KINSetJacTimesVecFn(sundials_mem, KINSolver::GradientMult);
2149 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetJacTimesVecFn()");
2150 }
2151 }
2152 }
2153}
2154
2156{
2157 if (jfnk)
2158 {
2159 SetJFNKSolver(solver);
2160 }
2161 else
2162 {
2163 // Store the solver
2164 prec = &solver;
2165
2166 // Free any existing linear solver
2167 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2168 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2169
2170 // Wrap KINSolver as SUNLinearSolver and SUNMatrix
2172 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
2173
2174 LSA->content = this;
2175 LSA->ops->gettype = LSGetType;
2176 LSA->ops->solve = KINSolver::LinSysSolve;
2177 LSA->ops->free = LSFree;
2178
2180 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
2181
2182 A->content = this;
2183 A->ops->getid = MatGetID;
2184 A->ops->destroy = MatDestroy;
2185
2186 // Attach the linear solver and matrix
2187 flag = KINSetLinearSolver(sundials_mem, LSA, A);
2188 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2189
2190 // Set the Jacobian evaluation function
2192 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetJacFn()");
2193 }
2194}
2195
// NOTE(review): this extract is missing source lines 2196 (the signature,
// presumably `void KINSolver::SetJFNKSolver(Solver &solver)`), 2201,
// 2208-2209 (the call that creates the SPFGMR solver assigned to LSA) and
// 2221-2222 (the remaining KINSetPreconditioner arguments). Restore them
// from the original file before compiling.
// Stores `solver` as the preconditioner prec, frees any previously attached
// A/LSA, attaches a SUNDIALS flexible GMRES (SPFGMR) linear solver with
// maxlrs restarts to KINSOL without a matrix, and, when prec is set,
// registers it with KINSetPreconditioner().
2197{
2198 // Store the solver
2199 prec = &solver;
2200
2202
2203 // Free any existing linear solver
2204 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2205 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2206
2207 // Setup FGMRES
2210 MFEM_VERIFY(LSA, "error in SUNLinSol_SPFGMR()");
2211
// NOTE(review): the message below names SUNLinSol_SPFGMR() although the
// call being checked is SUNLinSol_SPFGMRSetMaxRestarts().
2212 flag = SUNLinSol_SPFGMRSetMaxRestarts(LSA, maxlrs);
2213 MFEM_VERIFY(flag == SUN_SUCCESS, "error in SUNLinSol_SPFGMR()");
2214
2215 flag = KINSetLinearSolver(sundials_mem, LSA, NULL);
2216 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2217
2218 if (prec)
2219 {
2220 flag = KINSetPreconditioner(sundials_mem,
2223 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetPreconditioner()");
2224 }
2225}
2226
2228{
2229 flag = KINSetScaledStepTol(sundials_mem, sstol);
2230 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetScaledStepTol()");
2231}
2232
2234{
2235 flag = KINSetMaxSetupCalls(sundials_mem, max_calls);
2236 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMaxSetupCalls()");
2237}
2238
2239void KINSolver::EnableAndersonAcc(int n, int orth, int delay, double damping)
2240{
2241 if (sundials_mem != nullptr)
2242 {
2243 if (aa_n < n)
2244 {
2245 MFEM_ABORT("Subsequent calls to EnableAndersonAcc() must set"
2246 " the subspace size to less or equal to the initially requested size."
2247 " If SetOperator() has already been called, the subspace size can't be"
2248 " increased.");
2249 }
2250
2251 flag = KINSetMAA(sundials_mem, n);
2252 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMAA()");
2253
2254 flag = KINSetDelayAA(sundials_mem, delay);
2255 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDelayAA()");
2256
2257 flag = KINSetDampingAA(sundials_mem, damping);
2258 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDampingAA()");
2259
2260#if SUNDIALS_VERSION_MAJOR >= 6
2261 flag = KINSetOrthAA(sundials_mem, orth);
2262 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetOrthAA()");
2263#else
2264 if (orth != KIN_ORTH_MGS)
2265 {
2266 MFEM_WARNING("SUNDIALS < v6 does not support setting the Anderson"
2267 " acceleration orthogonalization routine!");
2268 }
2269#endif
2270 }
2271
2272 aa_n = n;
2273 aa_delay = delay;
2274 aa_damping = damping;
2275 aa_orth = orth;
2276}
2277
2278void KINSolver::SetDamping(double damping)
2279{
2280 fp_damping = damping;
2281 if (sundials_mem)
2282 {
2283 flag = KINSetDamping(sundials_mem, fp_damping);
2284 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDamping()");
2285 }
2286}
2287
2289{
2290 MFEM_ABORT("this method is not supported! Use SetPrintLevel(int) instead.");
2291}
2292
2293// Compute the scaling vectors and solve nonlinear system
2294void KINSolver::Mult(const Vector&, Vector &x) const
2295{
2296 // residual norm tolerance
2297 double tol;
2298
2299 // Uses c = 1, corresponding to x_scale.
2300 c = 1.0;
2301
2302 if (!iterative_mode) { x = 0.0; }
2303
2304 // For relative tolerance, r = 1 / |residual(x)|, corresponding to fx_scale.
2305 if (rel_tol > 0.0)
2306 {
2307
2308 oper->Mult(x, r);
2309
2310 // Note that KINSOL uses infinity norms.
2311 double norm = r.Normlinf();
2312#ifdef MFEM_USE_MPI
2313 if (Parallel())
2314 {
2315 double lnorm = norm;
2316 MPI_Allreduce(&lnorm, &norm, 1, MPITypeMap<real_t>::mpi_type, MPI_MAX,
2317 Y->GetComm());
2318 }
2319#endif
2320 if (abs_tol > rel_tol * norm)
2321 {
2322 r = 1.0;
2323 tol = abs_tol;
2324 }
2325 else
2326 {
2327 r = 1.0 / norm;
2328 tol = rel_tol;
2329 }
2330 }
2331 else
2332 {
2333 r = 1.0;
2334 tol = abs_tol;
2335 }
2336
2337 // Set the residual norm tolerance
2338 flag = KINSetFuncNormTol(sundials_mem, tol);
2339 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetFuncNormTol()");
2340
2341 // Solve the nonlinear system by calling the other Mult method
2342 KINSolver::Mult(x, c, r);
2343}
2344
2345// Solve the nonlinear system using the provided scaling vectors
2347 const Vector &x_scale, const Vector &fx_scale) const
2348{
2349 flag = KINSetNumMaxIters(sundials_mem, max_iter);
2350 MFEM_ASSERT(flag == KIN_SUCCESS, "KINSetNumMaxIters() failed!");
2351
2352 Y->MakeRef(x, 0, x.Size());
2353 y_scale->MakeRef(const_cast<Vector&>(x_scale), 0, x_scale.Size());
2354 f_scale->MakeRef(const_cast<Vector&>(fx_scale), 0, fx_scale.Size());
2355
2356 int rank = -1;
2357 if (!Parallel())
2358 {
2359 rank = 0;
2360 }
2361 else
2362 {
2363#ifdef MFEM_USE_MPI
2364 MPI_Comm_rank(Y->GetComm(), &rank);
2365#endif
2366 }
2367
2368 if (rank == 0)
2369 {
2370#if MFEM_SUNDIALS_VERSION < 70000
2371 flag = KINSetPrintLevel(sundials_mem, print_level);
2372 MFEM_VERIFY(flag == KIN_SUCCESS, "KINSetPrintLevel() failed!");
2373#endif
2374 // NOTE: there is no KINSetPrintLevel in SUNDIALS v7!
2375
2376#ifdef SUNDIALS_BUILD_WITH_MONITORING
2377 if (jfnk && print_level)
2378 {
2379 flag = SUNLinSolSetInfoFile_SPFGMR(LSA, stdout);
2380 MFEM_VERIFY(flag == SUN_SUCCESS,
2381 "error in SUNLinSolSetInfoFile_SPFGMR()");
2382
2383 flag = SUNLinSolSetPrintLevel_SPFGMR(LSA, 1);
2384 MFEM_VERIFY(flag == SUN_SUCCESS,
2385 "error in SUNLinSolSetPrintLevel_SPFGMR()");
2386 }
2387#endif
2388 }
2389
2390 if (!iterative_mode) { x = 0.0; }
2391
2392 // Solve the nonlinear system
2394 converged = (flag >= 0);
2395
2396 // Make sure host is up to date
2397 Y->HostRead();
2398
2399 // Get number of nonlinear iterations
2400 long int tmp_nni;
2401 flag = KINGetNumNonlinSolvIters(sundials_mem, &tmp_nni);
2402 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINGetNumNonlinSolvIters()");
2403 final_iter = (int) tmp_nni;
2404
2405 // Get the residual norm
2406 flag = KINGetFuncNorm(sundials_mem, &final_norm);
2407 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINGetFuncNorm()");
2408}
2409
2411{
2412 delete Y;
2413 delete y_scale;
2414 delete f_scale;
2415 SUNMatDestroy(A);
2416 SUNLinSolFree(LSA);
2417 KINFree(&sundials_mem);
2418}
2419
2420} // namespace mfem
2421
2422#endif // MFEM_USE_SUNDIALS
Interface to ARKode's ARKStep module – additive Runge-Kutta methods.
Definition sundials.hpp:711
static int RHS2(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void SetMaxStep(double dt_max)
Set the maximum time step.
void PrintInfo() const
Print various ARKStep statistics.
Type rk_type
Runge-Kutta type.
Definition sundials.hpp:722
ARKStepSolver(Type type=EXPLICIT)
Construct a serial wrapper to SUNDIALS' ARKode integrator.
void SetOrder(int order)
Chooses integration order for all explicit / implicit / IMEX methods.
void SetStepMode(int itask)
Select the ARKode step mode: ARK_NORMAL (default) or ARK_ONE_STEP.
Type
Types of ARKODE solvers.
Definition sundials.hpp:715
@ IMPLICIT
Implicit RK method.
Definition sundials.hpp:717
@ IMEX
Implicit-explicit ARK method.
Definition sundials.hpp:718
@ EXPLICIT
Explicit RK method.
Definition sundials.hpp:716
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $A x = b$.
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $A x = b$.
void UseSundialsLinearSolver()
Attach a SUNDIALS GMRES linear solver to ARKode.
void Init(TimeDependentOperator &f_) override
Initialize ARKode: calls ARKStepCreate() to create the ARKStep memory and set some defaults.
static int MassMult2(N_Vector x, N_Vector v, sunrealtype t, void *mtimes_data)
Compute the mass matrix-vector product $v = M x$ at time t.
void SetIRKTableNum(ARKODE_DIRKTableID table_id)
Choose a specific Butcher table for a diagonally implicit RK method.
void SetFixedStep(double dt)
Use a fixed time step size (disable temporal adaptivity).
void SetERKTableNum(ARKODE_ERKTableID table_id)
Choose a specific Butcher table for an explicit RK method.
int step_mode
ARKStep step mode (ARK_NORMAL or ARK_ONE_STEP).
Definition sundials.hpp:723
static int RHS1(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void UseMFEMMassLinearSolver(int tdep)
Attach mass matrix linear system setup, solve, and matrix-vector product methods from the TimeDepende...
bool use_implicit
True for implicit or imex integration.
Definition sundials.hpp:724
void UseSundialsMassLinearSolver(int tdep)
Attach the SUNDIALS GMRES linear solver and the mass matrix matrix-vector product method from the Tim...
virtual ~ARKStepSolver()
Destroy the associated ARKode memory and SUNDIALS objects.
void SetIMEXTableNum(ARKODE_ERKTableID etable_id, ARKODE_DIRKTableID itable_id)
Choose a specific Butcher table for an IMEX RK method.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
static int MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
Compute the mass matrix-vector product $v = M x$.
static int MassSysSolve(SUNLinearSolver LS, SUNMatrix M, N_Vector x, N_Vector b, sunrealtype tol)
Solve the mass matrix linear system $M x = b$.
static int MassSysSetup(sunrealtype t, SUNMatrix M, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the mass matrix linear system $M x = b$.
void Step(Vector &x, real_t &t, real_t &dt) override
Integrate the ODE with ARKode using the specified step mode.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
static int RHSB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
Wrapper to compute the ODE RHS backward function.
void EvalQuadIntegrationB(double t, Vector &dG_dp)
Evaluate Quadrature solution.
void EvalQuadIntegration(double t, Vector &q)
Evaluate Quadrature.
static int LinSysSetupB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $A x = b$.
static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system A x = b.
static constexpr double default_abs_tolB
Default scalar backward absolute tolerance.
Definition sundials.hpp:589
void SetMaxNStepsB(int mxstepsB)
Set the maximum number of backward steps.
static int RHSQ(sunrealtype t, const N_Vector y, N_Vector qdot, void *user_data)
Wrapper to compute the ODE RHS Quadrature function.
void Step(Vector &x, double &t, double &dt) override
void InitB(TimeDependentAdjointOperator &f_)
Initialize the adjoint problem.
static int RHSQB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_dataB)
Wrapper to compute the ODE RHS Backwards Quadrature function.
SundialsNVector * q
Quadrature vector.
Definition sundials.hpp:581
int indexB
backward problem index
Definition sundials.hpp:562
void GetForwardSolution(double tB, mfem::Vector &yy)
Get Interpolated Forward solution y at backward integration time tB.
SUNLinearSolver LSB
Linear solver for A.
Definition sundials.hpp:580
void SetSVtolerancesB(double reltol, Vector abstol)
Tolerance specification functions for the adjoint problem.
void UseSundialsLinearSolverB()
Use a built-in SUNDIALS linear solver for the backward problem.
void SetWFTolerances(EWTFunction func)
Set multiplicative error weights.
SundialsNVector * yy
State vector.
Definition sundials.hpp:583
void Init(TimeDependentAdjointOperator &f_)
SUNMatrix AB
Linear system A = I - gamma J, M - gamma J, or J.
Definition sundials.hpp:579
int ncheck
number of checkpoints used so far
Definition sundials.hpp:561
SundialsNVector * qB
Backward quadrature vector.
Definition sundials.hpp:584
static int ewt(N_Vector y, N_Vector w, void *user_data)
Error control function.
virtual void StepB(Vector &w, double &t, double &dt)
Solve one adjoint time step.
void InitQuadIntegrationB(mfem::Vector &qB0, double reltolQB=1e-3, double abstolQB=1e-8)
Initialize Quadrature Integration (Adjoint)
static constexpr double default_rel_tolB
Default scalar backward relative tolerance.
Definition sundials.hpp:587
void InitAdjointSolve(int steps, int interpolation)
Initialize Adjoint.
SundialsNVector * yB
State vector.
Definition sundials.hpp:582
void InitQuadIntegration(mfem::Vector &q0, double reltolQ=1e-3, double abstolQ=1e-8)
virtual ~CVODESSolver()
Destroy the associated CVODES memory and SUNDIALS objects.
void SetSStolerancesB(double reltol, double abstol)
Tolerance specification functions for the adjoint problem.
void UseMFEMLinearSolverB()
Set Linear Solver for the backward problem.
Interface to the CVODE library – linear multi-step methods.
Definition sundials.hpp:420
void SetStepMode(int itask)
Select the CVODE step mode: CV_NORMAL (default) or CV_ONE_STEP.
Definition sundials.cpp:896
void SetRootFinder(int components, RootFunction func)
Initialize Root Finder.
Definition sundials.cpp:688
std::function< int(sunrealtype t, Vector y, Vector gout, CVODESolver *)> RootFunction
Typedef for root finding functions.
Definition sundials.hpp:446
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
Definition sundials.cpp:901
void Init(TimeDependentOperator &f_) override
Initialize CVODE: calls CVodeCreate() to create the CVODE memory and set some defaults.
Definition sundials.cpp:735
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $A x = b$.
Definition sundials.cpp:711
virtual ~CVODESolver()
Destroy the associated CVODE memory and SUNDIALS objects.
Definition sundials.cpp:990
static int root(sunrealtype t, N_Vector y, sunrealtype *gout, void *user_data)
Prototype to define root finding for CVODE.
Definition sundials.cpp:675
void SetMaxNSteps(int steps)
Set the maximum number of time steps.
Definition sundials.cpp:925
CVODESolver(int lmm)
Construct a serial wrapper to SUNDIALS' CVODE integrator.
Definition sundials.cpp:721
long GetNumSteps()
Get the number of internal steps taken so far.
Definition sundials.cpp:931
void Step(Vector &x, double &t, double &dt) override
Integrate the ODE with CVODE using the specified step mode.
Definition sundials.cpp:823
EWTFunction ewt_func
A class member to facilitate pointing to a user-specified error weight function.
Definition sundials.hpp:456
static int RHS(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
Wrapper to compute the ODE right-hand side (rhs) function.
Definition sundials.cpp:658
void SetMaxStep(double dt_max)
Set the maximum time step.
Definition sundials.cpp:919
int lmm_type
Linear multistep method type.
Definition sundials.hpp:422
void PrintInfo() const
Print various CVODE statistics.
Definition sundials.cpp:945
void UseSundialsLinearSolver()
Attach SUNDIALS GMRES linear solver to CVODE.
Definition sundials.cpp:881
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
Definition sundials.cpp:850
RootFunction root_func
A class member to facilitate pointing to a user-specified root function.
Definition sundials.hpp:449
std::function< int(Vector y, Vector w, CVODESolver *)> EWTFunction
Typedef declaration for error weight functions.
Definition sundials.hpp:452
int step_mode
CVODE step mode (CV_NORMAL or CV_ONE_STEP).
Definition sundials.hpp:423
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Set up the linear system $A x = b$.
Definition sundials.cpp:696
void SetSVtolerances(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances.
Definition sundials.cpp:907
void SetMaxOrder(int max_order)
Set the maximum method order.
Definition sundials.cpp:939
static MemoryType GetHostMemoryType()
Get the current Host MemoryType. This is the MemoryType used by most MFEM classes when allocating mem...
Definition device.hpp:265
static bool IsAvailable()
Return true if an actual device (e.g. GPU) has been configured.
Definition device.hpp:244
static MemoryType GetDeviceMemoryType()
Get the current Device MemoryType. This is the MemoryType used by most MFEM classes when allocating m...
Definition device.hpp:274
Wrapper for hypre's parallel vector class.
Definition hypre.hpp:219
real_t abs_tol
Absolute tolerance.
Definition solvers.hpp:177
real_t rel_tol
Relative tolerance.
Definition solvers.hpp:174
const Operator * oper
Definition solvers.hpp:140
int print_level
(DEPRECATED) Legacy print level definition, which is left for compatibility with custom iterative sol...
Definition solvers.hpp:150
int max_iter
Limit for the number of iterations the solver is allowed to do.
Definition solvers.hpp:171
Interface to the KINSOL library – nonlinear solver methods.
Definition sundials.hpp:887
SundialsNVector * f_scale
scaling vectors
Definition sundials.hpp:891
KINSolver(int strategy, bool oper_grad=true)
Construct a serial wrapper to SUNDIALS' KINSOL nonlinear solver.
double aa_damping
Anderson Acceleration damping.
Definition sundials.hpp:895
void SetJFNKSolver(Solver &solver)
virtual ~KINSolver()
Destroy the associated KINSOL memory.
static int Mult(const N_Vector u, N_Vector fu, void *user_data)
Wrapper to compute the nonlinear residual $F(u)$ of the system $F(u) = 0$.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector u, N_Vector b, sunrealtype tol)
Solve the linear system $J u = b$.
int aa_delay
Anderson Acceleration delay.
Definition sundials.hpp:894
int global_strategy
KINSOL solution strategy.
Definition sundials.hpp:889
bool jfnk
enable JFNK
Definition sundials.hpp:898
static int PrecSolve(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void *user_data)
Solve the preconditioner equation $P z = v$.
int maxlrs
Maximum linear solver restarts.
Definition sundials.hpp:901
const Operator * jacobian
stores oper->GetGradient()
Definition sundials.hpp:892
int maxli
Maximum linear iterations.
Definition sundials.hpp:900
Vector wrk
Work vector needed for the JFNK PC.
Definition sundials.hpp:899
void SetScaledStepTol(double sstol)
Set KINSOL's scaled step tolerance.
void SetSolver(Solver &solver) override
Set the linear solver for inverting the Jacobian.
SundialsNVector * y_scale
Definition sundials.hpp:891
void SetOperator(const Operator &op) override
Set the nonlinear Operator of the system and initialize KINSOL.
void SetMaxSetupCalls(int max_calls)
Set maximum number of nonlinear iterations without a Jacobian update.
static int LinSysSetup(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
Set up the linear system $J u = b$.
int aa_orth
Anderson Acceleration orthogonalization routine.
Definition sundials.hpp:896
static int PrecSetup(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void *user_data)
Set up the preconditioner.
void SetDamping(double damping)
double fp_damping
Fixed Point or Picard damping parameter.
Definition sundials.hpp:897
void EnableAndersonAcc(int n, int orth=KIN_ORTH_MGS, int delay=0, double damping=1.0)
Enable Anderson Acceleration for KIN_FP or KIN_PICARD.
void SetPrintLevel(int print_lvl) override
Set the print level for the KINSetPrintLevel function.
bool use_oper_grad
use the Jv prod function
Definition sundials.hpp:890
int aa_n
number of acceleration vectors
Definition sundials.hpp:893
static int GradientMult(N_Vector v, N_Vector Jv, N_Vector u, sunbooleantype *new_u, void *user_data)
Wrapper to compute the Jacobian-vector product $J(u) v = Jv$.
bool IsKnown(const void *h_ptr)
Return true if the pointer is known by the memory manager.
Class used by MFEM to store pointers to host and/or device memory.
void SetHostPtrOwner(bool own) const
Set/clear the ownership flag for the host pointer. Ownership indicates whether the pointer will be de...
void SetDevicePtrOwner(bool own) const
Set/clear the ownership flag for the device pointer. Ownership indicates whether the pointer will be ...
void Wrap(T *ptr, int size, bool own)
Wrap an externally allocated host pointer, ptr with the current host memory type returned by MemoryMa...
void Delete()
Delete the owned pointers and reset the Memory object.
void ClearOwnerFlags() const
Clear the ownership flags for the host and device pointers, as well as any internal data allocated by...
void SetOperator(const Operator &op) override
Also calls SetOperator for the preconditioner if there is one.
Definition solvers.cpp:1898
TimeDependentOperator * f
Pointer to the associated TimeDependentOperator.
Definition ode.hpp:113
virtual void Init(TimeDependentOperator &f_)
Associate a TimeDependentOperator with the ODE solver.
Definition ode.cpp:161
Abstract operator.
Definition operator.hpp:25
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
int height
Dimension of the output / number of rows in the matrix.
Definition operator.hpp:27
virtual void Mult(const Vector &x, Vector &y) const =0
Operator application: y=A(x).
virtual Operator & GetGradient(const Vector &x) const
Evaluate the gradient operator at the point x. The default behavior in class Operator is to generate ...
Definition operator.hpp:122
Base class for solvers.
Definition operator.hpp:780
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition operator.hpp:783
virtual void SetOperator(const Operator &op)=0
Set/update the solver for the given operator.
SundialsMemHelper & operator=(const SundialsMemHelper &)=delete
Disable copy assignment.
SundialsMemHelper()=default
Default constructor – object must be moved to.
static int SundialsMemHelper_Alloc(SUNMemoryHelper helper, SUNMemory *memptr, size_t memsize, SUNMemoryType mem_type #if(SUNDIALS_VERSION_MAJOR >=6), void *queue #endif)
Definition sundials.cpp:265
static int SundialsMemHelper_Dealloc(SUNMemoryHelper helper, SUNMemory sunmem #if(SUNDIALS_VERSION_MAJOR >=6), void *queue #endif)
Definition sundials.cpp:309
Vector interface for SUNDIALS N_Vectors.
Definition sundials.hpp:212
long GlobalSize() const
Returns the MPI global length for the internal N_Vector x.
Definition sundials.hpp:278
MPI_Comm GetComm() const
Returns the MPI communicator for the internal N_Vector x.
Definition sundials.hpp:268
void MakeRef(Vector &base, int offset, int s)
Reset the Vector to be a reference to a sub-vector of base.
Definition sundials.hpp:295
void SetSize(int s, long glob_size=0)
Resize the vector to size s.
Definition sundials.cpp:542
static bool UseManagedMemory()
Definition sundials.hpp:349
static N_Vector MakeNVector(bool use_device)
Create a N_Vector.
Definition sundials.cpp:560
~SundialsNVector()
Calls SUNDIALS N_VDestroy function if the N_Vector is owned by 'this'.
Definition sundials.cpp:528
void SetDataAndSize(double *d, int s, long glob_size=0)
Set the vector data and size.
Definition sundials.cpp:554
void _SetNvecDataAndSize_(long glob_size=0)
Set data and length of internal N_Vector x from 'this'.
Definition sundials.cpp:347
void _SetDataAndSize_()
Set data and length from the internal N_Vector x.
Definition sundials.cpp:421
N_Vector x
The actual SUNDIALS object.
Definition sundials.hpp:217
void SetData(double *d)
Definition sundials.cpp:548
bool MPIPlusX() const
Definition sundials.hpp:331
SundialsNVector()
Creates an empty SundialsNVector.
Definition sundials.cpp:469
static constexpr double default_abs_tol
Default scalar absolute tolerance.
Definition sundials.hpp:389
SUNMatrix M
Mass matrix M.
Definition sundials.hpp:374
int flag
Last flag returned from a call to SUNDIALS.
Definition sundials.hpp:367
long saved_global_size
Global vector length on last initialization.
Definition sundials.hpp:369
static constexpr double default_rel_tol
Default scalar relative tolerance.
Definition sundials.hpp:387
bool reinit
Flag to signal that memory reinitialization is needed.
Definition sundials.hpp:368
SundialsNVector * Y
State vector.
Definition sundials.hpp:371
void * sundials_mem
SUNDIALS mem structure.
Definition sundials.hpp:366
SUNLinearSolver LSA
Linear solver for A.
Definition sundials.hpp:375
SUNLinearSolver LSM
Linear solver for M.
Definition sundials.hpp:376
SUNNonlinearSolver NLS
Nonlinear solver.
Definition sundials.hpp:377
bool Parallel() const
Definition sundials.hpp:380
Singleton class for SUNContext and SundialsMemHelper objects.
Definition sundials.hpp:176
static SUNContext & GetContext()
Provides access to the SUNContext object.
Definition sundials.cpp:185
static SundialsMemHelper & GetMemHelper()
Provides access to the SundialsMemHelper object.
Definition sundials.cpp:190
static void Init()
Definition sundials.cpp:174
int GetAdjointHeight()
Returns the size of the adjoint problem state space.
Definition operator.hpp:718
Base abstract class for first order time dependent operators.
Definition operator.hpp:332
bool isExplicit() const
True if type is EXPLICIT.
Definition operator.hpp:397
virtual int SUNImplicitSolve(const Vector &r, Vector &dk, real_t tol)
Solve the ODE linear system A dk = r , where A and r are defined by the method SUNImplicitSetup().
Definition operator.cpp:327
virtual int SUNMassMult(const Vector &x, Vector &v)
Compute the mass matrix-vector product v = M x, where M is defined by the method SUNMassSetup().
Definition operator.cpp:345
virtual int SUNMassSetup()
Set up the mass matrix $M$ in the ODE system $M y' = f(y,t)$.
Definition operator.cpp:333
void Mult(const Vector &u, Vector &k) const override
Perform the action of the operator (u,t) -> k(u,t) where t is the current time set by SetTime() and k...
Definition operator.cpp:293
virtual int SUNMassSolve(const Vector &b, Vector &x, real_t tol)
Solve the mass matrix linear system M x = b, where M is defined by the method SUNMassSetup().
Definition operator.cpp:339
virtual void ExplicitMult(const Vector &u, Vector &v) const
Perform the action of the explicit part of the operator, G: v = G(u, t) where t is the current time.
Definition operator.cpp:282
virtual void SetEvalMode(const EvalMode new_eval_mode)
Set the evaluation mode of the time-dependent operator.
Definition operator.hpp:417
virtual void SetTime(const real_t t_)
Set the current time.
Definition operator.hpp:394
virtual int SUNImplicitSetup(const Vector &y, const Vector &v, int jok, int *jcur, real_t gamma)
Setup a linear system as needed by some SUNDIALS ODE solvers to perform a similar action to ImplicitS...
Definition operator.cpp:319
virtual real_t GetTime() const
Read the currently set time.
Definition operator.hpp:391
Vector data type.
Definition vector.hpp:82
virtual const real_t * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:498
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:510
void SetDataAndSize(real_t *d, int s)
Set the Vector data and size.
Definition vector.hpp:183
real_t Normlinf() const
Returns the l_infinity norm of the vector.
Definition vector.cpp:972
Memory< real_t > data
Definition vector.hpp:85
Vector & Set(const real_t a, const Vector &x)
(*this) = a * x
Definition vector.cpp:337
virtual bool UseDevice() const
Return the device flag of the Memory object used by the Vector.
Definition vector.hpp:147
int Size() const
Returns the size of the vector.
Definition vector.hpp:226
void SetSize(int s)
Resize the vector to size s.
Definition vector.hpp:558
void SetData(real_t *d)
Definition vector.hpp:176
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:514
real_t b
Definition lissajous.cpp:42
T * HostReadWrite(Memory< T > &mem, int size)
Shortcut to ReadWrite(Memory<T> &mem, int size, false)
Definition device.hpp:382
real_t u(const Vector &xvec)
Definition lor_mms.hpp:22
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition globals.hpp:66
T * ReadWrite(Memory< T > &mem, int size, bool on_dev=true)
Get a pointer for read+write access to mem with the mfem::Device's DeviceMemoryClass,...
Definition device.hpp:375
MemoryManager mm
The (single) global memory manager object.
float real_t
Definition config.hpp:43
Settings for the output behavior of the IterativeSolver.
Definition solvers.hpp:98
Helper struct to convert a C++ type to an MPI type.
MFEM_DEPRECATED SUNLinearSolver SUNLinSolNewEmpty(SUNContext)
Definition sundials.cpp:68
MFEM_DEPRECATED N_Vector N_VMake_MPIPlusX(MPI_Comm comm, N_Vector local_vector, SUNContext)
Definition sundials.cpp:150
MFEM_DEPRECATED void * KINCreate(SUNContext)
Definition sundials.cpp:106
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype, int maxl, SUNContext)
Definition sundials.cpp:83
MFEM_DEPRECATED void * CVodeCreate(int lmm, SUNContext)
Definition sundials.cpp:91
MFEM_DEPRECATED N_Vector N_VNewEmpty_Parallel(MPI_Comm comm, sunindextype local_length, sunindextype global_length, SUNContext)
Definition sundials.cpp:115
MFEM_DEPRECATED SUNMatrix SUNMatNewEmpty(SUNContext)
Definition sundials.cpp:61
MFEM_DEPRECATED SUNMemoryHelper SUNMemoryHelper_NewEmpty(SUNContext)
Definition sundials.cpp:139
MFEM_DEPRECATED N_Vector SUN_Hip_OR_Cuda N_VNewWithMemHelp(sunindextype length, sunbooleantype use_managed_mem, SUNMemoryHelper helper, SUNContext)
Definition sundials.cpp:129
MFEM_DEPRECATED void * ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, sunrealtype t0, N_Vector y0, SUNContext)
Definition sundials.cpp:98
MFEM_DEPRECATED N_Vector N_VNewEmpty_Serial(sunindextype vec_length, SUNContext)
Definition sundials.cpp:54
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype, int maxl, SUNContext)
Definition sundials.cpp:75
int ARKODE_DIRKTableID
Definition sundials.hpp:66
realtype sunrealtype
'sunrealtype' was first introduced in v6.0.0
Definition sundials.hpp:76
int ARKODE_ERKTableID
Definition sundials.hpp:65
@ SUN_PREC_NONE
Definition sundials.hpp:81
@ SUN_PREC_RIGHT
Definition sundials.hpp:81
booleantype sunbooleantype
'sunbooleantype' was first introduced in v6.0.0
Definition sundials.hpp:78
constexpr ARKODE_ERKTableID ARKODE_ERK_NONE
Definition sundials.hpp:67
constexpr ARKODE_DIRKTableID ARKODE_DIRK_NONE
Definition sundials.hpp:68
@ SUN_SUCCESS
Definition sundials.hpp:95
void * SUNContext
Definition sundials.hpp:73