amgxsolver.hpp
// Copyright (c) 2010-2024, Lawrence Livermore National Security, LLC. Produced
// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
// LICENSE and NOTICE for details. LLNL-CODE-806117.
//
// This file is part of the MFEM library. For more information and source code
// availability visit https://mfem.org.
//
// MFEM is free software; you can redistribute it and/or modify it under the
// terms of the BSD-3 license. We welcome feedback and contributions, see file
// CONTRIBUTING.md for details.

#ifndef MFEM_AMGX_SOLVER
#define MFEM_AMGX_SOLVER

#include "../config/config.hpp"

#ifdef MFEM_USE_AMGX

#include <amgx_c.h>
#ifdef MFEM_USE_MPI
#include <mpi.h>
#include "hypre.hpp"
#else
#include "operator.hpp"
#include "sparsemat.hpp"
#endif

namespace mfem
{

/** @brief
   MFEM wrapper for Nvidia's multigrid library, AmgX (github.com/NVIDIA/AMGX)

   AmgX requires building MFEM with both CUDA and AmgX enabled. For
   distributed memory parallelism, MPI and Hypre (version 2.16.0+) are also
   required. Although CUDA is required for building, the AmgX solver is
   compatible with an MFEM CPU device configuration.

   The AmgXSolver class is designed to work as a solver or preconditioner for
   existing MFEM solvers. The AmgX solver class may be configured in one of
   three ways:

   Serial - Takes a SparseMatrix, solves on a single GPU, and assumes no MPI
   communication.

   Exclusive GPU - Takes a HypreParMatrix and assumes each MPI rank is paired
   with an Nvidia GPU.

   MPI Teams - Takes a HypreParMatrix and enables flexibility between the
   number of MPI ranks and GPUs. Specifically, MPI ranks are grouped with
   GPUs and a matrix consolidation step is taken so that the MPI root of each
   team performs the necessary AmgX library calls. The solution is then
   broadcast to the appropriate ranks. This is particularly useful when
   configuring MFEM's device as CPU. This work is based on the AmgXWrapper of
   Chuang and Barba. Routines were adapted and modified for setting up MPI
   communicators.

   Examples 1/1p in the examples/amgx directory demonstrate configuring the
   wrapper as a solver and preconditioner, as well as configuring and running
   with exclusive GPU or MPI teams modes.

   This work is partially based on:

      Pi-Yueh Chuang and Lorena A. Barba (2017).
      AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library.
      J. Open Source Software, 2(16):280, doi:10.21105/joss.00280

   See https://github.com/barbagroup/AmgXWrapper.
*/
class AmgXSolver : public Solver
{
public:

   /// Flags to configure AmgXSolver as a solver or preconditioner
   enum AMGX_MODE
   {
      /// Use the preconditioned conjugate gradient method with the AMG
      /// V-cycle used as a preconditioner. With the default configuration
      /// a block Jacobi smoother is used.
      SOLVER,
      /// Directly apply iterations of the AMG V cycle to the matrix.
      /// With the default configuration this will be 2 iterations
      /// with block Jacobi smoother.
      PRECONDITIONER
   };

   /// Flag to check for convergence
   bool ConvergenceCheck;

   /** @brief
      Flags to determine whether user solver settings are defined internally
      in the source code or will be read through an external JSON file.
   */
   enum CONFIG_SRC
   {
      /// Configuration will be read directly from a string
      INTERNAL,
      /// Configuration will be read from a specified file
      EXTERNAL,
      UNDEFINED
   };

   AmgXSolver();

   /** @brief
      Configures AmgX with a default configuration based on the AmgX mode and
      verbosity. Assumes no MPI parallelism.
   */
   AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose);

   /** @brief Initialize the AmgX library for serial execution once
      the solver configuration has been established through either the
      AmgXSolver::ReadParameters method or the constructor. The constructor
      will make this call.
   */
   void InitSerial();
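
   /* Usage sketch (illustrative, not part of this header): serial setup and
      solve. Assumes `A` is an assembled mfem::SparseMatrix, `B` and `X` are
      mfem::Vectors, and `config` is an AmgX JSON string (see
      AmgXSolver::ReadParameters below).

         AmgXSolver amgx;
         amgx.ReadParameters(config, AmgXSolver::INTERNAL);
         amgx.InitSerial();
         amgx.SetOperator(A); // upload A and run the AMG setup
         amgx.Mult(B, X);     // solve A X = B on a single GPU
   */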

#ifdef MFEM_USE_MPI

   /** @brief
      Configures AmgX with a default configuration based on the AMGX_MODE
      (AmgXSolver::SOLVER, AmgXSolver::PRECONDITIONER)
      and verbosity. Pairs each MPI rank with one GPU.
   */
   AmgXSolver(const MPI_Comm &comm, const AMGX_MODE amgxMode_,
              const bool verbose);

   /** @brief
      Configures AmgX with a default configuration based on the AMGX_MODE
      (AmgXSolver::SOLVER, AmgXSolver::PRECONDITIONER)
      and verbosity. Creates MPI teams around GPUs to support more ranks than
      GPUs. Consolidates linear solver data to avoid multiple ranks sharing
      GPUs. Requires specifying the number of devices in each compute node as
      @a nDevs.
   */
   AmgXSolver(const MPI_Comm &comm, const int nDevs,
              const AMGX_MODE amgxMode_, const bool verbose);

   /** @brief Initialize the AmgX library in parallel mode with exactly one
      GPU per rank after the solver configuration has been established,
      either through the constructor or the AmgXSolver::ReadParameters
      method. If configuring with a constructor, the constructor will make
      this call.
   */
   void InitExclusiveGPU(const MPI_Comm &comm);
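
   /* Usage sketch (illustrative): exclusive GPU mode, one GPU per MPI rank.
      Assumes `A` is an assembled mfem::HypreParMatrix, `B` and `X` are the
      corresponding distributed vectors, and "amgx.json" is a hypothetical
      AmgX configuration file.

         AmgXSolver amgx;
         amgx.ReadParameters("amgx.json", AmgXSolver::EXTERNAL);
         amgx.InitExclusiveGPU(MPI_COMM_WORLD);
         amgx.SetOperator(A);
         amgx.Mult(B, X);
   */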

   /** @brief Initialize the AmgX library and create MPI teams based on the
      number of devices on each node @a nDevs. If configuring with a
      constructor, the constructor will make this call; otherwise, this will
      need to be called after the solver configuration has been established
      through the AmgXSolver::ReadParameters call.
   */
   void InitMPITeams(const MPI_Comm &comm,
                     const int nDevs);
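
   /* Usage sketch (illustrative): MPI teams mode for runs with more MPI
      ranks than GPUs. Assumes `A`, `B`, `X` as in the exclusive-GPU sketch
      above; the device count is a placeholder value.

         int nDevs = 2; // GPUs available on each compute node
         AmgXSolver amgx;
         amgx.ReadParameters("amgx.json", AmgXSolver::EXTERNAL);
         amgx.InitMPITeams(MPI_COMM_WORLD, nDevs);
         amgx.SetOperator(A); // data is consolidated to each team's root
         amgx.Mult(B, X);     // the solution is scattered back to all ranks
   */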
#endif

   /** @brief Sets the Operator that is going to be solved via AmgX.
      Supports operators based on either an MFEM SparseMatrix or
      HypreParMatrix.
   */
   virtual void SetOperator(const Operator &op);

   /** @brief Change the input operator that is being solved via AmgX.
      Supports operators based on either an MFEM SparseMatrix or
      HypreParMatrix.
   */
   void UpdateOperator(const Operator &op);
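
   /* Usage sketch (illustrative): re-solving after the matrix values change,
      e.g. in a time-stepping or nonlinear loop. This assumes only the values
      of `A` change between solves, not its sparsity pattern.

         amgx.SetOperator(A);    // initial upload and AMG setup
         amgx.Mult(B, X);
         // ... the entries of A are modified ...
         amgx.UpdateOperator(A); // refresh the matrix held by AmgX
         amgx.Mult(B, X);
   */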

   /** @brief Utilize the AmgX library to solve the linear system
      where the "matrix" is the AMG approximation to the operator set
      by AmgXSolver::SetOperator. If the mode is set to
      AmgXSolver::PRECONDITIONER the initial guess for the
      @a x vector will be set to zero, otherwise the value of @a x passed
      in will be used.
   */
   virtual void Mult(const Vector& b, Vector& x) const;
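
   /* Usage sketch (illustrative): AmgX as a preconditioner for MFEM's CG
      solver, in the spirit of examples/amgx/ex1p.cpp. Assumes a parallel
      build with `A`, `B`, `X` as above; the tolerance and iteration counts
      are placeholder values.

         AmgXSolver prec(MPI_COMM_WORLD, AmgXSolver::PRECONDITIONER, false);
         CGSolver cg(MPI_COMM_WORLD);
         cg.SetRelTol(1e-12);
         cg.SetMaxIter(200);
         cg.SetPrintLevel(1);
         cg.SetOperator(A);
         cg.SetPreconditioner(prec);
         cg.Mult(B, X);
   */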

   /** @brief Return the number of iterations that were executed during the
      last solve phase. */
   int GetNumIterations();

   /** @brief Read in the AmgX parameters either through a file or directly
      through a properly formatted string. If @a source is set to
      AmgXSolver::EXTERNAL the parameters are loaded from a filename set by
      @a config. If @a source is set to AmgXSolver::INTERNAL the parameters
      are set directly by the string defined by @a config.
   */
   void ReadParameters(const std::string config, CONFIG_SRC source);
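
   /* Config sketch (illustrative): a minimal AmgX JSON configuration passed
      as a string with AmgXSolver::INTERNAL. The option names and schema are
      defined by the AmgX library, not by MFEM; consult the AmgX reference
      manual for the full set of options.

         const std::string config =
            "{\n"
            "   \"config_version\": 2,\n"
            "   \"solver\": {\n"
            "      \"solver\": \"AMG\",\n"
            "      \"max_iters\": 100,\n"
            "      \"tolerance\": 1e-6,\n"
            "      \"monitor_residual\": 1\n"
            "   }\n"
            "}\n";
         amgx.ReadParameters(config, AmgXSolver::INTERNAL);
   */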

   /** @brief Set up the AmgX library with the default parameters.
      @param [in] amgxMode_ AmgXSolver::PRECONDITIONER,
                            AmgXSolver::SOLVER.

      @param [in] verbose true, false. Specifies the level
                          of verbosity.

      When configured as a preconditioner, the default configuration is to
      apply two iterations of an AMG V cycle with AmgX's default smoother
      (block Jacobi).

      When configured as a solver, the preconditioned conjugate gradient
      method is used with the AMG V-cycle, and a block Jacobi smoother is
      used as a preconditioner.
   */
   void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose);

   /// Add a check for convergence after applying Mult.
   void SetConvergenceCheck(bool setConvergenceCheck_=true);

   /// Close down the AmgX library and free up any MPI Comms set up for it
   ~AmgXSolver();

   /// Close down the AmgX library and free up any MPI Comms set up for it
   void Finalize();

private:

   AMGX_MODE amgxMode;

   std::string amgx_config = "";

   CONFIG_SRC configSrc = UNDEFINED;

#ifdef MFEM_USE_MPI
   /** @brief Consolidates matrix diagonal and off-diagonal data and uploads
      the matrix to AmgX. */
   void SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
                                 const Array<double> &loc_A,
                                 const Array<int> &loc_I,
                                 const Array<int64_t> &loc_J,
                                 const bool update_mat = false);

   /** @brief Consolidates matrix diagonal and off-diagonal data for all
      ranks in an MPI team. The root rank of each MPI team holds the
      consolidated data and matrix. */
   void SetMatrixMPITeams(const HypreParMatrix &A, const Array<double> &loc_A,
                          const Array<int> &loc_I, const Array<int64_t> &loc_J,
                          const bool update_mat = false);

   /// Consolidate array data to the root node in an MPI team.
   void GatherArray(const Array<double> &inArr, Array<double> &outArr,
                    const int mpiTeamSz, const MPI_Comm &mpiTeam) const;

   /// Consolidate array data to the root node in an MPI team.
   void GatherArray(const Vector &inArr, Vector &outArr,
                    const int mpiTeamSz, const MPI_Comm &mpiTeam) const;

   /// Consolidate array data to the root node in an MPI team.
   void GatherArray(const Array<int> &inArr, Array<int> &outArr,
                    const int mpiTeamSz, const MPI_Comm &mpiTeam) const;

   /// Consolidate array data to the root node in an MPI team.
   void GatherArray(const Array<int64_t> &inArr, Array<int64_t> &outArr,
                    const int mpiTeamSz, const MPI_Comm &mpiTeam) const;

   /** @brief Consolidate array data to the root node in an MPI
      team, as well as store the array partitions and displacements in
      @a Apart and @a Adisp.
   */
   void GatherArray(const Vector &inArr, Vector &outArr,
                    const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                    Array<int> &Apart, Array<int> &Adisp) const;

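   /** @brief Scatter a consolidated array from the root of an MPI team back
       to the ranks in the team, using the partitions and displacements in
       @a Apart and @a Adisp (the inverse of the corresponding GatherArray).
   */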
   void ScatterArray(const Vector &inArr, Vector &outArr,
                     const int mpiTeamSz, const MPI_Comm &mpi_comm,
                     Array<int> &Apart, Array<int> &Adisp) const;

   void SetMatrix(const HypreParMatrix &A, const bool update_mat = false);
#endif

   void SetMatrix(const SparseMatrix &A, const bool update_mat = false);

   static int count;

   // Indicates whether this instance has been initialized.
   bool isInitialized = false;

#ifdef MFEM_USE_MPI
   // The name of the node that this MPI process belongs to.
   std::string nodeName;

   // Number of local GPU devices used by AmgX.
   int nDevs;

   // The ID of the corresponding GPU device used by this MPI process.
   int devID;

   // A flag indicating whether this process will invoke AmgX.
   int gpuProc = MPI_UNDEFINED;

   // Communicator for all MPI ranks.
   MPI_Comm globalCpuWorld = MPI_COMM_NULL;

   // Communicator for ranks in the same node.
   MPI_Comm localCpuWorld;

   // Communicator for ranks sharing a device.
   MPI_Comm devWorld;

   // A communicator for MPI processes that will launch AmgX (root of devWorld).
   MPI_Comm gpuWorld;

   // Global number of MPI procs + rank id.
   int globalSize;

   int myGlobalRank;

   // Total number of MPI procs in a node + rank id.
   int localSize;

   int myLocalRank;

   // Total number of MPI ranks sharing a device + rank id.
   int devWorldSize;

   int myDevWorldRank;

   // Total number of MPI procs calling AmgX + rank id.
   int gpuWorldSize;

   int myGpuWorldRank;
#endif

   // A parameter used by AmgX.
   int ring;

   // Sets the AmgX precision (currently only double precision is supported).
   AMGX_Mode precision_mode = AMGX_mode_dDDI;

   // AmgX config object.
   AMGX_config_handle cfg = nullptr;

   // AmgX matrix object.
   AMGX_matrix_handle AmgXA = nullptr;

   // AmgX vector object representing unknowns.
   AMGX_vector_handle AmgXP = nullptr;

   // AmgX vector object representing the RHS.
   AMGX_vector_handle AmgXRHS = nullptr;

   // AmgX solver object.
   AMGX_solver_handle solver = nullptr;

   // AmgX resource object.
   static AMGX_resources_handle rsrc;

   /// Set the ID of the corresponding GPU used by this process.
   void SetDeviceIDs(const int nDevs);

   /// Initialize all MPI communicators.
#ifdef MFEM_USE_MPI
   void InitMPIcomms(const MPI_Comm &comm, const int nDevs);
#endif

   void InitAmgX();

   // Row partition for the HypreParMatrix.
   int64_t mat_local_rows;

   std::string mpi_gpu_mode;
};
}
#endif // MFEM_USE_AMGX
#endif // MFEM_AMGX_SOLVER