MFEM  v4.5.1
Finite element discretization library
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Pages
amgxsolver.cpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2022, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 // Implementation of the MFEM wrapper for Nvidia's multigrid library, AmgX
13 //
14 // This work is partially based on:
15 //
16 // Pi-Yueh Chuang and Lorena A. Barba (2017).
17 // AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library.
18 // J. Open Source Software, 2(16):280, doi:10.21105/joss.00280
19 //
20 // See https://github.com/barbagroup/AmgXWrapper.
21 
22 #include "../config/config.hpp"
23 #include "amgxsolver.hpp"
24 #ifdef MFEM_USE_AMGX
25 
26 namespace mfem
27 {
28 
// Counter shared by all AmgXSolver objects in this process. The Init* routines
// in this file increment it and Finalize() decrements it; a value of 1 is used
// to decide when library-wide AmgX setup/teardown has to run.
29 int AmgXSolver::count = 0;
30 
// AmgX resources handle shared by all AmgXSolver objects (static member). The
// MPI setup path creates it only when count == 1.
31 AMGX_resources_handle AmgXSolver::rsrc = nullptr;
32 
34  : ConvergenceCheck(false) {};
35 
36 AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose)
37 {
38  amgxMode = amgxMode_;
39 
40  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
41  else { ConvergenceCheck = false;}
42 
43  DefaultParameters(amgxMode, verbose);
44 
45  InitSerial();
46 }
47 
48 #ifdef MFEM_USE_MPI
49 
50 AmgXSolver::AmgXSolver(const MPI_Comm &comm,
51  const AMGX_MODE amgxMode_, const bool verbose)
52 {
53  std::string config;
54  amgxMode = amgxMode_;
55 
56  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
57  else { ConvergenceCheck = false;}
58 
59  DefaultParameters(amgxMode, verbose);
60 
61  InitExclusiveGPU(comm);
62 }
63 
64 AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
65  const AMGX_MODE amgxMode_, const bool verbose)
66 {
67  std::string config;
68  amgxMode = amgxMode_;
69 
70  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
71  else { ConvergenceCheck = false;}
72 
73  DefaultParameters(amgxMode_, verbose);
74 
75  InitMPITeams(comm, nDevs);
76 }
77 
78 #endif
79 
81 {
82  if (isInitialized) { Finalize(); }
83 }
84 
86 {
87  count++;
88 
89  mpi_gpu_mode = "serial";
90 
91  AMGX_SAFE_CALL(AMGX_initialize());
92 
93  AMGX_SAFE_CALL(AMGX_initialize_plugins());
94 
95  AMGX_SAFE_CALL(AMGX_install_signal_handler());
96 
97  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
98  "AmgX configuration is not defined \n");
99 
100  if (configSrc == CONFIG_SRC::EXTERNAL)
101  {
102  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
103  }
104  else
105  {
106  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
107  }
108 
109  AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
110  AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
111  AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
112  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
113  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
114 
115  isInitialized = true;
116 }
117 
118 #ifdef MFEM_USE_MPI
119 
120 void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
121 {
122  // If this instance has already been initialized, skip
123  if (isInitialized)
124  {
125  mfem_error("This AmgXSolver instance has been initialized on this process.");
126  }
127 
128  // Note that every MPI rank may talk to a GPU
129  mpi_gpu_mode = "mpi-gpu-exclusive";
130  gpuProc = 0;
131 
132  // Increment number of AmgX instances
133  count++;
134 
135  MPI_Comm_dup(comm, &gpuWorld);
136  MPI_Comm_size(gpuWorld, &gpuWorldSize);
137  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
138 
139  // Each rank will only see 1 device call it device 0
140  nDevs = 1, devID = 0;
141 
142  InitAmgX();
143 
144  isInitialized = true;
145 }
146 
147 // Initialize for MPI ranks > GPUs, all devices are visible to all of the MPI
148 // ranks
149 void AmgXSolver::InitMPITeams(const MPI_Comm &comm,
150  const int nDevs)
151 {
152  // If this instance has already been initialized, skip
153  if (isInitialized)
154  {
155  mfem_error("This AmgXSolver instance has been initialized on this process.");
156  }
157 
158  mpi_gpu_mode = "mpi-teams";
159 
160  // Increment number of AmgX instances
161  count++;
162 
163  // Get the name of this node
164  int len;
165  char name[MPI_MAX_PROCESSOR_NAME];
166  MPI_Get_processor_name(name, &len);
167  nodeName = name;
168  int globalcommrank;
169 
170  MPI_Comm_rank(comm, &globalcommrank);
171 
172  // Initialize communicators and corresponding information
173  InitMPIcomms(comm, nDevs);
174 
175  // Only processes in gpuWorld are required to initialize AmgX
176  if (gpuProc == 0)
177  {
178  InitAmgX();
179  }
180 
181  isInitialized = true;
182 }
183 
184 #endif
185 
186 void AmgXSolver::ReadParameters(const std::string config,
187  const CONFIG_SRC source)
188 {
189  amgx_config = config;
190  configSrc = source;
191 }
192 
193 void AmgXSolver::SetConvergenceCheck(bool setConvergenceCheck_)
194 {
195  ConvergenceCheck = setConvergenceCheck_;
196 }
197 
199  const bool verbose)
200 {
201  amgxMode = amgxMode_;
202 
203  configSrc = INTERNAL;
204 
205  if (amgxMode == AMGX_MODE::PRECONDITIONER)
206  {
207  amgx_config = "{\n"
208  " \"config_version\": 2, \n"
209  " \"solver\": { \n"
210  " \"solver\": \"AMG\", \n"
211  " \"scope\": \"main\", \n"
212  " \"smoother\": \"JACOBI_L1\", \n"
213  " \"presweeps\": 1, \n"
214  " \"interpolator\": \"D2\", \n"
215  " \"max_row_sum\" : 0.9, \n"
216  " \"strength_threshold\" : 0.25, \n"
217  " \"postsweeps\": 1, \n"
218  " \"max_iters\": 1, \n"
219  " \"cycle\": \"V\"";
220  if (verbose)
221  {
222  amgx_config = amgx_config + ",\n"
223  " \"obtain_timings\": 1, \n"
224  " \"print_grid_stats\": 1, \n"
225  " \"monitor_residual\": 1, \n"
226  " \"print_solve_stats\": 1 \n";
227  }
228  else
229  {
230  amgx_config = amgx_config + "\n";
231  }
232  amgx_config = amgx_config + " }\n" + "}\n";
233  // use a zero initial guess in Mult()
234  iterative_mode = false;
235  }
236  else if (amgxMode == AMGX_MODE::SOLVER)
237  {
238  amgx_config = "{ \n"
239  " \"config_version\": 2, \n"
240  " \"solver\": { \n"
241  " \"preconditioner\": { \n"
242  " \"solver\": \"AMG\", \n"
243  " \"smoother\": { \n"
244  " \"scope\": \"jacobi\", \n"
245  " \"solver\": \"JACOBI_L1\" \n"
246  " }, \n"
247  " \"presweeps\": 1, \n"
248  " \"interpolator\": \"D2\", \n"
249  " \"max_row_sum\" : 0.9, \n"
250  " \"strength_threshold\" : 0.25, \n"
251  " \"max_iters\": 1, \n"
252  " \"scope\": \"amg\", \n"
253  " \"max_levels\": 100, \n"
254  " \"cycle\": \"V\", \n"
255  " \"postsweeps\": 1 \n"
256  " }, \n"
257  " \"solver\": \"PCG\", \n"
258  " \"max_iters\": 150, \n"
259  " \"convergence\": \"RELATIVE_INI_CORE\", \n"
260  " \"scope\": \"main\", \n"
261  " \"tolerance\": 1e-12, \n"
262  " \"monitor_residual\": 1, \n"
263  " \"norm\": \"L2\" ";
264  if (verbose)
265  {
266  amgx_config = amgx_config + ", \n"
267  " \"obtain_timings\": 1, \n"
268  " \"print_grid_stats\": 1, \n"
269  " \"print_solve_stats\": 1 \n";
270  }
271  else
272  {
273  amgx_config = amgx_config + "\n";
274  }
275  amgx_config = amgx_config + " } \n" + "} \n";
276  // use the user-specified vector as an initial guess in Mult()
277  iterative_mode = true;
278  }
279  else
280  {
281  mfem_error("AmgX mode not supported \n");
282  }
283 }
284 
285 // Sets up AmgX library for MPI builds
286 #ifdef MFEM_USE_MPI
287 void AmgXSolver::InitAmgX()
288 {
289  // Set up once
290  if (count == 1)
291  {
292  AMGX_SAFE_CALL(AMGX_initialize());
293 
294  AMGX_SAFE_CALL(AMGX_initialize_plugins());
295 
296  AMGX_SAFE_CALL(AMGX_install_signal_handler());
297 
298  AMGX_SAFE_CALL(AMGX_register_print_callback(
299  [](const char *msg, int length)->void
300  {
301  int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
302  if (irank == 0) { mfem::out<<msg;} }));
303  }
304 
305  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
306  "AmgX configuration is not defined \n");
307 
308  if (configSrc == CONFIG_SRC::EXTERNAL)
309  {
310  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
311  }
312  else
313  {
314  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
315  }
316 
317  // Let AmgX handle returned error codes internally
318  AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
319 
320  // Create an AmgX resource object, only the first instance needs to create
321  // the resource object.
322  if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }
323 
324  // Create AmgX vector object for unknowns and RHS
325  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
326  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
327 
328  // Create AmgX matrix object for unknowns and RHS
329  AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
330 
331  // Create an AmgX solver object
332  AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
333 
334  // Obtain the default number of rings based on current configuration
335  AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
336 }
337 
338 // Groups MPI ranks into teams and assigns the roots to talk to GPUs
339 void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
340 {
341  // Duplicate the global communicator
342  MPI_Comm_dup(comm, &globalCpuWorld);
343  MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");
344 
345  // Get size and rank for global communicator
346  MPI_Comm_size(globalCpuWorld, &globalSize);
347  MPI_Comm_rank(globalCpuWorld, &myGlobalRank);
348 
349  // Get the communicator for processors on the same node (local world)
350  MPI_Comm_split_type(globalCpuWorld,
351  MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
352  MPI_Comm_set_name(localCpuWorld, "localCpuWorld");
353 
354  // Get size and rank for local communicator
355  MPI_Comm_size(localCpuWorld, &localSize);
356  MPI_Comm_rank(localCpuWorld, &myLocalRank);
357 
358  // Set up corresponding ID of the device used by each local process
359  SetDeviceIDs(nDevs);
360 
361  MPI_Barrier(globalCpuWorld);
362 
363  // Split the global world into a world involved in AmgX and a null world
364  MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);
365 
366  // Get size and rank for the communicator corresponding to gpuWorld
367  if (gpuWorld != MPI_COMM_NULL)
368  {
369  MPI_Comm_set_name(gpuWorld, "gpuWorld");
370  MPI_Comm_size(gpuWorld, &gpuWorldSize);
371  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
372  }
373  else // for those that will not communicate with the GPU
374  {
375  gpuWorldSize = MPI_UNDEFINED;
376  myGpuWorldRank = MPI_UNDEFINED;
377  }
378 
379  // Split local world into worlds corresponding to each CUDA device
380  MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
381  MPI_Comm_set_name(devWorld, "devWorld");
382 
383  // Get size and rank for the communicator corresponding to myWorld
384  MPI_Comm_size(devWorld, &devWorldSize);
385  MPI_Comm_rank(devWorld, &myDevWorldRank);
386 
387  MPI_Barrier(globalCpuWorld);
388 }
389 
390 // Determine MPI teams based on available devices
391 void AmgXSolver::SetDeviceIDs(const int nDevs)
392 {
393  // Set the ID of device that each local process will use
394  if (nDevs == localSize) // # of the devices and local process are the same
395  {
396  devID = myLocalRank;
397  gpuProc = 0;
398  }
399  else if (nDevs > localSize) // there are more devices than processes
400  {
401  MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
402  " are more than the MPI processes launched. Only "<<
403  nDevs << " devices will be used.\n");
404  devID = myLocalRank;
405  gpuProc = 0;
406  }
407  else // in case there are more ranks than devices
408  {
409  int nBasic = localSize / nDevs,
410  nRemain = localSize % nDevs;
411 
412  if (myLocalRank < (nBasic+1)*nRemain)
413  {
414  devID = myLocalRank / (nBasic + 1);
415  if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
416  }
417  else
418  {
419  devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
420  if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
421  }
422  }
423 }
424 
425 void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
426  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
427 {
428  // Calculate number of elements to be collected from each process
429  Array<int> Apart(mpiTeamSz);
430  int locAsz = inArr.Size();
431  MPI_Gather(&locAsz, 1, MPI_INT,
432  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
433 
434  MPI_Barrier(mpiTeamComm);
435 
436  // Determine stride for process (to be used by root)
437  Array<int> Adisp(mpiTeamSz);
438  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
439  if (myid == 0)
440  {
441  Adisp[0] = 0;
442  for (int i=1; i<mpiTeamSz; ++i)
443  {
444  Adisp[i] = Adisp[i-1] + Apart[i-1];
445  }
446  }
447 
448  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
449  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
450  MPI_DOUBLE, 0, mpiTeamComm);
451 }
452 
453 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
454  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
455 {
456  // Calculate number of elements to be collected from each process
457  Array<int> Apart(mpiTeamSz);
458  int locAsz = inArr.Size();
459  MPI_Gather(&locAsz, 1, MPI_INT,
460  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
461 
462  MPI_Barrier(mpiTeamComm);
463 
464  // Determine stride for process (to be used by root)
465  Array<int> Adisp(mpiTeamSz);
466  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
467  if (myid == 0)
468  {
469  Adisp[0] = 0;
470  for (int i=1; i<mpiTeamSz; ++i)
471  {
472  Adisp[i] = Adisp[i-1] + Apart[i-1];
473  }
474  }
475 
476  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
477  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
478  MPI_DOUBLE, 0, mpiTeamComm);
479 }
480 
481 void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
482  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
483 {
484  // Calculate number of elements to be collected from each process
485  Array<int> Apart(mpiTeamSz);
486  int locAsz = inArr.Size();
487  MPI_Gather(&locAsz, 1, MPI_INT,
488  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
489 
490  MPI_Barrier(mpiTeamComm);
491 
492  // Determine stride for process (to be used by root)
493  Array<int> Adisp(mpiTeamSz);
494  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
495  if (myid == 0)
496  {
497  Adisp[0] = 0;
498  for (int i=1; i<mpiTeamSz; ++i)
499  {
500  Adisp[i] = Adisp[i-1] + Apart[i-1];
501  }
502  }
503 
504  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
505  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
506  MPI_INT, 0, mpiTeamComm);
507 }
508 
509 
510 void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
511  Array<int64_t> &outArr,
512  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
513 {
514  // Calculate number of elements to be collected from each process
515  Array<int> Apart(mpiTeamSz);
516  int locAsz = inArr.Size();
517  MPI_Gather(&locAsz, 1, MPI_INT,
518  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
519 
520  MPI_Barrier(mpiTeamComm);
521 
522  // Determine stride for process
523  Array<int> Adisp(mpiTeamSz);
524  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
525  if (myid == 0)
526  {
527  Adisp[0] = 0;
528  for (int i=1; i<mpiTeamSz; ++i)
529  {
530  Adisp[i] = Adisp[i-1] + Apart[i-1];
531  }
532  }
533 
534  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
535  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
536  MPI_INT64_T, 0, mpiTeamComm);
537 
538  MPI_Barrier(mpiTeamComm);
539 }
540 
541 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
542  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
543  Array<int> &Apart, Array<int> &Adisp) const
544 {
545  // Calculate number of elements to be collected from each process
546  int locAsz = inArr.Size();
547  MPI_Allgather(&locAsz, 1, MPI_INT,
548  Apart.HostWrite(),1, MPI_INT, mpiTeamComm);
549 
550  MPI_Barrier(mpiTeamComm);
551 
552  // Determine stride for process
553  Adisp[0] = 0;
554  for (int i=1; i<mpiTeamSz; ++i)
555  {
556  Adisp[i] = Adisp[i-1] + Apart[i-1];
557  }
558 
559  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
560  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
561  MPI_DOUBLE, 0, mpiTeamComm);
562 }
563 
564 void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
565  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
566  Array<int> &Apart, Array<int> &Adisp) const
567 {
568  MPI_Scatterv(inArr.HostRead(),Apart.HostRead(),Adisp.HostRead(),
569  MPI_DOUBLE,outArr.HostWrite(),outArr.Size(),
570  MPI_DOUBLE, 0, mpiTeamComm);
571 }
572 #endif
573 
574 void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
575 {
576  if (update_mat == false)
577  {
578  AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
579  in_A.NumNonZeroElems(),
580  1, 1,
581  in_A.ReadI(),
582  in_A.ReadJ(),
583  in_A.ReadData(), NULL));
584 
585  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
586  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
587  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
588  }
589  else
590  {
591  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
592  in_A.Height(),
593  in_A.NumNonZeroElems(),
594  in_A.ReadData(), NULL));
595  }
596 }
597 
598 #ifdef MFEM_USE_MPI
599 
600 void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
601 {
602  // Require hypre >= 2.16.
603 #if MFEM_HYPRE_VERSION < 21600
604  mfem_error("Hypre version 2.16+ is required when using AmgX \n");
605 #endif
606 
607  // Ensure HypreParMatrix is on the host
608  A.HostRead();
609 
610  hypre_ParCSRMatrix * A_ptr =
611  (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);
612 
613  hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);
614 
615  A.HypreRead();
616 
617  Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
618  const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);
619 
620  // Column index must be int64_t so we must promote here
621  Array<int64_t> loc_J((int)A_csr->num_nonzeros);
622  for (int i=0; i<A_csr->num_nonzeros; ++i)
623  {
624  loc_J[i] = A_csr->big_j[i];
625  }
626 
627  // Assumes one GPU per MPI rank
628  if (mpi_gpu_mode=="mpi-gpu-exclusive")
629  {
630  SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
631  // Free A_csr data from hypre_MergeDiagAndOffd method
632  hypre_CSRMatrixDestroy(A_csr);
633  return;
634  }
635 
636  // Assumes teams of MPI ranks are sharing a GPU
637  if (mpi_gpu_mode == "mpi-teams")
638  {
639  SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
640  // Free A_csr data from hypre_MergeDiagAndOffd method
641  hypre_CSRMatrixDestroy(A_csr);
642  return;
643  }
644 
645  mfem_error("Unsupported MPI_GPU combination \n");
646 }
647 
648 void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
649  const Array<double> &loc_A,
650  const Array<int> &loc_I,
651  const Array<int64_t> &loc_J,
652  const bool update_mat)
653 {
654  // Create a vector of offsets describing matrix row partitions
655  Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0.0;
656 
657  int64_t myStart = A.GetRowStarts()[0];
658 
659  MPI_Allgather(&myStart, 1, MPI_INT64_T,
660  rowPart.GetData(),1, MPI_INT64_T
661  ,gpuWorld);
662  MPI_Barrier(gpuWorld);
663 
664  rowPart[gpuWorldSize] = A.M();
665 
666  const int nGlobalRows = A.M();
667  const int local_rows = loc_I.Size()-1;
668  const int num_nnz = loc_I[local_rows];
669 
670  if (update_mat == false)
671  {
672  AMGX_distribution_handle dist;
673  AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
674  AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
675  AMGX_DIST_PARTITION_OFFSETS,
676  rowPart.GetData()));
677 
678  AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
679  local_rows, num_nnz, 1, 1,
680  loc_I.Read(), loc_J.Read(),
681  loc_A.Read(), NULL, dist));
682 
683  AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
684 
685  MPI_Barrier(gpuWorld);
686 
687  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
688 
689  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
690  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
691  }
692  else
693  {
694  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
695  num_nnz, loc_A, NULL));
696  }
697 }
698 
// Upload (or update) the matrix when teams of MPI ranks share a GPU.
//
// The local CSR arrays of all ranks in devWorld are gathered on the team root
// (myDevWorldRank == 0). The gathered row-pointer array is then fixed up into
// a single row-pointer array, and the ranks with gpuProc == 0 upload the
// consolidated rows to AmgX (update_mat == false) or replace the coefficients
// (update_mat == true). The number of consolidated rows is also stored in
// mat_local_rows.
699 void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
700  const Array<double> &loc_A,
701  const Array<int> &loc_I,
702  const Array<int64_t> &loc_J,
703  const bool update_mat)
704 {
705  // The following arrays hold the consolidated diagonal + off-diagonal matrix
706  // data
707  Array<int> all_I;
708  Array<int64_t> all_J;
709  Array<double> all_A;
710 
711  // Determine array sizes
712  int J_allsz(0), all_NNZ(0), nDevRows(0);
713  const int loc_row_len = std::abs(A.RowPart()[1] -
714  A.RowPart()[0]); // end of row partition
715  const int loc_Jz_sz = loc_J.Size();
716  const int loc_A_sz = loc_A.Size();
717 
// Sum the per-rank sizes onto the team root; the results are only meaningful
// on rank 0 of devWorld.
718  MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
719  MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
720  MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);
721 
722  MPI_Barrier(devWorld);
723 
724  if (myDevWorldRank == 0)
725  {
// One extra row-pointer entry per team member: each rank sends its full
// row-pointer array.
726  all_I.SetSize(nDevRows+devWorldSize);
727  all_J.SetSize(J_allsz); all_J = 0.0;
728  all_A.SetSize(all_NNZ);
729  }
730 
731  GatherArray(loc_I, all_I, devWorldSize, devWorld);
732  GatherArray(loc_J, all_J, devWorldSize, devWorld);
733  GatherArray(loc_A, all_A, devWorldSize, devWorld);
734 
735  MPI_Barrier(devWorld);
736 
737  int local_nnz(0);
738  int64_t local_rows(0);
739 
740  if (myDevWorldRank == 0)
741  {
742  // A fix up step is needed for the array holding row data to remove extra
743  // zeros when consolidating team data.
// NOTE(review): z_ind has devWorldSize+1 entries and one entry is written for
// every interior zero found in all_I; this presumably relies on zeros only
// appearing at the start of each rank's segment - confirm for matrices whose
// leading local rows are empty.
744  Array<int> z_ind(devWorldSize+1);
745  int iter = 1;
// Each pass handles the segment that follows the first interior zero and then
// shifts the remainder of the array left by one entry.
746  while (iter < devWorldSize-1)
747  {
748  // Determine the indices of zeros in global all_I array
749  int counter = 0;
750  z_ind[counter] = counter;
751  counter++;
752  for (int idx=1; idx<all_I.Size()-1; idx++)
753  {
754  if (all_I[idx]==0)
755  {
// Record the index just before the zero
756  z_ind[counter] = idx-1;
757  counter++;
758  }
759  }
760  z_ind[devWorldSize] = all_I.Size()-1;
761  // End of determining indices of zeros in global all_I Array
762 
763  // Bump all_I
// Rewrite the entries as a running sum: previous (already bumped) value plus
// the difference between the next two original entries.
764  for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
765  {
766  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
767  }
768 
769  // Shift array after bump to remove unnecessary values in middle of
770  // array
771  for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
772  {
773  all_I[idx] = all_I[idx+1];
774  }
775  iter++;
776  }
777 
778  // LAST TIME THROUGH ARRAY
779  // Determine the indices of zeros in global row_ptr array
780  int counter = 0;
781  z_ind[counter] = counter;
782  counter++;
783  for (int idx=1; idx<all_I.Size()-1; idx++)
784  {
785  if (all_I[idx]==0)
786  {
787  z_ind[counter] = idx-1;
788  counter++;
789  }
790  }
791 
792  z_ind[devWorldSize] = all_I.Size()-1;
793  // End of determining indices of zeros in global all_I Array BUMP all_I
794  // one last time
795  for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
796  {
797  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
798  }
// all_I.Size() is nDevRows + devWorldSize, so this reads entry nDevRows of the
// fixed-up row-pointer array.
799  local_nnz = all_I[all_I.Size()-devWorldSize];
800  local_rows = nDevRows;
801  }
802 
803  // Create row partition
804  mat_local_rows = local_rows; // class copy
805  Array<int64_t> rowPart;
806  if (gpuProc == 0)
807  {
808  rowPart.SetSize(gpuWorldSize+1); rowPart=0;
809 
// Gather the row counts into rowPart[1..]; the prefix sum below turns them
// into partition offsets starting at 0.
810  MPI_Allgather(&local_rows, 1, MPI_INT64_T,
811  &rowPart.GetData()[1], 1, MPI_INT64_T,
812  gpuWorld);
813  MPI_Barrier(gpuWorld);
814 
815  // Fixup step
816  for (int i=1; i<rowPart.Size(); ++i)
817  {
818  rowPart[i] += rowPart[i-1];
819  }
820 
821  // Upload A matrix to AmgX
822  MPI_Barrier(gpuWorld);
823 
824  int nGlobalRows = A.M();
825  if (update_mat == false)
826  {
827  AMGX_distribution_handle dist;
828  AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
829  AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
830  AMGX_DIST_PARTITION_OFFSETS,
831  rowPart.GetData()));
832 
833  AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
834  local_rows, local_nnz,
835  1, 1, all_I.ReadWrite(),
836  all_J.Read(),
837  all_A.Read(),
838  nullptr, dist));
839 
840  AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
841  MPI_Barrier(gpuWorld);
842 
843  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
844 
845  // Bind vectors to A
846  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
847  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
848  }
849  else
850  {
// Only the matrix values are replaced on this path
851  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
852  local_nnz, all_A, NULL));
853  }
854  }
855 }
856 
857 #endif
858 
860 {
861  height = op.Height();
862  width = op.Width();
863 
864  if (const SparseMatrix* Aptr =
865  dynamic_cast<const SparseMatrix*>(&op))
866  {
867  SetMatrix(*Aptr);
868  }
869 #ifdef MFEM_USE_MPI
870  else if (const HypreParMatrix* Aptr =
871  dynamic_cast<const HypreParMatrix*>(&op))
872  {
873  SetMatrix(*Aptr);
874  }
875 #endif
876  else
877  {
878  mfem_error("Unsupported Operator Type \n");
879  }
880 }
881 
883 {
884  if (const SparseMatrix* Aptr =
885  dynamic_cast<const SparseMatrix*>(&op))
886  {
887  SetMatrix(*Aptr, true);
888  }
889 #ifdef MFEM_USE_MPI
890  else if (const HypreParMatrix* Aptr =
891  dynamic_cast<const HypreParMatrix*>(&op))
892  {
893  SetMatrix(*Aptr, true);
894  }
895 #endif
896  else
897  {
898  mfem_error("Unsupported Operator Type \n");
899  }
900 }
901 
902 void AmgXSolver::Mult(const Vector& B, Vector& X) const
903 {
904  // Set initial guess to zero
905  X.UseDevice(true);
906  if (!iterative_mode) { X = 0.0; }
907 
908  // Mult for serial, and mpi-exclusive modes
909  if (mpi_gpu_mode != "mpi-teams")
910  {
911  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
912  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));
913 
914  if (mpi_gpu_mode != "serial")
915  {
916 #ifdef MFEM_USE_MPI
917  MPI_Barrier(gpuWorld);
918 #endif
919  }
920 
921  AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
922 
923  AMGX_SOLVE_STATUS status;
924  AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
925  if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck)
926  {
927  if (status == AMGX_SOLVE_DIVERGED)
928  {
929  mfem_error("AmgX solver diverged \n");
930  }
931  else
932  {
933  mfem_error("AmgX solver failed to solve system \n");
934  }
935  }
936 
937  AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
938  return;
939  }
940 
941 #ifdef MFEM_USE_MPI
942  Vector all_X(mat_local_rows);
943  Vector all_B(mat_local_rows);
944  Array<int> Apart_X(devWorldSize);
945  Array<int> Adisp_X(devWorldSize);
946  Array<int> Apart_B(devWorldSize);
947  Array<int> Adisp_B(devWorldSize);
948 
949  GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
950  GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
951  MPI_Barrier(devWorld);
952 
953  if (gpuWorld != MPI_COMM_NULL)
954  {
955  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
956  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));
957 
958  MPI_Barrier(gpuWorld);
959 
960  AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
961 
962  AMGX_SOLVE_STATUS status;
963  AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
964  if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
965  {
966  if (status == AMGX_SOLVE_DIVERGED)
967  {
968  mfem_error("AmgX solver diverged \n");
969  }
970  else
971  {
972  mfem_error("AmgX solver failed to solve system \n");
973  }
974  }
975 
976  AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
977  }
978 
979  ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
980 #endif
981 }
982 
984 {
985  int getIters;
986  AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
987  return getIters;
988 }
989 
991 {
992  // Check instance is initialized
993  if (! isInitialized || count < 1)
994  {
995  mfem_error("Error in AmgXSolver::Finalize(). \n"
996  "This AmgXWrapper has not been initialized. \n"
997  "Please initialize it before finalization.\n");
998  }
999 
1000  // Only processes using GPU are required to destroy AmgX content
1001 #ifdef MFEM_USE_MPI
1002  if (gpuProc == 0 || mpi_gpu_mode == "serial")
1003 #endif
1004  {
1005  // Destroy solver instance
1006  AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
1007 
1008  // Destroy matrix instance
1009  AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));
1010 
1011  // Destroy RHS and unknown vectors
1012  AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
1013  AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));
1014 
1015  // Only the last instance need to destroy resource and finalizing AmgX
1016  if (count == 1)
1017  {
1018  AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
1019  AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1020 
1021  AMGX_SAFE_CALL(AMGX_finalize_plugins());
1022  AMGX_SAFE_CALL(AMGX_finalize());
1023  }
1024  else
1025  {
1026  AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1027  }
1028 #ifdef MFEM_USE_MPI
1029  // destroy gpuWorld
1030  if (mpi_gpu_mode != "serial")
1031  {
1032  MPI_Comm_free(&gpuWorld);
1033  }
1034 #endif
1035  }
1036 
1037  // reset necessary variables in case users want to reuse the variable of
1038  // this instance for a new instance
1039 #ifdef MFEM_USE_MPI
1040  gpuProc = MPI_UNDEFINED;
1041  if (globalCpuWorld != MPI_COMM_NULL)
1042  {
1043  MPI_Comm_free(&globalCpuWorld);
1044  MPI_Comm_free(&localCpuWorld);
1045  MPI_Comm_free(&devWorld);
1046  }
1047 #endif
1048  // decrease the number of instances
1049  count -= 1;
1050 
1051  // change status
1052  isInitialized = false;
1053 }
1054 
1055 } // mfem namespace
1056 
1057 #endif
bool ConvergenceCheck
Flag to check for convergence.
Definition: amgxsolver.hpp:77
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition: operator.hpp:73
int Size() const
Returns the size of the vector.
Definition: vector.hpp:200
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition: operator.hpp:659
void InitMPITeams(const MPI_Comm &comm, const int nDevs)
Definition: amgxsolver.cpp:149
virtual void UseDevice(bool use_dev) const
Enable execution of Vector operations using the mfem::Device.
Definition: vector.hpp:118
void ReadParameters(const std::string config, CONFIG_SRC source)
Definition: amgxsolver.cpp:186
virtual void Mult(const Vector &b, Vector &x) const
Operator application: y=A(x).
Definition: amgxsolver.cpp:902
void source(const Vector &x, Vector &f)
Definition: ex25.cpp:620
void SetConvergenceCheck(bool setConvergenceCheck_=true)
Add a check for convergence after applying Mult.
Definition: amgxsolver.cpp:193
Data type sparse matrix.
Definition: sparsemat.hpp:50
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition: operator.hpp:67
virtual double * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:457
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
Definition: error.cpp:154
void UpdateOperator(const Operator &op)
Definition: amgxsolver.cpp:882
virtual void SetOperator(const Operator &op)
Definition: amgxsolver.cpp:859
void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
Definition: amgxsolver.cpp:198
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
virtual double * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:465
void InitExclusiveGPU(const MPI_Comm &comm)
Definition: amgxsolver.cpp:120
Vector data type.
Definition: vector.hpp:60
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition: globals.hpp:66
Abstract operator.
Definition: operator.hpp:24
Wrapper for hypre's ParCSR matrix class.
Definition: hypre.hpp:343
AMGX_MODE
Flags to configure AmgXSolver as a solver or preconditioner.
Definition: amgxsolver.hpp:74
virtual const double * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:449
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28