#include "../config/config.hpp"

int AmgXSolver::count = 0;

AMGX_resources_handle AmgXSolver::rsrc = nullptr;

AmgXSolver::AmgXSolver()
   : ConvergenceCheck(false) {};
AmgXSolver::AmgXSolver(const MPI_Comm &comm,
                       const AMGX_MODE amgxMode_, const bool verbose)
{
   amgxMode = amgxMode_;
   DefaultParameters(amgxMode, verbose);
   InitExclusiveGPU(comm);
}

AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
                       const AMGX_MODE amgxMode_, const bool verbose)
{
   amgxMode = amgxMode_;
   DefaultParameters(amgxMode, verbose);
   InitMPITeams(comm, nDevs);
}
void AmgXSolver::InitSerial()
{
   count++;
   mpi_gpu_mode = "serial";

   AMGX_SAFE_CALL(AMGX_initialize());
   AMGX_SAFE_CALL(AMGX_initialize_plugins());
   AMGX_SAFE_CALL(AMGX_install_signal_handler());

   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   isInitialized = true;
}
void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   count++;
   mpi_gpu_mode = "mpi-gpu-exclusive";

   // In this mode every MPI rank owns exactly one GPU
   MPI_Comm_dup(comm, &gpuWorld);
   MPI_Comm_size(gpuWorld, &gpuWorldSize);
   MPI_Comm_rank(gpuWorld, &myGpuWorldRank);

   nDevs = 1, devID = 0;

   InitAmgX();

   isInitialized = true;
}
void AmgXSolver::InitMPITeams(const MPI_Comm &comm, const int nDevs)
{
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   count++;
   mpi_gpu_mode = "mpi-teams";

   // Get the name of this node
   int len;
   char name[MPI_MAX_PROCESSOR_NAME];
   MPI_Get_processor_name(name, &len);
   nodeName = name;

   int globalcommrank;
   MPI_Comm_rank(comm, &globalcommrank);

   // Set up the team communicators; only ranks in gpuWorld initialize AmgX
   InitMPIcomms(comm, nDevs);
   if (gpuProc == 0) { InitAmgX(); }

   isInitialized = true;
}
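For context, the three initialization paths above correspond to three ways of driving this class. The following is an editor's sketch (not part of amgxsolver.cpp), assuming an MFEM build with MFEM_USE_AMGX and MFEM_USE_MPI and the member functions shown in this listing; the nDevs value is hypothetical.

#include "mfem.hpp"
using namespace mfem;

void InitModes(MPI_Comm comm)
{
   // Serial: a single rank drives a single GPU.
   AmgXSolver serial_amgx;
   serial_amgx.DefaultParameters(AmgXSolver::PRECONDITIONER, /*verbose=*/false);
   serial_amgx.InitSerial();

   // mpi-gpu-exclusive: every rank in 'comm' owns its own GPU.
   AmgXSolver exclusive_amgx;
   exclusive_amgx.DefaultParameters(AmgXSolver::PRECONDITIONER, /*verbose=*/false);
   exclusive_amgx.InitExclusiveGPU(comm);

   // mpi-teams: the ranks on each node are grouped into teams that share
   // nDevs GPUs; only one rank per team talks to a device.
   AmgXSolver teams_amgx;
   const int nDevs = 2; // hypothetical number of GPUs per node
   teams_amgx.DefaultParameters(AmgXSolver::SOLVER, /*verbose=*/true);
   teams_amgx.InitMPITeams(comm, nDevs);
}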
void AmgXSolver::ReadParameters(const std::string config,
                                const CONFIG_SRC source)
{
   amgx_config = config;
   configSrc = source;
}

void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose)
{
   amgxMode = amgxMode_;

   configSrc = CONFIG_SRC::INTERNAL;

   if (amgxMode == AMGX_MODE::PRECONDITIONER)
   {
      amgx_config = "{\n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    " \"solver\": \"AMG\", \n"
                    " \"scope\": \"main\", \n"
                    " \"smoother\": \"JACOBI_L1\", \n"
                    " \"presweeps\": 1, \n"
                    " \"interpolator\": \"D2\", \n"
                    " \"max_row_sum\" : 0.9, \n"
                    " \"strength_threshold\" : 0.25, \n"
                    " \"postsweeps\": 1, \n"
                    " \"max_iters\": 1, \n"
                    " \"cycle\": \"V\" ";
      if (verbose)
      {
         amgx_config = amgx_config + ",\n"
                       " \"obtain_timings\": 1, \n"
                       " \"print_grid_stats\": 1, \n"
                       " \"monitor_residual\": 1, \n"
                       " \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " }\n" + "}\n";
   }
   else if (amgxMode == AMGX_MODE::SOLVER)
   {
      amgx_config = "{ \n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    " \"preconditioner\": { \n"
                    "  \"solver\": \"AMG\", \n"
                    "  \"smoother\": { \n"
                    "  \"scope\": \"jacobi\", \n"
                    "  \"solver\": \"JACOBI_L1\" \n"
                    "  }, \n"
                    "  \"presweeps\": 1, \n"
                    "  \"interpolator\": \"D2\", \n"
                    "  \"max_row_sum\" : 0.9, \n"
                    "  \"strength_threshold\" : 0.25, \n"
                    "  \"max_iters\": 1, \n"
                    "  \"scope\": \"amg\", \n"
                    "  \"max_levels\": 100, \n"
                    "  \"cycle\": \"V\", \n"
                    "  \"postsweeps\": 1 \n"
                    "  }, \n"
                    " \"solver\": \"PCG\", \n"
                    " \"max_iters\": 150, \n"
                    " \"convergence\": \"RELATIVE_INI_CORE\", \n"
                    " \"scope\": \"main\", \n"
                    " \"tolerance\": 1e-12, \n"
                    " \"monitor_residual\": 1, \n"
                    " \"norm\": \"L2\" ";
      if (verbose)
      {
         amgx_config = amgx_config + ", \n"
                       " \"obtain_timings\": 1, \n"
                       " \"print_grid_stats\": 1, \n"
                       " \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " } \n" + "} \n";
   }
}
void AmgXSolver::InitAmgX()
{
   // Only the first instance needs to initialize the AmgX library
   if (count == 1)
   {
      AMGX_SAFE_CALL(AMGX_initialize());
      AMGX_SAFE_CALL(AMGX_initialize_plugins());
      AMGX_SAFE_CALL(AMGX_install_signal_handler());

      // Only rank 0 of the global communicator prints AmgX output
      AMGX_SAFE_CALL(AMGX_register_print_callback(
                        [](const char *msg, int length)->void
      {
         int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
         if (irank == 0) { mfem::out << msg; }
      }));
   }

   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");

   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }

   // Let AmgX handle returned error codes internally
   AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));

   // Only the first instance creates the shared resource object
   if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }

   // Create AmgX vectors for the unknowns and the RHS
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   // Create the AmgX matrix and solver objects
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));

   // Get the default number of halo rings for the current configuration
   AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
}
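The configuration string assembled in DefaultParameters is consumed here (and in InitSerial above) either as a file path or as in-memory JSON, depending on configSrc. An editor's sketch of the file-based path, not part of the original source; "amgx.json" is a hypothetical file name:

#include "mfem.hpp"
using namespace mfem;

void ConfigureFromFile(AmgXSolver &amgx, MPI_Comm comm)
{
   // With CONFIG_SRC::EXTERNAL the string is treated as a path and parsed by
   // AMGX_config_create_from_file; otherwise it is treated as the JSON text
   // itself and passed to AMGX_config_create.
   amgx.ReadParameters("amgx.json", AmgXSolver::CONFIG_SRC::EXTERNAL);
   amgx.InitExclusiveGPU(comm);
}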
void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
{
   // Duplicate the global communicator
   MPI_Comm_dup(comm, &globalCpuWorld);
   MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");

   // Get size and rank for the global communicator
   MPI_Comm_size(globalCpuWorld, &globalSize);
   MPI_Comm_rank(globalCpuWorld, &myGlobalRank);

   // Get the communicator for ranks on the same node (local world)
   MPI_Comm_split_type(globalCpuWorld,
                       MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
   MPI_Comm_set_name(localCpuWorld, "localCpuWorld");

   // Get size and rank for the local communicator
   MPI_Comm_size(localCpuWorld, &localSize);
   MPI_Comm_rank(localCpuWorld, &myLocalRank);

   // Assign a device ID to this rank and decide whether it talks to a GPU
   SetDeviceIDs(nDevs);
   MPI_Barrier(globalCpuWorld);

   // Split the global communicator into the ranks that own a GPU (gpuWorld)
   MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_set_name(gpuWorld, "gpuWorld");
      MPI_Comm_size(gpuWorld, &gpuWorldSize);
      MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
   }
   else // ranks that do not communicate with a GPU
   {
      gpuWorldSize = MPI_UNDEFINED;
      myGpuWorldRank = MPI_UNDEFINED;
   }

   // Split the local communicator into teams that share the same device
   MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
   MPI_Comm_set_name(devWorld, "devWorld");

   MPI_Comm_size(devWorld, &devWorldSize);
   MPI_Comm_rank(devWorld, &myDevWorldRank);

   MPI_Barrier(globalCpuWorld);
}
void AmgXSolver::SetDeviceIDs(const int nDevs)
{
   if (nDevs == localSize) // number of devices matches the local ranks
   {
      devID = myLocalRank;
      gpuProc = 0;
   }
   else if (nDevs > localSize) // more devices than local ranks
   {
      MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
                   " are more than the MPI processes launched. Only "<<
                   nDevs << " devices will be used.\n");
      devID = myLocalRank;
      gpuProc = 0;
   }
   else // more local ranks than devices: distribute ranks evenly over devices
   {
      int nBasic = localSize / nDevs,
          nRemain = localSize % nDevs;

      if (myLocalRank < (nBasic+1)*nRemain)
      {
         devID = myLocalRank / (nBasic + 1);
         if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
      }
      else
      {
         devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
         if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
      }
   }
}
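To make the block assignment above concrete, here is a small standalone sketch (editor's example, independent of the class) that reproduces the arithmetic for a hypothetical node with 5 local ranks and 2 devices:

#include <cstdio>

// Ranks are split into nDevs blocks, the first nRemain blocks holding one
// extra rank; the first rank of each block becomes the one that owns the GPU.
int main()
{
   const int localSize = 5, nDevs = 2;     // hypothetical node layout
   const int nBasic  = localSize / nDevs;  // = 2
   const int nRemain = localSize % nDevs;  // = 1

   for (int myLocalRank = 0; myLocalRank < localSize; ++myLocalRank)
   {
      int devID, gpuProc = -1;             // -1 stands in for MPI_UNDEFINED
      if (myLocalRank < (nBasic+1)*nRemain)
      {
         devID = myLocalRank / (nBasic + 1);
         if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
      }
      else
      {
         devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
         if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
      }
      // Ranks 0,1,2 map to device 0 (rank 0 owns it); ranks 3,4 map to
      // device 1 (rank 3 owns it).
      std::printf("rank %d: devID=%d gpuProc=%d\n", myLocalRank, devID, gpuProc);
   }
   return 0;
}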
void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each rank in the team
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each rank (only needed on the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each rank in the team
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each rank (only needed on the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each rank in the team
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each rank (only needed on the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
                             Array<int64_t> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Collect the number of elements contributed by each rank in the team
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each rank (only needed on the team root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT64_T, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);
}
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                             Array<int> &Apart, Array<int> &Adisp) const
{
   // Collect the number of elements contributed by each rank in the team
   int locAsz = inArr.Size();
   MPI_Allgather(&locAsz, 1, MPI_INT,
                 Apart.HostWrite(), 1, MPI_INT, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each rank
   Adisp[0] = 0;
   for (int i = 1; i < mpiTeamSz; ++i)
   {
      Adisp[i] = Adisp[i-1] + Apart[i-1];
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
                              const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                              Array<int> &Apart, Array<int> &Adisp) const
{
   MPI_Scatterv(inArr.HostRead(), Apart.HostRead(), Adisp.HostRead(),
                MPI_DOUBLE, outArr.HostWrite(), outArr.Size(),
                MPI_DOUBLE, 0, mpiTeamComm);
}
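All of the GatherArray/ScatterArray helpers above follow the same count-then-displacement pattern. A self-contained MPI sketch of that pattern (editor's example, independent of the MFEM types) looks like this:

#include <mpi.h>
#include <vector>

// Gather variable-length per-rank arrays onto rank 0 and scatter them back,
// mirroring the Apart/Adisp bookkeeping used by AmgXSolver::GatherArray.
int main(int argc, char **argv)
{
   MPI_Init(&argc, &argv);
   int rank, size;
   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
   MPI_Comm_size(MPI_COMM_WORLD, &size);

   // Each rank contributes rank+1 doubles.
   std::vector<double> local(rank + 1, double(rank));

   // 1. Gather the per-rank counts on every rank.
   int locSz = (int)local.size();
   std::vector<int> Apart(size), Adisp(size, 0);
   MPI_Allgather(&locSz, 1, MPI_INT, Apart.data(), 1, MPI_INT, MPI_COMM_WORLD);

   // 2. Turn counts into displacements (exclusive prefix sum).
   for (int i = 1; i < size; ++i) { Adisp[i] = Adisp[i-1] + Apart[i-1]; }

   // 3. Gather the payload on rank 0, then scatter it back unchanged.
   std::vector<double> all(Adisp[size-1] + Apart[size-1]);
   MPI_Gatherv(local.data(), locSz, MPI_DOUBLE,
               all.data(), Apart.data(), Adisp.data(), MPI_DOUBLE, 0, MPI_COMM_WORLD);
   MPI_Scatterv(all.data(), Apart.data(), Adisp.data(), MPI_DOUBLE,
                local.data(), locSz, MPI_DOUBLE, 0, MPI_COMM_WORLD);

   MPI_Finalize();
   return 0;
}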
void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
{
   if (update_mat == false)
   {
      AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
                                            in_A.NumNonZeroElems(),
                                            1, 1,
                                            in_A.ReadI(),
                                            in_A.ReadJ(),
                                            in_A.ReadData(), NULL));

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
                                                      in_A.Height(),
                                                      in_A.NumNonZeroElems(),
                                                      in_A.ReadData(), NULL));
   }
}
void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
{
   // Require hypre >= 2.16
#if MFEM_HYPRE_VERSION < 21600
   mfem_error("Hypre version 2.16+ is required when using AmgX \n");
#endif

   hypre_ParCSRMatrix * A_ptr =
      (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);

   hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);

   // Wrap the merged local CSR data
   Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
   const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);

   // Column indices must be promoted to int64_t
   Array<int64_t> loc_J((int)A_csr->num_nonzeros);
   for (int i = 0; i < A_csr->num_nonzeros; ++i)
   {
      loc_J[i] = A_csr->big_j[i];
   }

   if (mpi_gpu_mode == "mpi-gpu-exclusive")
   {
      SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);
      // Free the merged CSR matrix
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   if (mpi_gpu_mode == "mpi-teams")
   {
      SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);
      // Free the merged CSR matrix
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   mfem_error("Unsupported MPI_GPU combination \n");
}
void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
                                          const Array<double> &loc_A,
                                          const Array<int> &loc_I,
                                          const Array<int64_t> &loc_J,
                                          const bool update_mat)
{
   // Build a vector of offsets describing the matrix row partitions
   Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0.0;

   int64_t myStart = A.GetRowStarts()[0];

   MPI_Allgather(&myStart, 1, MPI_INT64_T,
                 rowPart.GetData(), 1, MPI_INT64_T, gpuWorld);
   MPI_Barrier(gpuWorld);

   rowPart[gpuWorldSize] = A.M();

   const int nGlobalRows = A.M();
   const int local_rows = loc_I.Size()-1;
   const int num_nnz = loc_I[local_rows];

   if (update_mat == false)
   {
      AMGX_distribution_handle dist;
      AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
      AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                          AMGX_DIST_PARTITION_OFFSETS,
                                                          rowPart.GetData()));

      AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                    local_rows, num_nnz, 1, 1,
                                                    loc_I.Read(), loc_J.Read(),
                                                    loc_A.Read(), NULL, dist));

      AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                      num_nnz, loc_A, NULL));
   }
}
void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
                                   const Array<double> &loc_A,
                                   const Array<int> &loc_I,
                                   const Array<int64_t> &loc_J,
                                   const bool update_mat)
{
   // Arrays holding the consolidated (per-team) CSR data
   Array<int> all_I;
   Array<int64_t> all_J;
   Array<double> all_A;

   // Determine the consolidated array sizes
   int J_allsz(0), all_NNZ(0), nDevRows(0);
   const int loc_row_len = std::abs(A.RowPart()[1] -
                                    A.RowPart()[0]);
   const int loc_Jz_sz = loc_J.Size();
   const int loc_A_sz = loc_A.Size();

   MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);

   MPI_Barrier(devWorld);

   if (myDevWorldRank == 0)
   {
      all_I.SetSize(nDevRows+devWorldSize);
      all_J.SetSize(J_allsz); all_J = 0.0;
      all_A.SetSize(all_NNZ);
   }

   // Consolidate the team's local CSR data on the team root
   GatherArray(loc_I, all_I, devWorldSize, devWorld);
   GatherArray(loc_J, all_J, devWorldSize, devWorld);
   GatherArray(loc_A, all_A, devWorldSize, devWorld);

   MPI_Barrier(devWorld);
   int local_nnz(0);
   int64_t local_rows(0);

   if (myDevWorldRank == 0)
   {
      // The gathered row pointers contain one extra leading zero per team
      // member; fix up all_I so it becomes a single valid CSR row-pointer array.
      Array<int> z_ind(devWorldSize+1);
      int iter = 1;
      while (iter < devWorldSize-1)
      {
         // Determine the indices of the zeros in the gathered all_I array
         int counter = 0;
         z_ind[counter] = counter;
         counter++;
         for (int idx=1; idx<all_I.Size()-1; idx++)
         {
            if (all_I[idx] == 0)
            {
               z_ind[counter] = idx-1;
               counter++;
            }
         }
         z_ind[devWorldSize] = all_I.Size()-1;

         // Bump the row pointers that follow the first zero
         for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
         {
            all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
         }
         // Shift the array to remove the now-unneeded entry
         for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
         {
            all_I[idx] = all_I[idx+1];
         }
         iter++;
      }

      // Last pass through the array
      int counter = 0;
      z_ind[counter] = counter;
      counter++;
      for (int idx=1; idx<all_I.Size()-1; idx++)
      {
         if (all_I[idx] == 0)
         {
            z_ind[counter] = idx-1;
            counter++;
         }
      }
      z_ind[devWorldSize] = all_I.Size()-1;

      // Bump all_I one last time
      for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
      {
         all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
      }
      local_nnz = all_I[all_I.Size()-devWorldSize];
      local_rows = nDevRows;
   }
   mat_local_rows = local_rows; // keep the consolidated row count for Mult()
   Array<int64_t> rowPart;
   if (gpuProc == 0)
   {
      // Create the row partition from the consolidated row counts
      rowPart.SetSize(gpuWorldSize+1); rowPart = 0;

      MPI_Allgather(&local_rows, 1, MPI_INT64_T,
                    &rowPart.GetData()[1], 1, MPI_INT64_T, gpuWorld);
      MPI_Barrier(gpuWorld);

      // Convert the per-rank counts into offsets
      for (int i=1; i<rowPart.Size(); ++i)
      {
         rowPart[i] += rowPart[i-1];
      }

      // Upload the consolidated matrix to the GPU
      MPI_Barrier(gpuWorld);

      int nGlobalRows = A.M();
      if (update_mat == false)
      {
         AMGX_distribution_handle dist;
         AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
         AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                             AMGX_DIST_PARTITION_OFFSETS,
                                                             rowPart.GetData()));

         AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                       local_rows, local_nnz,
                                                       1, 1, all_I.ReadWrite(),
                                                       all_J.Read(),
                                                       all_A.Read(),
                                                       nullptr, dist));

         AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
         MPI_Barrier(gpuWorld);

         AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

         // Bind the AmgX vectors to the uploaded matrix
         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                         local_nnz, all_A, NULL));
      }
   }
}
void AmgXSolver::SetOperator(const Operator &op)
{
   height = op.Height();
   width = op.Width();

   if (const SparseMatrix *Aptr = dynamic_cast<const SparseMatrix*>(&op))
   { SetMatrix(*Aptr); }
   else if (const HypreParMatrix *Aptr = dynamic_cast<const HypreParMatrix*>(&op))
   { SetMatrix(*Aptr); }
   else { mfem_error("Unsupported Operator Type \n"); }
}

void AmgXSolver::UpdateOperator(const Operator &op)
{
   if (const SparseMatrix *Aptr = dynamic_cast<const SparseMatrix*>(&op))
   { SetMatrix(*Aptr, true); }
   else if (const HypreParMatrix *Aptr = dynamic_cast<const HypreParMatrix*>(&op))
   { SetMatrix(*Aptr, true); }
   else { mfem_error("Unsupported Operator Type \n"); }
}
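Putting the pieces together, a typical use couples SetOperator with Mult, and UpdateOperator when only the matrix values change. The following is an editor's sketch, not taken from the MFEM examples; it assumes an MFEM_USE_AMGX + MPI build, that A, b, and x are set up by the surrounding application, and uses GetNumIterations as reconstructed further below.

#include "mfem.hpp"
using namespace mfem;

// Solve A x = b with AmgX configured as the outer AMG-preconditioned PCG solver.
void SolveWithAmgX(MPI_Comm comm, HypreParMatrix &A, Vector &b, Vector &x)
{
   AmgXSolver amgx;
   amgx.DefaultParameters(AmgXSolver::SOLVER, /*verbose=*/true);
   amgx.InitExclusiveGPU(comm);

   amgx.SetOperator(A);   // uploads the matrix and runs AMGX_solver_setup
   amgx.Mult(b, x);       // solves and downloads the result into x

   // If A's values change but its sparsity pattern does not, refresh only the
   // coefficients on the device before solving again.
   amgx.UpdateOperator(A);
   amgx.Mult(b, x);

   mfem::out << "AmgX iterations: " << amgx.GetNumIterations() << "\n";
}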
void AmgXSolver::Mult(const Vector &B, Vector &X) const
{
   // Use a zero initial guess unless iterative_mode is enabled
   X.UseDevice(true);
   if (!iterative_mode) { X = 0.0; }

   // Mult for the serial and mpi-gpu-exclusive modes
   if (mpi_gpu_mode != "mpi-teams")
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));

      if (mpi_gpu_mode != "serial")
      {
#ifdef MFEM_USE_MPI
         MPI_Barrier(gpuWorld);
#endif
      }

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));

      return;
   }
   // Mult for the mpi-teams mode: consolidate the team's vectors on the root
   Vector all_X(mat_local_rows);
   Vector all_B(mat_local_rows);
   Array<int> Apart_X(devWorldSize);
   Array<int> Adisp_X(devWorldSize);
   Array<int> Apart_B(devWorldSize);
   Array<int> Adisp_B(devWorldSize);

   GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
   GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
   MPI_Barrier(devWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
   }

   // Scatter the solution back to the members of the team
   ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
}
int AmgXSolver::GetNumIterations()
{
   int getIters;
   AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
   return getIters;
}
void AmgXSolver::Finalize()
{
   // Check whether this instance has been initialized
   if (! isInitialized || count < 1)
   {
      mfem_error("Error in AmgXSolver::Finalize(). \n"
                 "This AmgXWrapper has not been initialized. \n"
                 "Please initialize it before finalization.\n");
   }

   // Only ranks using a GPU need to destroy AmgX content
   if (gpuProc == 0 || mpi_gpu_mode == "serial")
   {
      // Destroy the solver, matrix, and vector instances
      AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
      AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));

      // Only the last instance destroys the shared resource and finalizes AmgX
      if (count == 1)
      {
         AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));

         AMGX_SAFE_CALL(AMGX_finalize_plugins());
         AMGX_SAFE_CALL(AMGX_finalize());
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
      }

      // Destroy the GPU communicator
      if (mpi_gpu_mode != "serial")
      {
         MPI_Comm_free(&gpuWorld);
      }
   }

   // Reset communicators and flags so the instance can be reused
   gpuProc = MPI_UNDEFINED;
   if (globalCpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_free(&globalCpuWorld);
      MPI_Comm_free(&localCpuWorld);
      MPI_Comm_free(&devWorld);
   }

   count -= 1;

   isInitialized = false;
}