#include "../config/config.hpp"

int AmgXSolver::count = 0;

AMGX_resources_handle AmgXSolver::rsrc = nullptr;

AmgXSolver::AmgXSolver()
   : ConvergenceCheck(false) {};
AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_,
                       const bool verbose)
{
   DefaultParameters(amgxMode_, verbose);
   InitSerial();
}

AmgXSolver::AmgXSolver(const MPI_Comm &comm,
                       const AMGX_MODE amgxMode_,
                       const bool verbose)
{
   DefaultParameters(amgxMode_, verbose);
   InitExclusiveGPU(comm);
}
void AmgXSolver::InitSerial()
{
   count++;

   mpi_gpu_mode = "serial";

   AMGX_SAFE_CALL(AMGX_initialize());

   AMGX_SAFE_CALL(AMGX_initialize_plugins());

   AMGX_SAFE_CALL(AMGX_install_signal_handler());

   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   isInitialized = true;
}
void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   // In this mode every MPI rank talks to its own GPU
   mpi_gpu_mode = "mpi-gpu-exclusive";
   gpuProc = 0;
   count++;

   MPI_Comm_dup(comm, &gpuWorld);
   MPI_Comm_size(gpuWorld, &gpuWorldSize);
   MPI_Comm_rank(gpuWorld, &myGpuWorldRank);

   // Each rank uses a single device, addressed locally as device 0
   nDevs = 1, devID = 0;

   InitAmgX();

   isInitialized = true;
}
void AmgXSolver::InitMPITeams(const MPI_Comm &comm,
                              const int nDevs)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   mpi_gpu_mode = "mpi-teams";
   count++;

   // Determine the name of this node and the rank in the global communicator
   int len;
   char name[MPI_MAX_PROCESSOR_NAME];
   MPI_Get_processor_name(name, &len);
   nodeName = name;
   int globalcommrank;
   MPI_Comm_rank(comm, &globalcommrank);

   // Set up the communicators used by the GPU teams
   InitMPIcomms(comm, nDevs);

   isInitialized = true;
}
void AmgXSolver::ReadParameters(const std::string config,
                                const CONFIG_SRC source)
{
   amgx_config = config;
   configSrc = source;
}

void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_,
                                   const bool verbose)
{
   amgxMode = amgxMode_;

   configSrc = CONFIG_SRC::INTERNAL;

   if (amgxMode == AMGX_MODE::PRECONDITIONER)
   {
      amgx_config = "{ \n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    " \"solver\": \"AMG\", \n"
                    " \"scope\": \"main\", \n"
                    " \"smoother\": \"JACOBI_L1\", \n"
                    " \"presweeps\": 1, \n"
                    " \"interpolator\": \"D2\", \n"
                    " \"max_row_sum\" : 0.9, \n"
                    " \"strength_threshold\" : 0.25, \n"
                    " \"postsweeps\": 1, \n"
                    " \"max_iters\": 1, \n"
                    " \"cycle\": \"V\"";
      if (verbose)
      {
         amgx_config = amgx_config + ",\n"
                       " \"obtain_timings\": 1, \n"
                       " \"print_grid_stats\": 1, \n"
                       " \"monitor_residual\": 1, \n"
                       " \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " }\n" + "}\n";
   }
   else if (amgxMode == AMGX_MODE::SOLVER)
   {
      amgx_config = "{ \n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    " \"preconditioner\": { \n"
                    " \"solver\": \"AMG\", \n"
                    " \"smoother\": { \n"
                    " \"scope\": \"jacobi\", \n"
                    " \"solver\": \"JACOBI_L1\" \n"
                    " }, \n"
                    " \"presweeps\": 1, \n"
                    " \"interpolator\": \"D2\", \n"
                    " \"max_row_sum\" : 0.9, \n"
                    " \"strength_threshold\" : 0.25, \n"
                    " \"max_iters\": 1, \n"
                    " \"scope\": \"amg\", \n"
                    " \"max_levels\": 100, \n"
                    " \"cycle\": \"V\", \n"
                    " \"postsweeps\": 1 \n"
                    " }, \n"
                    " \"solver\": \"PCG\", \n"
                    " \"max_iters\": 150, \n"
                    " \"convergence\": \"RELATIVE_INI_CORE\", \n"
                    " \"scope\": \"main\", \n"
                    " \"tolerance\": 1e-12, \n"
                    " \"monitor_residual\": 1, \n"
                    " \"norm\": \"L2\" ";
      if (verbose)
      {
         amgx_config = amgx_config + ", \n"
                       " \"obtain_timings\": 1, \n"
                       " \"print_grid_stats\": 1, \n"
                       " \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " } \n" + "} \n";
   }
}
void AmgXSolver::InitAmgX()
{
   AMGX_SAFE_CALL(AMGX_initialize());

   AMGX_SAFE_CALL(AMGX_initialize_plugins());

   AMGX_SAFE_CALL(AMGX_install_signal_handler());

   // Only rank 0 of the global communicator prints AmgX messages
   AMGX_SAFE_CALL(AMGX_register_print_callback(
                     [](const char *msg, int length)->void
   {
      int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
      if (irank == 0) { mfem::out << msg; }
   }));

   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   // Let AmgX handle returned error codes internally
   AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));

   // Only the first instance creates the (shared) resource object
   if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }

   // Create AmgX vectors for the unknowns and the RHS
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   // Create the AmgX matrix object
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));

   // Create the AmgX solver object
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));

   // Obtain the default number of rings based on the solver
   AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
}
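// NOTE: rsrc and count are static members (see the top of this file), so the
// AMGX resource object created when count == 1 is shared by every AmgXSolver
// instance in the process; Finalize() below appears to destroy it only when
// the last instance is torn down.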
void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
{
   // Duplicate the global communicator
   MPI_Comm_dup(comm, &globalCpuWorld);
   MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");

   // Get size and rank of the global communicator
   MPI_Comm_size(globalCpuWorld, &globalSize);
   MPI_Comm_rank(globalCpuWorld, &myGlobalRank);

   // Get the communicator for processes on the same node (local world)
   MPI_Comm_split_type(globalCpuWorld,
                       MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
   MPI_Comm_set_name(localCpuWorld, "localCpuWorld");

   // Get size and rank of the local communicator
   MPI_Comm_size(localCpuWorld, &localSize);
   MPI_Comm_rank(localCpuWorld, &myLocalRank);

   // Assign devices to local ranks and mark the ranks that own them
   SetDeviceIDs(nDevs);

   MPI_Barrier(globalCpuWorld);

   // Split the global world into the ranks that talk to AmgX and the rest
   MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_set_name(gpuWorld, "gpuWorld");
      MPI_Comm_size(gpuWorld, &gpuWorldSize);
      MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
   }
   else // processes not involved in AmgX solvers
   {
      gpuWorldSize = MPI_UNDEFINED;
      myGpuWorldRank = MPI_UNDEFINED;
   }

   // Split the local world into teams sharing the same device
   MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
   MPI_Comm_set_name(devWorld, "devWorld");

   // Get size and rank of the device communicator
   MPI_Comm_size(devWorld, &devWorldSize);
   MPI_Comm_rank(devWorld, &myDevWorldRank);

   MPI_Barrier(globalCpuWorld);
}
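// NOTE: the communicator hierarchy built above is
//   globalCpuWorld : all ranks passed to InitMPITeams()
//   localCpuWorld  : ranks sharing a node (MPI_COMM_TYPE_SHARED)
//   gpuWorld       : the ranks that talk to AmgX (those with gpuProc == 0)
//   devWorld       : all ranks on a node assigned to the same device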
void AmgXSolver::SetDeviceIDs(const int nDevs)
{
   if (nDevs == localSize) // one device per local process
   {
      devID = myLocalRank;
      gpuProc = 0;
   }
   else if (nDevs > localSize) // more devices than local processes
   {
      MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
                   " are more than the MPI processes launched. Only "<<
                   nDevs << " devices will be used.\n");
      devID = myLocalRank;
      gpuProc = 0;
   }
   else // more local processes than devices
   {
      int nBasic = localSize / nDevs,
          nRemain = localSize % nDevs;

      if (myLocalRank < (nBasic+1)*nRemain)
      {
         devID = myLocalRank / (nBasic + 1);
         if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
      }
      else
      {
         devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
         if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
      }
   }
}
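// NOTE: worked example of the last branch above. With localSize = 5 ranks
// sharing nDevs = 2 devices, nBasic = 2 and nRemain = 1, so ranks 0-2 map to
// device 0 and ranks 3-4 map to device 1; ranks 0 and 3 (the first rank per
// device) set gpuProc = 0 and therefore join gpuWorld.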
void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
                             const int mpiTeamSz,
                             const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine displacements (root only)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
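// NOTE: the overloads below repeat the same pattern: the root gathers the
// per-rank sizes (Apart), builds exclusive-prefix-sum displacements (Adisp),
// and then gathers the payload with MPI_Gatherv. Only the MPI datatype
// (MPI_DOUBLE, MPI_INT, MPI_INT64_T) changes.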
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz,
                             const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine displacements (root only)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
                             const int mpiTeamSz,
                             const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine displacements (root only)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
                             Array<int64_t> &outArr,
                             const int mpiTeamSz,
                             const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine displacements (root only)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i=1; i<mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT64_T, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);
}
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz,
                             const MPI_Comm &mpiTeamComm,
                             Array<int> &Apart, Array<int> &Adisp) const
{
   // Collect the number of elements contributed by each process
   int locAsz = inArr.Size();
   MPI_Allgather(&locAsz, 1, MPI_INT,
                 Apart.HostWrite(), 1, MPI_INT, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine displacements
   Adisp[0] = 0;
   for (int i=1; i<mpiTeamSz; ++i)
   {
      Adisp[i] = Adisp[i-1] + Apart[i-1];
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
                              const int mpiTeamSz,
                              const MPI_Comm &mpiTeamComm,
                              Array<int> &Apart, Array<int> &Adisp) const
{
   MPI_Scatterv(inArr.HostRead(), Apart.HostRead(), Adisp.HostRead(),
                MPI_DOUBLE, outArr.HostWrite(), outArr.Size(),
                MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
{
   if (update_mat == false)
   {
      AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
                                            in_A.NumNonZeroElems(),
                                            1, 1,
                                            in_A.ReadI(),
                                            in_A.ReadJ(),
                                            in_A.ReadData(), NULL));

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
                                                      in_A.Height(),
                                                      in_A.NumNonZeroElems(),
                                                      in_A.ReadData(), NULL));
   }
}
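// NOTE: SetMatrix() is called with update_mat == false the first time an
// operator is set (full upload + AMGX_solver_setup), and with
// update_mat == true from UpdateOperator() below, where only the numerical
// values are swapped via AMGX_matrix_replace_coefficients and the sparsity
// pattern is assumed unchanged.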
void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
{
#if MFEM_HYPRE_VERSION < 21600
   mfem_error("Hypre version 2.16+ is required when using AmgX \n");
#endif

   hypre_ParCSRMatrix * A_ptr =
      (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);

   hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);

   Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
   const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);

   // Column indices must be converted to 64-bit integers for AmgX
   Array<int64_t> loc_J((int)A_csr->num_nonzeros);
   for (int i=0; i<A_csr->num_nonzeros; ++i)
   {
      loc_J[i] = A_csr->big_j[i];
   }

   if (mpi_gpu_mode=="mpi-gpu-exclusive")
   {
      SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);

      // Free A_csr created by hypre_MergeDiagAndOffd
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   if (mpi_gpu_mode == "mpi-teams")
   {
      SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);

      // Free A_csr created by hypre_MergeDiagAndOffd
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   mfem_error("Unsupported MPI_GPU combination \n");
}
void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
                                          const Array<double> &loc_A,
                                          const Array<int> &loc_I,
                                          const Array<int64_t> &loc_J,
                                          const bool update_mat)
{
   // Create a vector of offsets describing the matrix row partitions
   Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0.0;

   int64_t myStart = A.GetRowStarts()[0];

   MPI_Allgather(&myStart, 1, MPI_INT64_T,
                 rowPart.GetData(), 1, MPI_INT64_T, gpuWorld);
   MPI_Barrier(gpuWorld);

   rowPart[gpuWorldSize] = A.M();

   const int nGlobalRows = A.M();
   const int local_rows = loc_I.Size()-1;
   const int num_nnz = loc_I[local_rows];

   if (update_mat == false)
   {
      AMGX_distribution_handle dist;
      AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
      AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                          AMGX_DIST_PARTITION_OFFSETS,
                                                          rowPart.GetData()));

      AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                    local_rows, num_nnz, 1, 1,
                                                    loc_I.Read(), loc_J.Read(),
                                                    loc_A.Read(), NULL, dist));

      AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                      num_nnz, loc_A, NULL));
   }
}
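// NOTE: AMGX_DIST_PARTITION_OFFSETS tells AmgX that rowPart holds the global
// row offset of each rank's block (gpuWorldSize+1 entries), so the 64-bit
// column indices in loc_J can be passed in global numbering and AmgX maps
// them to local/halo indices itself.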
void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
                                   const Array<double> &loc_A,
                                   const Array<int> &loc_I,
                                   const Array<int64_t> &loc_J,
                                   const bool update_mat)
{
   // Consolidated CSR data for the team (assembled on the team's root rank)
   Array<int> all_I;
   Array<int64_t> all_J;
   Array<double> all_A;

   // Determine the consolidated array sizes
   int J_allsz(0), all_NNZ(0), nDevRows(0);
   const int loc_row_len = std::abs(A.RowPart()[1] -
                                    A.RowPart()[0]);
   const int loc_Jz_sz = loc_J.Size();
   const int loc_A_sz = loc_A.Size();

   MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);

   MPI_Barrier(devWorld);

   if (myDevWorldRank == 0)
   {
      all_I.SetSize(nDevRows+devWorldSize);
      all_J.SetSize(J_allsz); all_J = 0.0;
      all_A.SetSize(all_NNZ);
   }

   GatherArray(loc_I, all_I, devWorldSize, devWorld);
   GatherArray(loc_J, all_J, devWorldSize, devWorld);
   GatherArray(loc_A, all_A, devWorldSize, devWorld);

   MPI_Barrier(devWorld);
   int64_t local_nnz(0);
   int64_t local_rows(0);

   if (myDevWorldRank == 0)
   {
      // Consolidate the gathered row pointers into a single CSR row pointer:
      // each gathered chunk restarts at zero, so locate the chunk boundaries
      // (z_ind), accumulate the offsets, and drop the duplicated end markers.
      Array<int> z_ind(devWorldSize+1);
      int iter = 1;
      while (iter < devWorldSize-1)
      {
         // Determine where each gathered chunk begins in all_I
         int counter = 0;
         z_ind[counter] = counter;
         counter++;
         for (int idx=1; idx<all_I.Size()-1; idx++)
         {
            if (all_I[idx] == 0)
            {
               z_ind[counter] = idx-1;
               counter++;
            }
         }
         z_ind[devWorldSize] = all_I.Size()-1;

         // Merge the second chunk into the first by accumulating offsets
         for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
         {
            all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
         }

         // Shift the remaining entries left over the duplicated boundary
         for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
         {
            all_I[idx] = all_I[idx+1];
         }
         iter++;
      }

      // Repeat once more for the final chunk
      int counter = 0;
      z_ind[counter] = counter;
      counter++;
      for (int idx=1; idx<all_I.Size()-1; idx++)
      {
         if (all_I[idx] == 0)
         {
            z_ind[counter] = idx-1;
            counter++;
         }
      }
      z_ind[devWorldSize] = all_I.Size()-1;

      for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
      {
         all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
      }

      local_nnz = all_I[all_I.Size()-devWorldSize];
      local_rows = nDevRows;
   }
   // Store the consolidated local row count for later use in Mult()
   mat_local_rows = local_rows;

   // Create a vector of offsets describing the row partitions among the
   // ranks that own a GPU
   Array<int64_t> rowPart;
   rowPart.SetSize(gpuWorldSize+1); rowPart = 0;

   MPI_Allgather(&local_rows, 1, MPI_INT64_T,
                 &rowPart.GetData()[1], 1, MPI_INT64_T, gpuWorld);
   MPI_Barrier(gpuWorld);

   // Convert the gathered row counts into global offsets (prefix sum)
   for (int i=1; i<rowPart.Size(); ++i)
   {
      rowPart[i] += rowPart[i-1];
   }

   MPI_Barrier(gpuWorld);
   int nGlobalRows = A.M();
   if (update_mat == false)
   {
      AMGX_distribution_handle dist;
      AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
      AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                          AMGX_DIST_PARTITION_OFFSETS,
                                                          rowPart.GetData()));

      AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                    local_rows, local_nnz,
                                                    1, 1, all_I.ReadWrite(),
                                                    all_J.ReadWrite(),
                                                    all_A.ReadWrite(),
                                                    nullptr, dist));

      AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                      local_nnz, all_A, NULL));
   }
}
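// NOTE: in "mpi-teams" mode every rank in a devWorld team gathers its rows to
// the team's root rank, and only that rank (a member of gpuWorld) uploads the
// consolidated CSR block to AmgX; the solution is scattered back to the team
// in Mult() below.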
void AmgXSolver::SetOperator(const Operator &op)
{
   if (const SparseMatrix *Aptr =
          dynamic_cast<const SparseMatrix*>(&op))
   {
      SetMatrix(*Aptr);
   }
   else if (const HypreParMatrix *Aptr =
               dynamic_cast<const HypreParMatrix*>(&op))
   {
      SetMatrix(*Aptr);
   }
}

void AmgXSolver::UpdateOperator(const Operator &op)
{
   if (const SparseMatrix *Aptr =
          dynamic_cast<const SparseMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
   else if (const HypreParMatrix *Aptr =
               dynamic_cast<const HypreParMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
}
void AmgXSolver::Mult(const Vector &B, Vector &X) const
{
   // Mult for the serial and mpi-gpu-exclusive modes
   if (mpi_gpu_mode != "mpi-teams")
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));

      if (mpi_gpu_mode != "serial")
      {
         MPI_Barrier(gpuWorld);
      }

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
      return;
   }

   // Mult for the mpi-teams mode: gather the team's vectors on the root rank
   // of devWorld, solve there, then scatter the solution back
   Vector all_X(mat_local_rows);
   Vector all_B(mat_local_rows);
   Array<int> Apart_X(devWorldSize), Adisp_X(devWorldSize);
   Array<int> Apart_B(devWorldSize), Adisp_B(devWorldSize);

   GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
   GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
   MPI_Barrier(devWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
   }

   ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
}
int AmgXSolver::GetNumIterations()
{
   int getIters;
   AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
   return getIters;
}
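// NOTE: when the object is configured as a SOLVER (see DefaultParameters
// above), Mult() checks the AMGX_SOLVE_STATUS and aborts via mfem_error() on
// divergence or failure; SetConvergenceCheck() declared in the header toggles
// the ConvergenceCheck flag used to request the same kind of post-solve check.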
void AmgXSolver::Finalize()
{
   // Check that this instance has been initialized
   if (! isInitialized || count < 1)
   {
      mfem_error("Error in AmgXSolver::Finalize(). \n"
                 "This AmgXWrapper has not been initialized. \n"
                 "Please initialize it before finalization.\n");
   }

   // Only processes using a GPU are required to destroy AmgX content
   if (gpuProc == 0 || mpi_gpu_mode == "serial")
   {
      AMGX_SAFE_CALL(AMGX_solver_destroy(solver));

      AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));

      // Only the last instance needs to destroy the resource object and
      // finalize AmgX
      if (count == 1)
      {
         AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));

         AMGX_SAFE_CALL(AMGX_finalize_plugins());
         AMGX_SAFE_CALL(AMGX_finalize());
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
      }

      // Destroy gpuWorld
      if (mpi_gpu_mode != "serial")
      {
         MPI_Comm_free(&gpuWorld);
      }
   }

   // Reset communicators and gpuProc
   gpuProc = MPI_UNDEFINED;
   if (globalCpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_free(&globalCpuWorld);
      MPI_Comm_free(&localCpuWorld);
      MPI_Comm_free(&devWorld);
   }

   count -= 1;

   isInitialized = false;
}
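// Example usage (a minimal sketch, not part of this file; the configuration
// file name "amgx.json" and the enum qualification are assumptions):
//
//    AmgXSolver amgx;
//    amgx.ReadParameters("amgx.json", AmgXSolver::EXTERNAL);
//    // ...or keep the built-in settings: amgx.DefaultParameters(AmgXSolver::SOLVER, false);
//    amgx.InitExclusiveGPU(MPI_COMM_WORLD);  // one MPI rank per GPU
//    amgx.SetOperator(A);   // A: HypreParMatrix (or SparseMatrix with InitSerial())
//    amgx.Mult(b, x);       // solve A x = b, or apply one cycle in
//                           // PRECONDITIONER mode
//
// Cleanup of the AmgX objects and communicators is handled by Finalize() above.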