MFEM  v4.3.0
Finite element discretization library
amgxsolver.cpp
1 // Copyright (c) 2010-2021, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 // Implementation of the MFEM wrapper for NVIDIA's multigrid library, AmgX
13 //
14 // This work is partially based on:
15 //
16 // Pi-Yueh Chuang and Lorena A. Barba (2017).
17 // AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library.
18 // J. Open Source Software, 2(16):280, doi:10.21105/joss.00280
19 //
20 // See https://github.com/barbagroup/AmgXWrapper.
21 
22 #include "../config/config.hpp"
23 #include "amgxsolver.hpp"
24 #ifdef MFEM_USE_AMGX
25 
26 namespace mfem
27 {
28 
29 int AmgXSolver::count = 0;
30 
31 AMGX_resources_handle AmgXSolver::rsrc = nullptr;
32 
33 AmgXSolver::AmgXSolver()
34  : ConvergenceCheck(false) {}
35 
36 AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose)
37 {
38  amgxMode = amgxMode_;
39 
40  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
41  else { ConvergenceCheck = false;}
42 
43  DefaultParameters(amgxMode, verbose);
44 
45  InitSerial();
46 }
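
For reference, a minimal serial usage sketch; A, b, and x stand for an assembled mfem::SparseMatrix and two mfem::Vectors that are not part of this file:

   AmgXSolver amgx(AmgXSolver::PRECONDITIONER, /*verbose=*/false);
   amgx.SetOperator(A);          // uploads A and runs the AmgX setup
   CGSolver cg;
   cg.SetOperator(A);
   cg.SetPreconditioner(amgx);   // one AMG V-cycle per CG iteration
   cg.SetRelTol(1e-12);
   cg.SetMaxIter(200);
   cg.Mult(b, x);
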
47 
48 #ifdef MFEM_USE_MPI
49 
50 AmgXSolver::AmgXSolver(const MPI_Comm &comm,
51  const AMGX_MODE amgxMode_, const bool verbose)
52 {
53  std::string config;
54  amgxMode = amgxMode_;
55 
56  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
57  else { ConvergenceCheck = false;}
58 
59  DefaultParameters(amgxMode, verbose);
60 
61  InitExclusiveGPU(comm);
62 }
63 
64 AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
65  const AMGX_MODE amgxMode_, const bool verbose)
66 {
67  std::string config;
68  amgxMode = amgxMode_;
69 
70  if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;}
71  else { ConvergenceCheck = false;}
72 
73  DefaultParameters(amgxMode_, verbose);
74 
75  InitMPITeams(comm, nDevs);
76 }
77 
78 #endif
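
The two MPI constructors correspond to the two parallel modes initialized below; a hedged sketch, assuming an assembled HypreParMatrix A:

   // One MPI rank per GPU ("mpi-gpu-exclusive"):
   AmgXSolver amgx(MPI_COMM_WORLD, AmgXSolver::SOLVER, /*verbose=*/true);
   amgx.SetOperator(A);

   // Teams of ranks sharing the GPUs on each node ("mpi-teams"),
   // here with an assumed two devices per node:
   // AmgXSolver amgx(MPI_COMM_WORLD, /*nDevs=*/2, AmgXSolver::SOLVER, true);
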
79 
80 AmgXSolver::~AmgXSolver()
81 {
82  if (isInitialized) { Finalize(); }
83 }
84 
85 void AmgXSolver::InitSerial()
86 {
87  count++;
88 
89  mpi_gpu_mode = "serial";
90 
91  AMGX_SAFE_CALL(AMGX_initialize());
92 
93  AMGX_SAFE_CALL(AMGX_initialize_plugins());
94 
95  AMGX_SAFE_CALL(AMGX_install_signal_handler());
96 
97  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
98  "AmgX configuration is not defined \n");
99 
100  if (configSrc == CONFIG_SRC::EXTERNAL)
101  {
102  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
103  }
104  else
105  {
106  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
107  }
108 
109  AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
110  AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
111  AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
112  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
113  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
114 
115  isInitialized = true;
116 }
117 
118 #ifdef MFEM_USE_MPI
119 
120 void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
121 {
122  // Abort if this instance has already been initialized
123  if (isInitialized)
124  {
125  mfem_error("This AmgXSolver instance has been initialized on this process.");
126  }
127 
128  // Note that every MPI rank may talk to a GPU
129  mpi_gpu_mode = "mpi-gpu-exclusive";
130  gpuProc = 0;
131 
132  // Increment number of AmgX instances
133  count++;
134 
135  MPI_Comm_dup(comm, &gpuWorld);
136  MPI_Comm_size(gpuWorld, &gpuWorldSize);
137  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
138 
139  // Each rank will only see one device, which it calls device 0
140  nDevs = 1, devID = 0;
141 
142  InitAmgX();
143 
144  isInitialized = true;
145 }
146 
147 // Initialize when there are more MPI ranks than GPUs; all devices are
148 // visible to all of the MPI ranks
149 void AmgXSolver::InitMPITeams(const MPI_Comm &comm,
150  const int nDevs)
151 {
152  // Abort if this instance has already been initialized
153  if (isInitialized)
154  {
155  mfem_error("This AmgXSolver instance has been initialized on this process.");
156  }
157 
158  mpi_gpu_mode = "mpi-teams";
159 
160  // Increment number of AmgX instances
161  count++;
162 
163  // Get the name of this node
164  int len;
165  char name[MPI_MAX_PROCESSOR_NAME];
166  MPI_Get_processor_name(name, &len);
167  nodeName = name;
168  int globalcommrank;
169 
170  MPI_Comm_rank(comm, &globalcommrank);
171 
172  // Initialize communicators and corresponding information
173  InitMPIcomms(comm, nDevs);
174 
175  // Only processes in gpuWorld are required to initialize AmgX
176  if (gpuProc == 0)
177  {
178  InitAmgX();
179  }
180 
181  isInitialized = true;
182 }
183 
184 #endif
185 
186 void AmgXSolver::ReadParameters(const std::string config,
187  const CONFIG_SRC source)
188 {
189  amgx_config = config;
190  configSrc = source;
191 }
192 
193 void AmgXSolver::SetConvergenceCheck(bool setConvergenceCheck_)
194 {
195  ConvergenceCheck = setConvergenceCheck_;
196 }
197 
198 void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_,
199  const bool verbose)
200 {
201  amgxMode = amgxMode_;
202 
203  configSrc = INTERNAL;
204 
205  if (amgxMode == AMGX_MODE::PRECONDITIONER)
206  {
207  amgx_config = "{\n"
208  " \"config_version\": 2, \n"
209  " \"solver\": { \n"
210  " \"solver\": \"AMG\", \n"
211  " \"scope\": \"main\", \n"
212  " \"smoother\": \"JACOBI_L1\", \n"
213  " \"presweeps\": 1, \n"
214  " \"interpolator\": \"D2\", \n"
215  " \"max_row_sum\" : 0.9, \n"
216  " \"strength_threshold\" : 0.25, \n"
217  " \"postsweeps\": 1, \n"
218  " \"max_iters\": 1, \n"
219  " \"cycle\": \"V\"";
220  if (verbose)
221  {
222  amgx_config = amgx_config + ",\n"
223  " \"obtain_timings\": 1, \n"
224  " \"print_grid_stats\": 1, \n"
225  " \"monitor_residual\": 1, \n"
226  " \"print_solve_stats\": 1 \n";
227  }
228  else
229  {
230  amgx_config = amgx_config + "\n";
231  }
232  amgx_config = amgx_config + " }\n" + "}\n";
233  // use a zero initial guess in Mult()
234  iterative_mode = false;
235  }
236  else if (amgxMode == AMGX_MODE::SOLVER)
237  {
238  amgx_config = "{ \n"
239  " \"config_version\": 2, \n"
240  " \"solver\": { \n"
241  " \"preconditioner\": { \n"
242  " \"solver\": \"AMG\", \n"
243  " \"smoother\": { \n"
244  " \"scope\": \"jacobi\", \n"
245  " \"solver\": \"JACOBI_L1\" \n"
246  " }, \n"
247  " \"presweeps\": 1, \n"
248  " \"interpolator\": \"D2\", \n"
249  " \"max_row_sum\" : 0.9, \n"
250  " \"strength_threshold\" : 0.25, \n"
251  " \"max_iters\": 1, \n"
252  " \"scope\": \"amg\", \n"
253  " \"max_levels\": 100, \n"
254  " \"cycle\": \"V\", \n"
255  " \"postsweeps\": 1 \n"
256  " }, \n"
257  " \"solver\": \"PCG\", \n"
258  " \"max_iters\": 150, \n"
259  " \"convergence\": \"RELATIVE_INI_CORE\", \n"
260  " \"scope\": \"main\", \n"
261  " \"tolerance\": 1e-12, \n"
262  " \"monitor_residual\": 1, \n"
263  " \"norm\": \"L2\" ";
264  if (verbose)
265  {
266  amgx_config = amgx_config + ", \n"
267  " \"obtain_timings\": 1, \n"
268  " \"print_grid_stats\": 1, \n"
269  " \"print_solve_stats\": 1 \n";
270  }
271  else
272  {
273  amgx_config = amgx_config + "\n";
274  }
275  amgx_config = amgx_config + " } \n" + "} \n";
276  // use the user-specified vector as an initial guess in Mult()
277  iterative_mode = true;
278  }
279  else
280  {
281  mfem_error("AmgX mode not supported \n");
282  }
283 }
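
A custom configuration can be supplied through ReadParameters instead of DefaultParameters; a sketch, where the inline JSON is an arbitrary illustration rather than a recommended setup:

   AmgXSolver amgx;
   amgx.ReadParameters("{\"config_version\": 2, "
                       "\"solver\": {\"solver\": \"AMG\", \"max_iters\": 2}}",
                       AmgXSolver::INTERNAL);   // JSON passed as a string
   // or read a configuration file from disk:
   // amgx.ReadParameters("amgx.json", AmgXSolver::EXTERNAL);
   amgx.InitSerial();
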
284 
285 // Sets up AmgX library for MPI builds
286 #ifdef MFEM_USE_MPI
287 void AmgXSolver::InitAmgX()
288 {
289  // Set up once
290  if (count == 1)
291  {
292  AMGX_SAFE_CALL(AMGX_initialize());
293 
294  AMGX_SAFE_CALL(AMGX_initialize_plugins());
295 
296  AMGX_SAFE_CALL(AMGX_install_signal_handler());
297 
298  AMGX_SAFE_CALL(AMGX_register_print_callback(
299  [](const char *msg, int length)->void
300  {
301  int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
302  if (irank == 0) { mfem::out<<msg;} }));
303  }
304 
305  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
306  "AmgX configuration is not defined \n");
307 
308  if (configSrc == CONFIG_SRC::EXTERNAL)
309  {
310  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
311  }
312  else
313  {
314  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
315  }
316 
317  // Let AmgX handle returned error codes internally
318  AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
319 
320  // Create an AmgX resource object, only the first instance needs to create
321  // the resource object.
322  if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }
323 
324  // Create AmgX vector objects for the unknowns and RHS
325  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
326  AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));
327 
328  // Create AmgX matrix object
329  AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
330 
331  // Create an AmgX solver object
332  AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
333 
334  // Obtain the default number of rings based on current configuration
335  AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
336 }
337 
338 // Groups MPI ranks into teams and assigns the roots to talk to GPUs
339 void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
340 {
341  // Duplicate the global communicator
342  MPI_Comm_dup(comm, &globalCpuWorld);
343  MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");
344 
345  // Get size and rank for global communicator
346  MPI_Comm_size(globalCpuWorld, &globalSize);
347  MPI_Comm_rank(globalCpuWorld, &myGlobalRank);
348 
349  // Get the communicator for processors on the same node (local world)
350  MPI_Comm_split_type(globalCpuWorld,
351  MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
352  MPI_Comm_set_name(localCpuWorld, "localCpuWorld");
353 
354  // Get size and rank for local communicator
355  MPI_Comm_size(localCpuWorld, &localSize);
356  MPI_Comm_rank(localCpuWorld, &myLocalRank);
357 
358  // Set up corresponding ID of the device used by each local process
359  SetDeviceIDs(nDevs);
360 
361  MPI_Barrier(globalCpuWorld);
362 
363  // Split the global world into a world involved in AmgX and a null world
364  MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);
365 
366  // Get size and rank for the communicator corresponding to gpuWorld
367  if (gpuWorld != MPI_COMM_NULL)
368  {
369  MPI_Comm_set_name(gpuWorld, "gpuWorld");
370  MPI_Comm_size(gpuWorld, &gpuWorldSize);
371  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
372  }
373  else // for those that will not communicate with the GPU
374  {
375  gpuWorldSize = MPI_UNDEFINED;
376  myGpuWorldRank = MPI_UNDEFINED;
377  }
378 
379  // Split local world into worlds corresponding to each CUDA device
380  MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
381  MPI_Comm_set_name(devWorld, "devWorld");
382 
383  // Get size and rank for the communicator corresponding to devWorld
384  MPI_Comm_size(devWorld, &devWorldSize);
385  MPI_Comm_rank(devWorld, &myDevWorldRank);
386 
387  MPI_Barrier(globalCpuWorld);
388 }
389 
390 // Determine MPI teams based on available devices
391 void AmgXSolver::SetDeviceIDs(const int nDevs)
392 {
393  // Set the ID of device that each local process will use
394  if (nDevs == localSize) // number of devices matches the number of local processes
395  {
396  devID = myLocalRank;
397  gpuProc = 0;
398  }
399  else if (nDevs > localSize) // there are more devices than processes
400  {
401  MFEM_WARNING("Node " << nodeName.c_str() << " has more CUDA devices than"
402  " MPI processes launched. Only " << localSize <<
403  " devices will be used.\n");
404  devID = myLocalRank;
405  gpuProc = 0;
406  }
407  else // in case there are more ranks than devices
408  {
409  int nBasic = localSize / nDevs,
410  nRemain = localSize % nDevs;
411 
412  if (myLocalRank < (nBasic+1)*nRemain)
413  {
414  devID = myLocalRank / (nBasic + 1);
415  if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
416  }
417  else
418  {
419  devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
420  if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
421  }
422  }
423 }
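
To make the team arithmetic concrete, a standalone sketch of the mapping for an assumed node with localSize = 5 ranks and nDevs = 2 devices (so nBasic = 2, nRemain = 1):

   #include <cstdio>

   int main()
   {
      const int localSize = 5, nDevs = 2;
      const int nBasic = localSize / nDevs, nRemain = localSize % nDevs;
      for (int rank = 0; rank < localSize; ++rank)
      {
         int devID, root;
         if (rank < (nBasic + 1) * nRemain) // first nRemain teams get an extra rank
         {
            devID = rank / (nBasic + 1);
            root = (rank % (nBasic + 1) == 0);
         }
         else
         {
            devID = (rank - (nBasic + 1) * nRemain) / nBasic + nRemain;
            root = ((rank - (nBasic + 1) * nRemain) % nBasic == 0);
         }
         // Prints: ranks 0-2 -> device 0 (rank 0 is root),
         //         ranks 3-4 -> device 1 (rank 3 is root).
         std::printf("rank %d -> device %d%s\n", rank, devID,
                     root ? " (team root)" : "");
      }
      return 0;
   }
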
424 
425 void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
426  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
427 {
428  // Calculate number of elements to be collected from each process
429  Array<int> Apart(mpiTeamSz);
430  int locAsz = inArr.Size();
431  MPI_Gather(&locAsz, 1, MPI_INT,
432  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
433 
434  MPI_Barrier(mpiTeamComm);
435 
436  // Determine displacements for each process (to be used by root)
437  Array<int> Adisp(mpiTeamSz);
438  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
439  if (myid == 0)
440  {
441  Adisp[0] = 0;
442  for (int i=1; i<mpiTeamSz; ++i)
443  {
444  Adisp[i] = Adisp[i-1] + Apart[i-1];
445  }
446  }
447 
448  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
449  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
450  MPI_DOUBLE, 0, mpiTeamComm);
451 }
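
The displacement computation repeated in each GatherArray overload is an exclusive prefix sum of the gathered counts; a self-contained sketch:

   #include <cstddef>
   #include <vector>

   std::vector<int> Displacements(const std::vector<int> &counts)
   {
      std::vector<int> disp(counts.size(), 0);
      for (std::size_t i = 1; i < counts.size(); ++i)
      {
         disp[i] = disp[i-1] + counts[i-1];
      }
      return disp;   // e.g. counts {3, 2, 4} -> disp {0, 3, 5}
   }
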
452 
453 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
454  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
455 {
456  // Calculate number of elements to be collected from each process
457  Array<int> Apart(mpiTeamSz);
458  int locAsz = inArr.Size();
459  MPI_Gather(&locAsz, 1, MPI_INT,
460  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
461 
462  MPI_Barrier(mpiTeamComm);
463 
464  // Determine displacements for each process (to be used by root)
465  Array<int> Adisp(mpiTeamSz);
466  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
467  if (myid == 0)
468  {
469  Adisp[0] = 0;
470  for (int i=1; i<mpiTeamSz; ++i)
471  {
472  Adisp[i] = Adisp[i-1] + Apart[i-1];
473  }
474  }
475 
476  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
477  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
478  MPI_DOUBLE, 0, mpiTeamComm);
479 }
480 
481 void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
482  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
483 {
484  // Calculate number of elements to be collected from each process
485  Array<int> Apart(mpiTeamSz);
486  int locAsz = inArr.Size();
487  MPI_Gather(&locAsz, 1, MPI_INT,
488  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
489 
490  MPI_Barrier(mpiTeamComm);
491 
492  // Determine displacements for each process (to be used by root)
493  Array<int> Adisp(mpiTeamSz);
494  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
495  if (myid == 0)
496  {
497  Adisp[0] = 0;
498  for (int i=1; i<mpiTeamSz; ++i)
499  {
500  Adisp[i] = Adisp[i-1] + Apart[i-1];
501  }
502  }
503 
504  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
505  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
506  MPI_INT, 0, mpiTeamComm);
507 }
508 
509 
510 void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
511  Array<int64_t> &outArr,
512  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
513 {
514  // Calculate number of elements to be collected from each process
515  Array<int> Apart(mpiTeamSz);
516  int locAsz = inArr.Size();
517  MPI_Gather(&locAsz, 1, MPI_INT,
518  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
519 
520  MPI_Barrier(mpiTeamComm);
521 
522  // Determine displacements for each process
523  Array<int> Adisp(mpiTeamSz);
524  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
525  if (myid == 0)
526  {
527  Adisp[0] = 0;
528  for (int i=1; i<mpiTeamSz; ++i)
529  {
530  Adisp[i] = Adisp[i-1] + Apart[i-1];
531  }
532  }
533 
534  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
535  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
536  MPI_INT64_T, 0, mpiTeamComm);
537 
538  MPI_Barrier(mpiTeamComm);
539 }
540 
541 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
542  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
543  Array<int> &Apart, Array<int> &Adisp) const
544 {
545  // Calculate number of elements to be collected from each process
546  int locAsz = inArr.Size();
547  MPI_Allgather(&locAsz, 1, MPI_INT,
548  Apart.HostWrite(),1, MPI_INT, mpiTeamComm);
549 
550  MPI_Barrier(mpiTeamComm);
551 
552  // Determine displacements for each process
553  Adisp[0] = 0;
554  for (int i=1; i<mpiTeamSz; ++i)
555  {
556  Adisp[i] = Adisp[i-1] + Apart[i-1];
557  }
558 
559  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
560  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
561  MPI_DOUBLE, 0, mpiTeamComm);
562 }
563 
564 void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
565  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
566  Array<int> &Apart, Array<int> &Adisp) const
567 {
568  MPI_Scatterv(inArr.HostRead(),Apart.HostRead(),Adisp.HostRead(),
569  MPI_DOUBLE,outArr.HostWrite(),outArr.Size(),
570  MPI_DOUBLE, 0, mpiTeamComm);
571 }
572 #endif
573 
574 void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
575 {
576  if (update_mat == false)
577  {
578  AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
579  in_A.NumNonZeroElems(),
580  1, 1,
581  in_A.ReadI(),
582  in_A.ReadJ(),
583  in_A.ReadData(), NULL));
584 
585  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
586  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
587  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
588  }
589  else
590  {
591  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
592  in_A.Height(),
593  in_A.NumNonZeroElems(),
594  in_A.ReadData(), NULL));
595  }
596 }
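
The update_mat branch lets an existing AmgX setup be reused when only the coefficients of A change while its sparsity pattern stays fixed; a usage sketch through the public wrappers defined later in this file (A, b, x as in the earlier sketches):

   AmgXSolver amgx(AmgXSolver::SOLVER, /*verbose=*/false);
   amgx.SetOperator(A);      // full upload plus AMGX_solver_setup
   amgx.Mult(b, x);
   // ... change the entries (not the pattern) of A ...
   amgx.UpdateOperator(A);   // takes the AMGX_matrix_replace_coefficients path
   amgx.Mult(b, x);
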
597 
598 #ifdef MFEM_USE_MPI
599 
600 void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
601 {
602  // Require hypre >= 2.16.
603 #if MFEM_HYPRE_VERSION < 21600
604  mfem_error("Hypre version 2.16+ is required when using AmgX \n");
605 #endif
606 
607  hypre_ParCSRMatrix * A_ptr =
608  (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);
609 
610  hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);
611 
612  Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
613  const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);
614 
615  // Column index must be int64_t so we must promote here
616  Array<int64_t> loc_J((int)A_csr->num_nonzeros);
617  for (int i=0; i<A_csr->num_nonzeros; ++i)
618  {
619  loc_J[i] = A_csr->big_j[i];
620  }
621 
622  // Assumes one GPU per MPI rank
623  if (mpi_gpu_mode=="mpi-gpu-exclusive")
624  {
625  SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
626  // Free A_csr data from hypre_MergeDiagAndOffd method
627  hypre_CSRMatrixDestroy(A_csr);
628  return;
629  }
630 
631  // Assumes teams of MPI ranks are sharing a GPU
632  if (mpi_gpu_mode == "mpi-teams")
633  {
634  SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
635  // Free A_csr data from hypre_MergeDiagAndOffd method
636  hypre_CSRMatrixDestroy(A_csr);
637  return;
638  }
639 
640  mfem_error("Unsupported MPI_GPU combination \n");
641 }
642 
643 void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
644  const Array<double> &loc_A,
645  const Array<int> &loc_I,
646  const Array<int64_t> &loc_J,
647  const bool update_mat)
648 {
649  // Create a vector of offsets describing matrix row partitions
650  Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0;
651 
652  int64_t myStart = A.GetRowStarts()[0];
653 
654  MPI_Allgather(&myStart, 1, MPI_INT64_T,
655  rowPart.GetData(),1, MPI_INT64_T
656  ,gpuWorld);
657  MPI_Barrier(gpuWorld);
658 
659  rowPart[gpuWorldSize] = A.M();
660 
661  const int nGlobalRows = A.M();
662  const int local_rows = loc_I.Size()-1;
663  const int num_nnz = loc_I[local_rows];
664 
665  if (update_mat == false)
666  {
667  AMGX_distribution_handle dist;
668  AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
669  AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
670  AMGX_DIST_PARTITION_OFFSETS,
671  rowPart.GetData()));
672 
673  AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
674  local_rows, num_nnz, 1, 1,
675  loc_I.Read(), loc_J.Read(),
676  loc_A.Read(), NULL, dist));
677 
678  AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
679 
680  MPI_Barrier(gpuWorld);
681 
682  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
683 
684  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
685  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
686  }
687  else
688  {
689  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
690  num_nnz, loc_A, NULL));
691  }
692 }
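
A worked example of the partition offsets, assuming three GPU ranks owning rows [0,4), [4,7), and [7,12) of a 12-row matrix:

   // Each rank contributes RowStarts()[0] through MPI_Allgather:
   //    rowPart = { 0, 4, 7, ? }
   // and the last entry is then set to the global row count A.M():
   //    rowPart = { 0, 4, 7, 12 }
   // AMGX_DIST_PARTITION_OFFSETS reads consecutive entries as [begin, end).
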
693 
694 void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
695  const Array<double> &loc_A,
696  const Array<int> &loc_I,
697  const Array<int64_t> &loc_J,
698  const bool update_mat)
699 {
700  // The following arrays hold the consolidated diagonal + off-diagonal matrix
701  // data
702  Array<int> all_I;
703  Array<int64_t> all_J;
704  Array<double> all_A;
705 
706  // Determine array sizes
707  int J_allsz(0), all_NNZ(0), nDevRows(0);
708  const int loc_row_len = std::abs(A.RowPart()[1] -
709  A.RowPart()[0]); // end of row partition
710  const int loc_Jz_sz = loc_J.Size();
711  const int loc_A_sz = loc_A.Size();
712 
713  MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
714  MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
715  MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);
716 
717  MPI_Barrier(devWorld);
718 
719  if (myDevWorldRank == 0)
720  {
721  all_I.SetSize(nDevRows+devWorldSize);
722  all_J.SetSize(J_allsz); all_J = 0;
723  all_A.SetSize(all_NNZ);
724  }
725 
726  GatherArray(loc_I, all_I, devWorldSize, devWorld);
727  GatherArray(loc_J, all_J, devWorldSize, devWorld);
728  GatherArray(loc_A, all_A, devWorldSize, devWorld);
729 
730  MPI_Barrier(devWorld);
731 
732  int local_nnz(0);
733  int64_t local_rows(0);
734 
735  if (myDevWorldRank == 0)
736  {
737  // A fix-up step is needed for the array holding row data to remove extra
738  // zeros when consolidating team data.
739  Array<int> z_ind(devWorldSize+1);
740  int iter = 1;
741  while (iter < devWorldSize-1)
742  {
743  // Determine the indices of zeros in global all_I array
744  int counter = 0;
745  z_ind[counter] = counter;
746  counter++;
747  for (int idx=1; idx<all_I.Size()-1; idx++)
748  {
749  if (all_I[idx]==0)
750  {
751  z_ind[counter] = idx-1;
752  counter++;
753  }
754  }
755  z_ind[devWorldSize] = all_I.Size()-1;
756  // End of determining indices of zeros in global all_I Array
757 
758  // Bump all_I
759  for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
760  {
761  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
762  }
763 
764  // Shift array after bump to remove unnecessary values in middle of
765  // array
766  for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
767  {
768  all_I[idx] = all_I[idx+1];
769  }
770  iter++;
771  }
772 
773  // Last pass through the array
774  // Determine the indices of zeros in global row_ptr array
775  int counter = 0;
776  z_ind[counter] = counter;
777  counter++;
778  for (int idx=1; idx<all_I.Size()-1; idx++)
779  {
780  if (all_I[idx]==0)
781  {
782  z_ind[counter] = idx-1;
783  counter++;
784  }
785  }
786 
787  z_ind[devWorldSize] = all_I.Size()-1;
788  // End of determining indices of zeros in the global all_I array.
789  // Bump all_I one last time
790  for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
791  {
792  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
793  }
794  local_nnz = all_I[all_I.Size()-devWorldSize];
795  local_rows = nDevRows;
796  }
797 
798  // Create row partition
799  mat_local_rows = local_rows; // stored in the class for use in Mult()
800  Array<int64_t> rowPart;
801  if (gpuProc == 0)
802  {
803  rowPart.SetSize(gpuWorldSize+1); rowPart=0;
804 
805  MPI_Allgather(&local_rows, 1, MPI_INT64_T,
806  &rowPart.GetData()[1], 1, MPI_INT64_T,
807  gpuWorld);
808  MPI_Barrier(gpuWorld);
809 
810  // Fixup step
811  for (int i=1; i<rowPart.Size(); ++i)
812  {
813  rowPart[i] += rowPart[i-1];
814  }
815 
816  // Upload A matrix to AmgX
817  MPI_Barrier(gpuWorld);
818 
819  int nGlobalRows = A.M();
820  if (update_mat == false)
821  {
822  AMGX_distribution_handle dist;
823  AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
824  AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
825  AMGX_DIST_PARTITION_OFFSETS,
826  rowPart.GetData()));
827 
828  AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
829  local_rows, local_nnz,
830  1, 1, all_I.ReadWrite(),
831  all_J.Read(),
832  all_A.Read(),
833  nullptr, dist));
834 
835  AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
836  MPI_Barrier(gpuWorld);
837 
838  AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
839 
840  // Bind vectors to A
841  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
842  AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
843  }
844  else
845  {
846  AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
847  local_nnz, all_A, NULL));
848  }
849  }
850 }
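
A worked example of the row-pointer fix-up, assuming a team of two ranks that each own two rows:

   // rank 0: I = {0, 2, 5}   (2 rows, 5 nonzeros)
   // rank 1: I = {0, 3, 4}   (2 rows, 4 nonzeros)
   // gathered:          all_I = {0, 2, 5, 0, 3, 4}
   // after the fix-up:  all_I = {0, 2, 5, 8, 9, ...}, local_nnz = 9,
   // i.e. one valid CSR row-pointer array for the consolidated 4-row block.
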
851 
852 #endif
853 
854 void AmgXSolver::SetOperator(const Operator &op)
855 {
856  height = op.Height();
857  width = op.Width();
858 
859  if (const SparseMatrix* Aptr =
860  dynamic_cast<const SparseMatrix*>(&op))
861  {
862  SetMatrix(*Aptr);
863  }
864 #ifdef MFEM_USE_MPI
865  else if (const HypreParMatrix* Aptr =
866  dynamic_cast<const HypreParMatrix*>(&op))
867  {
868  SetMatrix(*Aptr);
869  }
870 #endif
871  else
872  {
873  mfem_error("Unsupported Operator Type \n");
874  }
875 }
876 
877 void AmgXSolver::UpdateOperator(const Operator &op)
878 {
879  if (const SparseMatrix* Aptr =
880  dynamic_cast<const SparseMatrix*>(&op))
881  {
882  SetMatrix(*Aptr, true);
883  }
884 #ifdef MFEM_USE_MPI
885  else if (const HypreParMatrix* Aptr =
886  dynamic_cast<const HypreParMatrix*>(&op))
887  {
888  SetMatrix(*Aptr, true);
889  }
890 #endif
891  else
892  {
893  mfem_error("Unsupported Operator Type \n");
894  }
895 }
896 
897 void AmgXSolver::Mult(const Vector& B, Vector& X) const
898 {
899  // Use a zero initial guess unless iterative_mode is set
900  X.UseDevice(true);
901  if (!iterative_mode) { X = 0.0; }
902 
903  // Mult for the serial and mpi-gpu-exclusive modes
904  if (mpi_gpu_mode != "mpi-teams")
905  {
906  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
907  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));
908 
909  if (mpi_gpu_mode != "serial")
910  {
911 #ifdef MFEM_USE_MPI
912  MPI_Barrier(gpuWorld);
913 #endif
914  }
915 
916  AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
917 
918  AMGX_SOLVE_STATUS status;
919  AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
920  if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck)
921  {
922  if (status == AMGX_SOLVE_DIVERGED)
923  {
924  mfem_error("AmgX solver diverged \n");
925  }
926  else
927  {
928  mfem_error("AmgX solver failed to solve system \n");
929  }
930  }
931 
932  AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
933  return;
934  }
935 
936 #ifdef MFEM_USE_MPI
937  Vector all_X(mat_local_rows);
938  Vector all_B(mat_local_rows);
939  Array<int> Apart_X(devWorldSize);
940  Array<int> Adisp_X(devWorldSize);
941  Array<int> Apart_B(devWorldSize);
942  Array<int> Adisp_B(devWorldSize);
943 
944  GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
945  GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
946  MPI_Barrier(devWorld);
947 
948  if (gpuWorld != MPI_COMM_NULL)
949  {
950  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
951  AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));
952 
953  MPI_Barrier(gpuWorld);
954 
955  AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP));
956 
957  AMGX_SOLVE_STATUS status;
958  AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
959  if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
960  {
961  if (status == AMGX_SOLVE_DIVERGED)
962  {
963  mfem_error("AmgX solver diverged \n");
964  }
965  else
966  {
967  mfem_error("AmgX solver failed to solve system \n");
968  }
969  }
970 
971  AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
972  }
973 
974  ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
975 #endif
976 }
977 
978 int AmgXSolver::GetNumIterations()
979 {
980  int getIters;
981  AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
982  return getIters;
983 }
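
A brief usage sketch, with amgx, b, and x assumed from the earlier examples:

   amgx.Mult(b, x);
   int its = amgx.GetNumIterations();   // iteration count reported by AmgX
   mfem::out << "AmgX iterations: " << its << '\n';
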
984 
985 void AmgXSolver::Finalize()
986 {
987  // Check instance is initialized
988  if (! isInitialized || count < 1)
989  {
990  mfem_error("Error in AmgXSolver::Finalize(). \n"
991  "This AmgXWrapper has not been initialized. \n"
992  "Please initialize it before finalization.\n");
993  }
994 
995  // Only processes using a GPU are required to destroy AmgX content
996 #ifdef MFEM_USE_MPI
997  if (gpuProc == 0 || mpi_gpu_mode == "serial")
998 #endif
999  {
1000  // Destroy solver instance
1001  AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
1002 
1003  // Destroy matrix instance
1004  AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));
1005 
1006  // Destroy RHS and unknown vectors
1007  AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
1008  AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));
1009 
1010  // Only the last instance needs to destroy the resource object and finalize AmgX
1011  if (count == 1)
1012  {
1013  AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
1014  AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1015 
1016  AMGX_SAFE_CALL(AMGX_finalize_plugins());
1017  AMGX_SAFE_CALL(AMGX_finalize());
1018  }
1019  else
1020  {
1021  AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
1022  }
1023 #ifdef MFEM_USE_MPI
1024  // destroy gpuWorld
1025  if (mpi_gpu_mode != "serial")
1026  {
1027  MPI_Comm_free(&gpuWorld);
1028  }
1029 #endif
1030  }
1031 
1032  // Reset necessary variables in case the user wants to reuse this
1033  // instance for a new setup
1034 #ifdef MFEM_USE_MPI
1035  gpuProc = MPI_UNDEFINED;
1036  if (globalCpuWorld != MPI_COMM_NULL)
1037  {
1038  MPI_Comm_free(&globalCpuWorld);
1039  MPI_Comm_free(&localCpuWorld);
1040  MPI_Comm_free(&devWorld);
1041  }
1042 #endif
1043  // decrease the number of instances
1044  count -= 1;
1045 
1046  // change status
1047  isInitialized = false;
1048 }
1049 
1050 } // mfem namespace
1051 
1052 #endif
bool ConvergenceCheck
Flag to check for convergence.
Definition: amgxsolver.hpp:77
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition: operator.hpp:72
int Size() const
Returns the size of the vector.
Definition: vector.hpp:190
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition: operator.hpp:652
void InitMPITeams(const MPI_Comm &comm, const int nDevs)
Definition: amgxsolver.cpp:149
virtual void UseDevice(bool use_dev) const
Enable execution of Vector operations using the mfem::Device.
Definition: vector.hpp:108
void ReadParameters(const std::string config, CONFIG_SRC source)
Definition: amgxsolver.cpp:186
virtual void Mult(const Vector &b, Vector &x) const
Operator application: y=A(x).
Definition: amgxsolver.cpp:897
void SetConvergenceCheck(bool setConvergenceCheck_=true)
Add a check for convergence after applying Mult.
Definition: amgxsolver.cpp:193
Data type sparse matrix.
Definition: sparsemat.hpp:41
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition: operator.hpp:66
virtual double * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:434
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
Definition: error.cpp:153
void UpdateOperator(const Operator &op)
Definition: amgxsolver.cpp:877
virtual void SetOperator(const Operator &op)
Definition: amgxsolver.cpp:854
void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
Definition: amgxsolver.cpp:198
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
virtual double * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:442
void InitExclusiveGPU(const MPI_Comm &comm)
Definition: amgxsolver.cpp:120
Vector data type.
Definition: vector.hpp:60
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as std::cout.
Definition: globals.hpp:66
Abstract operator.
Definition: operator.hpp:24
Wrapper for hypre's ParCSR matrix class.
Definition: hypre.hpp:277
AMGX_MODE
Flags to configure AmgXSolver as a solver or preconditioner.
Definition: amgxsolver.hpp:74
virtual const double * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:426
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28