#include "../config/config.hpp"
int AmgXSolver::count = 0;

AMGX_resources_handle AmgXSolver::rsrc = nullptr;
AmgXSolver::AmgXSolver()
   : ConvergenceCheck(false) {}
AmgXSolver::AmgXSolver(const MPI_Comm &comm,
                       const AMGX_MODE amgxMode_, const bool verbose)
AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs,
                       const AMGX_MODE amgxMode_, const bool verbose)
void AmgXSolver::InitSerial()
{
   count++;

   mpi_gpu_mode = "serial";

   AMGX_SAFE_CALL(AMGX_initialize());

   AMGX_SAFE_CALL(AMGX_initialize_plugins());

   AMGX_SAFE_CALL(AMGX_install_signal_handler());

   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");
   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }
   AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg));
   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   isInitialized = true;
}
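// Editor's note: a minimal serial usage sketch, not part of the original
// source. It assumes an MFEM build with AMGX enabled and a preassembled
// SparseMatrix A with vectors b, x; all names here are illustrative.
//
//   AmgXSolver amgx(AmgXSolver::PRECONDITIONER, /* verbose = */ false);
//   amgx.SetOperator(A); // uploads A and runs AMGX_solver_setup
//   amgx.Mult(b, x);     // applies one AMG V-cycle to b, writing x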
void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm)
{
   // If this instance has already been initialized, skip
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   // Note that every MPI rank may talk to a GPU
   mpi_gpu_mode = "mpi-gpu-exclusive";
   gpuProc = 0;

   // Increment the number of AmgX instances
   count++;

   MPI_Comm_dup(comm, &gpuWorld);
   MPI_Comm_size(gpuWorld, &gpuWorldSize);
   MPI_Comm_rank(gpuWorld, &myGpuWorldRank);

   // Each rank has exclusive access to its device
   nDevs = 1, devID = 0;

   InitAmgX();

   isInitialized = true;
}
void AmgXSolver::InitMPITeams(const MPI_Comm &comm, const int nDevs)
{
   // If this instance has already been initialized, skip
   if (isInitialized)
   {
      mfem_error("This AmgXSolver instance has been initialized on this process.");
   }

   mpi_gpu_mode = "mpi-teams";

   // Increment the number of AmgX instances
   count++;

   // Get the name of this node
   int len;
   char name[MPI_MAX_PROCESSOR_NAME];
   MPI_Get_processor_name(name, &len);
   nodeName = name;
   int globalcommrank;

   MPI_Comm_rank(comm, &globalcommrank);

   // Initialize communicators and corresponding information
   InitMPIcomms(comm, nDevs);

   // Only processes in gpuWorld are required to initialize AmgX
   if (gpuProc == 0)
   {
      InitAmgX();
   }

   isInitialized = true;
}
void AmgXSolver::ReadParameters(const std::string config,
                                const CONFIG_SRC source)
{
   amgx_config = config;
   configSrc = source;
}
void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_,
                                   const bool verbose)
{
   amgxMode = amgxMode_;

   configSrc = INTERNAL;
   if (amgxMode == AMGX_MODE::PRECONDITIONER)
   {
      amgx_config = "{\n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    "   \"solver\": \"AMG\", \n"
                    "   \"scope\": \"main\", \n"
                    "   \"smoother\": \"JACOBI_L1\", \n"
                    "   \"presweeps\": 1, \n"
                    "   \"interpolator\": \"D2\", \n"
                    "   \"max_row_sum\" : 0.9, \n"
                    "   \"strength_threshold\" : 0.25, \n"
                    "   \"postsweeps\": 1, \n"
                    "   \"max_iters\": 1, \n"
                    "   \"cycle\": \"V\"";
      if (verbose)
      {
         amgx_config = amgx_config + ",\n"
                       "   \"obtain_timings\": 1, \n"
                       "   \"print_grid_stats\": 1, \n"
                       "   \"monitor_residual\": 1, \n"
                       "   \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " }\n" + "}\n";

      // Use a zero initial guess in Mult()
      iterative_mode = false;
   }
   else if (amgxMode == AMGX_MODE::SOLVER)
   {
      amgx_config = "{ \n"
                    " \"config_version\": 2, \n"
                    " \"solver\": { \n"
                    "   \"preconditioner\": { \n"
                    "     \"solver\": \"AMG\", \n"
                    "     \"smoother\": { \n"
                    "       \"scope\": \"jacobi\", \n"
                    "       \"solver\": \"JACOBI_L1\" \n"
                    "     }, \n"
                    "     \"presweeps\": 1, \n"
                    "     \"interpolator\": \"D2\", \n"
                    "     \"max_row_sum\" : 0.9, \n"
                    "     \"strength_threshold\" : 0.25, \n"
                    "     \"max_iters\": 1, \n"
                    "     \"scope\": \"amg\", \n"
                    "     \"max_levels\": 100, \n"
                    "     \"cycle\": \"V\", \n"
                    "     \"postsweeps\": 1 \n"
                    "   }, \n"
                    "   \"solver\": \"PCG\", \n"
                    "   \"max_iters\": 150, \n"
                    "   \"convergence\": \"RELATIVE_INI_CORE\", \n"
                    "   \"scope\": \"main\", \n"
                    "   \"tolerance\": 1e-12, \n"
                    "   \"monitor_residual\": 1, \n"
                    "   \"norm\": \"L2\" ";
      if (verbose)
      {
         amgx_config = amgx_config + ", \n"
                       "   \"obtain_timings\": 1, \n"
                       "   \"print_grid_stats\": 1, \n"
                       "   \"print_solve_stats\": 1 \n";
      }
      else
      {
         amgx_config = amgx_config + "\n";
      }
      amgx_config = amgx_config + " } \n" + "} \n";
   }
}
void AmgXSolver::InitAmgX()
{
   // Initialize the AmgX library only once per process
   if (count == 1)
   {
      AMGX_SAFE_CALL(AMGX_initialize());

      AMGX_SAFE_CALL(AMGX_initialize_plugins());

      AMGX_SAFE_CALL(AMGX_install_signal_handler());

      // Only MPI rank 0 prints AmgX messages
      AMGX_SAFE_CALL(AMGX_register_print_callback(
                        [](const char *msg, int length)->void
      {
         int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank);
         if (irank == 0) { mfem::out << msg; }
      }));
   }
   MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED,
               "AmgX configuration is not defined \n");
   // Create an AmgX configuration object from a file or from a string
   if (configSrc == CONFIG_SRC::EXTERNAL)
   {
      AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str()));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str()));
   }
   // Let AmgX handle returned error codes internally
   AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));

   // Only the first instance creates the shared resource object
   if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID)); }

   // Create AmgX vector objects for the unknowns and the RHS
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode));
   AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode));

   // Create the AmgX matrix and solver objects
   AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode));

   AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg));

   // Obtain the default number of rings based on the current configuration
   AMGX_SAFE_CALL(AMGX_config_get_default_number_of_rings(cfg, &ring));
}
void AmgXSolver::InitMPIcomms(const MPI_Comm &comm, const int nDevs)
{
   // Duplicate the global communicator
   MPI_Comm_dup(comm, &globalCpuWorld);
   MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld");

   // Get size and rank for the global communicator
   MPI_Comm_size(globalCpuWorld, &globalSize);
   MPI_Comm_rank(globalCpuWorld, &myGlobalRank);

   // Get the communicator for processors on the same node (local world)
   MPI_Comm_split_type(globalCpuWorld,
                       MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld);
   MPI_Comm_set_name(localCpuWorld, "localCpuWorld");

   // Get size and rank for the local communicator
   MPI_Comm_size(localCpuWorld, &localSize);
   MPI_Comm_rank(localCpuWorld, &myLocalRank);

   // Set up the corresponding devices
   SetDeviceIDs(nDevs);

   MPI_Barrier(globalCpuWorld);
   // Split the global world into a world involved in AmgX and a null world
   MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld);

   // Get size and rank for the communicator corresponding to gpuWorld
   if (gpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_set_name(gpuWorld, "gpuWorld");
      MPI_Comm_size(gpuWorld, &gpuWorldSize);
      MPI_Comm_rank(gpuWorld, &myGpuWorldRank);
   }
   else // for ranks not in gpuWorld
   {
      gpuWorldSize = MPI_UNDEFINED;
      myGpuWorldRank = MPI_UNDEFINED;
   }
   // Split the local world into worlds corresponding to each CUDA device
   MPI_Comm_split(localCpuWorld, devID, 0, &devWorld);
   MPI_Comm_set_name(devWorld, "devWorld");

   // Get size and rank for the communicator corresponding to devWorld
   MPI_Comm_size(devWorld, &devWorldSize);
   MPI_Comm_rank(devWorld, &myDevWorldRank);

   MPI_Barrier(globalCpuWorld);
}
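// Editor's note: illustrative summary of the communicator hierarchy built
// above (not from the original source):
//
//   globalCpuWorld : all ranks passed in through comm
//   localCpuWorld  : ranks sharing one node (MPI_COMM_TYPE_SHARED split)
//   devWorld       : ranks on a node sharing one CUDA device (split by devID)
//   gpuWorld       : the one root rank per device (gpuProc == 0) that
//                    actually drives AmgX; all other ranks get MPI_COMM_NULL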
void AmgXSolver::SetDeviceIDs(const int nDevs)
{
   // Set the ID of the device each local process will use
   if (nDevs == localSize) // number of devices matches the local processes
   {
      devID = myLocalRank;
      gpuProc = 0;
   }
   else if (nDevs > localSize) // more devices than processes
   {
      MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() <<
                   " are more than the MPI processes launched. Only " <<
                   localSize << " devices will be used.\n");
      devID = myLocalRank;
      gpuProc = 0;
   }
   else // more processes than devices: form teams
   {
      int nBasic = localSize / nDevs,
          nRemain = localSize % nDevs;

      if (myLocalRank < (nBasic+1)*nRemain)
      {
         devID = myLocalRank / (nBasic + 1);
         if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; }
      }
      else
      {
         devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
         if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; }
      }
   }
}
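// Editor's note: worked example of the team assignment above (illustrative,
// not from the original source). With localSize = 5 and nDevs = 2:
// nBasic = 2, nRemain = 1, so the first (nBasic+1)*nRemain = 3 ranks form
// the larger team:
//   ranks 0,1,2 -> devID 0, rank 0 becomes the team root (gpuProc == 0)
//   ranks 3,4   -> devID 1, rank 3 becomes the team root
// Each team root later joins gpuWorld and talks to AmgX on behalf of its team.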
void AmgXSolver::GatherArray(const Array<double> &inArr, Array<double> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
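// Editor's note: worked example of the MPI_Gatherv bookkeeping above
// (illustrative, not from the original source). For a 3-rank team with
// local sizes {3, 2, 4}:
//   Apart = {3, 2, 4}   counts gathered from each rank
//   Adisp = {0, 3, 5}   running prefix sum: rank r's data lands at
//                       outArr[Adisp[r] .. Adisp[r]+Apart[r]-1] on the root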
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.HostWrite(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int> &inArr, Array<int> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT, 0, mpiTeamComm);
}
void AmgXSolver::GatherArray(const Array<int64_t> &inArr,
                             Array<int64_t> &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const
{
   // Calculate the number of elements to be collected from each process
   Array<int> Apart(mpiTeamSz);
   int locAsz = inArr.Size();
   MPI_Gather(&locAsz, 1, MPI_INT,
              Apart.GetData(), 1, MPI_INT, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process (used by the root)
   Array<int> Adisp(mpiTeamSz);
   int myid; MPI_Comm_rank(mpiTeamComm, &myid);
   if (myid == 0)
   {
      Adisp[0] = 0;
      for (int i = 1; i < mpiTeamSz; ++i)
      {
         Adisp[i] = Adisp[i-1] + Apart[i-1];
      }
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_INT64_T,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_INT64_T, 0, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);
}
void AmgXSolver::GatherArray(const Vector &inArr, Vector &outArr,
                             const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                             Array<int> &Apart, Array<int> &Adisp) const
{
   // Calculate the number of elements to be collected from each process;
   // every rank in the team receives the counts via Allgather
   int locAsz = inArr.Size();
   MPI_Allgather(&locAsz, 1, MPI_INT,
                 Apart.HostWrite(), 1, MPI_INT, mpiTeamComm);

   MPI_Barrier(mpiTeamComm);

   // Determine the displacement for each process
   Adisp[0] = 0;
   for (int i = 1; i < mpiTeamSz; ++i)
   {
      Adisp[i] = Adisp[i-1] + Apart[i-1];
   }

   MPI_Gatherv(inArr.HostRead(), inArr.Size(), MPI_DOUBLE,
               outArr.HostWrite(), Apart.HostRead(), Adisp.HostRead(),
               MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::ScatterArray(const Vector &inArr, Vector &outArr,
                              const int mpiTeamSz, const MPI_Comm &mpiTeamComm,
                              Array<int> &Apart, Array<int> &Adisp) const
{
   // Inverse of the gather: the root sends each rank its slice back
   MPI_Scatterv(inArr.HostRead(), Apart.HostRead(), Adisp.HostRead(),
                MPI_DOUBLE, outArr.HostWrite(), outArr.Size(),
                MPI_DOUBLE, 0, mpiTeamComm);
}
void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat)
{
   if (update_mat == false)
   {
      AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(),
                                            in_A.NumNonZeroElems(),
                                            1, 1,
                                            in_A.ReadI(),
                                            in_A.ReadJ(),
                                            in_A.ReadData(), NULL));

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA,
                                                      in_A.Height(),
                                                      in_A.NumNonZeroElems(),
                                                      in_A.ReadData(), NULL));
   }
}
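// Editor's note: illustrative sketch of the CSR data handed to AmgX above
// (not from the original source). For the 2x2 matrix [[4,1],[0,5]]:
//   I    = {0, 2, 3}        row offsets  (in_A.ReadI())
//   J    = {0, 1, 1}        column index (in_A.ReadJ())
//   data = {4.0, 1.0, 5.0}  nonzeros     (in_A.ReadData())
// The two 1's passed to AMGX_matrix_upload_all are the block dimensions,
// i.e. a scalar (non-block) system.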
void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat)
{
   // Require hypre >= 2.16
#if MFEM_HYPRE_VERSION < 21600
   mfem_error("Hypre version 2.16+ is required when using AmgX \n");
#endif
   hypre_ParCSRMatrix *A_ptr =
      (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(A);

   hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr);

   Array<double> loc_A(A_csr->data, (int)A_csr->num_nonzeros);
   const Array<HYPRE_Int> loc_I(A_csr->i, (int)A_csr->num_rows+1);

   // Column indices must be int64_t to suit AmgX
   Array<int64_t> loc_J((int)A_csr->num_nonzeros);
   for (int i = 0; i < A_csr->num_nonzeros; ++i)
   {
      loc_J[i] = A_csr->big_j[i];
   }
   // Assumes one GPU per MPI rank
   if (mpi_gpu_mode == "mpi-gpu-exclusive")
   {
      SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat);

      // Free A_csr
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   // Assumes teams of MPI ranks share a GPU
   if (mpi_gpu_mode == "mpi-teams")
   {
      SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat);

      // Free A_csr
      hypre_CSRMatrixDestroy(A_csr);
      return;
   }

   mfem_error("Unsupported MPI_GPU combination \n");
}
void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A,
                                          const Array<double> &loc_A,
                                          const Array<int> &loc_I,
                                          const Array<int64_t> &loc_J,
                                          const bool update_mat)
{
   // Build the array of row-partition offsets
   Array<int64_t> rowPart(gpuWorldSize+1); rowPart = 0;

   int64_t myStart = A.GetRowStarts()[0];

   MPI_Allgather(&myStart, 1, MPI_INT64_T,
                 rowPart.GetData(), 1, MPI_INT64_T, gpuWorld);
   MPI_Barrier(gpuWorld);

   rowPart[gpuWorldSize] = A.M();

   const int nGlobalRows = A.M();
   const int local_rows = loc_I.Size()-1;
   const int num_nnz = loc_I[local_rows];

   if (update_mat == false)
   {
      AMGX_distribution_handle dist;
      AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
      AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                          AMGX_DIST_PARTITION_OFFSETS,
                                                          rowPart.GetData()));

      AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                    local_rows, num_nnz, 1, 1,
                                                    loc_I.Read(), loc_J.Read(),
                                                    loc_A.Read(), NULL, dist));

      AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
      AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
   }
   else
   {
      AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                      num_nnz, loc_A, NULL));
   }
}
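// Editor's note: illustrative example of the AMGX_DIST_PARTITION_OFFSETS
// format used above (not from the original source). With 3 GPU ranks and a
// 10-row global matrix split 4/3/3:
//   rowPart = {0, 4, 7, 10}
// i.e. rank r owns global rows [rowPart[r], rowPart[r+1]).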
void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A,
                                   const Array<double> &loc_A,
                                   const Array<int> &loc_I,
                                   const Array<int64_t> &loc_J,
                                   const bool update_mat)
{
   // The arrays that will hold the team-consolidated matrix data
   Array<int> all_I;
   Array<int64_t> all_J;
   Array<double> all_A;

   // Determine the consolidated array sizes on the team root
   int J_allsz(0), all_NNZ(0), nDevRows(0);
   const int loc_row_len = std::abs(A.RowPart()[1] -
                                    A.RowPart()[0]); // end of row partition
   const int loc_Jz_sz = loc_J.Size();
   const int loc_A_sz = loc_A.Size();

   MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld);
   MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld);

   MPI_Barrier(devWorld);
   if (myDevWorldRank == 0)
   {
      all_I.SetSize(nDevRows+devWorldSize);
      all_J.SetSize(J_allsz); all_J = 0;
      all_A.SetSize(all_NNZ);
   }

   GatherArray(loc_I, all_I, devWorldSize, devWorld);
   GatherArray(loc_J, all_J, devWorldSize, devWorld);
   GatherArray(loc_A, all_A, devWorldSize, devWorld);

   MPI_Barrier(devWorld);
   int local_nnz(0);
   int64_t local_rows(0);

   if (myDevWorldRank == 0)
   {
      // The gathered all_I holds one row-pointer array per team member, each
      // starting from zero; merge them into a single CSR row pointer.
      // z_ind records where each member's array begins (the zero entries).
      Array<int> z_ind(devWorldSize+1);
      int iter = 1;
      while (iter < devWorldSize-1)
      {
         // Determine the indices of zeros in all_I
         int counter = 0;
         z_ind[counter] = counter;
         counter++;
         for (int idx=1; idx<all_I.Size()-1; idx++)
         {
            if (all_I[idx] == 0)
            {
               z_ind[counter] = idx-1;
               counter++;
            }
         }
         z_ind[devWorldSize] = all_I.Size()-1;

         // Bump the next member's row pointers by the accumulated offsets
         for (int idx=z_ind[1]+1; idx < z_ind[2]; idx++)
         {
            all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
         }

         // Shift the array to remove the consumed zero marker
         for (int idx=z_ind[2]; idx < all_I.Size()-1; ++idx)
         {
            all_I[idx] = all_I[idx+1];
         }
         iter++;
      }

      // Last pass through the array
      int counter = 0;
      z_ind[counter] = counter;
      counter++;
      for (int idx=1; idx<all_I.Size()-1; idx++)
      {
         if (all_I[idx] == 0)
         {
            z_ind[counter] = idx-1;
            counter++;
         }
      }
      z_ind[devWorldSize] = all_I.Size()-1;

      // Bump all_I one last time
      for (int idx=z_ind[1]+1; idx < all_I.Size()-1; idx++)
      {
         all_I[idx] = all_I[idx-1] + (all_I[idx+1] - all_I[idx]);
      }
      local_nnz = all_I[all_I.Size()-devWorldSize];
      local_rows = nDevRows;
   }
   // Create the row partition across team roots
   mat_local_rows = local_rows; // class member copy
   Array<int64_t> rowPart;
   if (gpuProc == 0)
   {
      rowPart.SetSize(gpuWorldSize+1); rowPart = 0;

      MPI_Allgather(&local_rows, 1, MPI_INT64_T,
                    &rowPart.GetData()[1], 1, MPI_INT64_T,
                    gpuWorld);
      MPI_Barrier(gpuWorld);

      // Prefix-sum to convert local row counts into global offsets
      for (int i = 1; i < rowPart.Size(); ++i)
      {
         rowPart[i] += rowPart[i-1];
      }

      MPI_Barrier(gpuWorld);
      int nGlobalRows = A.M();
      if (update_mat == false)
      {
         AMGX_distribution_handle dist;
         AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg));
         AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist,
                                                             AMGX_DIST_PARTITION_OFFSETS,
                                                             rowPart.GetData()));

         AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows,
                                                       local_rows, local_nnz,
                                                       1, 1, all_I.ReadWrite(),
                                                       all_J.Read(),
                                                       all_A.Read(),
                                                       NULL, dist));

         AMGX_SAFE_CALL(AMGX_distribution_destroy(dist));
         MPI_Barrier(gpuWorld);

         AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA));

         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA));
         AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA));
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows,
                                                         local_nnz, all_A, NULL));
      }
   }
}
void AmgXSolver::SetOperator(const Operator &op)
{
   height = op.Height();
   width = op.Width();

   if (const SparseMatrix *Aptr =
          dynamic_cast<const SparseMatrix*>(&op))
   {
      SetMatrix(*Aptr);
   }
#ifdef MFEM_USE_MPI
   else if (const HypreParMatrix *Aptr =
               dynamic_cast<const HypreParMatrix*>(&op))
   {
      SetMatrix(*Aptr);
   }
#endif
   else
   {
      mfem_error("Unsupported Operator Type \n");
   }
}

void AmgXSolver::UpdateOperator(const Operator &op)
{
   if (const SparseMatrix *Aptr =
          dynamic_cast<const SparseMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
#ifdef MFEM_USE_MPI
   else if (const HypreParMatrix *Aptr =
               dynamic_cast<const HypreParMatrix*>(&op))
   {
      SetMatrix(*Aptr, true);
   }
#endif
   else
   {
      mfem_error("Operator update not supported \n");
   }
}
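// Editor's note: hedged usage sketch (not from the original source).
// UpdateOperator() keeps the uploaded sparsity pattern and only swaps the
// coefficients (the AMGX_matrix_replace_coefficients path above), which fits
// problems where matrix values change but the structure does not:
//
//   AmgXSolver amgx(AmgXSolver::SOLVER, /* verbose = */ true);
//   amgx.SetOperator(A);     // full upload + AMGX_solver_setup
//   amgx.Mult(b, x);
//   // ... change entries of A in place, same sparsity ...
//   amgx.UpdateOperator(A);  // coefficients only
//   amgx.Mult(b2, x2);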
void AmgXSolver::Mult(const Vector &B, Vector &X) const
{
   // Set the initial guess to zero unless iterative_mode is enabled
   X.UseDevice(true);
   if (!iterative_mode) { X = 0.0; }

   // Mult for the serial and mpi-gpu-exclusive modes
   if (mpi_gpu_mode != "mpi-teams")
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read()));

      if (mpi_gpu_mode != "serial")
      {
#ifdef MFEM_USE_MPI
         MPI_Barrier(gpuWorld);
#endif
      }

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write()));
      return;
   }
   // Mult for the mpi-teams mode: consolidate on the team root, solve there,
   // then scatter the solution back to the team
   Vector all_X(mat_local_rows);
   Vector all_B(mat_local_rows);
   Array<int> Apart_X(devWorldSize);
   Array<int> Adisp_X(devWorldSize);
   Array<int> Apart_B(devWorldSize);
   Array<int> Adisp_B(devWorldSize);

   GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X);
   GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B);
   MPI_Barrier(devWorld);

   if (gpuWorld != MPI_COMM_NULL)
   {
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite()));
      AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite()));

      MPI_Barrier(gpuWorld);

      AMGX_SAFE_CALL(AMGX_solver_solve(solver, AmgXRHS, AmgXP));

      AMGX_SOLVE_STATUS status;
      AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status));
      if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER)
      {
         if (status == AMGX_SOLVE_DIVERGED)
         {
            mfem_error("AmgX solver diverged \n");
         }
         else
         {
            mfem_error("AmgX solver failed to solve system \n");
         }
      }

      AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write()));
   }

   ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X);
}
int AmgXSolver::GetNumIterations()
{
   int getIters;
   AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters));
   return getIters;
}
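// Editor's note: hedged sketch (not from the original source) combining the
// convergence check with the iteration count:
//
//   AmgXSolver amgx(AmgXSolver::SOLVER, /* verbose = */ false);
//   amgx.SetConvergenceCheck(true); // make Mult() abort if AmgX fails
//   amgx.SetOperator(A);
//   amgx.Mult(b, x);
//   mfem::out << "AmgX iterations: " << amgx.GetNumIterations() << "\n";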
void AmgXSolver::Finalize()
{
   // Check that this instance has been initialized
   if (!isInitialized || count < 1)
   {
      mfem_error("Error in AmgXSolver::Finalize(). \n"
                 "This AmgXWrapper has not been initialized. \n"
                 "Please initialize it before finalization.\n");
   }

   // Only processes using a GPU are required to destroy AmgX content
#ifdef MFEM_USE_MPI
   if (gpuProc == 0 || mpi_gpu_mode == "serial")
#endif
   {
      // Destroy the solver, matrix, and vector instances
      AMGX_SAFE_CALL(AMGX_solver_destroy(solver));

      AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA));

      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP));
      AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS));

      // Only the last instance destroys the resource object and finalizes AmgX
      if (count == 1)
      {
         AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc));
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));

         AMGX_SAFE_CALL(AMGX_finalize_plugins());
         AMGX_SAFE_CALL(AMGX_finalize());
      }
      else
      {
         AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
      }

#ifdef MFEM_USE_MPI
      // Destroy gpuWorld
      if (mpi_gpu_mode != "serial")
      {
         MPI_Comm_free(&gpuWorld);
      }
#endif
   }

   // Reset communicators so the instance can be reused
#ifdef MFEM_USE_MPI
   gpuProc = MPI_UNDEFINED;
   if (globalCpuWorld != MPI_COMM_NULL)
   {
      MPI_Comm_free(&globalCpuWorld);
      MPI_Comm_free(&localCpuWorld);
      MPI_Comm_free(&devWorld);
   }
#endif

   // Decrease the instance count and mark this instance as finalized
   count -= 1;

   isInitialized = false;
}