12 #include "../config/config.hpp"
14 #ifdef MFEM_USE_STRUMPACK
20 using namespace strumpack;
25 STRUMPACKRowLocMatrix::STRUMPACKRowLocMatrix(MPI_Comm comm,
26 int num_loc_rows,
int first_loc_row,
27 int glob_nrows,
int glob_ncols,
28 int *I,
int *J,
double *data)
29 : comm_(comm), A_(NULL)
37 MPI_Comm_rank(comm_, &rank);
38 MPI_Comm_size(comm_, &nprocs);
39 int * dist =
new int[nprocs + 1];
40 dist[rank + 1] = first_loc_row + num_loc_rows;
42 MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
43 A_ =
new CSRMatrixMPI<double,int>(num_loc_rows, I, J, data, dist, comm_,
false);
48 : comm_(hypParMat.GetComm()),
52 hypre_ParCSRMatrix * parcsr_op =
53 (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(hypParMat);
55 MFEM_ASSERT(parcsr_op != NULL,
"STRUMPACK: const_cast failed in SetOperator");
59 hypre_CSRMatrix * csr_op = hypre_MergeDiagAndOffd(parcsr_op);
60 hypre_CSRMatrixSetDataOwner(csr_op,0);
61 #if MFEM_HYPRE_VERSION >= 21600
66 hypre_CSRMatrixBigJtoJ(csr_op);
70 width = csr_op->num_rows;
73 MPI_Comm_rank(comm_, &rank);
74 MPI_Comm_size(comm_, &nprocs);
75 int * dist =
new int[nprocs + 1];
76 dist[rank + 1] = parcsr_op->first_row_index + csr_op->num_rows;
78 MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
79 A_ =
new CSRMatrixMPI<double,int>(csr_op->num_rows, csr_op->i, csr_op->j,
80 csr_op->data, dist, comm_,
false);
84 hypre_CSRMatrixDestroy(csr_op);
90 if ( A_ != NULL ) {
delete A_; }
98 this->
Init(argc, argv);
102 : comm_(A.GetComm()),
117 void STRUMPACKSolver::Init(
int argc,
char* argv[] )
125 solver_ =
new StrumpackSparseSolverMPIDist<double,int>(
comm_, argc, argv,
131 solver_->options().set_from_command_line( );
146 solver_->options().set_Krylov_solver( method );
152 solver_->options().set_reordering_method( method );
157 #if STRUMPACK_VERSION_MAJOR >= 3
158 solver_->options().set_matching( strumpack::MatchingJob::NONE );
160 solver_->options().set_mc64job( strumpack::MC64Job::NONE );
166 #if STRUMPACK_VERSION_MAJOR >= 3
167 solver_->options().set_matching
168 ( strumpack::MatchingJob::MAX_DIAGONAL_PRODUCT_SCALING );
171 ( strumpack::MC64Job::MAX_DIAGONAL_PRODUCT_SCALING );
175 #if STRUMPACK_VERSION_MAJOR >= 3
178 solver_->options().set_matching
179 ( strumpack::MatchingJob::COMBBLAS );
185 solver_->options().set_rel_tol( rtol );
190 solver_->options().set_abs_tol( atol );
196 MFEM_ASSERT(
APtr_ != NULL,
197 "STRUMPACK Error: The operator must be set before"
198 " the system can be solved.");
199 MFEM_ASSERT(x.
Size() ==
Width(),
"invalid x.Size() = " << x.
Size()
200 <<
", expected size = " <<
Width());
201 MFEM_ASSERT(y.
Size() ==
Height(),
"invalid y.Size() = " << y.
Size()
202 <<
", expected size = " <<
Height());
204 double* yPtr = (
double*)y;
205 double* xPtr = (
double*)(const_cast<Vector&>(x));
208 ReturnCode ret =
solver_->factor();
211 case ReturnCode::SUCCESS:
break;
212 case ReturnCode::MATRIX_NOT_SET:
214 MFEM_ABORT(
"STRUMPACK: Matrix was not set!");
217 case ReturnCode::REORDERING_ERROR:
219 MFEM_ABORT(
"STRUMPACK: Matrix reordering failed!");
234 mfem_error(
"STRUMPACKSolver::SetOperator : not STRUMPACKRowLocMatrix!");
247 #endif // MFEM_USE_MPI
248 #endif // MFEM_USE_STRUMPACK
const STRUMPACKRowLocMatrix * APtr_
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
STRUMPACKSolver(int argc, char *argv[], MPI_Comm comm)
int Size() const
Returns the size of the vector.
strumpack::StrumpackSparseSolverMPIDist< double, int > * solver_
void SetFromCommandLine()
strumpack::CSRMatrixMPI< double, int > * getA() const
void SetPrintFactorStatistics(bool print_stat)
STRUMPACKRowLocMatrix(MPI_Comm comm, int num_loc_rows, int first_loc_row, int glob_nrows, int glob_ncols, int *I, int *J, double *data)
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
void SetPrintSolveStatistics(bool print_stat)
void SetReorderingStrategy(strumpack::ReorderingStrategy method)
void SetRelTol(double rtol)
A class to initialize the size of a Tensor.
int height
Dimension of the output / number of rows in the matrix.
void Mult(const Vector &x, Vector &y) const
Operator application: y=A(x).
void SetKrylovSolver(strumpack::KrylovSolver method)
void SetAbsTol(double atol)
void SetOperator(const Operator &op)
Set/update the solver for the given operator.
Wrapper for hypre's ParCSR matrix class.
int width
Dimension of the input / number of columns in the matrix.
void EnableParallelMatching()