int AmgXSolver::count = 0;

AMGX_resources_handle AmgXSolver::rsrc = nullptr;

AmgXSolver::AmgXSolver()
   : ConvergenceCheck(false) {}
// MPI constructors (one-GPU-per-rank and MPI-teams variants); their bodies set
// up the default parameters and call InitExclusiveGPU()/InitMPITeams().
AmgXSolver::AmgXSolver(const MPI_Comm &comm,
                       const AMGX_MODE amgxMode_, const bool verbose)

AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
                       const AMGX_MODE amgxMode_, const bool verbose)
void AmgXSolver::InitSerial()
{
   count++;

   mpi_gpu_mode = "serial";

   AMGX_SAFE_CALL(AMGX_initialize());
   AMGX_SAFE_CALL(AMGX_initialize_plugins());
   AMGX_SAFE_CALL(AMGX_install_signal_handler());

   MFEM_VERIFY(configs_src != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configs_src == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   isInitialized = true;
}
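// Illustrative usage sketch (not part of the original file): how the serial
// path above is typically driven from user code. The function name
// ExampleSerialAmgXSolve and the 1D Laplacian test problem are assumptions
// made for this example only; it assumes #include "mfem.hpp" and a build with
// MFEM_USE_AMGX enabled.
static void ExampleSerialAmgXSolve()
{
   using namespace mfem;

   // Build a small tridiagonal (1D Laplacian) test matrix.
   const int n = 16;
   SparseMatrix A(n, n);
   for (int i = 0; i < n; i++)
   {
      A.Add(i, i, 2.0);
      if (i > 0)   { A.Add(i, i-1, -1.0); }
      if (i < n-1) { A.Add(i, i+1, -1.0); }
   }
   A.Finalize();

   Vector b(n), x(n);
   b = 1.0; x = 0.0;

   // Configure, initialize for single-process use, and solve A x = b.
   AmgXSolver amgx;
   amgx.DefaultParameters(AmgXSolver::SOLVER, /*verbose=*/true);
   amgx.InitSerial();
   amgx.SetOperator(A);
   amgx.Mult(b, x);
}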
void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   // Every MPI rank talks to its own GPU
   mpi_gpu_mode = "mpi-gpu-exclusive";
   gpuProc = 0;
   count++;

   MPI_Comm_dup(comm, &gpuWorld);
   MPI_Comm_size(gpuWorld, &gpuWorldSize);
   MPI_Comm_rank(gpuWorld, &myGpuWorldRank);

   // Each rank sees exactly one device and calls it device 0
   nDevs = 1, devID = 0;

   InitAmgX();

   isInitialized = true;
}
void AmgXSolver::InitMPITeams(const MPI_Comm &comm, const int nDevs)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   mpi_gpu_mode = "mpi-teams";
   count++;

   // Get the name of this node
   int len;
   char name[MPI_MAX_PROCESSOR_NAME];
   MPI_Get_processor_name(name, &len);
   nodeName = name;

   int globalcommrank;
   MPI_Comm_rank(comm, &globalcommrank);

   // Initialize communicators and the corresponding information
   InitMPIcomms(comm, nDevs);

   // Only processes in gpuWorld need to initialize AmgX
   if (gpuProc == 0) { InitAmgX(); }

   isInitialized = true;
}
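// Illustrative usage sketch (not part of the original file): choosing between
// the two MPI initialization paths above. The function name ExampleParallelInit
// and its arguments are assumptions made for this example only.
static void ExampleParallelInit(MPI_Comm comm, mfem::AmgXSolver &amgx,
                                bool one_gpu_per_rank, int devices_per_node)
{
   // A configuration must be defined before either Init call.
   amgx.DefaultParameters(mfem::AmgXSolver::PRECONDITIONER, /*verbose=*/false);

   if (one_gpu_per_rank)
   {
      // Every MPI rank owns its own GPU.
      amgx.InitExclusiveGPU(comm);
   }
   else
   {
      // More MPI ranks than GPUs: ranks on each node are grouped into teams
      // and only the team roots talk to AmgX (see InitMPIcomms/SetDeviceIDs).
      amgx.InitMPITeams(comm, devices_per_node);
   }
}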
void AmgXSolver::ReadParameters(const std::string config,
                                const CONFIG_SRC source)
{
   amgx_config = config;
   configs_src = source;
}
void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
{
   amgxMode = amgxMode_;

   if (amgxMode == SOLVER)
   {
      amgx_config = "{\n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    "   \"solver\": \"AMG\", \n"
                    "   \"scope\": \"main\", \n"
                    "   \"smoother\": \"JACOBI_L1\", \n"
                    "   \"presweeps\": 1, \n"
                    "   \"interpolator\": \"D2\", \n"
                    "   \"max_row_sum\" : 0.9, \n"
                    "   \"strength_threshold\" : 0.25, \n"
                    "   \"postsweeps\": 1, \n"
                    "   \"max_iters\": 1 \n";
      if (verbose)
      {
         amgx_config = amgx_config + ",\n"
                       "   \"obtain_timings\": 1, \n"
                       "   \"print_grid_stats\": 1, \n"
                       "   \"monitor_residual\": 1, \n"
                       "   \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " }\n" + "}\n";
   }
243 " \"config_version\": 2, \n"
245 " \"preconditioner\": { \n"
246 " \"solver\": \"AMG\", \n"
247 " \"smoother\": { \n"
248 " \"scope\": \"jacobi\", \n"
249 " \"solver\": \"JACOBI_L1\" \n"
251 " \"presweeps\": 1, \n"
252 " \"interpolator\": \"D2\", \n"
253 " \"max_row_sum\" : 0.9, \n"
254 " \"strength_threshold\" : 0.25, \n"
255 " \"max_iters\": 1, \n"
256 " \"scope\": \"amg\", \n"
257 " \"max_levels\": 100, \n"
258 " \"cycle\": \"V\", \n"
259 " \"postsweeps\": 1 \n"
261 " \"solver\": \"PCG\", \n"
262 " \"max_iters\": 150, \n"
263 " \"convergence\": \"RELATIVE_INI_CORE\", \n"
264 " \"scope\": \"main\", \n"
265 " \"tolerance\": 1e-12, \n"
266 " \"monitor_residual\": 1, \n"
267 " \"norm\": \"L2\" ";
270 amgx_config = amgx_config +
", \n"
271 " \"obtain_timings\": 1, \n"
272 " \"print_grid_stats\": 1, \n"
273 " \"print_solve_stats\": 1 \n";
277 amgx_config = amgx_config +
"\n";
279 amgx_config = amgx_config +
" } \n" +
"} \n";
void AmgXSolver::InitAmgX()
{
   // The AmgX library itself is set up only once, by the first instance
   if (count == 1)
   {
      AMGX_SAFE_CALL(AMGX_initialize());
      AMGX_SAFE_CALL(AMGX_initialize_plugins());
      AMGX_SAFE_CALL(AMGX_install_signal_handler());

      AMGX_SAFE_CALL(AMGX_register_print_callback(
                        [](const char *msg, int length)->void
      {
         int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
         if (irank == 0) { mfem::out << msg; }
      }));
   }

   MFEM_VERIFY(configs_src != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configs_src == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   // Let AmgX handle returned error codes internally
   AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));

   // Only the first instance creates the shared resource object
   if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }

   // Create AmgX vectors for the unknowns and the RHS
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   // Create the AmgX matrix and solver objects
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));

   // Obtain the default number of rings based on the current configuration
   AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
}
void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
{
   // Duplicate the global communicator
   MPI_Comm_dup(comm, &globalCpuWorld);
   MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");

   // Get size and rank for the global communicator
   MPI_Comm_size(globalCpuWorld, &globalSize);
   MPI_Comm_rank(globalCpuWorld, &myGlobalRank);

   // Get the communicator for processes on the same node (local world)
   MPI_Comm_split_type(globalCpuWorld,
                       MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
   MPI_Comm_set_name(localCpuWorld, "localCpuWorld");

   // Get size and rank for the local communicator
   MPI_Comm_size(localCpuWorld, &localSize);
   MPI_Comm_rank(localCpuWorld, &myLocalRank);

   // Set the ID of the device each local process will use
   SetDeviceIDs(nDevs);

   MPI_Barrier(globalCpuWorld);

   // Split the global world into a world involved in AmgX and a null world
   MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);

   // Get size and rank for the communicator corresponding to gpuWorld
   if (gpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_set_name(gpuWorld, "gpuWorld");
      MPI_Comm_size(gpuWorld, &gpuWorldSize);
      MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
   }
   else // ranks that will not communicate with a GPU
   {
      gpuWorldSize = MPI_UNDEFINED;
      myGpuWorldRank = MPI_UNDEFINED;
   }

   // Split the local world into worlds corresponding to each CUDA device
   MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
   MPI_Comm_set_name(devWorld, "devWorld");

   // Get size and rank for the communicator corresponding to devWorld
   MPI_Comm_size(devWorld, &devWorldSize);
   MPI_Comm_rank(devWorld, &myDevWorldRank);

   MPI_Barrier(globalCpuWorld);
}
void AmgXSolver::SetDeviceIDs(const int nDevs)
{
   if (nDevs == localSize) // the devices and local processes match one to one
   {
      devID = myLocalRank;
      gpuProc = 0;
   }
   else if (nDevs > localSize) // more devices than local processes
   {
      MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
                   " are more than the MPI processes launched. Only "<<
                   nDevs << " devices will be used.\n");
      devID = myLocalRank;
      gpuProc = 0;
   }
   else // more processes than devices: group the local ranks into teams
   {
      int nBasic = localSize / nDevs,
          nRemain = localSize % nDevs;

      if (myLocalRank < (nBasic+1)*nRemain)
      {
         devID = myLocalRank / (nBasic + 1);
         if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
      }
      else
      {
         devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
         if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
      }
   }
}
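// Illustrative worked example (not part of the original file): the team
// assignment above with localSize = 7 ranks and nDevs = 2 devices gives
// nBasic = 3 and nRemain = 1, so local ranks {0,1,2,3} map to device 0 (team
// root: rank 0) and ranks {4,5,6} map to device 1 (team root: rank 4). The
// helper below only reproduces that arithmetic; its name and signature are
// assumptions made for this example.
static void ExampleMapRankToDevice(int myLocalRank, int localSize, int nDevs,
                                   int &devID, bool &isTeamRoot)
{
   const int nBasic  = localSize / nDevs;  // base number of ranks per team
   const int nRemain = localSize % nDevs;  // teams that receive one extra rank

   if (myLocalRank < (nBasic + 1) * nRemain)
   {
      devID = myLocalRank / (nBasic + 1);
      isTeamRoot = (myLocalRank % (nBasic + 1) == 0);
   }
   else
   {
      devID = (myLocalRank - (nBasic + 1) * nRemain) / nBasic + nRemain;
      isTeamRoot = ((myLocalRank - (nBasic + 1) * nRemain) % nBasic == 0);
   }
}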
void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
}
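// Illustrative worked example (not part of the original file): for a team of
// three ranks whose local arrays have sizes {4, 2, 3}, the gather above
// produces Apart = {4, 2, 3} and Adisp = {0, 4, 6} on the team root, so
// MPI_Gatherv packs the three contributions back to back into outArr (total
// size 9). The helper below only restates the displacement computation; its
// name is an assumption made for this example.
static void ExampleDisplacementsFromCounts(const mfem::Array<int> &Apart,
                                           mfem::Array<int> &Adisp)
{
   Adisp.SetSize(Apart.Size());
   Adisp[0] = 0;
   for (int i = 1; i < Apart.Size(); ++i)
   {
      Adisp[i] = Adisp[i-1] + Apart[i-1];
   }
}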
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
                             Array<int64_t> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT64_T, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);
}
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                             Array<int> &Apart, Array<int> &Adisp) const
{
   // Calculate the number of elements to be collected from each process
   int locAsz = inArr.Size();
   MPI_Allgather(&locAsz, 1, MPI_INT,
                 Apart.HostWrite(), 1, MPI_INT, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process
   Adisp[0] = 0;
   for (int i=1; i<mpiTeamSz; ++i)
   {
      Adisp[i] = Adisp[i-1] + Apart[i-1];
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPITypeMap<real_t>::mpi_type,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
}
void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
                              const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                              Array<int> &Apart, Array<int> &Adisp) const
{
   MPI_Scatterv(inArr.HostRead(), Apart.HostRead(), Adisp.HostRead(),
                MPITypeMap<real_t>::mpi_type, outArr.HostWrite(), outArr.Size(),
                MPITypeMap<real_t>::mpi_type, 0, mpiTeamComm);
}
void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
{
   if (update_mat == false)
   {
      AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
                                            in_A.NumNonZeroElems(),
                                            1, 1,
                                            in_A.ReadI(),
                                            in_A.ReadJ(),
                                            in_A.ReadData(), NULL));

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
                                                      in_A.Height(),
                                                      in_A.NumNonZeroElems(),
                                                      in_A.ReadData(), NULL));
   }
}
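// Illustrative worked example (not part of the original file): the CSR data
// handed to AMGX_matrix_upload_all above for the 3x3 matrix
//   [ 2 -1  0 ]
//   [-1  2 -1 ]
//   [ 0 -1  2 ]
// is n = 3, nnz = 7, with 1x1 blocks. The arrays below only restate that
// layout and are not used elsewhere.
static const int    example_csr_I[]    = {0, 2, 5, 7};
static const int    example_csr_J[]    = {0, 1, 0, 1, 2, 1, 2};
static const double example_csr_data[] = {2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0};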
void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
{
   // Require hypre >= 2.16
#if MFEM_HYPRE_VERSION < 21600
   mfem_error("Hypre version 2.16+ is required when using AmgX \n");
#endif

   hypre_ParCSRMatrix * A_ptr =
      (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);

   hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);

   Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
   const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);

   // Column indices must be int64_t for AmgX
   Array<int64_t> loc_J((int)A_csr->num_nonzeros);
   for (int i=0; i<A_csr->num_nonzeros; ++i)
   {
      loc_J[i] = A_csr->big_j[i];
   }

   // One GPU per MPI rank
   if (mpi_gpu_mode == "mpi-gpu-exclusive")
   {
      SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
      // Free the merged CSR matrix created by hypre_MergeDiagAndOffd
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   // More MPI ranks than GPUs
   if (mpi_gpu_mode == "mpi-teams")
   {
      SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
      // Free the merged CSR matrix created by hypre_MergeDiagAndOffd
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   mfem_error("Unsupported MPI_GPU combination \n");
}
void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
                                          const Array<double> &loc_A,
                                          const Array<int> &loc_I,
                                          const Array<int64_t> &loc_J,
                                          const bool update_mat)
{
   // Create a vector of offsets describing the matrix row partitions
   Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0;

   int64_t myStart = A.GetRowStarts()[0];

   MPI_Allgather(&myStart, 1, MPI_INT64_T,
                 rowPart.GetData(), 1, MPI_INT64_T, gpuWorld);
   MPI_Barrier(gpuWorld);

   rowPart[gpuWorldSize] = A.M();

   const int nGlobalRows = A.M();
   const int local_rows = loc_I.Size()-1;
   const int num_nnz = loc_I[local_rows];

   if (update_mat == false)
   {
      AMGX_distribution_handle dist;
      AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
      AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                          AMGX_DIST_PARTITION_OFFSETS,
                                                          rowPart.GetData()));

      AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                    local_rows, num_nnz, 1, 1,
                                                    loc_I.Read(), loc_J.Read(),
                                                    loc_A.Read(), NULL, dist));

      AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                      num_nnz, loc_A, NULL));
   }
}
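// Illustrative worked example (not part of the original file): with two GPU
// ranks owning 100 and 150 rows of a 250-row matrix, A.GetRowStarts()[0] is 0
// on rank 0 and 100 on rank 1, so the gathered array becomes {0, 100} and,
// after the final entry is set to A.M(), rowPart = {0, 100, 250}. This is the
// AMGX_DIST_PARTITION_OFFSETS layout consumed above. The array below (which
// assumes <cstdint>) only restates that example.
static const int64_t example_row_offsets[] = {0, 100, 250};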
void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
                                   const Array<double> &loc_A,
                                   const Array<int> &loc_I,
                                   const Array<int64_t> &loc_J,
                                   const bool update_mat)
{
   // Consolidate each team's matrix data on the team root before uploading
   Array<int> all_I;
   Array<int64_t> all_J;
   Array<double> all_A;

   int J_allsz(0), all_NNZ(0), nDevRows(0);
   const int loc_row_len = std::abs(A.RowPart()[1] -
                                    A.RowPart()[0]);
   const int loc_Jz_sz = loc_J.Size();
   const int loc_A_sz = loc_A.Size();

   MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);

   MPI_Barrier(devWorld);

   if (myDevWorldRank == 0)
   {
      all_I.SetSize(nDevRows+devWorldSize);
      all_J.SetSize(J_allsz); all_J = 0;
      all_A.SetSize(all_NNZ);
   }

   GatherArray(loc_I, all_I, devWorldSize, devWorld);
   GatherArray(loc_J, all_J, devWorldSize, devWorld);
   GatherArray(loc_A, all_A, devWorldSize, devWorld);

   MPI_Barrier(devWorld);
   int64_t local_nnz(0);
   int64_t local_rows(0);

   if (myDevWorldRank == 0)
   {
      // Fix-up step: remove the extra zero entries introduced when the row
      // pointer (I) arrays of the team members were concatenated.
      Array<int> z_ind(devWorldSize+1);
      int iter = 1;
      while (iter < devWorldSize-1)
      {
         // Determine the indices of zeros in the consolidated all_I array
         int counter = 0;
         z_ind[counter] = counter;
         counter++;
         for (int idx=1; idx<all_I.Size()-1; idx++)
         {
            if (all_I[idx] == 0)
            {
               z_ind[counter] = idx-1;
               counter++;
            }
         }
         z_ind[devWorldSize] = all_I.Size()-1;

         // Bump all_I
         for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
         {
            all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
         }
         // Shift the array after the bump to remove the extra values
         for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
         {
            all_I[idx] = all_I[idx+1];
         }
         iter++;
      }

      // Last pass through the array
      int counter = 0;
      z_ind[counter] = counter;
      counter++;
      for (int idx=1; idx<all_I.Size()-1; idx++)
      {
         if (all_I[idx] == 0)
         {
            z_ind[counter] = idx-1;
            counter++;
         }
      }
      z_ind[devWorldSize] = all_I.Size()-1;

      // Bump all_I one last time
      for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
      {
         all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
      }
      local_nnz = all_I[all_I.Size()-devWorldSize];
      local_rows = nDevRows;
   }

   // Create the row partition on the team roots (gpuWorld)
   mat_local_rows = local_rows;
   Array<int64_t> rowPart;
   if (gpuProc == 0)
   {
      rowPart.SetSize(gpuWorldSize+1); rowPart = 0;

      MPI_Allgather(&local_rows, 1, MPI_INT64_T,
                    &rowPart.GetData()[1], 1, MPI_INT64_T,
                    gpuWorld);
      MPI_Barrier(gpuWorld);

      // Convert the local row counts into global offsets
      for (int i=1; i<rowPart.Size(); ++i)
      {
         rowPart[i] += rowPart[i-1];
      }

      // Upload the matrix to AmgX
      MPI_Barrier(gpuWorld);

      int nGlobalRows = A.M();
      if (update_mat == false)
      {
         AMGX_distribution_handle dist;
         AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
         AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                             AMGX_DIST_PARTITION_OFFSETS,
                                                             rowPart.GetData()));

         AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                       local_rows, local_nnz,
                                                       1, 1, all_I.ReadWrite(),
                                                       all_J.Read(),
                                                       all_A.Read(),
                                                       nullptr, dist));

         AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
         MPI_Barrier(gpuWorld);

         AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                         local_nnz, all_A, NULL));
      }
   }
}
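// Illustrative worked example (not part of the original file): for a team of
// two ranks whose local CSR row pointers are {0, 2, 5} (2 rows, 5 nonzeros)
// and {0, 3, 4} (2 rows, 4 nonzeros), the gather above yields
//   all_I = {0, 2, 5, 0, 3, 4}.
// The fix-up pass locates the embedded zero (z_ind = {0, 2, 5}) and bumps the
// trailing entries, giving all_I = {0, 2, 5, 8, 9, 4}; the first
// nDevRows + 1 = 5 entries {0, 2, 5, 8, 9} form the consolidated row pointer,
// with local_rows = 4 and local_nnz = all_I[4] = 9. The arrays below only
// restate that example.
static const int example_team_all_I_before[] = {0, 2, 5, 0, 3, 4};
static const int example_team_all_I_after[]  = {0, 2, 5, 8, 9, 4};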
void AmgXSolver::UpdateOperator(const Operator &op)
{
   if (const SparseMatrix *Aptr = dynamic_cast<const SparseMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
   else if (const HypreParMatrix *Aptr = dynamic_cast<const HypreParMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
}
void AmgXSolver::Mult(const Vector &B, Vector &X) const
{
   // Serial and mpi-gpu-exclusive modes: every rank uploads its own data
   if (mpi_gpu_mode != "mpi-teams")
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));

      if (mpi_gpu_mode != "serial")
      {
         MPI_Barrier(gpuWorld);
      }

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
      return;
   }
   // mpi-teams mode: consolidate the team data on the team roots first
   Vector all_X(mat_local_rows);
   Vector all_B(mat_local_rows);
   Array<int> Apart_X(devWorldSize), Adisp_X(devWorldSize);
   Array<int> Apart_B(devWorldSize), Adisp_B(devWorldSize);

   GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
   GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
   MPI_Barrier(devWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
   }

   // Distribute the consolidated solution back to the team members
   ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
}
int AmgXSolver::GetNumIterations()
{
   int getIters;
   AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
   return getIters;
}
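// Illustrative usage sketch (not part of the original file): using the wrapper
// as a preconditioner for MFEM's CG solver and querying the AmgX iteration
// count afterwards. The function name, tolerances and iteration limits are
// assumptions made for this example only.
static void ExamplePrecondCG(mfem::HypreParMatrix &A,
                             mfem::Vector &b, mfem::Vector &x, MPI_Comm comm)
{
   using namespace mfem;

   AmgXSolver prec;
   prec.DefaultParameters(AmgXSolver::PRECONDITIONER, /*verbose=*/false);
   prec.InitExclusiveGPU(comm);
   prec.SetOperator(A);

   CGSolver cg(comm);
   cg.SetRelTol(1e-8);
   cg.SetMaxIter(200);
   cg.SetOperator(A);
   cg.SetPreconditioner(prec);
   cg.Mult(b, x);

   // Iterations taken by AmgX in its most recent (inner) application.
   mfem::out << "AmgX iterations: " << prec.GetNumIterations() << "\n";
}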
void AmgXSolver::Finalize()
{
   // Check that this instance has been initialized
   if (! isInitialized || count < 1)
   {
      mfem_error("Error in AmgXSolver::Finalize(). \n"
                 "This AmgXWrapper has not been initialized. \n"
                 "Please initialize it before finalization.\n");
   }

   // Only processes using a GPU are required to destroy AmgX content
   if (gpuProc == 0 || mpi_gpu_mode == "serial")
   {
      // Destroy the solver, matrix and vector instances
      AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
      AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));

      // Only the last instance destroys the shared resource object and
      // finalizes AmgX
      if (count == 1)
      {
         AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));

         AMGX_SAFE_CALL(AMGX_finalize_plugins());
         AMGX_SAFE_CALL(AMGX_finalize());
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
      }

      if (mpi_gpu_mode != "serial")
      {
         MPI_Comm_free(&gpuWorld);
      }
   }

   // Reset communicators so the instance can be reused
   gpuProc = MPI_UNDEFINED;
   if (globalCpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_free(&globalCpuWorld);
      MPI_Comm_free(&localCpuWorld);
      MPI_Comm_free(&devWorld);
   }

   count -= 1;
   isInitialized = false;
}