MFEM  v4.2.0
Finite element discretization library
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Pages
amgxsolver.cpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2020, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 // Implementation of the MFEM wrapper for Nvidia's multigrid library, AmgX
13 //
14 // This work is partially based on:
15 //
16 // Pi-Yueh Chuang and Lorena A. Barba (2017).
17 // AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library.
18 // J. Open Source Software, 2(16):280, doi:10.21105/joss.00280
19 //
20 // See https://github.com/barbagroup/AmgXWrapper.
21 
22 #include "../config/config.hpp"
23 #include "amgxsolver.hpp"
24 #ifdef MFEM_USE_AMGX
25 
26 namespace mfem
27 {
28 
// Number of AmgXSolver instances currently initialized in this process. The
// MPI initialization routines increment it, and InitAmgX() uses count == 1 to
// decide whether the one-time AmgX library setup has to run.
29 int AmgXSolver::count = 0;
30 
// AmgX resources handle stored as a static member, i.e. shared by all
// AmgXSolver instances. InitAmgX() creates it only when count == 1.
31 AMGX_resources_handle AmgXSolver::rsrc = nullptr;
32 
33 AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose)
34 {
35  amgxMode = amgxMode_;
36 
37  DefaultParameters(amgxMode, verbose);
38 
39  InitSerial();
40 }
41 
42 #ifdef MFEM_USE_MPI
43 
44 AmgXSolver::AmgXSolver(const MPI_Comm &comm,
45  const AMGX_MODE amgxMode_, const bool verbose)
46 {
47  std::string config;
48  amgxMode = amgxMode_;
49 
50  DefaultParameters(amgxMode, verbose);
51 
52  InitExclusiveGPU(comm);
53 }
54 
55 AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
56  const AMGX_MODE amgxMode_, const bool verbose)
57 {
58  std::string config;
59  amgxMode = amgxMode_;
60 
61  DefaultParameters(amgxMode_, verbose);
62 
63  InitMPITeams(comm, nDevs);
64 }
65 
66 #endif
67 
69 {
70  if (isInitialized) { Finalize(); }
71 }
72 
74 {
75  count++;
76 
77  mpi_gpu_mode = "serial";
78 
79  AMGX_SAFE_CALL(AMGX_initialize());
80 
81  AMGX_SAFE_CALL(AMGX_initialize_plugins());
82 
83  AMGX_SAFE_CALL(AMGX_install_signal_handler());
84 
85  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
86  "AmgX configuration is not defined \n");
87 
88  if (configSrc == CONFIG_SRC::EXTERNAL)
89  {
90  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
91  }
92  else
93  {
94  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
95  }
96 
97  AMGX_resources_create_simple(&rsrc, cfg);
98  AMGX_solver_create(&solver, rsrc, precision_mode, cfg);
99  AMGX_matrix_create(&AmgXA, rsrc, precision_mode);
100  AMGX_vector_create(&AmgXP, rsrc, precision_mode);
101  AMGX_vector_create(&AmgXRHS, rsrc, precision_mode);
102 
103  isInitialized = true;
104 }
105 
106 #ifdef MFEM_USE_MPI
107 
108 void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
109 {
110  // If this instance has already been initialized, skip
111  if (isInitialized)
112  {
113  mfem_error("This AmgXSolver instance has been initialized on this process.");
114  }
115 
116  // Note that every MPI rank may talk to a GPU
117  mpi_gpu_mode = "mpi-gpu-exclusive";
118  gpuProc = 0;
119 
120  // Increment number of AmgX instances
121  count++;
122 
123  MPI_Comm_dup(comm, &gpuWorld);
124  MPI_Comm_size(gpuWorld, &gpuWorldSize);
125  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
126 
127  // Each rank will only see 1 device call it device 0
128  nDevs = 1, devID = 0;
129 
130  InitAmgX();
131 
132  isInitialized = true;
133 }
134 
135 // Initialize for MPI ranks > GPUs, all devices are visible to all of the MPI
136 // ranks
137 void AmgXSolver::InitMPITeams(const MPI_Comm &comm,
138  const int nDevs)
139 {
140  // If this instance has already been initialized, skip
141  if (isInitialized)
142  {
143  mfem_error("This AmgXSolver instance has been initialized on this process.");
144  }
145 
146  mpi_gpu_mode = "mpi-teams";
147 
148  // Increment number of AmgX instances
149  count++;
150 
151  // Get the name of this node
152  int len;
153  char name[MPI_MAX_PROCESSOR_NAME];
154  MPI_Get_processor_name(name, &len);
155  nodeName = name;
156  int globalcommrank;
157 
158  MPI_Comm_rank(comm, &globalcommrank);
159 
160  // Initialize communicators and corresponding information
161  InitMPIcomms(comm, nDevs);
162 
163  // Only processes in gpuWorld are required to initialize AmgX
164  if (gpuProc == 0)
165  {
166  InitAmgX();
167  }
168 
169  isInitialized = true;
170 }
171 
172 #endif
173 
174 void AmgXSolver::ReadParameters(const std::string config,
175  const CONFIG_SRC source)
176 {
177  amgx_config = config;
178  configSrc = source;
179 }
180 
182  const bool verbose)
183 {
184  amgxMode = amgxMode_;
185 
186  configSrc = INTERNAL;
187 
188  if (amgxMode == AMGX_MODE::PRECONDITIONER)
189  {
190  amgx_config = "{\n"
191  " \"config_version\": 2, \n"
192  " \"solver\": { \n"
193  " \"solver\": \"AMG\", \n"
194  " \"presweeps\": 1, \n"
195  " \"postsweeps\": 1, \n"
196  " \"interpolator\": \"D2\", \n"
197  " \"max_iters\": 2, \n"
198  " \"convergence\": \"ABSOLUTE\", \n"
199  " \"cycle\": \"V\"";
200  if (verbose)
201  {
202  amgx_config = amgx_config + ",\n"
203  " \"obtain_timings\": 1, \n"
204  " \"monitor_residual\": 1, \n"
205  " \"print_grid_stats\": 1, \n"
206  " \"print_solve_stats\": 1 \n";
207  }
208  else
209  {
210  amgx_config = amgx_config + "\n";
211  }
212  amgx_config = amgx_config + " }\n" + "}\n";
213  }
214  else if (amgxMode == AMGX_MODE::SOLVER)
215  {
216  amgx_config = "{ \n"
217  " \"config_version\": 2, \n"
218  " \"solver\": { \n"
219  " \"preconditioner\": { \n"
220  " \"solver\": \"AMG\", \n"
221  " \"smoother\": { \n"
222  " \"scope\": \"jacobi\", \n"
223  " \"solver\": \"BLOCK_JACOBI\", \n"
224  " \"relaxation_factor\": 0.7 \n"
225  " }, \n"
226  " \"presweeps\": 1, \n"
227  " \"interpolator\": \"D2\", \n"
228  " \"max_row_sum\" : 0.9, \n"
229  " \"strength_threshold\" : 0.25, \n"
230  " \"max_iters\": 2, \n"
231  " \"scope\": \"amg\", \n"
232  " \"max_levels\": 100, \n"
233  " \"cycle\": \"V\", \n"
234  " \"postsweeps\": 1 \n"
235  " }, \n"
236  " \"solver\": \"PCG\", \n"
237  " \"max_iters\": 100, \n"
238  " \"convergence\": \"RELATIVE_MAX\", \n"
239  " \"scope\": \"main\", \n"
240  " \"tolerance\": 1e-12, \n"
241  " \"norm\": \"L2\" ";
242  if (verbose)
243  {
244  amgx_config = amgx_config + ", \n"
245  " \"obtain_timings\": 1, \n"
246  " \"monitor_residual\": 1, \n"
247  " \"print_grid_stats\": 1, \n"
248  " \"print_solve_stats\": 1 \n";
249  }
250  else
251  {
252  amgx_config = amgx_config + "\n";
253  }
254  amgx_config = amgx_config + " } \n" + "} \n";
255  }
256  else
257  {
258  mfem_error("AmgX mode not supported \n");
259  }
260 }
261 
262 // Sets up AmgX library for MPI builds
263 #ifdef MFEM_USE_MPI
264 void AmgXSolver::InitAmgX()
265 {
266  // Set up once
267  if (count == 1)
268  {
269  AMGX_SAFE_CALL(AMGX_initialize());
270 
271  AMGX_SAFE_CALL(AMGX_initialize_plugins());
272 
273  AMGX_SAFE_CALL(AMGX_install_signal_handler());
274 
275  AMGX_SAFE_CALL(AMGX_register_print_callback(
276  [](const char *msg, int length)->void
277  {
278  int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
279  if (irank == 0) { mfem::out<<msg;} }));
280  }
281 
282  MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
283  "AmgX configuration is not defined \n");
284 
285  if (configSrc == CONFIG_SRC::EXTERNAL)
286  {
287  AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
288  }
289  else
290  {
291  AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
292  }
293 
294  // Let AmgX handle returned error codes internally
295  AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
296 
297  // Create an AmgX resource object, only the first instance needs to create
298  // the resource object.
299  if (count == 1) { AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID); }
300 
301  // Create AmgX vector object for unknowns and RHS
302  AMGX_vector_create(&AmgXP, rsrc, precision_mode);
303  AMGX_vector_create(&AmgXRHS, rsrc, precision_mode);
304 
305  // Create AmgX matrix object for unknowns and RHS
306  AMGX_matrix_create(&AmgXA, rsrc, precision_mode);
307 
308  // Create an AmgX solver object
309  AMGX_solver_create(&solver, rsrc, precision_mode, cfg);
310 
311  // Obtain the default number of rings based on current configuration
312  AMGX_config_get_default_number_of_rings(cfg, &ring);
313 }
314 
315 // Groups MPI ranks into teams and assigns the roots to talk to GPUs
316 void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
317 {
318  // Duplicate the global communicator
319  MPI_Comm_dup(comm, &globalCpuWorld);
320  MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");
321 
322  // Get size and rank for global communicator
323  MPI_Comm_size(globalCpuWorld, &globalSize);
324  MPI_Comm_rank(globalCpuWorld, &myGlobalRank);
325 
326  // Get the communicator for processors on the same node (local world)
327  MPI_Comm_split_type(globalCpuWorld,
328  MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
329  MPI_Comm_set_name(localCpuWorld, "localCpuWorld");
330 
331  // Get size and rank for local communicator
332  MPI_Comm_size(localCpuWorld, &localSize);
333  MPI_Comm_rank(localCpuWorld, &myLocalRank);
334 
335  // Set up corresponding ID of the device used by each local process
336  SetDeviceIDs(nDevs);
337 
338  MPI_Barrier(globalCpuWorld);
339 
340  // Split the global world into a world involved in AmgX and a null world
341  MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);
342 
343  // Get size and rank for the communicator corresponding to gpuWorld
344  if (gpuWorld != MPI_COMM_NULL)
345  {
346  MPI_Comm_set_name(gpuWorld, "gpuWorld");
347  MPI_Comm_size(gpuWorld, &gpuWorldSize);
348  MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
349  }
350  else // for those that will not communicate with the GPU
351  {
352  gpuWorldSize = MPI_UNDEFINED;
353  myGpuWorldRank = MPI_UNDEFINED;
354  }
355 
356  // Split local world into worlds corresponding to each CUDA device
357  MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
358  MPI_Comm_set_name(devWorld, "devWorld");
359 
360  // Get size and rank for the communicator corresponding to myWorld
361  MPI_Comm_size(devWorld, &devWorldSize);
362  MPI_Comm_rank(devWorld, &myDevWorldRank);
363 
364  MPI_Barrier(globalCpuWorld);
365 }
366 
367 // Determine MPI teams based on available devices
368 void AmgXSolver::SetDeviceIDs(const int nDevs)
369 {
370  // Set the ID of device that each local process will use
371  if (nDevs == localSize) // # of the devices and local process are the same
372  {
373  devID = myLocalRank;
374  gpuProc = 0;
375  }
376  else if (nDevs > localSize) // there are more devices than processes
377  {
378  MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
379  " are more than the MPI processes launched. Only "<<
380  nDevs << " devices will be used.\n");
381  devID = myLocalRank;
382  gpuProc = 0;
383  }
384  else // in case there are more ranks than devices
385  {
386  int nBasic = localSize / nDevs,
387  nRemain = localSize % nDevs;
388 
389  if (myLocalRank < (nBasic+1)*nRemain)
390  {
391  devID = myLocalRank / (nBasic + 1);
392  if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
393  }
394  else
395  {
396  devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
397  if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
398  }
399  }
400 }
401 
402 void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
403  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
404 {
405  // Calculate number of elements to be collected from each process
406  Array<int> Apart(mpiTeamSz);
407  int locAsz = inArr.Size();
408  MPI_Gather(&locAsz, 1, MPI_INT,
409  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
410 
411  MPI_Barrier(mpiTeamComm);
412 
413  // Determine stride for process (to be used by root)
414  Array<int> Adisp(mpiTeamSz);
415  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
416  if (myid == 0)
417  {
418  Adisp[0] = 0;
419  for (int i=1; i<mpiTeamSz; ++i)
420  {
421  Adisp[i] = Adisp[i-1] + Apart[i-1];
422  }
423  }
424 
425  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
426  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
427  MPI_DOUBLE, 0, mpiTeamComm);
428 }
429 
430 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
431  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
432 {
433  // Calculate number of elements to be collected from each process
434  Array<int> Apart(mpiTeamSz);
435  int locAsz = inArr.Size();
436  MPI_Gather(&locAsz, 1, MPI_INT,
437  Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm);
438 
439  MPI_Barrier(mpiTeamComm);
440 
441  // Determine stride for process (to be used by root)
442  Array<int> Adisp(mpiTeamSz);
443  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
444  if (myid == 0)
445  {
446  Adisp[0] = 0;
447  for (int i=1; i<mpiTeamSz; ++i)
448  {
449  Adisp[i] = Adisp[i-1] + Apart[i-1];
450  }
451  }
452 
453  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
454  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
455  MPI_DOUBLE, 0, mpiTeamComm);
456 }
457 
458 void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
459  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
460 {
461  // Calculate number of elements to be collected from each process
462  Array<int> Apart(mpiTeamSz);
463  int locAsz = inArr.Size();
464  MPI_Gather(&locAsz, 1, MPI_INT,
465  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
466 
467  MPI_Barrier(mpiTeamComm);
468 
469  // Determine stride for process (to be used by root)
470  Array<int> Adisp(mpiTeamSz);
471  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
472  if (myid == 0)
473  {
474  Adisp[0] = 0;
475  for (int i=1; i<mpiTeamSz; ++i)
476  {
477  Adisp[i] = Adisp[i-1] + Apart[i-1];
478  }
479  }
480 
481  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
482  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
483  MPI_INT, 0, mpiTeamComm);
484 }
485 
486 
487 void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
488  Array<int64_t> &outArr,
489  const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
490 {
491  // Calculate number of elements to be collected from each process
492  Array<int> Apart(mpiTeamSz);
493  int locAsz = inArr.Size();
494  MPI_Gather(&locAsz, 1, MPI_INT,
495  Apart.GetData(),1, MPI_INT,0,mpiTeamComm);
496 
497  MPI_Barrier(mpiTeamComm);
498 
499  // Determine stride for process
500  Array<int> Adisp(mpiTeamSz);
501  int myid; MPI_Comm_rank(mpiTeamComm, &myid);
502  if (myid == 0)
503  {
504  Adisp[0] = 0;
505  for (int i=1; i<mpiTeamSz; ++i)
506  {
507  Adisp[i] = Adisp[i-1] + Apart[i-1];
508  }
509  }
510 
511  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
512  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
513  MPI_INT64_T, 0, mpiTeamComm);
514 
515  MPI_Barrier(mpiTeamComm);
516 }
517 
518 void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
519  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
520  Array<int> &Apart, Array<int> &Adisp) const
521 {
522  // Calculate number of elements to be collected from each process
523  int locAsz = inArr.Size();
524  MPI_Allgather(&locAsz, 1, MPI_INT,
525  Apart.HostWrite(),1, MPI_INT, mpiTeamComm);
526 
527  MPI_Barrier(mpiTeamComm);
528 
529  // Determine stride for process
530  Adisp[0] = 0;
531  for (int i=1; i<mpiTeamSz; ++i)
532  {
533  Adisp[i] = Adisp[i-1] + Apart[i-1];
534  }
535 
536  MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
537  outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
538  MPI_DOUBLE, 0, mpiTeamComm);
539 }
540 
541 void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
542  const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
543  Array<int> &Apart, Array<int> &Adisp) const
544 {
545  MPI_Scatterv(inArr.HostRead(),Apart.HostRead(),Adisp.HostRead(),
546  MPI_DOUBLE,outArr.HostWrite(),outArr.Size(),
547  MPI_DOUBLE, 0, mpiTeamComm);
548 }
549 #endif
550 
551 void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
552 {
553  if (update_mat == false)
554  {
555  AMGX_matrix_upload_all(AmgXA, in_A.Height(),
556  in_A.NumNonZeroElems(),
557  1, 1,
558  in_A.ReadI(),
559  in_A.ReadJ(),
560  in_A.ReadData(), NULL);
561 
562  AMGX_solver_setup(solver, AmgXA);
563  AMGX_vector_bind(AmgXP, AmgXA);
564  AMGX_vector_bind(AmgXRHS, AmgXA);
565  }
566  else
567  {
568  AMGX_matrix_replace_coefficients(AmgXA,
569  in_A.Height(),
570  in_A.NumNonZeroElems(),
571  in_A.ReadData(), NULL);
572  }
573 }
574 
575 #ifdef MFEM_USE_MPI
576 
577 void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
578 {
579  // Require hypre >= 2.16.
580 #if MFEM_HYPRE_VERSION < 21600
581  mfem_error("Hypre version 2.16+ is required when using AmgX \n");
582 #endif
583 
584  hypre_ParCSRMatrix * A_ptr =
585  (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);
586 
587  hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);
588 
589  Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
590  const Array<int> loc_I(A_csr->i, (int)A_csr->num_rows+1);
591 
592  // Column index must be int64_t so we must promote here
593  Array<int64_t> loc_J((int)A_csr->num_nonzeros);
594  for (int i=0; i<A_csr->num_nonzeros; ++i)
595  {
596  loc_J[i] = A_csr->big_j[i];
597  }
598 
599  // Assumes one GPU per MPI rank
600  if (mpi_gpu_mode=="mpi-gpu-exclusive")
601  {
602  return SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
603  }
604 
605  // Assumes teams of MPI ranks are sharing a GPU
606  if (mpi_gpu_mode == "mpi-teams")
607  {
608  return SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
609  }
610 
611  mfem_error("Unsupported MPI_GPU combination \n");
612 }
613 
614 void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
615  const Array<double> &loc_A,
616  const Array<int> &loc_I,
617  const Array<int64_t> &loc_J,
618  const bool update_mat)
619 {
620  // Create a vector of offsets describing matrix row partitions
621  Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0.0;
622 
623  int64_t myStart = A.GetRowStarts()[0];
624 
625  MPI_Allgather(&myStart, 1, MPI_INT64_T,
626  rowPart.GetData(),1, MPI_INT64_T
627  ,gpuWorld);
628  MPI_Barrier(gpuWorld);
629 
630  rowPart[gpuWorldSize] = A.M();
631 
632  const int nGlobalRows = A.M();
633  const int local_rows = loc_I.Size()-1;
634  const int num_nnz = loc_I[local_rows];
635 
636  if (update_mat == false)
637  {
638  AMGX_distribution_handle dist;
639  AMGX_distribution_create(&dist, cfg);
640  AMGX_distribution_set_partition_data(dist, AMGX_DIST_PARTITION_OFFSETS,
641  rowPart.GetData());
642 
643  AMGX_matrix_upload_distributed(AmgXA, nGlobalRows, local_rows,
644  num_nnz, 1, 1, loc_I.Read(),
645  loc_J.Read(), loc_A.Read(),
646  NULL, dist);
647 
648  AMGX_distribution_destroy(dist);
649 
650  MPI_Barrier(gpuWorld);
651 
652  AMGX_solver_setup(solver, AmgXA);
653 
654  AMGX_vector_bind(AmgXP, AmgXA);
655  AMGX_vector_bind(AmgXRHS, AmgXA);
656  }
657  else
658  {
659  AMGX_matrix_replace_coefficients(AmgXA,nGlobalRows,num_nnz,loc_A, NULL);
660  }
661 }
662 
// Upload (or update) a distributed matrix in "mpi-teams" mode, where the
// ranks of devWorld share one GPU.
//
// The local CSR pieces (loc_I, loc_J, loc_A) of all ranks in devWorld are
// gathered on devWorld rank 0, the gathered row-pointer array is repaired so
// that it describes one consolidated CSR matrix, and the ranks with
// gpuProc == 0 then pass the consolidated data to AmgX.
//
// A           parallel matrix; used for the local row count and A.M()
// loc_A       local CSR values
// loc_I       local CSR row pointers
// loc_J       local CSR column indices as int64_t
// update_mat  false: full upload + solver setup + vector binding
//             true:  replace coefficients only
663 void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
664  const Array<double> &loc_A,
665  const Array<int> &loc_I,
666  const Array<int64_t> &loc_J,
667  const bool update_mat)
668 {
669  // The following arrays hold the consolidated diagonal + off-diagonal matrix
670  // data
671  Array<int> all_I;
672  Array<int64_t> all_J;
673  Array<double> all_A;
674 
675  // Determine array sizes
676  int J_allsz(0), all_NNZ(0), nDevRows(0);
677  const int loc_row_len = std::abs(A.RowPart()[1] -
678  A.RowPart()[0]); // end of row partition
679  const int loc_Jz_sz = loc_J.Size();
680  const int loc_A_sz = loc_A.Size();
681 
// Sum the local sizes over the team; only devWorld rank 0 receives the totals
682  MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
683  MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
684  MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);
685 
686  MPI_Barrier(devWorld);
687 
688  if (myDevWorldRank == 0)
689  {
// nDevRows+devWorldSize entries: the sum of the local row counts plus one
// extra entry per team member
690  all_I.SetSize(nDevRows+devWorldSize);
691  all_J.SetSize(J_allsz); all_J = 0.0;
692  all_A.SetSize(all_NNZ);
693  }
694 
695  GatherArray(loc_I, all_I, devWorldSize, devWorld);
696  GatherArray(loc_J, all_J, devWorldSize, devWorld);
697  GatherArray(loc_A, all_A, devWorldSize, devWorld);
698 
699  MPI_Barrier(devWorld);
700 
701  int local_nnz(0);
702  int64_t local_rows(0);
703 
704  if (myDevWorldRank == 0)
705  {
706  // A fix up step is needed for the array holding row data to remove extra
707  // zeros when consolidating team data.
// NOTE(review): segment starts are detected purely by all_I[idx]==0, and
// z_ind has room for devWorldSize+1 entries. Presumably this assumes no
// gathered row-pointer piece contains a zero other than its leading one
// (e.g. no empty leading rows) - confirm.
708  Array<int> z_ind(devWorldSize+1);
709  int iter = 1;
710  while (iter < devWorldSize-1)
711  {
712  // Determine the indices of zeros in global all_I array
713  int counter = 0;
714  z_ind[counter] = counter;
715  counter++;
716  for (int idx=1; idx<all_I.Size()-1; idx++)
717  {
718  if (all_I[idx]==0)
719  {
720  z_ind[counter] = idx-1;
721  counter++;
722  }
723  }
724  z_ind[devWorldSize] = all_I.Size()-1;
725  // End of determining indices of zeros in global all_I Array
726 
727  // Bump all_I
728  for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
729  {
730  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
731  }
732 
733  // Shift array after bump to remove unnecessary values in middle of
734  // array
735  for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
736  {
737  all_I[idx] = all_I[idx+1];
738  }
739  iter++;
740  }
741 
742  // LAST TIME THROUGH ARRAY
743  // Determine the indices of zeros in global row_ptr array
744  int counter = 0;
745  z_ind[counter] = counter;
746  counter++;
747  for (int idx=1; idx<all_I.Size()-1; idx++)
748  {
749  if (all_I[idx]==0)
750  {
751  z_ind[counter] = idx-1;
752  counter++;
753  }
754  }
755 
756  z_ind[devWorldSize] = all_I.Size()-1;
757  // End of determining indices of zeros in global all_I Array BUMP all_I
758  // one last time
759  for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
760  {
761  all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
762  }
// all_I.Size()-devWorldSize equals nDevRows, so this reads the row pointer at
// index nDevRows as the consolidated nonzero count
763  local_nnz = all_I[all_I.Size()-devWorldSize];
764  local_rows = nDevRows;
765  }
766 
767  // Create row partition
768  mat_local_rows = local_rows; // class copy
769  Array<int64_t> rowPart;
770  if (gpuProc == 0)
771  {
772  rowPart.SetSize(gpuWorldSize+1); rowPart=0;
773 
// Gather the per-rank row counts into rowPart[1..]; the loop below turns them
// into offsets by a running sum
774  MPI_Allgather(&local_rows, 1, MPI_INT64_T,
775  &rowPart.GetData()[1], 1, MPI_INT64_T,
776  gpuWorld);
777  MPI_Barrier(gpuWorld);
778 
779  // Fixup step
780  for (int i=1; i<rowPart.Size(); ++i)
781  {
782  rowPart[i] += rowPart[i-1];
783  }
784 
785  // Upload A matrix to AmgX
786  MPI_Barrier(gpuWorld);
787 
788  int nGlobalRows = A.M();
789  if (update_mat == false)
790  {
791  AMGX_distribution_handle dist;
792  AMGX_distribution_create(&dist, cfg);
793  AMGX_distribution_set_partition_data(dist, AMGX_DIST_PARTITION_OFFSETS,
794  rowPart.GetData());
795 
796  AMGX_matrix_upload_distributed(AmgXA, nGlobalRows, local_rows,
797  local_nnz,
798  1, 1, all_I.ReadWrite(),
799  all_J.Read(),
800  all_A.Read(),
801  nullptr, dist);
802 
803  AMGX_distribution_destroy(dist);
804  MPI_Barrier(gpuWorld);
805 
806  AMGX_solver_setup(solver, AmgXA);
807 
808  // Bind vectors to A
809  AMGX_vector_bind(AmgXP, AmgXA);
810  AMGX_vector_bind(AmgXRHS, AmgXA);
811  }
812  else
813  {
814  AMGX_matrix_replace_coefficients(AmgXA,nGlobalRows,local_nnz,all_A,NULL);
815  }
816  }
817 }
818 
819 #endif
820 
822 {
823  height = op.Height();
824  width = op.Width();
825 
826  if (const SparseMatrix* Aptr =
827  dynamic_cast<const SparseMatrix*>(&op))
828  {
829  SetMatrix(*Aptr);
830  }
831 #ifdef MFEM_USE_MPI
832  else if (const HypreParMatrix* Aptr =
833  dynamic_cast<const HypreParMatrix*>(&op))
834  {
835  SetMatrix(*Aptr);
836  }
837 #endif
838  else
839  {
840  mfem_error("Unsupported Operator Type \n");
841  }
842 }
843 
845 {
846  if (const SparseMatrix* Aptr =
847  dynamic_cast<const SparseMatrix*>(&op))
848  {
849  SetMatrix(*Aptr, true);
850  }
851 #ifdef MFEM_USE_MPI
852  else if (const HypreParMatrix* Aptr =
853  dynamic_cast<const HypreParMatrix*>(&op))
854  {
855  SetMatrix(*Aptr, true);
856  }
857 #endif
858  else
859  {
860  mfem_error("Unsupported Operator Type \n");
861  }
862 }
863 
864 void AmgXSolver::Mult(const Vector& B, Vector& X) const
865 {
866  // Set initial guess to zero
867  X.UseDevice(true);
868  X = 0.0;
869 
870  // Mult for serial, and mpi-exclusive modes
871  if (mpi_gpu_mode != "mpi-teams")
872  {
873  AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite());
874  AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read());
875 
876  if (mpi_gpu_mode != "serial")
877  {
878 #ifdef MFEM_USE_MPI
879  MPI_Barrier(gpuWorld);
880 #endif
881  }
882 
883  AMGX_solver_solve(solver,AmgXRHS, AmgXP);
884 
885  AMGX_SOLVE_STATUS status;
886  AMGX_solver_get_status(solver, &status);
887  if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
888  {
889  if (status == AMGX_SOLVE_DIVERGED)
890  {
891  mfem_error("AmgX solver failed to solve system \n");
892  }
893  else
894  {
895  mfem_error("AmgX solver diverged \n");
896  }
897  }
898 
899  AMGX_vector_download(AmgXP, X.Write());
900  return;
901  }
902 
903 #ifdef MFEM_USE_MPI
904  Vector all_X(mat_local_rows);
905  Vector all_B(mat_local_rows);
906  Array<int> Apart_X(devWorldSize);
907  Array<int> Adisp_X(devWorldSize);
908  Array<int> Apart_B(devWorldSize);
909  Array<int> Adisp_B(devWorldSize);
910 
911  GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
912  GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
913  MPI_Barrier(devWorld);
914 
915  if (gpuWorld != MPI_COMM_NULL)
916  {
917  AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite());
918  AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite());
919 
920  MPI_Barrier(gpuWorld);
921 
922  AMGX_solver_solve(solver,AmgXRHS, AmgXP);
923 
924  AMGX_SOLVE_STATUS status;
925  AMGX_solver_get_status(solver, &status);
926  if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
927  {
928  if (status == AMGX_SOLVE_DIVERGED)
929  {
930  mfem_error("AmgX solver failed to solve system \n");
931  }
932  else
933  {
934  mfem_error("AmgX solver diverged \n");
935  }
936  }
937 
938  AMGX_vector_download(AmgXP, all_X.Write());
939  }
940 
941  ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
942 #endif
943 }
944 
946 {
947  int getIters;
948  AMGX_solver_get_iterations_number(solver, &getIters);
949  return getIters;
950 }
951 
953 {
954  // Check instance is initialized
955  if (! isInitialized || count < 1)
956  {
957  mfem_error("Error in AmgXSolver::Finalize(). \n"
958  "This AmgXWrapper has not been initialized. \n"
959  "Please initialize it before finalization.\n");
960  }
961 
962  // Only processes using GPU are required to destroy AmgX content
963 #ifdef MFEM_USE_MPI
964  if (gpuProc == 0 || mpi_gpu_mode == "serial")
965 #endif
966  {
967  // Destroy solver instance
968  AMGX_solver_destroy(solver);
969 
970  // Destroy matrix instance
971  AMGX_matrix_destroy(AmgXA);
972 
973  // Destroy RHS and unknown vectors
974  AMGX_vector_destroy(AmgXP);
975  AMGX_vector_destroy(AmgXRHS);
976 
977  // Only the last instance need to destroy resource and finalizing AmgX
978  if (count == 1)
979  {
980  AMGX_resources_destroy(rsrc);
981  AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
982 
983  AMGX_SAFE_CALL(AMGX_finalize_plugins());
984  AMGX_SAFE_CALL(AMGX_finalize());
985  }
986  else
987  {
988  AMGX_config_destroy(cfg);
989  }
990 #ifdef MFEM_USE_MPI
991  // destroy gpuWorld
992  if (mpi_gpu_mode != "serial")
993  {
994  MPI_Comm_free(&gpuWorld);
995  }
996 #endif
997  }
998 
999  // re-set necessary variables in case users want to reuse the variable of
1000  // this instance for a new instance
1001 #ifdef MFEM_USE_MPI
1002  gpuProc = MPI_UNDEFINED;
1003  if (globalCpuWorld != MPI_COMM_NULL)
1004  {
1005  MPI_Comm_free(&globalCpuWorld);
1006  MPI_Comm_free(&localCpuWorld);
1007  MPI_Comm_free(&devWorld);
1008  }
1009 #endif
1010  // decrease the number of instances
1011  count -= 1;
1012 
1013  // change status
1014  isInitialized = false;
1015 }
1016 
1017 } // mfem namespace
1018 
1019 #endif
AmgXSolver()=default
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition: operator.hpp:71
void UseDevice(bool use_dev) const
Enable execution of Vector operations using the mfem::Device.
Definition: vector.hpp:89
int Size() const
Returns the size of the vector.
Definition: vector.hpp:160
void InitMPITeams(const MPI_Comm &comm, const int nDevs)
Definition: amgxsolver.cpp:137
double * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:380
void ReadParameters(const std::string config, CONFIG_SRC source)
Definition: amgxsolver.cpp:174
virtual void Mult(const Vector &b, Vector &x) const
Operator application: y=A(x).
Definition: amgxsolver.cpp:864
void source(const Vector &x, Vector &f)
Definition: ex25.cpp:581
double * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:388
Data type sparse matrix.
Definition: sparsemat.hpp:46
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition: operator.hpp:65
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
Definition: error.cpp:153
void UpdateOperator(const Operator &op)
Definition: amgxsolver.cpp:844
virtual void SetOperator(const Operator &op)
Definition: amgxsolver.cpp:821
void DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
Definition: amgxsolver.cpp:181
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
void InitExclusiveGPU(const MPI_Comm &comm)
Definition: amgxsolver.cpp:108
const double * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:372
Vector data type.
Definition: vector.hpp:51
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition: globals.hpp:66
Abstract operator.
Definition: operator.hpp:24
Wrapper for hypre's ParCSR matrix class.
Definition: hypre.hpp:181
AMGX_MODE
Flags to configure AmgXSolver as a solver or preconditioner.
Definition: amgxsolver.hpp:74
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28