MFEM v4.7.0
Finite element discretization library
Loading...
Searching...
No Matches
amgxsolver.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2024, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12// Implementation of the MFEM wrapper for Nvidia's multigrid library, AmgX
13//
14// This work is partially based on:
15//
16// Pi-Yueh Chuang and Lorena A. Barba (2017).
17// AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library.
18// J. Open Source Software, 2(16):280, doi:10.21105/joss.00280
19//
20// See https://github.com/barbagroup/AmgXWrapper.
21
22#include "../config/config.hpp"
23#include "amgxsolver.hpp"
24#ifdef MFEM_USE_AMGX
25
26#ifdef MFEM_USE_MPI
28#endif
29
30namespace mfem
31{
32
33int AmgXSolver::count = 0;
34
35AMGX_resources_handle AmgXSolver::rsrc = nullptr;
36
38 : ConvergenceCheck(false) {};
39
40AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose)
41{
42 amgxMode = amgxMode_;
43
44 if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
45 else { ConvergenceCheck = false;}
46
47 DefaultParameters(amgxMode, verbose);
48
49 InitSerial();
50}
51
52#ifdef MFEM_USE_MPI
53
54AmgXSolver::AmgXSolver(const MPI_Comm &comm,
55 const AMGX_MODE amgxMode_, const bool verbose)
56{
57 std::string config;
58 amgxMode = amgxMode_;
59
60 if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
61 else { ConvergenceCheck = false;}
62
63 DefaultParameters(amgxMode, verbose);
64
65 InitExclusiveGPU(comm);
66}
67
68AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
69 const AMGX_MODE amgxMode_, const bool verbose)
70{
71 std::string config;
72 amgxMode = amgxMode_;
73
74 if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
75 else { ConvergenceCheck = false;}
76
77 DefaultParameters(amgxMode_, verbose);
78
79 InitMPITeams(comm, nDevs);
80}
81
82#endif
83
85{
86 if (isInitialized) { Finalize(); }
87}
88
90{
91 count++;
92
93 mpi_gpu_mode = "serial";
94
95 AMGX_SAFE_CALL(AMGX_initialize());
96
97 AMGX_SAFE_CALL(AMGX_initialize_plugins());
98
99 AMGX_SAFE_CALL(AMGX_install_signal_handler());
100
101 MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
102 "AmgX configuration is not defined \n");
103
104 if (configSrc == CONFIG_SRC::EXTERNAL)
105 {
106 AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
107 }
108 else
109 {
110 AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
111 }
112
113 AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
114 AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
115 AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
116 AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
117 AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
118
119 isInitialized = true;
120}
121
122#ifdef MFEM_USE_MPI
123
124void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
125{
126 // If this instance has already been initialized, skip
127 if (isInitialized)
128 {
129 mfem_error("This AmgXSolver instance has been initialized on this process.");
130 }
131
132 // Note that every MPI rank may talk to a GPU
133 mpi_gpu_mode = "mpi-gpu-exclusive";
134 gpuProc = 0;
135
136 // Increment number of AmgX instances
137 count++;
138
139 MPI_Comm_dup(comm, &gpuWorld);
140 MPI_Comm_size(gpuWorld, &gpuWorldSize);
141 MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
142
143 // Each rank will only see 1 device call it device 0
144 nDevs = 1, devID = 0;
145
146 InitAmgX();
147
148 isInitialized = true;
149}
150
151// Initialize for MPI ranks > GPUs, all devices are visible to all of the MPI
152// ranks
153void AmgXSolver::InitMPITeams(const MPI_Comm &comm,
154 const int nDevs)
155{
156 // If this instance has already been initialized, skip
157 if (isInitialized)
158 {
159 mfem_error("This AmgXSolver instance has been initialized on this process.");
160 }
161
162 mpi_gpu_mode = "mpi-teams";
163
164 // Increment number of AmgX instances
165 count++;
166
167 // Get the name of this node
168 int len;
169 char name[MPI_MAX_PROCESSOR_NAME];
170 MPI_Get_processor_name(name, &len);
171 nodeName = name;
172 int globalcommrank;
173
174 MPI_Comm_rank(comm, &globalcommrank);
175
176 // Initialize communicators and corresponding information
177 InitMPIcomms(comm, nDevs);
178
179 // Only processes in gpuWorld are required to initialize AmgX
180 if (gpuProc == 0)
181 {
182 InitAmgX();
183 }
184
185 isInitialized = true;
186}
187
188#endif
189
190void AmgXSolver::ReadParameters(const std::string config,
191 const CONFIG_SRC source)
192{
193 amgx_config = config;
194 configSrc = source;
195}
196
197void AmgXSolver::SetConvergenceCheck(bool setConvergenceCheck_)
198{
199 ConvergenceCheck = setConvergenceCheck_;
200}
201
203 const bool verbose)
204{
205 amgxMode = amgxMode_;
206
207 configSrc = INTERNAL;
208
209 if (amgxMode == AMGX_MODE::PRECONDITIONER)
210 {
211 amgx_config = "{\n"
212 " \"config_version\": 2, \n"
213 " \"solver\": { \n"
214 " \"solver\": \"AMG\", \n"
215 " \"scope\": \"main\", \n"
216 " \"smoother\": \"JACOBI_L1\", \n"
217 " \"presweeps\": 1, \n"
218 " \"interpolator\": \"D2\", \n"
219 " \"max_row_sum\" : 0.9, \n"
220 " \"strength_threshold\" : 0.25, \n"
221 " \"postsweeps\": 1, \n"
222 " \"max_iters\": 1, \n"
223 " \"cycle\": \"V\"";
224 if (verbose)
225 {
226 amgx_config = amgx_config + ",\n"
227 " \"obtain_timings\": 1, \n"
228 " \"print_grid_stats\": 1, \n"
229 " \"monitor_residual\": 1, \n"
230 " \"print_solve_stats\": 1 \n";
231 }
232 else
233 {
234 amgx_config = amgx_config + "\n";
235 }
236 amgx_config = amgx_config + " }\n" + "}\n";
237 // use a zero initial guess in Mult()
238 iterative_mode = false;
239 }
240 else if (amgxMode == AMGX_MODE::SOLVER)
241 {
242 amgx_config = "{ \n"
243 " \"config_version\": 2, \n"
244 " \"solver\": { \n"
245 " \"preconditioner\": { \n"
246 " \"solver\": \"AMG\", \n"
247 " \"smoother\": { \n"
248 " \"scope\": \"jacobi\", \n"
249 " \"solver\": \"JACOBI_L1\" \n"
250 " }, \n"
251 " \"presweeps\": 1, \n"
252 " \"interpolator\": \"D2\", \n"
253 " \"max_row_sum\" : 0.9, \n"
254 " \"strength_threshold\" : 0.25, \n"
255 " \"max_iters\": 1, \n"
256 " \"scope\": \"amg\", \n"
257 " \"max_levels\": 100, \n"
258 " \"cycle\": \"V\", \n"
259 " \"postsweeps\": 1 \n"
260 " }, \n"
261 " \"solver\": \"PCG\", \n"
262 " \"max_iters\": 150, \n"
263 " \"convergence\": \"RELATIVE_INI_CORE\", \n"
264 " \"scope\": \"main\", \n"
265 " \"tolerance\": 1e-12, \n"
266 " \"monitor_residual\": 1, \n"
267 " \"norm\": \"L2\" ";
268 if (verbose)
269 {
270 amgx_config = amgx_config + ", \n"
271 " \"obtain_timings\": 1, \n"
272 " \"print_grid_stats\": 1, \n"
273 " \"print_solve_stats\": 1 \n";
274 }
275 else
276 {
277 amgx_config = amgx_config + "\n";
278 }
279 amgx_config = amgx_config + " } \n" + "} \n";
280 // use the user-specified vector as an initial guess in Mult()
281 iterative_mode = true;
282 }
283 else
284 {
285 mfem_error("AmgX mode not supported \n");
286 }
287}
288
289// Sets up AmgX library for MPI builds
290#ifdef MFEM_USE_MPI
291void AmgXSolver::InitAmgX()
292{
293 // Set up once
294 if (count == 1)
295 {
296 AMGX_SAFE_CALL(AMGX_initialize());
297
298 AMGX_SAFE_CALL(AMGX_initialize_plugins());
299
300 AMGX_SAFE_CALL(AMGX_install_signal_handler());
301
302 AMGX_SAFE_CALL(AMGX_register_print_callback(
303 [](const char *msg, int length)->void
304 {
305 int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
306 if (irank == 0) { mfem::out<<msg;} }));
307 }
308
309 MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
310 "AmgX configuration is not defined \n");
311
312 if (configSrc == CONFIG_SRC::EXTERNAL)
313 {
314 AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
315 }
316 else
317 {
318 AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
319 }
320
321 // Let AmgX handle returned error codes internally
322 AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
323
324 // Create an AmgX resource object, only the first instance needs to create
325 // the resource object.
326 if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }
327
328 // Create AmgX vector object for unknowns and RHS
329 AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
330 AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
331
332 // Create AmgX matrix object for unknowns and RHS
333 AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
334
335 // Create an AmgX solver object
336 AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
337
338 // Obtain the default number of rings based on current configuration
339 AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
340}
341
342// Groups MPI ranks into teams and assigns the roots to talk to GPUs
343void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
344{
345 // Duplicate the global communicator
346 MPI_Comm_dup(comm, &globalCpuWorld);
347 MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");
348
349 // Get size and rank for global communicator
350 MPI_Comm_size(globalCpuWorld, &globalSize);
351 MPI_Comm_rank(globalCpuWorld, &myGlobalRank);
352
353 // Get the communicator for processors on the same node (local world)
354 MPI_Comm_split_type(globalCpuWorld,
355 MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
356 MPI_Comm_set_name(localCpuWorld, "localCpuWorld");
357
358 // Get size and rank for local communicator
359 MPI_Comm_size(localCpuWorld, &localSize);
360 MPI_Comm_rank(localCpuWorld, &myLocalRank);
361
362 // Set up corresponding ID of the device used by each local process
363 SetDeviceIDs(nDevs);
364
365 MPI_Barrier(globalCpuWorld);
366
367 // Split the global world into a world involved in AmgX and a null world
368 MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);
369
370 // Get size and rank for the communicator corresponding to gpuWorld
371 if (gpuWorld != MPI_COMM_NULL)
372 {
373 MPI_Comm_set_name(gpuWorld, "gpuWorld");
374 MPI_Comm_size(gpuWorld, &gpuWorldSize);
375 MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
376 }
377 else // for those that will not communicate with the GPU
378 {
379 gpuWorldSize = MPI_UNDEFINED;
380 myGpuWorldRank = MPI_UNDEFINED;
381 }
382
383 // Split local world into worlds corresponding to each CUDA device
384 MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
385 MPI_Comm_set_name(devWorld, "devWorld");
386
387 // Get size and rank for the communicator corresponding to myWorld
388 MPI_Comm_size(devWorld, &devWorldSize);
389 MPI_Comm_rank(devWorld, &myDevWorldRank);
390
391 MPI_Barrier(globalCpuWorld);
392}
393
394// Determine MPI teams based on available devices
395void AmgXSolver::SetDeviceIDs(const int nDevs)
396{
397 // Set the ID of device that each local process will use
398 if (nDevs == localSize) // # of the devices and local process are the same
399 {
400 devID = myLocalRank;
401 gpuProc = 0;
402 }
403 else if (nDevs > localSize) // there are more devices than processes
404 {
405 MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
406 " are more than the MPI processes launched. Only "<<
407 nDevs << " devices will be used.\n");
408 devID = myLocalRank;
409 gpuProc = 0;
410 }
411 else // in case there are more ranks than devices
412 {
413 int nBasic = localSize / nDevs,
414 nRemain = localSize % nDevs;
415
416 if (myLocalRank < (nBasic+1)*nRemain)
417 {
418 devID = myLocalRank / (nBasic + 1);
419 if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
420 }
421 else
422 {
423 devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
424 if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
425 }
426 }
427}
428
429void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
430 const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
431{
432 // Calculate number of elements to be collected from each process
433 Array<int> Apart(mpiTeamSz);
434 int locAsz = inArr.Size();
435 MPI_Gather(&locAsz, 1, MPI_INT,
436 Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
437
438 MPI_Barrier(mpiTeamComm);
439
440 // Determine stride for process (to be used by root)
441 Array<int> Adisp(mpiTeamSz);
442 int myid; MPI_Comm_rank(mpiTeamComm, &myid);
443 if (myid == 0)
444 {
445 Adisp[0] = 0;
446 for (int i=1; i<mpiTeamSz; ++i)
447 {
448 Adisp[i] = Adisp[i-1] + Apart[i-1];
449 }
450 }
451
452 MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
453 outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
454 MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
455}
456
457void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
458 const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
459{
460 // Calculate number of elements to be collected from each process
461 Array<int> Apart(mpiTeamSz);
462 int locAsz = inArr.Size();
463 MPI_Gather(&locAsz, 1, MPI_INT,
464 Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
465
466 MPI_Barrier(mpiTeamComm);
467
468 // Determine stride for process (to be used by root)
469 Array<int> Adisp(mpiTeamSz);
470 int myid; MPI_Comm_rank(mpiTeamComm, &myid);
471 if (myid == 0)
472 {
473 Adisp[0] = 0;
474 for (int i=1; i<mpiTeamSz; ++i)
475 {
476 Adisp[i] = Adisp[i-1] + Apart[i-1];
477 }
478 }
479
480 MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
481 outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
482 MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
483}
484
485void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
486 const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
487{
488 // Calculate number of elements to be collected from each process
489 Array<int> Apart(mpiTeamSz);
490 int locAsz = inArr.Size();
491 MPI_Gather(&locAsz, 1, MPI_INT,
492 Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
493
494 MPI_Barrier(mpiTeamComm);
495
496 // Determine stride for process (to be used by root)
497 Array<int> Adisp(mpiTeamSz);
498 int myid; MPI_Comm_rank(mpiTeamComm, &myid);
499 if (myid == 0)
500 {
501 Adisp[0] = 0;
502 for (int i=1; i<mpiTeamSz; ++i)
503 {
504 Adisp[i] = Adisp[i-1] + Apart[i-1];
505 }
506 }
507
508 MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
509 outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
510 MPI_INT, 0, mpiTeamComm);
511}
512
513
514void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
515 Array<int64_t> &outArr,
516 const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
517{
518 // Calculate number of elements to be collected from each process
519 Array<int> Apart(mpiTeamSz);
520 int locAsz = inArr.Size();
521 MPI_Gather(&locAsz, 1, MPI_INT,
522 Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
523
524 MPI_Barrier(mpiTeamComm);
525
526 // Determine stride for process
527 Array<int> Adisp(mpiTeamSz);
528 int myid; MPI_Comm_rank(mpiTeamComm, &myid);
529 if (myid == 0)
530 {
531 Adisp[0] = 0;
532 for (int i=1; i<mpiTeamSz; ++i)
533 {
534 Adisp[i] = Adisp[i-1] + Apart[i-1];
535 }
536 }
537
538 MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
539 outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
540 MPI_INT64_T, 0, mpiTeamComm);
541
542 MPI_Barrier(mpiTeamComm);
543}
544
545void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
546 const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
547 Array<int> &Apart, Array<int> &Adisp) const
548{
549 // Calculate number of elements to be collected from each process
550 int locAsz = inArr.Size();
551 MPI_Allgather(&locAsz, 1, MPI_INT,
552 Apart.HostWrite(),1, MPI_INT, mpiTeamComm);
553
554 MPI_Barrier(mpiTeamComm);
555
556 // Determine stride for process
557 Adisp[0] = 0;
558 for (int i=1; i<mpiTeamSz; ++i)
559 {
560 Adisp[i] = Adisp[i-1] + Apart[i-1];
561 }
562
563 MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
564 outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
565 MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
566}
567
568void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
569 const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
570 Array<int> &Apart, Array<int> &Adisp) const
571{
572 MPI_Scatterv(inArr.HostRead(),Apart.HostRead(),Adisp.HostRead(),
573 MPITypeMap<real_t>::mpi_type,outArr.HostWrite(),outArr.Size(),
574 MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
575}
576#endif
577
578void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
579{
580 if (update_mat == false)
581 {
582 AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
583 in_A.NumNonZeroElems(),
584 1, 1,
585 in_A.ReadI(),
586 in_A.ReadJ(),
587 in_A.ReadData(), NULL));
588
589 AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
590 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
591 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
592 }
593 else
594 {
595 AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
596 in_A.Height(),
597 in_A.NumNonZeroElems(),
598 in_A.ReadData(), NULL));
599 }
600}
601
602#ifdef MFEM_USE_MPI
603
604void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
605{
606 // Require hypre >= 2.16.
607#if MFEM_HYPRE_VERSION < 21600
608 mfem_error("Hypre version 2.16+ is required when using AmgX \n");
609#endif
610
611 // Ensure HypreParMatrix is on the host
612 A.HostRead();
613
614 hypre_ParCSRMatrix * A_ptr =
615 (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);
616
617 hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);
618
619 A.HypreRead();
620
621 Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
622 const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);
623
624 // Column index must be int64_t so we must promote here
625 Array<int64_t> loc_J((int)A_csr->num_nonzeros);
626 for (int i=0; i<A_csr->num_nonzeros; ++i)
627 {
628 loc_J[i] = A_csr->big_j[i];
629 }
630
631 // Assumes one GPU per MPI rank
632 if (mpi_gpu_mode=="mpi-gpu-exclusive")
633 {
634 SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
635 // Free A_csr data from hypre_MergeDiagAndOffd method
636 hypre_CSRMatrixDestroy(A_csr);
637 return;
638 }
639
640 // Assumes teams of MPI ranks are sharing a GPU
641 if (mpi_gpu_mode == "mpi-teams")
642 {
643 SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
644 // Free A_csr data from hypre_MergeDiagAndOffd method
645 hypre_CSRMatrixDestroy(A_csr);
646 return;
647 }
648
649 mfem_error("Unsupported MPI_GPU combination \n");
650}
651
652void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
653 const Array<double> &loc_A,
654 const Array<int> &loc_I,
655 const Array<int64_t> &loc_J,
656 const bool update_mat)
657{
658 // Create a vector of offsets describing matrix row partitions
659 Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0.0;
660
661 int64_t myStart = A.GetRowStarts()[0];
662
663 MPI_Allgather(&myStart, 1, MPI_INT64_T,
664 rowPart.GetData(),1, MPI_INT64_T
665 ,gpuWorld);
666 MPI_Barrier(gpuWorld);
667
668 rowPart[gpuWorldSize] = A.M();
669
670 const int nGlobalRows = A.M();
671 const int local_rows = loc_I.Size()-1;
672 const int num_nnz = loc_I[local_rows];
673
674 if (update_mat == false)
675 {
676 AMGX_distribution_handle dist;
677 AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
678 AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
679 AMGX_DIST_PARTITION_OFFSETS,
680 rowPart.GetData()));
681
682 AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
683 local_rows, num_nnz, 1, 1,
684 loc_I.Read(), loc_J.Read(),
685 loc_A.Read(), NULL, dist));
686
687 AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
688
689 MPI_Barrier(gpuWorld);
690
691 AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
692
693 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
694 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
695 }
696 else
697 {
698 AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
699 num_nnz, loc_A, NULL));
700 }
701}
702
// Upload (or update) the distributed matrix in "mpi-teams" mode.
//
// The local CSR data (values loc_A, row pointers loc_I, column indices loc_J)
// of all ranks of a device team (devWorld) are gathered on the team root
// (myDevWorldRank == 0), the concatenated row pointers are fixed up in place
// into a single row-pointer array, and the processes with gpuProc == 0 then
// upload the consolidated rows to AmgX, or only replace the coefficients when
// update_mat is true. The number of consolidated rows is also stored in
// mat_local_rows.
703void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
704 const Array<double> &loc_A,
705 const Array<int> &loc_I,
706 const Array<int64_t> &loc_J,
707 const bool update_mat)
708{
709 // The following arrays hold the consolidated diagonal + off-diagonal matrix
710 // data
711 Array<int> all_I;
712 Array<int64_t> all_J;
713 Array<double> all_A;
714
715 // Determine array sizes
716 int J_allsz(0), all_NNZ(0), nDevRows(0);
717 const int loc_row_len = std::abs(A.RowPart()[1] -
718 A.RowPart()[0]); // end of row partition
719 const int loc_Jz_sz = loc_J.Size();
720 const int loc_A_sz = loc_A.Size();
721
// Sum the local sizes over the team; the totals are only available on rank 0
// of devWorld
722 MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
723 MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
724 MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);
725
726 MPI_Barrier(devWorld);
727
728 if (myDevWorldRank == 0)
729 {
// Room for nDevRows + devWorldSize row-pointer entries.
// NOTE(review): presumably each rank contributes (local rows + 1) entries
// of loc_I - confirm against the caller.
730 all_I.SetSize(nDevRows+devWorldSize);
731 all_J.SetSize(J_allsz); all_J = 0.0;
732 all_A.SetSize(all_NNZ);
733 }
734
// Concatenate the team's CSR arrays on the team root, ordered by rank
735 GatherArray(loc_I, all_I, devWorldSize, devWorld);
736 GatherArray(loc_J, all_J, devWorldSize, devWorld);
737 GatherArray(loc_A, all_A, devWorldSize, devWorld);
738
739 MPI_Barrier(devWorld);
740
741 int local_nnz(0);
742 int64_t local_rows(0);
743
744 if (myDevWorldRank == 0)
745 {
746 // A fix up step is needed for the array holding row data to remove extra
747 // zeros when consolidating team data.
// Each pass of the loop below locates the zeros of all_I (positions >= 1),
// accumulates the entries between the first and the second located zero
// onto the preceding ones, and shifts the rest of the array down by one.
// NOTE(review): the boundaries between the per-rank pieces are recognized
// only by the value 0, and z_ind (devWorldSize+1 entries) is written once
// per zero found. This presumably assumes that no zero occurs in all_I
// other than the leading entry of each rank's piece - confirm.
748 Array<int> z_ind(devWorldSize+1);
749 int iter = 1;
750 while (iter < devWorldSize-1)
751 {
752 // Determine the indices of zeros in global all_I array
753 int counter = 0;
754 z_ind[counter] = counter;
755 counter++;
756 for (int idx=1; idx<all_I.Size()-1; idx++)
757 {
758 if (all_I[idx]==0)
759 {
760 z_ind[counter] = idx-1;
761 counter++;
762 }
763 }
764 z_ind[devWorldSize] = all_I.Size()-1;
765 // End of determining indices of zeros in global all_I Array
766
767 // Bump all_I
768 for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
769 {
770 all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
771 }
772
773 // Shift array after bump to remove unnecessary values in middle of
774 // array
775 for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
776 {
777 all_I[idx] = all_I[idx+1];
778 }
779 iter++;
780 }
781
782 // LAST TIME THROUGH ARRAY
783 // Determine the indices of zeros in global row_ptr array
784 int counter = 0;
785 z_ind[counter] = counter;
786 counter++;
787 for (int idx=1; idx<all_I.Size()-1; idx++)
788 {
789 if (all_I[idx]==0)
790 {
791 z_ind[counter] = idx-1;
792 counter++;
793 }
794 }
795
796 z_ind[devWorldSize] = all_I.Size()-1;
797 // End of determining indices of zeros in global all_I Array BUMP all_I
798 // one last time
799 for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
800 {
801 all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
802 }
// all_I.Size()-devWorldSize equals nDevRows, so this reads the row-pointer
// entry that follows the last consolidated row
803 local_nnz = all_I[all_I.Size()-devWorldSize];
804 local_rows = nDevRows;
805 }
806
807 // Create row partition
808 mat_local_rows = local_rows; // class copy
809 Array<int64_t> rowPart;
810 if (gpuProc == 0)
811 {
812 rowPart.SetSize(gpuWorldSize+1); rowPart=0;
813
// Gather the row counts of the GPU processes into rowPart[1..]; the prefix
// sum below turns them into partition offsets starting at 0
814 MPI_Allgather(&local_rows, 1, MPI_INT64_T,
815 &rowPart.GetData()[1], 1, MPI_INT64_T,
816 gpuWorld);
817 MPI_Barrier(gpuWorld);
818
819 // Fixup step
820 for (int i=1; i<rowPart.Size(); ++i)
821 {
822 rowPart[i] += rowPart[i-1];
823 }
824
825 // Upload A matrix to AmgX
826 MPI_Barrier(gpuWorld);
827
828 int nGlobalRows = A.M();
829 if (update_mat == false)
830 {
831 AMGX_distribution_handle dist;
832 AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
833 AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
834 AMGX_DIST_PARTITION_OFFSETS,
835 rowPart.GetData()));
836
837 AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
838 local_rows, local_nnz,
839 1, 1, all_I.ReadWrite(),
840 all_J.Read(),
841 all_A.Read(),
842 nullptr, dist));
843
844 AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
845 MPI_Barrier(gpuWorld);
846
847 AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
848
849 // Bind vectors to A
850 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
851 AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
852 }
853 else
854 {
// NOTE(review): the row count passed here is nGlobalRows, whereas the
// serial SetMatrix() passes the local height - confirm which one
// AMGX_matrix_replace_coefficients expects for a distributed matrix.
855 AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
856 local_nnz, all_A, NULL));
857 }
858 }
859}
860
861#endif
862
864{
865 height = op.Height();
866 width = op.Width();
867
868 if (const SparseMatrix* Aptr =
869 dynamic_cast<const SparseMatrix*>(&op))
870 {
871 SetMatrix(*Aptr);
872 }
873#ifdef MFEM_USE_MPI
874 else if (const HypreParMatrix* Aptr =
875 dynamic_cast<const HypreParMatrix*>(&op))
876 {
877 SetMatrix(*Aptr);
878 }
879#endif
880 else
881 {
882 mfem_error("Unsupported Operator Type \n");
883 }
884}
885
887{
888 if (const SparseMatrix* Aptr =
889 dynamic_cast<const SparseMatrix*>(&op))
890 {
891 SetMatrix(*Aptr, true);
892 }
893#ifdef MFEM_USE_MPI
894 else if (const HypreParMatrix* Aptr =
895 dynamic_cast<const HypreParMatrix*>(&op))
896 {
897 SetMatrix(*Aptr, true);
898 }
899#endif
900 else
901 {
902 mfem_error("Unsupported Operator Type \n");
903 }
904}
905
906void AmgXSolver::Mult(const Vector& B, Vector& X) const
907{
908 // Set initial guess to zero
909 X.UseDevice(true);
910 if (!iterative_mode) { X = 0.0; }
911
912 // Mult for serial, and mpi-exclusive modes
913 if (mpi_gpu_mode != "mpi-teams")
914 {
915 AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
916 AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));
917
918 if (mpi_gpu_mode != "serial")
919 {
920#ifdef MFEM_USE_MPI
921 MPI_Barrier(gpuWorld);
922#endif
923 }
924
925 AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
926
927 AMGX_SOLVE_STATUS status;
928 AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
929 if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck)
930 {
931 if (status == AMGX_SOLVE_DIVERGED)
932 {
933 mfem_error("AmgX solver diverged \n");
934 }
935 else
936 {
937 mfem_error("AmgX solver failed to solve system \n");
938 }
939 }
940
941 AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
942 return;
943 }
944
945#ifdef MFEM_USE_MPI
946 Vector all_X(mat_local_rows);
947 Vector all_B(mat_local_rows);
948 Array<int> Apart_X(devWorldSize);
949 Array<int> Adisp_X(devWorldSize);
950 Array<int> Apart_B(devWorldSize);
951 Array<int> Adisp_B(devWorldSize);
952
953 GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
954 GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
955 MPI_Barrier(devWorld);
956
957 if (gpuWorld != MPI_COMM_NULL)
958 {
959 AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
960 AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));
961
962 MPI_Barrier(gpuWorld);
963
964 AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
965
966 AMGX_SOLVE_STATUS status;
967 AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
968 if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
969 {
970 if (status == AMGX_SOLVE_DIVERGED)
971 {
972 mfem_error("AmgX solver diverged \n");
973 }
974 else
975 {
976 mfem_error("AmgX solver failed to solve system \n");
977 }
978 }
979
980 AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
981 }
982
983 ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
984#endif
985}
986
988{
989 int getIters;
990 AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
991 return getIters;
992}
993
995{
996 // Check instance is initialized
997 if (! isInitialized || count < 1)
998 {
999 mfem_error("Error in AmgXSolver::Finalize(). \n"
1000 "This AmgXWrapper has not been initialized. \n"
1001 "Please initialize it before finalization.\n");
1002 }
1003
1004 // Only processes using GPU are required to destroy AmgX content
1005#ifdef MFEM_USE_MPI
1006 if (gpuProc == 0 || mpi_gpu_mode == "serial")
1007#endif
1008 {
1009 // Destroy solver instance
1010 AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
1011
1012 // Destroy matrix instance
1013 AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));
1014
1015 // Destroy RHS and unknown vectors
1016 AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
1017 AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));
1018
1019 // Only the last instance need to destroy resource and finalizing AmgX
1020 if (count == 1)
1021 {
1022 AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
1023 AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1024
1025 AMGX_SAFE_CALL(AMGX_finalize_plugins());
1026 AMGX_SAFE_CALL(AMGX_finalize());
1027 }
1028 else
1029 {
1030 AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1031 }
1032#ifdef MFEM_USE_MPI
1033 // destroy gpuWorld
1034 if (mpi_gpu_mode != "serial")
1035 {
1036 MPI_Comm_free(&gpuWorld);
1037 }
1038#endif
1039 }
1040
1041 // reset necessary variables in case users want to reuse the variable of
1042 // this instance for a new instance
1043#ifdef MFEM_USE_MPI
1044 gpuProc = MPI_UNDEFINED;
1045 if (globalCpuWorld != MPI_COMM_NULL)
1046 {
1047 MPI_Comm_free(&globalCpuWorld);
1048 MPI_Comm_free(&localCpuWorld);
1049 MPI_Comm_free(&devWorld);
1050 }
1051#endif
1052 // decrease the number of instances
1053 count -= 1;
1054
1055 // change status
1056 isInitialized = false;
1057}
1058
1059} // mfem namespace
1060
1061#endif
int GetNumIterations()
Return the number of iterations that were executed during the last solve phase.
bool ConvergenceCheck
Flag to check for convergence.
CONFIG_SRC
Flags to determine whether user solver settings are defined internally in the source code or will be ...
@ EXTERNAL
Configure will be read from a specified file.
@ INTERNAL
Configuration will be read directly from a string.
void Finalize()
Close down the AmgX library and free up any MPI Comms set up for it.
void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
Set up the AmgX library with the default parameters.
~AmgXSolver()
Close down the AmgX library and free up any MPI Comms set up for it.
virtual void SetOperator(const Operator &op)
Sets the Operator that is going to be solved via AmgX. Supports operators based on either an MFEM Spa...
void InitSerial()
Initialize the AmgX library for serial execution once the solver configuration has been established t...
AMGX_MODE
Flags to configure AmgXSolver as a solver or preconditioner.
void SetConvergenceCheck(bool setConvergenceCheck_=true)
Add a check for convergence after applying Mult.
void InitMPITeams(const MPI_Comm &comm, const int nDevs)
Initialize the AmgX library and create MPI teams based on the number of devices on each node nDevs....
virtual void Mult(const Vector &b, Vector &x) const
Utilize the AmgX library to solve the linear system where the "matrix" is the AMG approximation to th...
void ReadParameters(const std::string config, CONFIG_SRC source)
Read in the AmgX parameters either through a file or directly through a properly formatted string....
void InitExclusiveGPU(const MPI_Comm &comm)
Initialize the AmgX library in parallel mode with exactly one GPU per rank after the solver configura...
void UpdateOperator(const Operator &op)
Change the input operator that is being solved via AmgX. Supports operators based on either an MFEM S...
Wrapper for hypre's ParCSR matrix class.
Definition hypre.hpp:388
Abstract operator.
Definition operator.hpp:25
int width
Dimension of the input / number of columns in the matrix.
Definition operator.hpp:28
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
int height
Dimension of the output / number of rows in the matrix.
Definition operator.hpp:27
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition operator.hpp:72
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition operator.hpp:686
Data type sparse matrix.
Definition sparsemat.hpp:51
Vector data type.
Definition vector.hpp:80
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:474
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:490
int Size() const
Returns the size of the vector.
Definition vector.hpp:218
virtual void UseDevice(bool use_dev) const
Enable execution of Vector operations using the mfem::Device.
Definition vector.hpp:136
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:482
void source(const Vector &x, Vector &f)
Definition ex25.cpp:620
void mfem_error(const char *msg)
Definition error.cpp:154
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition globals.hpp:66