12#ifndef MFEM_AMGX_SOLVER
13#define MFEM_AMGX_SOLVER
135 AmgXSolver(
const MPI_Comm &comm,
const int nDevs,
136 const AMGX_MODE amgx_Mode_,
const bool verbose);
219 std::string amgx_config =
"";
230 const bool update_mat =
false);
237 const bool update_mat =
false);
241 const int mpiTeamSz,
const MPI_Comm &mpiTeam)
const;
245 const int mpiTeamSz,
const MPI_Comm &mpiTeam)
const;
249 const int mpiTeamSz,
const MPI_Comm &mpiTeam)
const;
253 const int mpiTeamSz,
const MPI_Comm &mpiTeam)
const;
260 const int mpiTeamSz,
const MPI_Comm &mpiTeamComm,
264 const int mpiTeamSz,
const MPI_Comm &mpi_comm,
267 void SetMatrix(
const HypreParMatrix &A,
const bool update_mat =
false);
270 void SetMatrix(
const SparseMatrix &A,
const bool update_mat =
false);
275 bool isInitialized =
false;
279 std::string nodeName;
288 int gpuProc = MPI_UNDEFINED;
291 MPI_Comm globalCpuWorld = MPI_COMM_NULL;
294 MPI_Comm localCpuWorld;
327 AMGX_Mode precision_mode = AMGX_mode_dDDI;
330 AMGX_config_handle cfg =
nullptr;
333 AMGX_matrix_handle AmgXA =
nullptr;
336 AMGX_vector_handle AmgXP =
nullptr;
339 AMGX_vector_handle AmgXRHS =
nullptr;
342 AMGX_solver_handle solver =
nullptr;
345 static AMGX_resources_handle rsrc;
348 void SetDeviceIDs(
const int nDevs);
352 void InitMPIcomms(
const MPI_Comm &comm,
const int nDevs);
358 int64_t mat_local_rows;
360 std::string mpi_gpu_mode;
MFEM wrapper for Nvidia's multigrid library, AmgX (github.com/NVIDIA/AMGX)
int GetNumIterations()
Return the number of iterations that were executed during the last solve phase.
bool ConvergenceCheck
Flag to check for convergence.
CONFIG_SRC
Flags to determine whether user solver settings are defined internally in the source code or will be ...
@ EXTERNAL
Configuration will be read from a specified file.
@ INTERNAL
Configuration will be read directly from a string.
void Finalize()
Close down the AmgX library and free up any MPI Comms set up for it.
void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
Set up the AmgX library with the default parameters.
~AmgXSolver()
Close down the AmgX library and free up any MPI Comms set up for it.
virtual void SetOperator(const Operator &op)
Sets the Operator that is going to be solved via AmgX. Supports operators based on either an MFEM Spa...
void InitSerial()
Initialize the AmgX library for serial execution once the solver configuration has been established t...
AMGX_MODE
Flags to configure AmgXSolver as a solver or preconditioner.
void SetConvergenceCheck(bool setConvergenceCheck_=true)
Add a check for convergence after applying Mult.
void InitMPITeams(const MPI_Comm &comm, const int nDevs)
Initialize the AmgX library and create MPI teams based on the number of devices on each node nDevs....
virtual void Mult(const Vector &b, Vector &x) const
Utilize the AmgX library to solve the linear system where the "matrix" is the AMG approximation to th...
void ReadParameters(const std::string config, CONFIG_SRC source)
Read in the AmgX parameters either through a file or directly through a properly formatted string....
void InitExclusiveGPU(const MPI_Comm &comm)
Initialize the AmgX library in parallel mode with exactly one GPU per rank after the solver configura...
void UpdateOperator(const Operator &op)
Change the input operator that is being solved via AmgX. Supports operators based on either an MFEM S...
Wrapper for hypre's ParCSR matrix class.
void source(const Vector &x, Vector &f)