MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
sundials.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#include "sundials.hpp"
13
14#ifdef MFEM_USE_SUNDIALS
15
16#include "solvers.hpp"
17#ifdef MFEM_USE_MPI
18#include "hypre.hpp"
19#endif
20
21// SUNDIALS vectors
22#include <nvector/nvector_serial.h>
23#if defined(MFEM_USE_CUDA)
24#include <nvector/nvector_cuda.h>
25#elif defined(MFEM_USE_HIP)
26#include <nvector/nvector_hip.h>
27#endif
28#ifdef MFEM_USE_MPI
29#include <nvector/nvector_mpiplusx.h>
30#include <nvector/nvector_parallel.h>
31#endif
32
33// SUNDIALS linear solvers
34#include <sunlinsol/sunlinsol_spgmr.h>
35#include <sunlinsol/sunlinsol_spfgmr.h>
36
37// Access SUNDIALS object's content pointer
38#define GET_CONTENT(X) ( X->content )
39
40#if defined(MFEM_USE_CUDA)
41#define SUN_Hip_OR_Cuda(X) X##_Cuda
42#define SUN_HIP_OR_CUDA(X) X##_CUDA
43#elif defined(MFEM_USE_HIP)
44#define SUN_Hip_OR_Cuda(X) X##_Hip
45#define SUN_HIP_OR_CUDA(X) X##_HIP
46#endif
47
48using namespace std;
49
50#if (SUNDIALS_VERSION_MAJOR < 6)
51
52/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
53/// version < 6
54MFEM_DEPRECATED N_Vector N_VNewEmpty_Serial(sunindextype vec_length, SUNContext)
55{
56 return N_VNewEmpty_Serial(vec_length);
57}
58
59/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
60/// version < 6
61MFEM_DEPRECATED SUNMatrix SUNMatNewEmpty(SUNContext)
62{
63 return SUNMatNewEmpty();
64}
65
66/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
67/// version < 6
68MFEM_DEPRECATED SUNLinearSolver SUNLinSolNewEmpty(SUNContext)
69{
70 return SUNLinSolNewEmpty();
71}
72
73/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
74/// version < 6
75MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype,
76 int maxl, SUNContext)
77{
78 return SUNLinSol_SPGMR(y, pretype, maxl);
79}
80
81/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
82/// version < 6
83MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype,
84 int maxl, SUNContext)
85{
86 return SUNLinSol_SPFGMR(y, pretype, maxl);
87}
88
89/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
90/// version < 6
91MFEM_DEPRECATED void* CVodeCreate(int lmm, SUNContext)
92{
93 return CVodeCreate(lmm);
94}
95
96/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
97/// version < 6
98MFEM_DEPRECATED void* ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, sunrealtype t0,
99 N_Vector y0, SUNContext)
100{
101 return ARKStepCreate(fe, fi, t0, y0);
102}
103
104/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
105/// version < 6
106MFEM_DEPRECATED void* KINCreate(SUNContext)
107{
108 return KINCreate();
109}
110
111#ifdef MFEM_USE_MPI
112
113/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
114/// version < 6
115MFEM_DEPRECATED N_Vector N_VNewEmpty_Parallel(MPI_Comm comm,
116 sunindextype local_length,
117 sunindextype global_length,
119{
120 return N_VNewEmpty_Parallel(comm, local_length, global_length);
121}
122
123#endif // MFEM_USE_MPI
124
125#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
126
127/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
128/// version < 6
129MFEM_DEPRECATED N_Vector SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(sunindextype length,
130 sunbooleantype use_managed_mem,
131 SUNMemoryHelper helper,
133{
134 return SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(length, use_managed_mem, helper);
135}
136
137/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
138/// version < 6
139MFEM_DEPRECATED SUNMemoryHelper SUNMemoryHelper_NewEmpty(SUNContext)
140{
142}
143
144#endif // MFEM_USE_CUDA || MFEM_USE_HIP
145
146#if defined(MFEM_USE_MPI) && (defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP))
147
148/// (DEPRECATED) Wrapper function for backwards compatibility with SUNDIALS
149/// version < 6
150MFEM_DEPRECATED N_Vector N_VMake_MPIPlusX(MPI_Comm comm, N_Vector local_vector,
152{
153 return N_VMake_MPIPlusX(comm, local_vector);
154}
155
156#endif // MFEM_USE_MPI && (MFEM_USE_CUDA || MFEM_USE_HIP)
157
158#endif // SUNDIALS_VERSION_MAJOR < 6
159
160#if MFEM_SUNDIALS_VERSION < 70100
161#define MFEM_ARKode(FUNC) ARKStep##FUNC
162#else
163#define MFEM_ARKode(FUNC) ARKode##FUNC
164#endif
165
166// Macro STR(): expand the argument and add double quotes
167#define STR1(s) #s
168#define STR(s) STR1(s)
169
170
171namespace mfem
172{
173
175{
176 Sundials::Instance();
177}
178
179Sundials &Sundials::Instance()
180{
181 static Sundials sundials;
182 return sundials;
183}
184
186{
187 return Sundials::Instance().context;
188}
189
191{
192 return Sundials::Instance().memHelper;
193}
194
195#if (SUNDIALS_VERSION_MAJOR >= 6)
196
197Sundials::Sundials()
198{
199#ifdef MFEM_USE_MPI
200 int mpi_initialized = 0;
201 MPI_Initialized(&mpi_initialized);
202 MPI_Comm communicator = mpi_initialized ? MPI_COMM_WORLD : MPI_COMM_NULL;
203#if SUNDIALS_VERSION_MAJOR < 7
204 int return_val = SUNContext_Create((void*) &communicator, &context);
205#else
206 int return_val = SUNContext_Create(communicator, &context);
207#endif
208#else // #ifdef MFEM_USE_MPI
209#if SUNDIALS_VERSION_MAJOR < 7
210 int return_val = SUNContext_Create(nullptr, &context);
211#else
212 int return_val = SUNContext_Create((SUNComm)(0), &context);
213#endif
214#endif // #ifdef MFEM_USE_MPI
215 MFEM_VERIFY(return_val == 0, "Call to SUNContext_Create failed");
216 SundialsMemHelper actual_helper(context);
217 memHelper = std::move(actual_helper);
218 isInitialized = true;
219}
220
221Sundials::~Sundials()
222{
224}
225
227{
228 Sundials& sundials = Instance();
229 if (sundials.isInitialized)
230 {
231 SUNContext context = GetContext();
232 SUNContext_Free(&context);
233 sundials.isInitialized = false;
234 }
235}
236
237#else // SUNDIALS_VERSION_MAJOR >= 6
238
239Sundials::Sundials()
240{
241 // Do nothing
242}
243
244Sundials::~Sundials()
245{
246 // Do nothing
247}
248
250{
251 // Do nothing
252}
253
254#endif // SUNDIALS_VERSION_MAJOR >= 6
255
256#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
258{
259 /* Allocate helper */
260 h = SUNMemoryHelper_NewEmpty(context);
261
262 /* Set the ops */
263 h->ops->alloc = SundialsMemHelper_Alloc;
264 h->ops->dealloc = SundialsMemHelper_Dealloc;
265 h->ops->copy = SUN_Hip_OR_Cuda(SUNMemoryHelper_Copy);
266 h->ops->copyasync = SUN_Hip_OR_Cuda(SUNMemoryHelper_CopyAsync);
267}
268
270{
271 this->h = that_helper.h;
272 that_helper.h = nullptr;
273}
274
276{
277 this->h = rhs.h;
278 rhs.h = nullptr;
279 return *this;
280}
281
283 SUNMemory* memptr, size_t memsize,
284 SUNMemoryType mem_type
285#if (SUNDIALS_VERSION_MAJOR >= 6)
286 , void*
287#endif
288 )
289{
290#if (SUNDIALS_VERSION_MAJOR < 7)
291 SUNMemory sunmem = SUNMemoryNewEmpty();
292#else
293 SUNMemory sunmem = SUNMemoryNewEmpty(helper->sunctx);
294#endif
295
296 sunmem->ptr = NULL;
297 sunmem->own = SUNTRUE;
298
299 // memsize is the number of bytes to allocate, so we use Memory<char>
300 if (mem_type == SUNMEMTYPE_HOST)
301 {
303 mem.SetHostPtrOwner(false);
304 sunmem->ptr = mfem::HostReadWrite(mem, memsize);
305 sunmem->type = SUNMEMTYPE_HOST;
306 mem.Delete();
307 }
308 else if (mem_type == SUNMEMTYPE_DEVICE || mem_type == SUNMEMTYPE_UVM)
309 {
311 mem.SetDevicePtrOwner(false);
312 sunmem->ptr = mfem::ReadWrite(mem, memsize);
313 sunmem->type = mem_type;
314 mem.Delete();
315 }
316 else
317 {
318 free(sunmem);
319 return -1;
320 }
321
322 *memptr = sunmem;
323 return 0;
324}
325
327 SUNMemory sunmem
328#if (SUNDIALS_VERSION_MAJOR >= 6)
329 , void*
330#endif
331 )
332{
333 if (sunmem->ptr && sunmem->own && !mm.IsKnown(sunmem->ptr))
334 {
335 if (sunmem->type == SUNMEMTYPE_HOST)
336 {
337 Memory<char> mem(static_cast<char*>(sunmem->ptr), 1,
339 mem.Delete();
340 }
341 else if (sunmem->type == SUNMEMTYPE_DEVICE || sunmem->type == SUNMEMTYPE_UVM)
342 {
343 Memory<char> mem(static_cast<char*>(sunmem->ptr), 1,
345 mem.Delete();
346 }
347 else
348 {
349 MFEM_ABORT("Invalid SUNMEMTYPE");
350 return -1;
351 }
352 }
353 free(sunmem);
354 return 0;
355}
356
357#endif // MFEM_USE_CUDA || MFEM_USE_HIP
358
359
360// ---------------------------------------------------------------------------
361// SUNDIALS N_Vector interface functions
362// ---------------------------------------------------------------------------
363
365{
366#ifdef MFEM_USE_MPI
367 N_Vector local_x = MPIPlusX() ? N_VGetLocalVector_MPIPlusX(x) : x;
368#else
369 N_Vector local_x = x;
370#endif
371 N_Vector_ID id = N_VGetVectorID(local_x);
372
373 // Set the N_Vector data and length from the Vector data and size.
374 switch (id)
375 {
376 case SUNDIALS_NVEC_SERIAL:
377 {
378 MFEM_ASSERT(NV_OWN_DATA_S(local_x) == SUNFALSE, "invalid serial N_Vector");
379 NV_DATA_S(local_x) = HostReadWrite();
380 NV_LENGTH_S(local_x) = size;
381 break;
382 }
383#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
384 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
385 {
386 SUN_Hip_OR_Cuda(N_VSetHostArrayPointer)(HostReadWrite(), local_x);
387 SUN_Hip_OR_Cuda(N_VSetDeviceArrayPointer)(ReadWrite(), local_x);
388 static_cast<SUN_Hip_OR_Cuda(N_VectorContent)>(GET_CONTENT(
389 local_x))->length = size;
390 break;
391 }
392#endif
393#ifdef MFEM_USE_MPI
394 case SUNDIALS_NVEC_PARALLEL:
395 {
396 MFEM_ASSERT(NV_OWN_DATA_P(x) == SUNFALSE, "invalid parallel N_Vector");
397 NV_DATA_P(x) = HostReadWrite();
398 NV_LOCLENGTH_P(x) = size;
399 if (glob_size == 0)
400 {
401 glob_size = GlobalSize();
402
403 if (glob_size == 0 && glob_size != size)
404 {
405 long local_size = size;
406 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
407 MPI_SUM, GetComm());
408 }
409 }
410 NV_GLOBLENGTH_P(x) = glob_size;
411 break;
412 }
413#endif
414 default:
415 MFEM_ABORT("N_Vector type " << id << " is not supported");
416 }
417
418#ifdef MFEM_USE_MPI
419 if (MPIPlusX())
420 {
421 if (glob_size == 0)
422 {
423 glob_size = GlobalSize();
424
425 if (glob_size == 0 && glob_size != size)
426 {
427 long local_size = size;
428 MPI_Allreduce(&local_size, &glob_size, 1, MPI_LONG,
429 MPI_SUM, GetComm());
430 }
431 }
432 static_cast<N_VectorContent_MPIManyVector>(GET_CONTENT(x))->global_length =
433 glob_size;
434 }
435#endif
436}
437
439{
440#ifdef MFEM_USE_MPI
441 N_Vector local_x = MPIPlusX() ? N_VGetLocalVector_MPIPlusX(x) : x;
442#else
443 N_Vector local_x = x;
444#endif
445 N_Vector_ID id = N_VGetVectorID(local_x);
446
447 // The SUNDIALS NVector owns the data if it created it.
448 switch (id)
449 {
450 case SUNDIALS_NVEC_SERIAL:
451 {
452 const bool known = mm.IsKnown(NV_DATA_S(local_x));
453 size = NV_LENGTH_S(local_x);
454 data.Wrap(NV_DATA_S(local_x), size, false);
455 if (known) { data.ClearOwnerFlags(); }
456 break;
457 }
458#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
459 case SUN_HIP_OR_CUDA(SUNDIALS_NVEC):
460 {
461 double *h_ptr = SUN_Hip_OR_Cuda(N_VGetHostArrayPointer)(local_x);
462 double *d_ptr = SUN_Hip_OR_Cuda(N_VGetDeviceArrayPointer)(local_x);
463 const bool known = mm.IsKnown(h_ptr);
464 size = SUN_Hip_OR_Cuda(N_VGetLength)(local_x);
465 data.Wrap(h_ptr, d_ptr, size, Device::GetHostMemoryType(), false, false, true);
466 if (known) { data.ClearOwnerFlags(); }
467 UseDevice(true);
468 break;
469 }
470#endif
471#ifdef MFEM_USE_MPI
472 case SUNDIALS_NVEC_PARALLEL:
473 {
474 const bool known = mm.IsKnown(NV_DATA_P(x));
475 size = NV_LENGTH_S(x);
476 data.Wrap(NV_DATA_P(x), NV_LOCLENGTH_P(x), false);
477 if (known) { data.ClearOwnerFlags(); }
478 break;
479 }
480#endif
481 default:
482 MFEM_ABORT("N_Vector type " << id << " is not supported");
483 }
484}
485
487 : Vector()
488{
489 // MFEM creates and owns the data,
490 // and provides it to the SUNDIALS NVector.
493 own_NVector = 1;
494}
495
496SundialsNVector::SundialsNVector(double *data_, int size_)
497 : Vector(data_, size_)
498{
501 own_NVector = 1;
503}
504
506 : x(nv)
507{
509 own_NVector = 0;
510}
511
512#ifdef MFEM_USE_MPI
514 : Vector()
515{
517 x = MakeNVector(comm, UseDevice());
518 own_NVector = 1;
519}
520
521SundialsNVector::SundialsNVector(MPI_Comm comm, int loc_size, long glob_size)
522 : Vector(loc_size)
523{
525 x = MakeNVector(comm, UseDevice());
526 own_NVector = 1;
527 _SetNvecDataAndSize_(glob_size);
528}
529
530SundialsNVector::SundialsNVector(MPI_Comm comm, double *data_, int loc_size,
531 long glob_size)
532 : Vector(data_, loc_size)
533{
535 x = MakeNVector(comm, UseDevice());
536 own_NVector = 1;
537 _SetNvecDataAndSize_(glob_size);
538}
539
541 : SundialsNVector(vec.GetComm(), vec.GetData(), vec.Size(), vec.GlobalSize())
542{}
543#endif
544
546{
547 if (own_NVector)
548 {
549#ifdef MFEM_USE_MPI
550 if (MPIPlusX())
551 {
552 N_VDestroy(N_VGetLocalVector_MPIPlusX(x));
553 }
554#endif
555 N_VDestroy(x);
556 }
557}
558
559void SundialsNVector::SetSize(int s, long glob_size)
560{
562 _SetNvecDataAndSize_(glob_size);
563}
564
566{
569}
570
571void SundialsNVector::SetDataAndSize(double *d, int s, long glob_size)
572{
574 _SetNvecDataAndSize_(glob_size);
575}
576
577N_Vector SundialsNVector::MakeNVector(bool use_device)
578{
579 N_Vector x;
580#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
581 if (use_device)
582 {
583 x = SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(0, UseManagedMemory(),
586 }
587 else
588 {
590 }
591#else
593#endif
594
595 MFEM_VERIFY(x, "Error in SundialsNVector::MakeNVector.");
596
597 return x;
598}
599
600#ifdef MFEM_USE_MPI
601N_Vector SundialsNVector::MakeNVector(MPI_Comm comm, bool use_device)
602{
603 N_Vector x;
604
605 if (comm == MPI_COMM_NULL)
606 {
607 x = MakeNVector(use_device);
608 }
609 else
610 {
611#if defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP)
612 if (use_device)
613 {
614 x = N_VMake_MPIPlusX(comm, SUN_Hip_OR_Cuda(N_VNewWithMemHelp)(0,
619 }
620 else
621 {
623 }
624#else
626#endif // MFEM_USE_CUDA || MFEM_USE_HIP
627 }
628
629 MFEM_VERIFY(x, "Error in SundialsNVector::MakeNVector.");
630
631 return x;
632}
633#endif // MFEM_USE_MPI
634
635
636// ---------------------------------------------------------------------------
637// SUNMatrix interface functions
638// ---------------------------------------------------------------------------
639
640// Return the matrix ID
641static SUNMatrix_ID MatGetID(SUNMatrix)
642{
643 return (SUNMATRIX_CUSTOM);
644}
645
646static void MatDestroy(SUNMatrix A)
647{
648 if (A->content) { A->content = NULL; }
649 if (A->ops) { free(A->ops); A->ops = NULL; }
650 free(A); A = NULL;
651 return;
652}
653
654// ---------------------------------------------------------------------------
655// SUNLinearSolver interface functions
656// ---------------------------------------------------------------------------
657
658// Return the linear solver type
659static SUNLinearSolver_Type LSGetType(SUNLinearSolver)
660{
661 return (SUNLINEARSOLVER_MATRIX_ITERATIVE);
662}
663
664static int LSFree(SUNLinearSolver LS)
665{
666 if (LS->content) { LS->content = NULL; }
667 if (LS->ops) { free(LS->ops); LS->ops = NULL; }
668 free(LS); LS = NULL;
669 return (0);
670}
671
672// ---------------------------------------------------------------------------
673// CVODE interface
674// ---------------------------------------------------------------------------
675int CVODESolver::RHS(sunrealtype t, const N_Vector y, N_Vector ydot,
676 void *user_data)
677{
678 // At this point the up-to-date data for N_Vector y and ydot is on the device.
679 const SundialsNVector mfem_y(y);
680 SundialsNVector mfem_ydot(ydot);
681
682 CVODESolver *self = static_cast<CVODESolver*>(user_data);
683
684 // Compute y' = f(t, y)
685 self->f->SetTime(t);
686 self->f->Mult(mfem_y, mfem_ydot);
687
688 // Return success
689 return (0);
690}
691
693 void *user_data)
694{
695 CVODESolver *self = static_cast<CVODESolver*>(user_data);
696
697 if (!self->root_func) { return CV_RTFUNC_FAIL; }
698
699 SundialsNVector mfem_y(y);
700 SundialsNVector mfem_gout(gout, self->root_components);
701
702 return self->root_func(t, mfem_y, mfem_gout, self);
703}
704
705void CVODESolver::SetRootFinder(int components, RootFunction func)
706{
707 root_func = func;
708
709 flag = CVodeRootInit(sundials_mem, components, root);
710 MFEM_VERIFY(flag == CV_SUCCESS, "error in SetRootFinder()");
711}
712
713int CVODESolver::LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy,
714 SUNMatrix A, sunbooleantype jok,
715 sunbooleantype *jcur, sunrealtype gamma,
716 void*, N_Vector, N_Vector, N_Vector)
717{
718 // Get data from N_Vectors
719 const SundialsNVector mfem_y(y);
720 const SundialsNVector mfem_fy(fy);
721 CVODESolver *self = static_cast<CVODESolver*>(GET_CONTENT(A));
722
723 // Compute the linear system
724 self->f->SetTime(t);
725 return (self->f->SUNImplicitSetup(mfem_y, mfem_fy, jok, jcur, gamma));
726}
727
728int CVODESolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
729 N_Vector b, sunrealtype tol)
730{
731 SundialsNVector mfem_x(x);
732 const SundialsNVector mfem_b(b);
733 CVODESolver *self = static_cast<CVODESolver*>(GET_CONTENT(LS));
734 // Solve the linear system
735 return (self->f->SUNImplicitSolve(mfem_b, mfem_x, tol));
736}
737
739 : lmm_type(lmm), step_mode(CV_NORMAL)
740{
741 Y = new SundialsNVector();
742}
743
744#ifdef MFEM_USE_MPI
745CVODESolver::CVODESolver(MPI_Comm comm, int lmm)
746 : lmm_type(lmm), step_mode(CV_NORMAL)
747{
748 Y = new SundialsNVector(comm);
749}
750#endif
751
753{
754 // Initialize the base class
755 ODESolver::Init(f_);
756
757 // Get the vector length
758 long local_size = f_.Height();
759
760#ifdef MFEM_USE_MPI
761 long global_size = 0;
762 if (Parallel())
763 {
764 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
765 Y->GetComm());
766 }
767#endif
768
769 // Get current time
770 double t = f_.GetTime();
771
772 if (sundials_mem)
773 {
774 // Check if the problem size has changed since the last Init() call
775 int resize = 0;
776 if (!Parallel())
777 {
778 resize = (Y->Size() != local_size);
779 }
780 else
781 {
782#ifdef MFEM_USE_MPI
783 int l_resize = (Y->Size() != local_size) ||
784 (saved_global_size != global_size);
785 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
786 Y->GetComm());
787#endif
788 }
789
790 // Free existing solver memory and re-create with new vector size
791 if (resize)
792 {
793 CVodeFree(&sundials_mem);
794 sundials_mem = NULL;
795 }
796 }
797
798 if (!sundials_mem)
799 {
800 // Temporarily set N_Vector wrapper data to create CVODE. The correct
801 // initial condition will be set using CVodeReInit() when Step() is
802 // called.
803
804 if (!Parallel())
805 {
806 Y->SetSize(local_size);
807 }
808#ifdef MFEM_USE_MPI
809 else
810 {
811 Y->SetSize(local_size, global_size);
812 saved_global_size = global_size;
813 }
814#endif
815
816 // Create CVODE
818 MFEM_VERIFY(sundials_mem, "error in CVodeCreate()");
819
820 // Initialize CVODE
821 flag = CVodeInit(sundials_mem, CVODESolver::RHS, t, *Y);
822 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeInit()");
823
824 // Attach the CVODESolver as user-defined data
825 flag = CVodeSetUserData(sundials_mem, this);
826 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetUserData()");
827
828 // Set default tolerances
829 flag = CVodeSStolerances(sundials_mem, default_rel_tol, default_abs_tol);
830 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetSStolerances()");
831
832 // Attach MFEM linear solver by default
834 }
835
836 // Set the reinit flag to call CVodeReInit() in the next Step() call.
837 reinit = true;
838}
839
840void CVODESolver::Step(Vector &x, double &t, double &dt)
841{
842 Y->MakeRef(x, 0, x.Size());
843 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
844
845 // Reinitialize CVODE memory if needed
846 if (reinit)
847 {
848 flag = CVodeReInit(sundials_mem, t, *Y);
849 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
850 // reset flag
851 reinit = false;
852 }
853
854 // Integrate the system
855 double tout = t + dt;
856 flag = CVode(sundials_mem, tout, *Y, &t, step_mode);
857 MFEM_VERIFY(flag >= 0, "error in CVode()");
858
859 // Make sure host is up to date
860 Y->HostRead();
861
862 // Return the last incremental step size
863 flag = CVodeGetLastStep(sundials_mem, &dt);
864 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetLastStep()");
865}
866
868{
869 // Free any existing matrix and linear solver
870 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
871 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
872
873 // Wrap linear solver as SUNLinearSolver and SUNMatrix
875 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
876
877 LSA->content = this;
878 LSA->ops->gettype = LSGetType;
879 LSA->ops->solve = CVODESolver::LinSysSolve;
880 LSA->ops->free = LSFree;
881
883 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
884
885 A->content = this;
886 A->ops->getid = MatGetID;
887 A->ops->destroy = MatDestroy;
888
889 // Attach the linear solver and matrix
890 flag = CVodeSetLinearSolver(sundials_mem, LSA, A);
891 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolver()");
892
893 // Set the linear system evaluation function
894 flag = CVodeSetLinSysFn(sundials_mem, CVODESolver::LinSysSetup);
895 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinSysFn()");
896}
897
899{
900 // Free any existing matrix and linear solver
901 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
902 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
903
904 // Create linear solver
906 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
907
908 // Attach linear solver
909 flag = CVodeSetLinearSolver(sundials_mem, LSA, NULL);
910 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolver()");
911}
912
914{
915 step_mode = itask;
916}
917
918void CVODESolver::SetSStolerances(double reltol, double abstol)
919{
920 flag = CVodeSStolerances(sundials_mem, reltol, abstol);
921 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSStolerances()");
922}
923
924void CVODESolver::SetSVtolerances(double reltol, Vector abstol)
925{
926 MFEM_VERIFY(abstol.Size() == f->Height(),
927 "abs tolerance is not the same size.");
928
929 SundialsNVector mfem_abstol;
930 mfem_abstol.MakeRef(abstol, 0, abstol.Size());
931
932 flag = CVodeSVtolerances(sundials_mem, reltol, mfem_abstol);
933 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSVtolerances()");
934}
935
936void CVODESolver::SetMaxStep(double dt_max)
937{
938 flag = CVodeSetMaxStep(sundials_mem, dt_max);
939 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxStep()");
940}
941
943{
944 flag = CVodeSetMaxNumSteps(sundials_mem, mxsteps);
945 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxNumSteps()");
946}
947
949{
950 long nsteps;
951 flag = CVodeGetNumSteps(sundials_mem, &nsteps);
952 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetNumSteps()");
953 return nsteps;
954}
955
956void CVODESolver::SetMaxOrder(int max_order)
957{
958 flag = CVodeSetMaxOrd(sundials_mem, max_order);
959 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetMaxOrd()");
960}
961
963{
964 long int nsteps, nfevals, nlinsetups, netfails;
965 int qlast, qcur;
966 double hinused, hlast, hcur, tcur;
967 long int nniters, nncfails;
968
969 // Get integrator stats
970 flag = CVodeGetIntegratorStats(sundials_mem,
971 &nsteps,
972 &nfevals,
973 &nlinsetups,
974 &netfails,
975 &qlast,
976 &qcur,
977 &hinused,
978 &hlast,
979 &hcur,
980 &tcur);
981 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetIntegratorStats()");
982
983 // Get nonlinear solver stats
984 flag = CVodeGetNonlinSolvStats(sundials_mem,
985 &nniters,
986 &nncfails);
987 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetNonlinSolvStats()");
988
989 mfem::out <<
990 "CVODE:\n"
991 "num steps: " << nsteps << "\n"
992 "num rhs evals: " << nfevals << "\n"
993 "num lin setups: " << nlinsetups << "\n"
994 "num nonlin sol iters: " << nniters << "\n"
995 "num nonlin conv fail: " << nncfails << "\n"
996 "num error test fails: " << netfails << "\n"
997 "last order: " << qlast << "\n"
998 "current order: " << qcur << "\n"
999 "initial dt: " << hinused << "\n"
1000 "last dt: " << hlast << "\n"
1001 "current dt: " << hcur << "\n"
1002 "current t: " << tcur << "\n" << endl;
1003
1004 return;
1005}
1006
1008{
1009 delete Y;
1010 SUNMatDestroy(A);
1011 SUNLinSolFree(LSA);
1012 SUNNonlinSolFree(NLS);
1013 CVodeFree(&sundials_mem);
1014}
1015
1016// ---------------------------------------------------------------------------
1017// CVODESSolver interface
1018// ---------------------------------------------------------------------------
1019
1021 CVODESolver(lmm),
1022 ncheck(0),
1023 indexB(0),
1024 AB(nullptr),
1025 LSB(nullptr)
1026{
1027 q = new SundialsNVector();
1028 qB = new SundialsNVector();
1029 yB = new SundialsNVector();
1030 yy = new SundialsNVector();
1031}
1032
1033#ifdef MFEM_USE_MPI
1034CVODESSolver::CVODESSolver(MPI_Comm comm, int lmm) :
1035 CVODESolver(comm, lmm),
1036 ncheck(0),
1037 indexB(0),
1038 AB(nullptr),
1039 LSB(nullptr)
1040{
1041 q = new SundialsNVector(comm);
1042 qB = new SundialsNVector(comm);
1043 yB = new SundialsNVector(comm);
1044 yy = new SundialsNVector(comm);
1045}
1046#endif
1047
1049{
1050 MFEM_VERIFY(t <= f->GetTime(), "t > current forward solver time");
1051
1052 flag = CVodeGetQuad(sundials_mem, &t, *q);
1053 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetQuad()");
1054
1055 Q.Set(1., *q);
1056}
1057
1059{
1060 MFEM_VERIFY(t <= f->GetTime(), "t > current forward solver time");
1061
1062 flag = CVodeGetQuadB(sundials_mem, indexB, &t, *qB);
1063 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetQuadB()");
1064
1065 dG_dp.Set(-1., *qB);
1066}
1067
1069{
1070 yy->MakeRef(yyy, 0, yyy.Size());
1071
1072 flag = CVodeGetAdjY(sundials_mem, tB, *yy);
1073 MFEM_VERIFY(flag >= 0, "error in CVodeGetAdjY()");
1074}
1075
1076// Implemented to enforce type checking for TimeDependentAdjointOperator
1081
1083{
1084 long local_size = f_.GetAdjointHeight();
1085
1086 // Get current time
1087 double tB = f_.GetTime();
1088
1089 yB->SetSize(local_size);
1090
1091 // Create the solver memory
1092 flag = CVodeCreateB(sundials_mem, CV_BDF, &indexB);
1093 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeCreateB()");
1094
1095 // Initialize
1096 flag = CVodeInitB(sundials_mem, indexB, RHSB, tB, *yB);
1097 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeInit()");
1098
1099 // Attach the CVODESSolver as user-defined data
1100 flag = CVodeSetUserDataB(sundials_mem, indexB, this);
1101 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetUserDataB()");
1102
1103 // Set default tolerances
1104 flag = CVodeSStolerancesB(sundials_mem, indexB, default_rel_tolB,
1106 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetSStolerancesB()");
1107
1108 // Attach MFEM linear solver by default
1110
1111 // Set the reinit flag to call CVodeReInit() in the next Step() call.
1112 reinit = true;
1113}
1114
1115void CVODESSolver::InitAdjointSolve(int steps, int interpolation)
1116{
1117 flag = CVodeAdjInit(sundials_mem, steps, interpolation);
1118 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeAdjInit");
1119}
1120
1122{
1123 flag = CVodeSetMaxNumStepsB(sundials_mem, indexB, mxstepsB);
1124 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetMaxNumStepsB()");
1125}
1126
1128 double abstolQ)
1129{
1130 q->MakeRef(q0, 0, q0.Size());
1131
1132 flag = CVodeQuadInit(sundials_mem, RHSQ, *q);
1133 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadInit()");
1134
1135 flag = CVodeSetQuadErrCon(sundials_mem, SUNTRUE);
1136 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetQuadErrCon");
1137
1138 flag = CVodeQuadSStolerances(sundials_mem, reltolQ, abstolQ);
1139 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadSStolerances");
1140}
1141
1143 double abstolQB)
1144{
1145 qB->MakeRef(qB0, 0, qB0.Size());
1146
1147 flag = CVodeQuadInitB(sundials_mem, indexB, RHSQB, *qB);
1148 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadInitB()");
1149
1150 flag = CVodeSetQuadErrConB(sundials_mem, indexB, SUNTRUE);
1151 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeSetQuadErrConB");
1152
1153 flag = CVodeQuadSStolerancesB(sundials_mem, indexB, reltolQB, abstolQB);
1154 MFEM_VERIFY(flag == CV_SUCCESS, "Error in CVodeQuadSStolerancesB");
1155}
1156
1158{
1159 // Free any existing linear solver
1160 if (AB != NULL) { SUNMatDestroy(AB); AB = NULL; }
1161 if (LSB != NULL) { SUNLinSolFree(LSB); LSB = NULL; }
1162
1163 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1165 MFEM_VERIFY(LSB, "error in SUNLinSolNewEmpty()");
1166
1167 LSB->content = this;
1168 LSB->ops->gettype = LSGetType;
1169 LSB->ops->solve = CVODESSolver::LinSysSolveB; // JW change
1170 LSB->ops->free = LSFree;
1171
1173 MFEM_VERIFY(AB, "error in SUNMatNewEmpty()");
1174
1175 AB->content = this;
1176 AB->ops->getid = MatGetID;
1177 AB->ops->destroy = MatDestroy;
1178
1179 // Attach the linear solver and matrix
1180 flag = CVodeSetLinearSolverB(sundials_mem, indexB, LSB, AB);
1181 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolverB()");
1182
1183 // Set the linear system evaluation function
1184 flag = CVodeSetLinSysFnB(sundials_mem, indexB,
1185 CVODESSolver::LinSysSetupB); // JW change
1186 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinSysFn()");
1187}
1188
1190{
1191 // Free any existing matrix and linear solver
1192 if (AB != NULL) { SUNMatDestroy(AB); AB = NULL; }
1193 if (LSB != NULL) { SUNLinSolFree(LSB); LSB = NULL; }
1194
1195 // Set default linear solver (Newton is the default Nonlinear Solver)
1197 MFEM_VERIFY(LSB, "error in SUNLinSol_SPGMR()");
1198
1199 /* Attach the matrix and linear solver */
1200 flag = CVodeSetLinearSolverB(sundials_mem, indexB, LSB, NULL);
1201 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSetLinearSolverB()");
1202}
1203
1204int CVODESSolver::LinSysSetupB(sunrealtype t, N_Vector y, N_Vector yB,
1205 N_Vector fyB, SUNMatrix AB,
1206 sunbooleantype jokB, sunbooleantype *jcurB,
1207 sunrealtype gammaB, void *user_data,
1208 N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
1209{
1210 // Get data from N_Vectors
1211 const SundialsNVector mfem_y(y);
1212 const SundialsNVector mfem_yB(yB);
1213 SundialsNVector mfem_fyB(fyB);
1214 CVODESSolver *self = static_cast<CVODESSolver*>(GET_CONTENT(AB));
1216 (self->f);
1217 f->SetTime(t);
1218 // Compute the linear system
1219 return (f->SUNImplicitSetupB(t, mfem_y, mfem_yB, mfem_fyB, jokB, jcurB,
1220 gammaB));
1221}
1222
1223int CVODESSolver::LinSysSolveB(SUNLinearSolver LS, SUNMatrix AB, N_Vector yB,
1224 N_Vector Rb, sunrealtype tol)
1225{
1226 SundialsNVector mfem_yB(yB);
1227 const SundialsNVector mfem_Rb(Rb);
1228 CVODESSolver *self = static_cast<CVODESSolver*>(GET_CONTENT(LS));
1230 (self->f);
1231 // Solve the linear system
1232 int ret = f->SUNImplicitSolveB(mfem_yB, mfem_Rb, tol);
1233 return (ret);
1234}
1235
1236void CVODESSolver::SetSStolerancesB(double reltol, double abstol)
1237{
1238 flag = CVodeSStolerancesB(sundials_mem, indexB, reltol, abstol);
1239 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSStolerancesB()");
1240}
1241
1242void CVODESSolver::SetSVtolerancesB(double reltol, Vector abstol)
1243{
1244 MFEM_VERIFY(abstol.Size() == f->Height(),
1245 "abs tolerance is not the same size.");
1246
1247 SundialsNVector mfem_abstol;
1248 mfem_abstol.MakeRef(abstol, 0, abstol.Size());
1249
1250 flag = CVodeSVtolerancesB(sundials_mem, indexB, reltol, mfem_abstol);
1251 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeSVtolerancesB()");
1252}
1253
1255{
1256 ewt_func = func;
1257 CVodeWFtolerances(sundials_mem, ewt);
1258}
1259
1260// CVODESSolver static functions
1261
1262int CVODESSolver::RHSQ(sunrealtype t, const N_Vector y, N_Vector qdot,
1263 void *user_data)
1264{
1265 CVODESSolver *self = static_cast<CVODESSolver*>(user_data);
1266 const SundialsNVector mfem_y(y);
1267 SundialsNVector mfem_qdot(qdot);
1269 (self->f);
1270 f->SetTime(t);
1271 f->QuadratureIntegration(mfem_y, mfem_qdot);
1272 return 0;
1273}
1274
1275int CVODESSolver::RHSQB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector qBdot,
1276 void *user_dataB)
1277{
1278 CVODESSolver *self = static_cast<CVODESSolver*>(user_dataB);
1279 SundialsNVector mfem_y(y);
1280 SundialsNVector mfem_yB(yB);
1281 SundialsNVector mfem_qBdot(qBdot);
1283 (self->f);
1284 f->SetTime(t);
1285 f->QuadratureSensitivityMult(mfem_y, mfem_yB, mfem_qBdot);
1286 return 0;
1287}
1288
1289int CVODESSolver::RHSB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector yBdot,
1290 void *user_dataB)
1291{
1292 CVODESSolver *self = static_cast<CVODESSolver*>(user_dataB);
1293 SundialsNVector mfem_y(y);
1294 SundialsNVector mfem_yB(yB);
1295 SundialsNVector mfem_yBdot(yBdot);
1296
1297 mfem_yBdot = 0.;
1299 (self->f);
1300 f->SetTime(t);
1301 f->AdjointRateMult(mfem_y, mfem_yB, mfem_yBdot);
1302 return 0;
1303}
1304
1305int CVODESSolver::ewt(N_Vector y, N_Vector w, void *user_data)
1306{
1307 CVODESSolver *self = static_cast<CVODESSolver*>(user_data);
1308
1309 SundialsNVector mfem_y(y);
1310 SundialsNVector mfem_w(w);
1311
1312 return self->ewt_func(mfem_y, mfem_w, self);
1313}
1314
1315// Pretty much a copy of CVODESolver::Step except we use CVodeF instead of CVode
1316void CVODESSolver::Step(Vector &x, double &t, double &dt)
1317{
1318 Y->MakeRef(x, 0, x.Size());
1319 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
1320
1321 // Reinitialize CVODE memory if needed, initializes the N_Vector y with x
1322 if (reinit)
1323 {
1324 flag = CVodeReInit(sundials_mem, t, *Y);
1325 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
1326
1327 // reset flag
1328 reinit = false;
1329 }
1330
1331 // Integrate the system
1332 double tout = t + dt;
1333 flag = CVodeF(sundials_mem, tout, *Y, &t, step_mode, &ncheck);
1334 MFEM_VERIFY(flag >= 0, "error in CVodeF()");
1335
1336 // Make sure host is up to date
1337 Y->HostRead();
1338
1339 // Return the last incremental step size
1340 flag = CVodeGetLastStep(sundials_mem, &dt);
1341 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeGetLastStep()");
1342}
1343
1344void CVODESSolver::StepB(Vector &xB, double &tB, double &dtB)
1345{
1346 yB->MakeRef(xB, 0, xB.Size());
1347 MFEM_VERIFY(yB->Size() == xB.Size(), "");
1348
1349 // Reinitialize CVODE memory if needed
1350 if (reinit)
1351 {
1352 flag = CVodeReInitB(sundials_mem, indexB, tB, *yB);
1353 MFEM_VERIFY(flag == CV_SUCCESS, "error in CVodeReInit()");
1354
1355 // reset flag
1356 reinit = false;
1357 }
1358
1359 // Integrate the system
1360 double tout = tB - dtB;
1361 flag = CVodeB(sundials_mem, tout, step_mode);
1362 MFEM_VERIFY(flag >= 0, "error in CVodeB()");
1363
1364 // Call CVodeGetB to get yB of the backward ODE problem.
1365 flag = CVodeGetB(sundials_mem, indexB, &tB, *yB);
1366 MFEM_VERIFY(flag >= 0, "error in CVodeGetB()");
1367
1368 // Make sure host is up to date
1369 yB->HostRead();
1370}
1371
1373{
1374 delete yB;
1375 delete yy;
1376 delete qB;
1377 delete q;
1378 SUNMatDestroy(AB);
1379 SUNLinSolFree(LSB);
1380}
1381
1382
1383// ---------------------------------------------------------------------------
1384// ARKStep interface
1385// ---------------------------------------------------------------------------
1386
1387int ARKStepSolver::RHS1(sunrealtype t, const N_Vector y, N_Vector result,
1388 void *user_data)
1389{
1390 // Get data from N_Vectors
1391 const SundialsNVector mfem_y(y);
1392 SundialsNVector mfem_result(result);
1393 ARKStepSolver *self = static_cast<ARKStepSolver*>(user_data);
1394
1395 // Compute either f(t, y) in one of
1396 // 1. y' = f(t, y)
1397 // 2. M y' = f(t, y)
1398 // or fe(t, y) in one of
1399 // 1. y' = fe(t, y) + fi(t, y)
1400 // 2. M y' = fe(t, y) + fi(t, y)
1401 self->f->SetTime(t);
1402 if (self->rk_type == IMEX)
1403 {
1405 }
1406 if (self->f->isExplicit()) // ODE is in form 1
1407 {
1408 self->f->Mult(mfem_y, mfem_result);
1409 }
1410 else // ODE is in form 2
1411 {
1412 self->f->ExplicitMult(mfem_y, mfem_result);
1413 }
1414
1415 // Return success
1416 return (0);
1417}
1418
1419int ARKStepSolver::RHS2(sunrealtype t, const N_Vector y, N_Vector result,
1420 void *user_data)
1421{
1422 // Get data from N_Vectors
1423 const SundialsNVector mfem_y(y);
1424 SundialsNVector mfem_result(result);
1425 ARKStepSolver *self = static_cast<ARKStepSolver*>(user_data);
1426
1427 // Compute fi(t, y) in one of
1428 // 1. y' = fe(t, y) + fi(t, y) (ODE is expressed in EXPLICIT form)
1429 // 2. M y' = fe(t, y) + fi(y, t) (ODE is expressed in IMPLICIT form)
1430 self->f->SetTime(t);
1432 if (self->f->isExplicit())
1433 {
1434 self->f->Mult(mfem_y, mfem_result);
1435 }
1436 else
1437 {
1438 self->f->ExplicitMult(mfem_y, mfem_result);
1439 }
1440
1441 // Return success
1442 return (0);
1443}
1444
1445int ARKStepSolver::LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy,
1446 SUNMatrix A, SUNMatrix, sunbooleantype jok,
1447 sunbooleantype *jcur, sunrealtype gamma,
1448 void*, N_Vector, N_Vector, N_Vector)
1449{
1450 // Get data from N_Vectors
1451 const SundialsNVector mfem_y(y);
1452 const SundialsNVector mfem_fy(fy);
1453 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(A));
1454
1455 // Compute the linear system
1456 self->f->SetTime(t);
1457 if (self->rk_type == IMEX)
1458 {
1460 }
1461 return (self->f->SUNImplicitSetup(mfem_y, mfem_fy, jok, jcur, gamma));
1462}
1463
1464int ARKStepSolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
1465 N_Vector b, sunrealtype tol)
1466{
1467 SundialsNVector mfem_x(x);
1468 const SundialsNVector mfem_b(b);
1469 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(LS));
1470
1471 // Solve the linear system
1472 if (self->rk_type == IMEX)
1473 {
1475 }
1476 return (self->f->SUNImplicitSolve(mfem_b, mfem_x, tol));
1477}
1478
1480 void*, N_Vector, N_Vector, N_Vector)
1481{
1482 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(M));
1483
1484 // Compute the mass matrix system
1485 self->f->SetTime(t);
1486 return (self->f->SUNMassSetup());
1487}
1488
1489int ARKStepSolver::MassSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector x,
1490 N_Vector b, sunrealtype tol)
1491{
1492 SundialsNVector mfem_x(x);
1493 const SundialsNVector mfem_b(b);
1494 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(LS));
1495
1496 // Solve the mass matrix system
1497 return (self->f->SUNMassSolve(mfem_b, mfem_x, tol));
1498}
1499
1500int ARKStepSolver::MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
1501{
1502 const SundialsNVector mfem_x(x);
1503 SundialsNVector mfem_v(v);
1504 ARKStepSolver *self = static_cast<ARKStepSolver*>(GET_CONTENT(M));
1505
1506 // Compute the mass matrix-vector product
1507 return (self->f->SUNMassMult(mfem_x, mfem_v));
1508}
1509
1510int ARKStepSolver::MassMult2(N_Vector x, N_Vector v, sunrealtype t,
1511 void* mtimes_data)
1512{
1513 const SundialsNVector mfem_x(x);
1514 SundialsNVector mfem_v(v);
1515 ARKStepSolver *self = static_cast<ARKStepSolver*>(mtimes_data);
1516
1517 // Compute the mass matrix-vector product
1518 self->f->SetTime(t);
1519 return (self->f->SUNMassMult(mfem_x, mfem_v));
1520}
1521
1523 : rk_type(type), step_mode(ARK_NORMAL),
1524 use_implicit(type == IMPLICIT || type == IMEX)
1525{
1526 Y = new SundialsNVector();
1527}
1528
#ifdef MFEM_USE_MPI
// Parallel constructor: identical to the serial one except that the solution
// wrapper is created on the communicator `comm`.
ARKStepSolver::ARKStepSolver(MPI_Comm comm, Type type)
   : rk_type(type), step_mode(ARK_NORMAL),
     use_implicit(type == IMPLICIT || type == IMEX)
{
   // Parallel N_Vector wrapper; local/global lengths are set later in Init().
   Y = new SundialsNVector(comm);
}
#endif
1537
1539{
1540 // Initialize the base class
1541 ODESolver::Init(f_);
1542
1543 // Get the vector length
1544 long local_size = f_.Height();
1545#ifdef MFEM_USE_MPI
1546 long global_size;
1547#endif
1548
1549 if (Parallel())
1550 {
1551#ifdef MFEM_USE_MPI
1552 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
1553 Y->GetComm());
1554#endif
1555 }
1556
1557 // Get current time
1558 double t = f_.GetTime();
1559
1560 if (sundials_mem)
1561 {
1562 // Check if the problem size has changed since the last Init() call
1563 int resize = 0;
1564 if (!Parallel())
1565 {
1566 resize = (Y->Size() != local_size);
1567 }
1568 else
1569 {
1570#ifdef MFEM_USE_MPI
1571 int l_resize = (Y->Size() != local_size) ||
1572 (saved_global_size != global_size);
1573 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
1574 Y->GetComm());
1575#endif
1576 }
1577
1578 // Free existing solver memory and re-create with new vector size
1579 if (resize)
1580 {
1581 MFEM_ARKode(Free)(&sundials_mem);
1582 sundials_mem = NULL;
1583 }
1584 }
1585
1586 if (!sundials_mem)
1587 {
1588 if (!Parallel())
1589 {
1590 Y->SetSize(local_size);
1591 }
1592#ifdef MFEM_USE_MPI
1593 else
1594 {
1595 Y->SetSize(local_size, global_size);
1596 saved_global_size = global_size;
1597 }
1598#endif
1599
1600 // Create ARKStep memory
1601 if (rk_type == IMPLICIT)
1602 {
1605 }
1606 else if (rk_type == EXPLICIT)
1607 {
1610 }
1611 else
1612 {
1614 t, *Y, Sundials::GetContext());
1615 }
1616 MFEM_VERIFY(sundials_mem, "error in ARKStepCreate()");
1617
1618 // Attach the ARKStepSolver as user-defined data
1619 flag = MFEM_ARKode(SetUserData)(sundials_mem, this);
1620 MFEM_VERIFY(flag == ARK_SUCCESS,
1621 "error in " STR(MFEM_ARKode(SetUserData)) "()");
1622
1623 // Set default tolerances
1624 flag = MFEM_ARKode(SStolerances)(sundials_mem, default_rel_tol,
1626 MFEM_VERIFY(flag == ARK_SUCCESS,
1627 "error in " STR(MFEM_ARKode(SStolerances)) "()");
1628
1629 // If implicit, attach MFEM linear solver by default
1631 }
1632
1633 // Set the reinit flag to call ARKStepReInit() in the next Step() call.
1634 reinit = true;
1635}
1636
1638{
1639 Y->MakeRef(x, 0, x.Size());
1640 MFEM_VERIFY(Y->Size() == x.Size(), "size mismatch");
1641
1642 // Reinitialize ARKStep memory if needed
1643 if (reinit)
1644 {
1645 if (rk_type == IMPLICIT)
1646 {
1647 flag = ARKStepReInit(sundials_mem, NULL, ARKStepSolver::RHS1, t, *Y);
1648 }
1649 else if (rk_type == EXPLICIT)
1650 {
1651 flag = ARKStepReInit(sundials_mem, ARKStepSolver::RHS1, NULL, t, *Y);
1652 }
1653 else
1654 {
1655 flag = ARKStepReInit(sundials_mem,
1657 }
1658 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepReInit()");
1659
1660 // reset flag
1661 reinit = false;
1662 }
1663
1664 // Integrate the system
1665 double tout = t + dt;
1666 flag = MFEM_ARKode(Evolve)(sundials_mem, tout, *Y, &t, step_mode);
1667 MFEM_VERIFY(flag >= 0, "error in " STR(MFEM_ARKode(Evolve)) "()");
1668
1669 // Make sure host is up to date
1670 Y->HostRead();
1671
1672 // Return the last incremental step size
1673 flag = MFEM_ARKode(GetLastStep)(sundials_mem, &dt);
1674 MFEM_VERIFY(flag == ARK_SUCCESS,
1675 "error in " STR(MFEM_ARKode(GetLastStep)) "()");
1676}
1677
1679{
1680 // Free any existing matrix and linear solver
1681 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
1682 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
1683
1684 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1686 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
1687
1688 LSA->content = this;
1689 LSA->ops->gettype = LSGetType;
1690 LSA->ops->solve = ARKStepSolver::LinSysSolve;
1691 LSA->ops->free = LSFree;
1692
1694 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
1695
1696 A->content = this;
1697 A->ops->getid = MatGetID;
1698 A->ops->destroy = MatDestroy;
1699
1700 // Attach the linear solver and matrix
1701 flag = MFEM_ARKode(SetLinearSolver)(sundials_mem, LSA, A);
1702 MFEM_VERIFY(flag == ARK_SUCCESS,
1703 "error in " STR(MFEM_ARKode(SetLinearSolver)) "()");
1704
1705 // Set the linear system evaluation function
1706 flag = MFEM_ARKode(SetLinSysFn)(sundials_mem, ARKStepSolver::LinSysSetup);
1707 MFEM_VERIFY(flag == ARK_SUCCESS,
1708 "error in " STR(MFEM_ARKode(SetLinSysFn)) "()");
1709}
1710
1712{
1713 // Free any existing matrix and linear solver
1714 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
1715 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
1716
1717 // Create linear solver
1719 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
1720
1721 // Attach linear solver
1722 flag = MFEM_ARKode(SetLinearSolver)(sundials_mem, LSA, NULL);
1723 MFEM_VERIFY(flag == ARK_SUCCESS,
1724 "error in " STR(MFEM_ARKode(SetLinearSolver)) "()");
1725}
1726
1728{
1729 // Free any existing matrix and linear solver
1730 if (M != NULL) { SUNMatDestroy(M); M = NULL; }
1731 if (LSM != NULL) { SUNLinSolFree(LSM); LSM = NULL; }
1732
1733 // Wrap linear solver as SUNLinearSolver and SUNMatrix
1735 MFEM_VERIFY(LSM, "error in SUNLinSolNewEmpty()");
1736
1737 LSM->content = this;
1738 LSM->ops->gettype = LSGetType;
1739 LSM->ops->solve = ARKStepSolver::MassSysSolve;
1740 LSM->ops->free = LSFree;
1741
1743 MFEM_VERIFY(M, "error in SUNMatNewEmpty()");
1744
1745 M->content = this;
1746 M->ops->getid = MatGetID;
1747 M->ops->matvec = ARKStepSolver::MassMult1;
1748 M->ops->destroy = MatDestroy;
1749
1750 // Attach the linear solver and matrix
1751 flag = MFEM_ARKode(SetMassLinearSolver)(sundials_mem, LSM, M, tdep);
1752 MFEM_VERIFY(flag == ARK_SUCCESS,
1753 "error in " STR(MFEM_ARKode(SetMassLinearSolver)) "()");
1754
1755 // Set the linear system function
1756 flag = MFEM_ARKode(SetMassFn)(sundials_mem, ARKStepSolver::MassSysSetup);
1757 MFEM_VERIFY(flag == ARK_SUCCESS,
1758 "error in " STR(MFEM_ARKode(SetMassFn)) "()");
1759
1760 // Check that the ODE is not expressed in EXPLICIT form
1761 MFEM_VERIFY(!f->isExplicit(), "ODE operator is expressed in EXPLICIT form")
1762}
1763
1765{
1766 // Free any existing matrix and linear solver
1767 if (M != NULL) { SUNMatDestroy(A); M = NULL; }
1768 if (LSM != NULL) { SUNLinSolFree(LSM); LSM = NULL; }
1769
1770 // Create linear solver
1772 MFEM_VERIFY(LSM, "error in SUNLinSol_SPGMR()");
1773
1774 // Attach linear solver
1775 flag = MFEM_ARKode(SetMassLinearSolver)(sundials_mem, LSM, NULL, tdep);
1776 MFEM_VERIFY(flag == ARK_SUCCESS,
1777 "error in " STR(MFEM_ARKode(SetMassLinearSolver)) "()");
1778
1779 // Attach matrix multiplication function
1780 flag = MFEM_ARKode(SetMassTimes)(sundials_mem, NULL,
1782 MFEM_VERIFY(flag == ARK_SUCCESS,
1783 "error in " STR(MFEM_ARKode(SetMassTimes)) "()");
1784
1785 // Check that the ODE is not expressed in EXPLICIT form
1786 MFEM_VERIFY(!f->isExplicit(), "ODE operator is expressed in EXPLICIT form")
1787}
1788
1790{
1791 step_mode = itask;
1792}
1793
1794void ARKStepSolver::SetSStolerances(double reltol, double abstol)
1795{
1796 flag = MFEM_ARKode(SStolerances)(sundials_mem, reltol, abstol);
1797 MFEM_VERIFY(flag == ARK_SUCCESS,
1798 "error in " STR(MFEM_ARKode(SStolerances)) "()");
1799}
1800
1801void ARKStepSolver::SetMaxStep(double dt_max)
1802{
1803 flag = MFEM_ARKode(SetMaxStep)(sundials_mem, dt_max);
1804 MFEM_VERIFY(flag == ARK_SUCCESS,
1805 "error in " STR(MFEM_ARKode(SetMaxStep)) "()");
1806}
1807
1809{
1810 flag = MFEM_ARKode(SetOrder)(sundials_mem, order);
1811 MFEM_VERIFY(flag == ARK_SUCCESS,
1812 "error in " STR(MFEM_ARKode(SetOrder)) "()");
1813}
1814
1816{
1817 flag = ARKStepSetTableNum(sundials_mem, ARKODE_DIRK_NONE, table_id);
1818 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1819}
1820
1822{
1823 flag = ARKStepSetTableNum(sundials_mem, table_id, ARKODE_ERK_NONE);
1824 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1825}
1826
1828 ARKODE_DIRKTableID itable_id)
1829{
1830 flag = ARKStepSetTableNum(sundials_mem, itable_id, etable_id);
1831 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepSetTableNum()");
1832}
1833
1835{
1836 flag = MFEM_ARKode(SetFixedStep)(sundials_mem, dt);
1837 MFEM_VERIFY(flag == ARK_SUCCESS,
1838 "error in " STR(MFEM_ARKode(SetFixedStep)) "()");
1839}
1840
1842{
1843 long int nsteps, expsteps, accsteps, step_attempts;
1844 long int nfe_evals, nfi_evals;
1845 long int nlinsetups, netfails;
1846 double hinused, hlast, hcur, tcur;
1847 long int nniters, nncfails;
1848
1849 // Get integrator stats
1850 flag = ARKStepGetTimestepperStats(sundials_mem,
1851 &expsteps,
1852 &accsteps,
1853 &step_attempts,
1854 &nfe_evals,
1855 &nfi_evals,
1856 &nlinsetups,
1857 &netfails);
1858 MFEM_VERIFY(flag == ARK_SUCCESS, "error in ARKStepGetTimestepperStats()");
1859
1860 flag = MFEM_ARKode(GetStepStats)(sundials_mem,
1861 &nsteps,
1862 &hinused,
1863 &hlast,
1864 &hcur,
1865 &tcur);
1866
1867 // Get nonlinear solver stats
1868 flag = MFEM_ARKode(GetNonlinSolvStats)(sundials_mem,
1869 &nniters,
1870 &nncfails);
1871 MFEM_VERIFY(flag == ARK_SUCCESS,
1872 "error in " STR(MFEM_ARKode(GetNonlinSolvStats)) "()");
1873
1874 mfem::out <<
1875 "ARKStep:\n"
1876 "num steps: " << nsteps << "\n"
1877 "num exp rhs evals: " << nfe_evals << "\n"
1878 "num imp rhs evals: " << nfi_evals << "\n"
1879 "num lin setups: " << nlinsetups << "\n"
1880 "num nonlin sol iters: " << nniters << "\n"
1881 "num nonlin conv fail: " << nncfails << "\n"
1882 "num steps attempted: " << step_attempts << "\n"
1883 "num acc limited steps: " << accsteps << "\n"
1884 "num exp limited stepfails: " << expsteps << "\n"
1885 "num error test fails: " << netfails << "\n"
1886 "initial dt: " << hinused << "\n"
1887 "last dt: " << hlast << "\n"
1888 "current dt: " << hcur << "\n"
1889 "current t: " << tcur << "\n" << endl;
1890
1891 return;
1892}
1893
1895{
1896 delete Y;
1897 SUNMatDestroy(A);
1898 SUNLinSolFree(LSA);
1899 SUNNonlinSolFree(NLS);
1900 MFEM_ARKode(Free)(&sundials_mem);
1901}
1902
1903// ---------------------------------------------------------------------------
1904// KINSOL interface
1905// ---------------------------------------------------------------------------
1906
1907// Wrapper for evaluating the nonlinear residual F(u) = 0
1908int KINSolver::Mult(const N_Vector u, N_Vector fu, void *user_data)
1909{
1910 const SundialsNVector mfem_u(u);
1911 SundialsNVector mfem_fu(fu);
1912 KINSolver *self = static_cast<KINSolver*>(user_data);
1913
1914 // Compute the non-linear action F(u).
1915 self->oper->Mult(mfem_u, mfem_fu);
1916
1917 // Return success
1918 return 0;
1919}
1920
1921// Wrapper for computing Jacobian-vector products
1922int KINSolver::GradientMult(N_Vector v, N_Vector Jv, N_Vector u,
1923 sunbooleantype *new_u, void *user_data)
1924{
1925 const SundialsNVector mfem_v(v);
1926 SundialsNVector mfem_Jv(Jv);
1927 KINSolver *self = static_cast<KINSolver*>(user_data);
1928
1929 // Update Jacobian information if needed
1930 if (*new_u)
1931 {
1932 const SundialsNVector mfem_u(u);
1933 self->jacobian = &self->oper->GetGradient(mfem_u);
1934 *new_u = SUNFALSE;
1935 }
1936
1937 // Compute the Jacobian-vector product
1938 self->jacobian->Mult(mfem_v, mfem_Jv);
1939
1940 // Return success
1941 return 0;
1942}
1943
1944// Wrapper for evaluating linear systems J u = b
1945int KINSolver::LinSysSetup(N_Vector u, N_Vector, SUNMatrix J,
1946 void *, N_Vector, N_Vector )
1947{
1948 const SundialsNVector mfem_u(u);
1949 KINSolver *self = static_cast<KINSolver*>(GET_CONTENT(J));
1950
1951 // Update the Jacobian
1952 self->jacobian = &self->oper->GetGradient(mfem_u);
1953
1954 // Set the Jacobian solve operator
1955 self->prec->SetOperator(*self->jacobian);
1956
1957 // Return success
1958 return (0);
1959}
1960
1961// Wrapper for solving linear systems J u = b
1962int KINSolver::LinSysSolve(SUNLinearSolver LS, SUNMatrix, N_Vector u,
1963 N_Vector b, sunrealtype)
1964{
1965 SundialsNVector mfem_u(u), mfem_b(b);
1966 KINSolver *self = static_cast<KINSolver*>(GET_CONTENT(LS));
1967
1968 // Solve for u = [J(u)]^{-1} b, maybe approximately.
1969 self->prec->Mult(mfem_b, mfem_u);
1970
1971 // Return success
1972 return (0);
1973}
1974
1975int KINSolver::PrecSetup(N_Vector uu,
1976 N_Vector uscale,
1977 N_Vector fval,
1978 N_Vector fscale,
1979 void *user_data)
1980{
1981 SundialsNVector mfem_u(uu);
1982 KINSolver *self = static_cast<KINSolver *>(user_data);
1983
1984 // Update the Jacobian
1985 self->jacobian = &self->oper->GetGradient(mfem_u);
1986
1987 // Set the Jacobian solve operator
1988 self->prec->SetOperator(*self->jacobian);
1989
1990 return 0;
1991}
1992
1993int KINSolver::PrecSolve(N_Vector uu,
1994 N_Vector uscale,
1995 N_Vector fval,
1996 N_Vector fscale,
1997 N_Vector vv,
1998 void *user_data)
1999{
2000 KINSolver *self = static_cast<KINSolver *>(user_data);
2001 SundialsNVector mfem_v(vv);
2002
2003 self->wrk = 0.0;
2004
2005 // Solve for u = P^{-1} v
2006 self->prec->Mult(mfem_v, self->wrk);
2007
2008 mfem_v = self->wrk;
2009
2010 return 0;
2011}
2012
2013KINSolver::KINSolver(int strategy, bool oper_grad)
2014 : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
2015 f_scale(NULL), jacobian(NULL)
2016{
2017 Y = new SundialsNVector();
2018 y_scale = new SundialsNVector();
2019 f_scale = new SundialsNVector();
2020
2021 // Default abs_tol and print_level
2022#if MFEM_SUNDIALS_VERSION < 70000
2023 abs_tol = pow(UNIT_ROUNDOFF, 1.0/3.0);
2024#else
2025 abs_tol = pow(SUN_UNIT_ROUNDOFF, 1.0/3.0);
2026#endif
2027 print_level = 0;
2028}
2029
#ifdef MFEM_USE_MPI
// Parallel constructor. Same defaults as the serial one, but the solution and
// scaling-vector wrappers live on the communicator `comm`.
KINSolver::KINSolver(MPI_Comm comm, int strategy, bool oper_grad)
   : global_strategy(strategy), use_oper_grad(oper_grad), y_scale(NULL),
     f_scale(NULL), jacobian(NULL)
{
   Y = new SundialsNVector(comm);
   y_scale = new SundialsNVector(comm);
   f_scale = new SundialsNVector(comm);

   // SUNDIALS v7 renamed UNIT_ROUNDOFF to SUN_UNIT_ROUNDOFF.
#if MFEM_SUNDIALS_VERSION < 70000
   const auto uround = UNIT_ROUNDOFF;
#else
   const auto uround = SUN_UNIT_ROUNDOFF;
#endif
   abs_tol = pow(uround, 1.0/3.0);
   print_level = 0;
}
#endif
2048
2049
2051{
2052 // Initialize the base class
2054 jacobian = NULL;
2055
2056 // Get the vector length
2057 long local_size = height;
2058#ifdef MFEM_USE_MPI
2059 long global_size;
2060#endif
2061
2062 if (Parallel())
2063 {
2064#ifdef MFEM_USE_MPI
2065 MPI_Allreduce(&local_size, &global_size, 1, MPI_LONG, MPI_SUM,
2066 Y->GetComm());
2067#endif
2068 }
2069
2070 if (sundials_mem)
2071 {
2072 // Check if the problem size has changed since the last SetOperator call
2073 int resize = 0;
2074 if (!Parallel())
2075 {
2076 resize = (Y->Size() != local_size);
2077 }
2078 else
2079 {
2080#ifdef MFEM_USE_MPI
2081 int l_resize = (Y->Size() != local_size) ||
2082 (saved_global_size != global_size);
2083 MPI_Allreduce(&l_resize, &resize, 1, MPI_INT, MPI_LOR,
2084 Y->GetComm());
2085#endif
2086 }
2087
2088 // Free existing solver memory and re-create with new vector size
2089 if (resize)
2090 {
2091 KINFree(&sundials_mem);
2092 sundials_mem = NULL;
2093 }
2094 }
2095
2096 if (!sundials_mem)
2097 {
2098 if (!Parallel())
2099 {
2100 Y->SetSize(local_size);
2101 }
2102#ifdef MFEM_USE_MPI
2103 else
2104 {
2105 Y->SetSize(local_size, global_size);
2106 y_scale->SetSize(local_size, global_size);
2107 f_scale->SetSize(local_size, global_size);
2108 saved_global_size = global_size;
2109 }
2110#endif
2111
2112 // Create the solver memory
2114 MFEM_VERIFY(sundials_mem, "Error in KINCreate().");
2115
2116 // Enable Anderson Acceleration
2117 if (aa_n > 0)
2118 {
2119 flag = KINSetMAA(sundials_mem, aa_n);
2120 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMAA()");
2121
2122 flag = KINSetDelayAA(sundials_mem, aa_delay);
2123 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDelayAA()");
2124
2125 flag = KINSetDampingAA(sundials_mem, aa_damping);
2126 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDampingAA()");
2127
2128#if SUNDIALS_VERSION_MAJOR >= 6
2129 flag = KINSetOrthAA(sundials_mem, aa_orth);
2130 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetOrthAA()");
2131#endif
2132 }
2133
2134 // Initialize KINSOL
2135 flag = KINInit(sundials_mem, KINSolver::Mult, *Y);
2136 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINInit()");
2137
2138 // Attach the KINSolver as user-defined data
2139 flag = KINSetUserData(sundials_mem, this);
2140 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetUserData()");
2141
2142 flag = KINSetDamping(sundials_mem, fp_damping);
2143 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDamping()");
2144
2145 // Set the linear solver
2146 if (prec || jfnk)
2147 {
2149 }
2150 else
2151 {
2152 // Free any existing linear solver
2153 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2154 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2155
2157 MFEM_VERIFY(LSA, "error in SUNLinSol_SPGMR()");
2158
2159 flag = KINSetLinearSolver(sundials_mem, LSA, NULL);
2160 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2161
2162 // Set Jacobian-vector product function
2163 if (use_oper_grad)
2164 {
2165 flag = KINSetJacTimesVecFn(sundials_mem, KINSolver::GradientMult);
2166 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetJacTimesVecFn()");
2167 }
2168 }
2169 }
2170}
2171
2173{
2174 if (jfnk)
2175 {
2176 SetJFNKSolver(solver);
2177 }
2178 else
2179 {
2180 // Store the solver
2181 prec = &solver;
2182
2183 // Free any existing linear solver
2184 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2185 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2186
2187 // Wrap KINSolver as SUNLinearSolver and SUNMatrix
2189 MFEM_VERIFY(LSA, "error in SUNLinSolNewEmpty()");
2190
2191 LSA->content = this;
2192 LSA->ops->gettype = LSGetType;
2193 LSA->ops->solve = KINSolver::LinSysSolve;
2194 LSA->ops->free = LSFree;
2195
2197 MFEM_VERIFY(A, "error in SUNMatNewEmpty()");
2198
2199 A->content = this;
2200 A->ops->getid = MatGetID;
2201 A->ops->destroy = MatDestroy;
2202
2203 // Attach the linear solver and matrix
2204 flag = KINSetLinearSolver(sundials_mem, LSA, A);
2205 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2206
2207 // Set the Jacobian evaluation function
2209 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetJacFn()");
2210 }
2211}
2212
2214{
2215 // Store the solver
2216 prec = &solver;
2217
2219
2220 // Free any existing linear solver
2221 if (A != NULL) { SUNMatDestroy(A); A = NULL; }
2222 if (LSA != NULL) { SUNLinSolFree(LSA); LSA = NULL; }
2223
2224 // Setup FGMRES
2227 MFEM_VERIFY(LSA, "error in SUNLinSol_SPFGMR()");
2228
2229 flag = SUNLinSol_SPFGMRSetMaxRestarts(LSA, maxlrs);
2230 MFEM_VERIFY(flag == SUN_SUCCESS, "error in SUNLinSol_SPFGMR()");
2231
2232 flag = KINSetLinearSolver(sundials_mem, LSA, NULL);
2233 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetLinearSolver()");
2234
2235 if (prec)
2236 {
2237 flag = KINSetPreconditioner(sundials_mem,
2240 MFEM_VERIFY(flag == KIN_SUCCESS, "error in KINSetPreconditioner()");
2241 }
2242}
2243
2245{
2246 flag = KINSetScaledStepTol(sundials_mem, sstol);
2247 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetScaledStepTol()");
2248}
2249
2251{
2252 flag = KINSetMaxSetupCalls(sundials_mem, max_calls);
2253 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMaxSetupCalls()");
2254}
2255
2256void KINSolver::EnableAndersonAcc(int n, int orth, int delay, double damping)
2257{
2258 if (sundials_mem != nullptr)
2259 {
2260 if (aa_n < n)
2261 {
2262 MFEM_ABORT("Subsequent calls to EnableAndersonAcc() must set"
2263 " the subspace size to less or equal to the initially requested size."
2264 " If SetOperator() has already been called, the subspace size can't be"
2265 " increased.");
2266 }
2267
2268 flag = KINSetMAA(sundials_mem, n);
2269 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetMAA()");
2270
2271 flag = KINSetDelayAA(sundials_mem, delay);
2272 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDelayAA()");
2273
2274 flag = KINSetDampingAA(sundials_mem, damping);
2275 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDampingAA()");
2276
2277#if SUNDIALS_VERSION_MAJOR >= 6
2278 flag = KINSetOrthAA(sundials_mem, orth);
2279 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetOrthAA()");
2280#else
2281 if (orth != KIN_ORTH_MGS)
2282 {
2283 MFEM_WARNING("SUNDIALS < v6 does not support setting the Anderson"
2284 " acceleration orthogonalization routine!");
2285 }
2286#endif
2287 }
2288
2289 aa_n = n;
2290 aa_delay = delay;
2291 aa_damping = damping;
2292 aa_orth = orth;
2293}
2294
2295void KINSolver::SetDamping(double damping)
2296{
2297 fp_damping = damping;
2298 if (sundials_mem)
2299 {
2300 flag = KINSetDamping(sundials_mem, fp_damping);
2301 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetDamping()");
2302 }
2303}
2304
2306{
2307 MFEM_ABORT("this method is not supported! Use SetPrintLevel(int) instead.");
2308}
2309
2310// Compute the scaling vectors and solve nonlinear system
2311void KINSolver::Mult(const Vector&, Vector &x) const
2312{
2313 // residual norm tolerance
2314 double tol;
2315
2316 // Uses c = 1, corresponding to x_scale.
2317 c = 1.0;
2318
2319 if (!iterative_mode) { x = 0.0; }
2320
2321 // For relative tolerance, r = 1 / |residual(x)|, corresponding to fx_scale.
2322 if (rel_tol > 0.0)
2323 {
2324
2325 oper->Mult(x, r);
2326
2327 // Note that KINSOL uses infinity norms.
2328 double norm = r.Normlinf();
2329#ifdef MFEM_USE_MPI
2330 if (Parallel())
2331 {
2332 double lnorm = norm;
2333 MPI_Allreduce(&lnorm, &norm, 1, MPITypeMap<real_t>::mpi_type, MPI_MAX,
2334 Y->GetComm());
2335 }
2336#endif
2337 if (abs_tol > rel_tol * norm)
2338 {
2339 r = 1.0;
2340 tol = abs_tol;
2341 }
2342 else
2343 {
2344 r = 1.0 / norm;
2345 tol = rel_tol;
2346 }
2347 }
2348 else
2349 {
2350 r = 1.0;
2351 tol = abs_tol;
2352 }
2353
2354 // Set the residual norm tolerance
2355 flag = KINSetFuncNormTol(sundials_mem, tol);
2356 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINSetFuncNormTol()");
2357
2358 // Solve the nonlinear system by calling the other Mult method
2359 KINSolver::Mult(x, c, r);
2360}
2361
2362// Solve the nonlinear system using the provided scaling vectors
2364 const Vector &x_scale, const Vector &fx_scale) const
2365{
2366 flag = KINSetNumMaxIters(sundials_mem, max_iter);
2367 MFEM_ASSERT(flag == KIN_SUCCESS, "KINSetNumMaxIters() failed!");
2368
2369 Y->MakeRef(x, 0, x.Size());
2370 y_scale->MakeRef(const_cast<Vector&>(x_scale), 0, x_scale.Size());
2371 f_scale->MakeRef(const_cast<Vector&>(fx_scale), 0, fx_scale.Size());
2372
2373 int rank = -1;
2374 if (!Parallel())
2375 {
2376 rank = 0;
2377 }
2378 else
2379 {
2380#ifdef MFEM_USE_MPI
2381 MPI_Comm_rank(Y->GetComm(), &rank);
2382#endif
2383 }
2384
2385 if (rank == 0)
2386 {
2387#if MFEM_SUNDIALS_VERSION < 70000
2388 flag = KINSetPrintLevel(sundials_mem, print_level);
2389 MFEM_VERIFY(flag == KIN_SUCCESS, "KINSetPrintLevel() failed!");
2390#endif
2391 // NOTE: there is no KINSetPrintLevel in SUNDIALS v7!
2392
2393#ifdef SUNDIALS_BUILD_WITH_MONITORING
2394 if (jfnk && print_level)
2395 {
2396 flag = SUNLinSolSetInfoFile_SPFGMR(LSA, stdout);
2397 MFEM_VERIFY(flag == SUN_SUCCESS,
2398 "error in SUNLinSolSetInfoFile_SPFGMR()");
2399
2400 flag = SUNLinSolSetPrintLevel_SPFGMR(LSA, 1);
2401 MFEM_VERIFY(flag == SUN_SUCCESS,
2402 "error in SUNLinSolSetPrintLevel_SPFGMR()");
2403 }
2404#endif
2405 }
2406
2407 if (!iterative_mode) { x = 0.0; }
2408
2409 // Solve the nonlinear system
2411 converged = (flag >= 0);
2412
2413 // Make sure host is up to date
2414 Y->HostRead();
2415
2416 // Get number of nonlinear iterations
2417 long int tmp_nni;
2418 flag = KINGetNumNonlinSolvIters(sundials_mem, &tmp_nni);
2419 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINGetNumNonlinSolvIters()");
2420 final_iter = (int) tmp_nni;
2421
2422 // Get the residual norm
2423 flag = KINGetFuncNorm(sundials_mem, &final_norm);
2424 MFEM_ASSERT(flag == KIN_SUCCESS, "error in KINGetFuncNorm()");
2425}
2426
2428{
2429 delete Y;
2430 delete y_scale;
2431 delete f_scale;
2432 SUNMatDestroy(A);
2433 SUNLinSolFree(LSA);
2434 KINFree(&sundials_mem);
2435}
2436
2437} // namespace mfem
2438
2439#endif // MFEM_USE_SUNDIALS
Interface to ARKode's ARKStep module – additive Runge-Kutta methods.
Definition sundials.hpp:721
static int RHS2(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void SetMaxStep(double dt_max)
Set the maximum time step.
void PrintInfo() const
Print various ARKStep statistics.
Type rk_type
Runge-Kutta type.
Definition sundials.hpp:732
ARKStepSolver(Type type=EXPLICIT)
Construct a serial wrapper to SUNDIALS' ARKode integrator.
void SetOrder(int order)
Chooses integration order for all explicit / implicit / IMEX methods.
void SetStepMode(int itask)
Select the ARKode step mode: ARK_NORMAL (default) or ARK_ONE_STEP.
Type
Types of ARKODE solvers.
Definition sundials.hpp:725
@ IMPLICIT
Implicit RK method.
Definition sundials.hpp:727
@ IMEX
Implicit-explicit ARK method.
Definition sundials.hpp:728
@ EXPLICIT
Explicit RK method.
Definition sundials.hpp:726
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $A x = b$.
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, SUNMatrix M, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $A x = b$.
void UseSundialsLinearSolver()
Attach a SUNDIALS GMRES linear solver to ARKode.
void Init(TimeDependentOperator &f_) override
Initialize ARKode: calls ARKStepCreate() to create the ARKStep memory and set some defaults.
static int MassMult2(N_Vector x, N_Vector v, sunrealtype t, void *mtimes_data)
Compute the matrix-vector product $v = M_t x$ at time t.
void SetIRKTableNum(ARKODE_DIRKTableID table_id)
Choose a specific Butcher table for a diagonally implicit RK method.
void SetFixedStep(double dt)
Use a fixed time step size (disable temporal adaptivity).
void SetERKTableNum(ARKODE_ERKTableID table_id)
Choose a specific Butcher table for an explicit RK method.
int step_mode
ARKStep step mode (ARK_NORMAL or ARK_ONE_STEP).
Definition sundials.hpp:733
static int RHS1(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
void UseMFEMMassLinearSolver(int tdep)
Attach mass matrix linear system setup, solve, and matrix-vector product methods from the TimeDepende...
bool use_implicit
True for implicit or imex integration.
Definition sundials.hpp:734
void UseSundialsMassLinearSolver(int tdep)
Attach the SUNDIALS GMRES linear solver and the mass matrix matrix-vector product method from the Tim...
virtual ~ARKStepSolver()
Destroy the associated ARKode memory and SUNDIALS objects.
void SetIMEXTableNum(ARKODE_ERKTableID etable_id, ARKODE_DIRKTableID itable_id)
Choose a specific Butcher table for an IMEX RK method.
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
static int MassMult1(SUNMatrix M, N_Vector x, N_Vector v)
Compute the matrix-vector product $v = M x$.
static int MassSysSolve(SUNLinearSolver LS, SUNMatrix M, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $M x = b$.
static int MassSysSetup(sunrealtype t, SUNMatrix M, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $M x = b$.
void Step(Vector &x, real_t &t, real_t &dt) override
Integrate the ODE with ARKode using the specified step mode.
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
static int RHSB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_dataB)
Wrapper to compute the ODE RHS backward function.
void EvalQuadIntegrationB(double t, Vector &dG_dp)
Evaluate Quadrature solution.
void EvalQuadIntegration(double t, Vector &q)
Evaluate Quadrature.
static int LinSysSetupB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system A x = b.
static int LinSysSolveB(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system A x = b.
static constexpr double default_abs_tolB
Default scalar backward absolute tolerance.
Definition sundials.hpp:599
void SetMaxNStepsB(int mxstepsB)
Set the maximum number of backward steps.
static int RHSQ(sunrealtype t, const N_Vector y, N_Vector qdot, void *user_data)
Wrapper to compute the ODE RHS Quadrature function.
void Step(Vector &x, double &t, double &dt) override
void InitB(TimeDependentAdjointOperator &f_)
Initialize the adjoint problem.
static int RHSQB(sunrealtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_dataB)
Wrapper to compute the ODE RHS Backwards Quadrature function.
SundialsNVector * q
Quadrature vector.
Definition sundials.hpp:591
int indexB
backward problem index
Definition sundials.hpp:572
void GetForwardSolution(double tB, mfem::Vector &yy)
Get Interpolated Forward solution y at backward integration time tB.
SUNLinearSolver LSB
Linear solver for A.
Definition sundials.hpp:590
void SetSVtolerancesB(double reltol, Vector abstol)
Tolerance specification functions for the adjoint problem.
void UseSundialsLinearSolverB()
Use built in SUNDIALS Newton solver.
void SetWFTolerances(EWTFunction func)
Set multiplicative error weights.
SundialsNVector * yy
State vector.
Definition sundials.hpp:593
void Init(TimeDependentAdjointOperator &f_)
SUNMatrix AB
Linear system A = I - gamma J, M - gamma J, or J.
Definition sundials.hpp:589
int ncheck
number of checkpoints used so far
Definition sundials.hpp:571
SundialsNVector * qB
State vector.
Definition sundials.hpp:594
static int ewt(N_Vector y, N_Vector w, void *user_data)
Error control function.
virtual void StepB(Vector &w, double &t, double &dt)
Solve one adjoint time step.
void InitQuadIntegrationB(mfem::Vector &qB0, double reltolQB=1e-3, double abstolQB=1e-8)
Initialize Quadrature Integration (Adjoint)
static constexpr double default_rel_tolB
Default scalar backward relative tolerance.
Definition sundials.hpp:597
void InitAdjointSolve(int steps, int interpolation)
Initialize Adjoint.
SundialsNVector * yB
State vector.
Definition sundials.hpp:592
void InitQuadIntegration(mfem::Vector &q0, double reltolQ=1e-3, double abstolQ=1e-8)
virtual ~CVODESSolver()
Destroy the associated CVODES memory and SUNDIALS objects.
void SetSStolerancesB(double reltol, double abstol)
Tolerance specification functions for the adjoint problem.
void UseMFEMLinearSolverB()
Set Linear Solver for the backward problem.
Interface to the CVODE library – linear multi-step methods.
Definition sundials.hpp:430
void SetStepMode(int itask)
Select the CVODE step mode: CV_NORMAL (default) or CV_ONE_STEP.
Definition sundials.cpp:913
void SetRootFinder(int components, RootFunction func)
Initialize Root Finder.
Definition sundials.cpp:705
std::function< int(sunrealtype t, Vector y, Vector gout, CVODESolver *)> RootFunction
Typedef for root finding functions.
Definition sundials.hpp:456
void SetSStolerances(double reltol, double abstol)
Set the scalar relative and scalar absolute tolerances.
Definition sundials.cpp:918
void Init(TimeDependentOperator &f_) override
Initialize CVODE: calls CVodeCreate() to create the CVODE memory and set some defaults.
Definition sundials.cpp:752
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix A, N_Vector x, N_Vector b, sunrealtype tol)
Solve the linear system $A x = b$.
Definition sundials.cpp:728
virtual ~CVODESolver()
Destroy the associated CVODE memory and SUNDIALS objects.
static int root(sunrealtype t, N_Vector y, sunrealtype *gout, void *user_data)
Prototype to define root finding for CVODE.
Definition sundials.cpp:692
void SetMaxNSteps(int steps)
Set the maximum number of time steps.
Definition sundials.cpp:942
CVODESolver(int lmm)
Construct a serial wrapper to SUNDIALS' CVODE integrator.
Definition sundials.cpp:738
long GetNumSteps()
Get the number of internal steps taken so far.
Definition sundials.cpp:948
void Step(Vector &x, double &t, double &dt) override
Integrate the ODE with CVODE using the specified step mode.
Definition sundials.cpp:840
EWTFunction ewt_func
A class member to facilitate pointing to a user-specified error weight function.
Definition sundials.hpp:466
static int RHS(sunrealtype t, const N_Vector y, N_Vector ydot, void *user_data)
Number of components in gout.
Definition sundials.cpp:675
void SetMaxStep(double dt_max)
Set the maximum time step.
Definition sundials.cpp:936
int lmm_type
Linear multistep method type.
Definition sundials.hpp:432
void PrintInfo() const
Print various CVODE statistics.
Definition sundials.cpp:962
void UseSundialsLinearSolver()
Attach SUNDIALS GMRES linear solver to CVODE.
Definition sundials.cpp:898
void UseMFEMLinearSolver()
Attach the linear system setup and solve methods from the TimeDependentOperator i....
Definition sundials.cpp:867
RootFunction root_func
A class member to facilitate pointing to a user-specified root function.
Definition sundials.hpp:459
std::function< int(Vector y, Vector w, CVODESolver *)> EWTFunction
Typedef declaration for error weight functions.
Definition sundials.hpp:462
int step_mode
CVODE step mode (CV_NORMAL or CV_ONE_STEP).
Definition sundials.hpp:433
static int LinSysSetup(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix A, sunbooleantype jok, sunbooleantype *jcur, sunrealtype gamma, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Setup the linear system $A x = b$.
Definition sundials.cpp:713
void SetSVtolerances(double reltol, Vector abstol)
Set the scalar relative and vector of absolute tolerances.
Definition sundials.cpp:924
void SetMaxOrder(int max_order)
Set the maximum method order.
Definition sundials.cpp:956
static MemoryType GetHostMemoryType()
Get the current Host MemoryType. This is the MemoryType used by most MFEM classes when allocating mem...
Definition device.hpp:268
static bool IsAvailable()
Return true if an actual device (e.g. GPU) has been configured.
Definition device.hpp:244
static MemoryType GetDeviceMemoryType()
Get the current Device MemoryType. This is the MemoryType used by most MFEM classes when allocating m...
Definition device.hpp:277
Wrapper for hypre's parallel vector class.
Definition hypre.hpp:230
real_t abs_tol
Absolute tolerance.
Definition solvers.hpp:183
real_t rel_tol
Relative tolerance.
Definition solvers.hpp:180
const Operator * oper
Definition solvers.hpp:145
int print_level
(DEPRECATED) Legacy print level definition, which is left for compatibility with custom iterative sol...
Definition solvers.hpp:156
int max_iter
Limit for the number of iterations the solver is allowed to do.
Definition solvers.hpp:177
Interface to the KINSOL library – nonlinear solver methods.
Definition sundials.hpp:897
SundialsNVector * f_scale
scaling vectors
Definition sundials.hpp:901
KINSolver(int strategy, bool oper_grad=true)
Construct a serial wrapper to SUNDIALS' KINSOL nonlinear solver.
double aa_damping
Anderson Acceleration damping.
Definition sundials.hpp:905
void SetJFNKSolver(Solver &solver)
virtual ~KINSolver()
Destroy the associated KINSOL memory.
static int Mult(const N_Vector u, N_Vector fu, void *user_data)
Wrapper to compute the nonlinear residual $F(u) = 0$.
static int LinSysSolve(SUNLinearSolver LS, SUNMatrix J, N_Vector u, N_Vector b, sunrealtype tol)
Solve the linear system $J u = b$.
int aa_delay
Anderson Acceleration delay.
Definition sundials.hpp:904
int global_strategy
KINSOL solution strategy.
Definition sundials.hpp:899
bool jfnk
enable JFNK
Definition sundials.hpp:908
static int PrecSolve(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, N_Vector vv, void *user_data)
Solve the preconditioner equation $P z = v$.
int maxlrs
Maximum linear solver restarts.
Definition sundials.hpp:911
const Operator * jacobian
stores oper->GetGradient()
Definition sundials.hpp:902
int maxli
Maximum linear iterations.
Definition sundials.hpp:910
Vector wrk
Work vector needed for the JFNK PC.
Definition sundials.hpp:909
void SetScaledStepTol(double sstol)
Set KINSOL's scaled step tolerance.
void SetSolver(Solver &solver) override
Set the linear solver for inverting the Jacobian.
SundialsNVector * y_scale
Definition sundials.hpp:901
void SetOperator(const Operator &op) override
Set the nonlinear Operator of the system and initialize KINSOL.
void SetMaxSetupCalls(int max_calls)
Set maximum number of nonlinear iterations without a Jacobian update.
static int LinSysSetup(N_Vector u, N_Vector fu, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2)
Setup the linear system $J u = b$.
int aa_orth
Anderson Acceleration orthogonalization routine.
Definition sundials.hpp:906
static int PrecSetup(N_Vector uu, N_Vector uscale, N_Vector fval, N_Vector fscale, void *user_data)
Setup the preconditioner.
void SetDamping(double damping)
double fp_damping
Fixed Point or Picard damping parameter.
Definition sundials.hpp:907
void EnableAndersonAcc(int n, int orth=KIN_ORTH_MGS, int delay=0, double damping=1.0)
Enable Anderson Acceleration for KIN_FP or KIN_PICARD.
void SetPrintLevel(int print_lvl) override
Set the print level for the KINSetPrintLevel function.
bool use_oper_grad
use the Jv prod function
Definition sundials.hpp:900
int aa_n
number of acceleration vectors
Definition sundials.hpp:903
static int GradientMult(N_Vector v, N_Vector Jv, N_Vector u, sunbooleantype *new_u, void *user_data)
Wrapper to compute the Jacobian-vector product $J(u) v = Jv$.
bool IsKnown(const void *h_ptr)
Return true if the pointer is known by the memory manager.
Class used by MFEM to store pointers to host and/or device memory.
void SetHostPtrOwner(bool own) const
Set/clear the ownership flag for the host pointer. Ownership indicates whether the pointer will be de...
void SetDevicePtrOwner(bool own) const
Set/clear the ownership flag for the device pointer. Ownership indicates whether the pointer will be ...
void Wrap(T *ptr, int size, bool own)
Wrap an externally allocated host pointer, ptr with the current host memory type returned by MemoryMa...
void Delete()
Delete the owned pointers and reset the Memory object.
void ClearOwnerFlags() const
Clear the ownership flags for the host and device pointers, as well as any internal data allocated by...
void SetOperator(const Operator &op) override
Also calls SetOperator for the preconditioner if there is one.
Definition solvers.cpp:2048
TimeDependentOperator * f
Pointer to the associated TimeDependentOperator.
Definition ode.hpp:124
virtual void Init(TimeDependentOperator &f_)
Associate a TimeDependentOperator with the ODE solver.
Definition ode.cpp:181
Abstract operator.
Definition operator.hpp:25
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
int height
Dimension of the output / number of rows in the matrix.
Definition operator.hpp:27
virtual void Mult(const Vector &x, Vector &y) const =0
Operator application: y=A(x).
virtual Operator & GetGradient(const Vector &x) const
Evaluate the gradient operator at the point x. The default behavior in class Operator is to generate ...
Definition operator.hpp:134
Base class for solvers.
Definition operator.hpp:792
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition operator.hpp:795
virtual void SetOperator(const Operator &op)=0
Set/update the solver for the given operator.
SundialsMemHelper & operator=(const SundialsMemHelper &)=delete
Disable copy assignment.
SundialsMemHelper()=default
Default constructor – object must be moved to.
static int SundialsMemHelper_Alloc(SUNMemoryHelper helper, SUNMemory *memptr, size_t memsize, SUNMemoryType mem_type, void *queue) — the trailing `void *queue` parameter is present only when SUNDIALS_VERSION_MAJOR >= 6.
Definition sundials.cpp:282
static int SundialsMemHelper_Dealloc(SUNMemoryHelper helper, SUNMemory sunmem, void *queue) — the trailing `void *queue` parameter is present only when SUNDIALS_VERSION_MAJOR >= 6.
Definition sundials.cpp:326
Vector interface for SUNDIALS N_Vectors.
Definition sundials.hpp:222
long GlobalSize() const
Returns the MPI global length for the internal N_Vector x.
Definition sundials.hpp:288
MPI_Comm GetComm() const
Returns the MPI communicator for the internal N_Vector x.
Definition sundials.hpp:278
void MakeRef(Vector &base, int offset, int s)
Reset the Vector to be a reference to a sub-vector of base.
Definition sundials.hpp:305
void SetSize(int s, long glob_size=0)
Resize the vector to size s.
Definition sundials.cpp:559
static bool UseManagedMemory()
Definition sundials.hpp:359
static N_Vector MakeNVector(bool use_device)
Create a N_Vector.
Definition sundials.cpp:577
~SundialsNVector()
Calls SUNDIALS N_VDestroy function if the N_Vector is owned by 'this'.
Definition sundials.cpp:545
void SetDataAndSize(double *d, int s, long glob_size=0)
Set the vector data and size.
Definition sundials.cpp:571
void _SetNvecDataAndSize_(long glob_size=0)
Set data and length of internal N_Vector x from 'this'.
Definition sundials.cpp:364
void _SetDataAndSize_()
Set data and length from the internal N_Vector x.
Definition sundials.cpp:438
N_Vector x
The actual SUNDIALS object.
Definition sundials.hpp:227
void SetData(double *d)
Definition sundials.cpp:565
bool MPIPlusX() const
Definition sundials.hpp:341
SundialsNVector()
Creates an empty SundialsNVector.
Definition sundials.cpp:486
static constexpr double default_abs_tol
Default scalar absolute tolerance.
Definition sundials.hpp:399
SUNMatrix M
Mass matrix M.
Definition sundials.hpp:384
int flag
Last flag returned from a call to SUNDIALS.
Definition sundials.hpp:377
long saved_global_size
Global vector length on last initialization.
Definition sundials.hpp:379
static constexpr double default_rel_tol
Default scalar relative tolerance.
Definition sundials.hpp:397
bool reinit
Flag to signal that memory reinitialization is needed.
Definition sundials.hpp:378
SundialsNVector * Y
State vector.
Definition sundials.hpp:381
void * sundials_mem
SUNDIALS mem structure.
Definition sundials.hpp:376
SUNLinearSolver LSA
Linear solver for A.
Definition sundials.hpp:385
SUNLinearSolver LSM
Linear solver for M.
Definition sundials.hpp:386
SUNNonlinearSolver NLS
Nonlinear solver.
Definition sundials.hpp:387
bool Parallel() const
Definition sundials.hpp:390
Singleton class for SUNContext and SundialsMemHelper objects.
Definition sundials.hpp:176
static SUNContext & GetContext()
Provides access to the SUNContext object.
Definition sundials.cpp:185
static SundialsMemHelper & GetMemHelper()
Provides access to the SundialsMemHelper object.
Definition sundials.cpp:190
static void Finalize()
Finalize sundials (called automatically at program exit if Sundials::Init() has been called).
Definition sundials.cpp:226
static void Init()
Initializes SUNContext and SundialsMemHelper objects. Should be called at the beginning of the callin...
Definition sundials.cpp:174
int GetAdjointHeight()
Returns the size of the adjoint problem state space.
Definition operator.hpp:730
Base abstract class for first order time dependent operators.
Definition operator.hpp:344
bool isExplicit() const
True if type is EXPLICIT.
Definition operator.hpp:409
virtual int SUNImplicitSolve(const Vector &r, Vector &dk, real_t tol)
Solve the ODE linear system A dk = r , where A and r are defined by the method SUNImplicitSetup().
Definition operator.cpp:327
virtual int SUNMassMult(const Vector &x, Vector &v)
Compute the mass matrix-vector product v = M x, where M is defined by the method SUNMassSetup().
Definition operator.cpp:345
virtual int SUNMassSetup()
Setup the mass matrix in the ODE system $M y' = f(y,t)$.
Definition operator.cpp:333
void Mult(const Vector &u, Vector &k) const override
Perform the action of the operator (u,t) -> k(u,t) where t is the current time set by SetTime() and k...
Definition operator.cpp:293
virtual int SUNMassSolve(const Vector &b, Vector &x, real_t tol)
Solve the mass matrix linear system M x = b, where M is defined by the method SUNMassSetup().
Definition operator.cpp:339
virtual void ExplicitMult(const Vector &u, Vector &v) const
Perform the action of the explicit part of the operator, G: v = G(u, t) where t is the current time.
Definition operator.cpp:282
virtual void SetEvalMode(const EvalMode new_eval_mode)
Set the evaluation mode of the time-dependent operator.
Definition operator.hpp:429
virtual void SetTime(const real_t t_)
Set the current time.
Definition operator.hpp:406
virtual int SUNImplicitSetup(const Vector &y, const Vector &v, int jok, int *jcur, real_t gamma)
Setup a linear system as needed by some SUNDIALS ODE solvers to perform a similar action to ImplicitS...
Definition operator.cpp:319
virtual real_t GetTime() const
Read the currently set time.
Definition operator.hpp:403
Vector data type.
Definition vector.hpp:82
virtual const real_t * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:524
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:536
void SetDataAndSize(real_t *d, int s)
Set the Vector data and size.
Definition vector.hpp:191
real_t Normlinf() const
Returns the l_infinity norm of the vector.
Definition vector.cpp:1004
Memory< real_t > data
Definition vector.hpp:85
Vector & Set(const real_t a, const Vector &x)
(*this) = a * x
Definition vector.cpp:341
virtual bool UseDevice() const
Return the device flag of the Memory object used by the Vector.
Definition vector.hpp:148
int Size() const
Returns the size of the vector.
Definition vector.hpp:234
void SetSize(int s)
Resize the vector to size s.
Definition vector.hpp:584
void SetData(real_t *d)
Definition vector.hpp:184
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:540
real_t b
Definition lissajous.cpp:42
T * HostReadWrite(Memory< T > &mem, int size)
Shortcut to ReadWrite(Memory<T> &mem, int size, false)
Definition device.hpp:389
real_t u(const Vector &xvec)
Definition lor_mms.hpp:22
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition globals.hpp:66
T * ReadWrite(Memory< T > &mem, int size, bool on_dev=true)
Get a pointer for read+write access to mem with the mfem::Device's DeviceMemoryClass,...
Definition device.hpp:382
MemoryManager mm
The (single) global memory manager object.
float real_t
Definition config.hpp:46
MFEM_HOST_DEVICE real_t norm(const Complex &z)
Settings for the output behavior of the IterativeSolver.
Definition solvers.hpp:103
Helper struct to convert a C++ type to an MPI type.
MFEM_DEPRECATED SUNLinearSolver SUNLinSolNewEmpty(SUNContext)
Definition sundials.cpp:68
MFEM_DEPRECATED N_Vector N_VMake_MPIPlusX(MPI_Comm comm, N_Vector local_vector, SUNContext)
Definition sundials.cpp:150
MFEM_DEPRECATED void * KINCreate(SUNContext)
Definition sundials.cpp:106
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype, int maxl, SUNContext)
Definition sundials.cpp:83
MFEM_DEPRECATED void * CVodeCreate(int lmm, SUNContext)
Definition sundials.cpp:91
MFEM_DEPRECATED N_Vector N_VNewEmpty_Parallel(MPI_Comm comm, sunindextype local_length, sunindextype global_length, SUNContext)
Definition sundials.cpp:115
MFEM_DEPRECATED SUNMatrix SUNMatNewEmpty(SUNContext)
Definition sundials.cpp:61
MFEM_DEPRECATED SUNMemoryHelper SUNMemoryHelper_NewEmpty(SUNContext)
Definition sundials.cpp:139
MFEM_DEPRECATED N_Vector SUN_Hip_OR_Cuda N_VNewWithMemHelp(sunindextype length, sunbooleantype use_managed_mem, SUNMemoryHelper helper, SUNContext)
Definition sundials.cpp:129
MFEM_DEPRECATED void * ARKStepCreate(ARKRhsFn fe, ARKRhsFn fi, sunrealtype t0, N_Vector y0, SUNContext)
Definition sundials.cpp:98
MFEM_DEPRECATED N_Vector N_VNewEmpty_Serial(sunindextype vec_length, SUNContext)
Definition sundials.cpp:54
MFEM_DEPRECATED SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype, int maxl, SUNContext)
Definition sundials.cpp:75
int ARKODE_DIRKTableID
Definition sundials.hpp:66
realtype sunrealtype
'sunrealtype' was first introduced in v6.0.0
Definition sundials.hpp:76
int ARKODE_ERKTableID
Definition sundials.hpp:65
@ SUN_PREC_NONE
Definition sundials.hpp:81
@ SUN_PREC_RIGHT
Definition sundials.hpp:81
booleantype sunbooleantype
'sunbooleantype' was first introduced in v6.0.0
Definition sundials.hpp:78
constexpr ARKODE_ERKTableID ARKODE_ERK_NONE
Definition sundials.hpp:67
constexpr ARKODE_DIRKTableID ARKODE_DIRK_NONE
Definition sundials.hpp:68
@ SUN_SUCCESS
Definition sundials.hpp:95
void * SUNContext
Definition sundials.hpp:73