14#ifdef MFEM_USE_STRUMPACK
28 double *data,
bool sym_sparse)
36 MPI_Comm_rank(comm, &rank);
37 MPI_Comm_size(comm, &nprocs);
40 dist[rank + 1] = first_loc_row + (
HYPRE_BigInt)num_loc_rows;
41 MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
42 dist.
GetData() + 1, 1, HYPRE_MPI_BIG_INT, comm);
44#if !(defined(HYPRE_BIGINT) || defined(HYPRE_MIXEDINT))
45 A_ =
new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
50 for (
int i = 0; i <= num_loc_rows; i++) { II[i] = (
HYPRE_BigInt)I[i]; }
51 A_ =
new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
61 MFEM_VERIFY(APtr,
"Not a compatible matrix type");
62 MPI_Comm comm = APtr->
GetComm();
69 hypre_ParCSRMatrix *parcsr_op =
75 hypre_CSRMatrix *csr_op = hypre_MergeDiagAndOffd(parcsr_op);
77 HYPRE_Int *Iptr = csr_op->i;
78#if MFEM_HYPRE_VERSION >= 21600
81 HYPRE_Int *Jptr = csr_op->j;
83 double *data = csr_op->data;
86 HYPRE_Int m_loc = csr_op->num_rows;
90 MPI_Comm_rank(comm, &rank);
91 MPI_Comm_size(comm, &nprocs);
95 MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
96 dist.GetData() + 1, 1, HYPRE_MPI_BIG_INT, comm);
98#if !defined(HYPRE_MIXEDINT)
99 A_ =
new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
104 for (
int i = 0; i <= m_loc; i++) { II[i] = (
HYPRE_BigInt)Iptr[i]; }
105 A_ =
new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
106 (
HYPRE_BigInt)m_loc, II.GetData(), Jptr, data, dist.GetData(),
111 hypre_CSRMatrixDestroy(csr_op);
119template <
typename STRUMPACKSolverType>
123 factor_verbose_(false),
124 solve_verbose_(false),
125 reorder_reuse_(false),
128 solver_ =
new STRUMPACKSolverType(comm, argc, argv,
false);
131template <
typename STRUMPACKSolverType>
135 factor_verbose_(false),
136 solve_verbose_(false),
137 reorder_reuse_(false),
140 solver_ =
new STRUMPACKSolverType(A.
GetComm(), argc, argv,
false);
144template <
typename STRUMPACKSolverType>
151template <
typename STRUMPACKSolverType>
155 solver_->options().set_from_command_line();
158template <
typename STRUMPACKSolverType>
162 factor_verbose_ = print_stat;
165template <
typename STRUMPACKSolverType>
169 solve_verbose_ = print_stat;
172template <
typename STRUMPACKSolverType>
176 solver_->options().set_rel_tol(rtol);
179template <
typename STRUMPACKSolverType>
183 solver_->options().set_abs_tol(atol);
186template <
typename STRUMPACKSolverType>
190 solver_->options().set_maxit(max_it);
193template <
typename STRUMPACKSolverType>
197 reorder_reuse_ = reuse;
200template <
typename STRUMPACKSolverType>
204 solver_->options().enable_gpu();
207template <
typename STRUMPACKSolverType>
211 solver_->options().disable_gpu();
214template <
typename STRUMPACKSolverType>
218 solver_->options().set_Krylov_solver(method);
221template <
typename STRUMPACKSolverType>
225 solver_->options().set_reordering_method(method);
228template <
typename STRUMPACKSolverType>
232 solver_->options().set_matching(job);
235template <
typename STRUMPACKSolverType>
239#if STRUMPACK_VERSION_MAJOR >= 5
240 solver_->options().set_compression(type);
244 case strumpack::NONE:
245 solver_->options().disable_BLR();
246 solver_->options().disable_HSS();
249 solver_->options().enable_BLR();
252 solver_->options().enable_HSS();
255 MFEM_ABORT(
"Invalid compression type for STRUMPACK version " <<
256 STRUMPACK_VERSION_MAJOR <<
"!");
262template <
typename STRUMPACKSolverType>
266#if STRUMPACK_VERSION_MAJOR >= 5
267 solver_->options().set_compression_rel_tol(rtol);
269 solver_->options().BLR_options().set_rel_tol(rtol);
270 solver_->options().HSS_options().set_rel_tol(rtol);
274template <
typename STRUMPACKSolverType>
278#if STRUMPACK_VERSION_MAJOR >= 5
279 solver_->options().set_compression_abs_tol(atol);
281 solver_->options().BLR_options().set_abs_tol(atol);
282 solver_->options().HSS_options().set_abs_tol(atol);
286#if STRUMPACK_VERSION_MAJOR >= 5
287template <
typename STRUMPACKSolverType>
291 solver_->options().set_lossy_precision(precision);
294template <
typename STRUMPACKSolverType>
298 solver_->options().HODLR_options().set_butterfly_levels(levels);
302template <
typename STRUMPACKSolverType>
307 bool first_mat = !APtr_;
310 "STRUMPACK: Operator is not a STRUMPACKRowLocMatrix!");
316 if (first_mat || !reorder_reuse_)
318 solver_->set_matrix(*(APtr_->GetA()));
322 solver_->update_matrix_values(*(APtr_->GetA()));
326template <
typename STRUMPACKSolverType>
331 "STRUMPACK: Operator must be set before the system can be "
333 solver_->options().set_verbose(factor_verbose_);
334 strumpack::ReturnCode ret = solver_->factor();
335 if (ret != strumpack::ReturnCode::SUCCESS)
337#if STRUMPACK_VERSION_MAJOR >= 7
338 MFEM_ABORT(
"STRUMPACK: Factor failed with return code " << ret <<
"!");
340 MFEM_ABORT(
"STRUMPACK: Factor failed!");
345template <
typename STRUMPACKSolverType>
349 MFEM_ASSERT(x.
Size() == Width(),
350 "STRUMPACK: Invalid x.Size() = " << x.
Size() <<
351 ", expected size = " << Width() <<
"!");
352 MFEM_ASSERT(y.
Size() == Height(),
353 "STRUMPACK: Invalid y.Size() = " << y.
Size() <<
354 ", expected size = " << Height() <<
"!");
360 solver_->options().set_verbose(solve_verbose_);
361 strumpack::ReturnCode ret = solver_->solve(xPtr, yPtr,
false);
362 if (ret != strumpack::ReturnCode::SUCCESS)
364#if STRUMPACK_VERSION_MAJOR >= 7
365 MFEM_ABORT(
"STRUMPACK: Solve failed with return code " << ret <<
"!");
367 MFEM_ABORT(
"STRUMPACK: Solve failed!");
372template <
typename STRUMPACKSolverType>
377 "Number of columns mismatch in STRUMPACK solve!");
381 MFEM_ASSERT(X[0] && Y[0],
"Missing Vector in STRUMPACK solve!");
388 if (nrhs_ != X.
Size())
390 rhs_.SetSize(X.
Size() * ldx);
391 sol_.SetSize(X.
Size() * ldx);
394 for (
int i = 0; i < nrhs_; i++)
396 MFEM_ASSERT(X[i] && X[i]->Size() == Width(),
397 "STRUMPACK: Missing or invalid sized RHS Vector in solve!");
402 const double *xPtr = rhs_.
HostRead();
403 double *yPtr = sol_.HostReadWrite();
406 solver_->options().set_verbose(solve_verbose_);
407 strumpack::ReturnCode ret = solver_->solve(nrhs_, xPtr, ldx, yPtr, ldx,
409 if (ret != strumpack::ReturnCode::SUCCESS)
411#if STRUMPACK_VERSION_MAJOR >= 7
412 MFEM_ABORT(
"STRUMPACK: Solve failed with return code " << ret <<
"!");
414 MFEM_ABORT(
"STRUMPACK: Solve failed!");
418 for (
int i = 0; i < nrhs_; i++)
420 MFEM_ASSERT(Y[i] && Y[i]->Size() == Width(),
421 "STRUMPACK: Missing or invalid sized solution Vector in solve!");
443 (comm, argc, argv) {}
451#if STRUMPACK_VERSION_MAJOR >= 7
455 SparseSolverMixedPrecisionMPIDist<float, double,
HYPRE_BigInt>>
461 SparseSolverMixedPrecisionMPIDist<float, double,
HYPRE_BigInt>>
467 SparseSolverMixedPrecisionMPIDist<float, double,
HYPRE_BigInt>>
468 (comm, argc, argv) {}
473 SparseSolverMixedPrecisionMPIDist<float, double,
HYPRE_BigInt>>
478 SparseSolverMPIDist<double, HYPRE_BigInt>>;
479#if STRUMPACK_VERSION_MAJOR >= 7
481 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>;
const T * HostRead() const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), false).
int Size() const
Return the logical size of the array.
T * GetData()
Returns the data.
Wrapper for hypre's ParCSR matrix class.
void HypreRead() const
Update the internal hypre_ParCSRMatrix object, A, to be in hypre memory space.
void HostRead() const
Update the internal hypre_ParCSRMatrix object, A, to be on host.
MPI_Comm GetComm() const
MPI communicator.
int width
Dimension of the input / number of columns in the matrix.
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
int height
Dimension of the output / number of rows in the matrix.
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
STRUMPACKMixedPrecisionSolver(MPI_Comm comm)
Constructor with MPI_Comm parameter.
STRUMPACKRowLocMatrix(MPI_Comm comm, int num_loc_rows, HYPRE_BigInt first_loc_row, HYPRE_BigInt glob_nrows, HYPRE_BigInt glob_ncols, int *I, HYPRE_BigInt *J, double *data, bool sym_sparse=false)
Creates a general parallel matrix from a local CSR matrix on each processor.
MPI_Comm GetComm() const
Get the MPI Comm being used by the parallel matrix.
void SetOperator(const Operator &op)
Set the operator/matrix.
void ArrayMult(const Array< const Vector * > &X, Array< Vector * > &Y) const
Factor and solve the linear systems across the array of vectors.
void SetMatching(strumpack::MatchingJob job)
Configure static pivoting for stability.
void SetCompression(strumpack::CompressionType type)
Select compression for sparse data types.
STRUMPACKSolverBase(MPI_Comm comm, int argc, char *argv[])
Constructor with MPI_Comm parameter and command line arguments.
void SetPrintFactorStatistics(bool print_stat)
Set up verbose printing during the factor step.
void SetKrylovSolver(strumpack::KrylovSolver method)
Set the Krylov solver method to use.
void Mult(const Vector &x, Vector &y) const
Factor and solve the linear system $y = Op^{-1} x$.
void SetCompressionButterflyLevels(int levels)
Set the number of butterfly levels for the HODLR compression option.
void SetCompressionAbsTol(double atol)
Set the absolute tolerance for low rank compression methods.
virtual ~STRUMPACKSolverBase()
Default destructor.
void SetFromCommandLine()
Set options that were captured from the command line.
void SetCompressionLossyPrecision(int precision)
Set the precision for the lossy compression option.
void SetPrintSolveStatistics(bool print_stat)
Set up verbose printing during the solve step.
void SetCompressionRelTol(double rtol)
Set the relative tolerance for low rank compression methods.
void SetReorderingStrategy(strumpack::ReorderingStrategy method)
Set matrix reordering strategy.
STRUMPACKSolver(MPI_Comm comm)
Constructor with the MPI Comm parameter.
virtual const real_t * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
int Size() const
Returns the size of the vector.
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
void Mult(const Table &A, const Table &B, Table &C)
C = A * B (as boolean matrices)