MFEM v4.7.0
Finite element discretization library
Loading...
Searching...
No Matches
strumpack.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2024, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#include "../config/config.hpp"
13
14#ifdef MFEM_USE_STRUMPACK
15#ifdef MFEM_USE_MPI
16
17#include "strumpack.hpp"
18
19namespace mfem
20{
21
23 int num_loc_rows,
24 HYPRE_BigInt first_loc_row,
25 HYPRE_BigInt glob_nrows,
26 HYPRE_BigInt glob_ncols,
27 int *I, HYPRE_BigInt *J,
28 double *data, bool sym_sparse)
29{
// Builds the STRUMPACK distributed CSR matrix A_ from this rank's local CSR
// arrays (I, J, data). glob_nrows and glob_ncols are not referenced in this
// body; the row distribution is derived from first_loc_row and num_loc_rows.
30 // Set mfem::Operator member data
// Both dimensions are set to the local row count.
31 height = num_loc_rows;
32 width = num_loc_rows;
33
34 // Allocate STRUMPACK's CSRMatrixMPI (copies all inputs)
35 int rank, nprocs;
36 MPI_Comm_rank(comm, &rank);
37 MPI_Comm_size(comm, &nprocs);
// dist has nprocs + 1 entries: dist[0] = 0 and dist[r + 1] is one past the
// last global row owned by rank r. Each rank writes its own entry, and the
// in-place all-gather (receive buffer offset by one) fills in the others.
38 Array<HYPRE_BigInt> dist(nprocs + 1);
39 dist[0] = 0;
40 dist[rank + 1] = first_loc_row + (HYPRE_BigInt)num_loc_rows;
41 MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
42 dist.GetData() + 1, 1, HYPRE_MPI_BIG_INT, comm);
43
44#if !(defined(HYPRE_BIGINT) || defined(HYPRE_MIXEDINT))
// The row-pointer array I is passed through unchanged in this configuration.
45 A_ = new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
46 (HYPRE_BigInt)num_loc_rows, I, J, data, dist.GetData(),
47 comm, sym_sparse);
48#else
// Widen the int row pointers to HYPRE_BigInt in a temporary array first.
49 Array<HYPRE_BigInt> II(num_loc_rows+1);
50 for (int i = 0; i <= num_loc_rows; i++) { II[i] = (HYPRE_BigInt)I[i]; }
51 A_ = new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
52 (HYPRE_BigInt)num_loc_rows, II.GetData(), J, data, dist.GetData(),
53 comm, sym_sparse);
54#endif
55}
56
58 bool sym_sparse)
59{
// Builds A_ from an operator that must be a HypreParMatrix: the hypre matrix
// is converted to a single local CSR matrix with hypre_MergeDiagAndOffd(),
// its arrays are handed to STRUMPACK's CSRMatrixMPI, and the temporary CSR
// matrix is destroyed at the end.
60 const HypreParMatrix *APtr = dynamic_cast<const HypreParMatrix *>(&op);
61 MFEM_VERIFY(APtr, "Not a compatible matrix type");
62 MPI_Comm comm = APtr->GetComm();
63
64 // Set mfem::Operator member data
65 height = op.Height();
66 width = op.Width();
67
68 // First cast the parameter to a hypre_ParCSRMatrix
69 hypre_ParCSRMatrix *parcsr_op =
70 (hypre_ParCSRMatrix *)const_cast<HypreParMatrix &>(*APtr);
71
72 // Create the CSRMatrixMPI A by taking the internal data from a
73 // hypre_CSRMatrix
// The matrix is brought to host for the merge and then returned to hypre's
// memory space.
74 APtr->HostRead();
75 hypre_CSRMatrix *csr_op = hypre_MergeDiagAndOffd(parcsr_op);
76 APtr->HypreRead();
77 HYPRE_Int *Iptr = csr_op->i;
// The column-index field read from the merged matrix depends on the hypre
// version: big_j from 2.16.0 on, j before that.
78#if MFEM_HYPRE_VERSION >= 21600
79 HYPRE_BigInt *Jptr = csr_op->big_j;
80#else
81 HYPRE_Int *Jptr = csr_op->j;
82#endif
83 double *data = csr_op->data;
84
85 HYPRE_BigInt fst_row = parcsr_op->first_row_index;
86 HYPRE_Int m_loc = csr_op->num_rows;
87
88 // Allocate STRUMPACK's CSRMatrixMPI
89 int rank, nprocs;
90 MPI_Comm_rank(comm, &rank);
91 MPI_Comm_size(comm, &nprocs);
// dist[0] = 0 and dist[r + 1] is one past the last global row of rank r; each
// rank writes its own entry and the in-place all-gather fills in the rest.
92 Array<HYPRE_BigInt> dist(nprocs + 1);
93 dist[0] = 0;
94 dist[rank + 1] = fst_row + (HYPRE_BigInt)m_loc;
95 MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
96 dist.GetData() + 1, 1, HYPRE_MPI_BIG_INT, comm);
97
98#if !defined(HYPRE_MIXEDINT)
// The row pointers from hypre are passed through unchanged.
99 A_ = new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
100 (HYPRE_BigInt)m_loc, Iptr, Jptr, data, dist.GetData(),
101 comm, sym_sparse);
102#else
// Widen the HYPRE_Int row pointers to HYPRE_BigInt before handing them over.
// NOTE(review): II is used below but its declaration is not visible in this
// listing (listing line 103 is absent) - confirm against the upstream file.
104 for (int i = 0; i <= m_loc; i++) { II[i] = (HYPRE_BigInt)Iptr[i]; }
105 A_ = new strumpack::CSRMatrixMPI<double, HYPRE_BigInt>(
106 (HYPRE_BigInt)m_loc, II.GetData(), Jptr, data, dist.GetData(),
107 comm, sym_sparse);
108#endif
109
110 // Everything has been copied so delete the structure
111 hypre_CSRMatrixDestroy(csr_op);
112}
115{
116 delete A_;
117}
118
// Constructs the solver wrapper without an operator: APtr_ is null, factor and
// solve verbosity and reordering reuse are off, and nrhs_ starts at -1. The
// underlying STRUMPACK solver object is created on comm with the given
// command-line arguments.
119template <typename STRUMPACKSolverType>
121STRUMPACKSolverBase(MPI_Comm comm, int argc, char *argv[])
122 : APtr_(NULL),
123 factor_verbose_(false),
124 solve_verbose_(false),
125 reorder_reuse_(false),
126 nrhs_(-1)
127{
// NOTE(review): the trailing false is presumably STRUMPACK's verbose flag -
// confirm against the STRUMPACK constructor documentation.
128 solver_ = new STRUMPACKSolverType(comm, argc, argv, false);
129}
130
// Constructs the solver on A's communicator and immediately installs A through
// SetOperator(). APtr_ is already non-null when SetOperator() runs, so its
// first_mat flag is false there; the set_matrix() path is still taken because
// reorder_reuse_ is false at this point.
131template <typename STRUMPACKSolverType>
133STRUMPACKSolverBase(STRUMPACKRowLocMatrix &A, int argc, char *argv[])
134 : APtr_(&A),
135 factor_verbose_(false),
136 solve_verbose_(false),
137 reorder_reuse_(false),
138 nrhs_(-1)
139{
// NOTE(review): the trailing false is presumably STRUMPACK's verbose flag -
// confirm against the STRUMPACK constructor documentation.
140 solver_ = new STRUMPACKSolverType(A.GetComm(), argc, argv, false);
141 SetOperator(A);
142}
143
144template <typename STRUMPACKSolverType>
147{
148 delete solver_;
149}
150
151template <typename STRUMPACKSolverType>
154{
155 solver_->options().set_from_command_line();
156}
157
// Stores the verbosity flag for the factor step; FactorInternal() passes it to
// the solver options via set_verbose() before factoring.
158template <typename STRUMPACKSolverType>
160SetPrintFactorStatistics(bool print_stat)
161{
162 factor_verbose_ = print_stat;
163}
164
// Stores the verbosity flag for the solve step; Mult() and ArrayMult() pass it
// to the solver options via set_verbose() before solving.
165template <typename STRUMPACKSolverType>
167SetPrintSolveStatistics(bool print_stat)
168{
169 solve_verbose_ = print_stat;
170}
171
// Forwards rtol to the STRUMPACK solver options through set_rel_tol().
172template <typename STRUMPACKSolverType>
174::SetRelTol(double rtol)
175{
176 solver_->options().set_rel_tol(rtol);
177}
178
// Forwards atol to the STRUMPACK solver options through set_abs_tol().
179template <typename STRUMPACKSolverType>
181::SetAbsTol(double atol)
182{
183 solver_->options().set_abs_tol(atol);
184}
185
// Forwards max_it to the STRUMPACK solver options through set_maxit().
186template <typename STRUMPACKSolverType>
188::SetMaxIter(int max_it)
189{
190 solver_->options().set_maxit(max_it);
191}
192
193template <typename STRUMPACKSolverType>
196{
197 reorder_reuse_ = reuse;
198}
199
200template <typename STRUMPACKSolverType>
203{
204 solver_->options().enable_gpu();
205}
206
207template <typename STRUMPACKSolverType>
210{
211 solver_->options().disable_gpu();
212}
213
// Forwards the chosen Krylov method to the STRUMPACK solver options through
// set_Krylov_solver().
214template <typename STRUMPACKSolverType>
216SetKrylovSolver(strumpack::KrylovSolver method)
217{
218 solver_->options().set_Krylov_solver(method);
219}
220
// Forwards the chosen reordering strategy to the STRUMPACK solver options
// through set_reordering_method().
221template <typename STRUMPACKSolverType>
223SetReorderingStrategy(strumpack::ReorderingStrategy method)
224{
225 solver_->options().set_reordering_method(method);
226}
227
// Forwards the matching job to the STRUMPACK solver options through
// set_matching().
228template <typename STRUMPACKSolverType>
230SetMatching(strumpack::MatchingJob job)
231{
232 solver_->options().set_matching(job);
233}
234
// Selects the compression type. With STRUMPACK major version >= 5 the type is
// passed straight to set_compression(). For older versions only NONE, BLR and
// HSS are mapped onto the enable/disable calls; any other value aborts.
235template <typename STRUMPACKSolverType>
237SetCompression(strumpack::CompressionType type)
238{
239#if STRUMPACK_VERSION_MAJOR >= 5
240 solver_->options().set_compression(type);
241#else
242 switch (type)
243 {
244 case strumpack::NONE:
// NONE turns off both compression schemes.
245 solver_->options().disable_BLR();
246 solver_->options().disable_HSS();
247 break;
248 case strumpack::BLR:
// Enables BLR only; HSS is not explicitly disabled here.
249 solver_->options().enable_BLR();
250 break;
251 case strumpack::HSS:
// Enables HSS only; BLR is not explicitly disabled here.
252 solver_->options().enable_HSS();
253 break;
254 default:
255 MFEM_ABORT("Invalid compression type for STRUMPACK version " <<
256 STRUMPACK_VERSION_MAJOR << "!");
257 break;
258 }
259#endif
260}
261
// Sets the relative compression tolerance. With STRUMPACK major version >= 5
// a single option is set; for older versions the same value is applied to both
// the BLR and the HSS option sets.
262template <typename STRUMPACKSolverType>
264SetCompressionRelTol(double rtol)
265{
266#if STRUMPACK_VERSION_MAJOR >= 5
267 solver_->options().set_compression_rel_tol(rtol);
268#else
269 solver_->options().BLR_options().set_rel_tol(rtol);
270 solver_->options().HSS_options().set_rel_tol(rtol);
271#endif
272}
273
// Sets the absolute compression tolerance. With STRUMPACK major version >= 5
// a single option is set; for older versions the same value is applied to both
// the BLR and the HSS option sets.
274template <typename STRUMPACKSolverType>
276SetCompressionAbsTol(double atol)
277{
278#if STRUMPACK_VERSION_MAJOR >= 5
279 solver_->options().set_compression_abs_tol(atol);
280#else
281 solver_->options().BLR_options().set_abs_tol(atol);
282 solver_->options().HSS_options().set_abs_tol(atol);
283#endif
284}
285
286#if STRUMPACK_VERSION_MAJOR >= 5
// Forwards precision to the STRUMPACK solver options through
// set_lossy_precision(). Only compiled for STRUMPACK major version >= 5.
287template <typename STRUMPACKSolverType>
289SetCompressionLossyPrecision(int precision)
290{
291 solver_->options().set_lossy_precision(precision);
292}
293
294template <typename STRUMPACKSolverType>
297{
298 solver_->options().HODLR_options().set_butterfly_levels(levels);
299}
300#endif
301
// Installs op as the system matrix. op must be a STRUMPACKRowLocMatrix,
// otherwise MFEM_VERIFY aborts. The first matrix, or any matrix when reordering
// reuse is off, is passed to set_matrix(); otherwise only the values are
// refreshed through update_matrix_values().
302template <typename STRUMPACKSolverType>
304SetOperator(const Operator &op)
305{
306 // Verify that we have a compatible operator
// Remember whether a matrix was already attached before APtr_ is overwritten.
307 bool first_mat = !APtr_;
308 APtr_ = dynamic_cast<const STRUMPACKRowLocMatrix *>(&op);
309 MFEM_VERIFY(APtr_,
310 "STRUMPACK: Operator is not a STRUMPACKRowLocMatrix!");
311
312 // Set mfem::Operator member data
313 height = op.Height();
314 width = op.Width();
315
316 if (first_mat || !reorder_reuse_)
317 {
318 solver_->set_matrix(*(APtr_->GetA()));
319 }
320 else
321 {
322 solver_->update_matrix_values(*(APtr_->GetA()));
323 }
324}
325
// Applies the factor-step verbosity and calls factor() on the underlying
// solver. Any return code other than SUCCESS aborts; the code itself is only
// included in the message for STRUMPACK major version >= 7.
326template <typename STRUMPACKSolverType>
328FactorInternal() const
329{
330 MFEM_ASSERT(APtr_,
331 "STRUMPACK: Operator must be set before the system can be "
332 "solved!");
333 solver_->options().set_verbose(factor_verbose_);
334 strumpack::ReturnCode ret = solver_->factor();
335 if (ret != strumpack::ReturnCode::SUCCESS)
336 {
337#if STRUMPACK_VERSION_MAJOR >= 7
338 MFEM_ABORT("STRUMPACK: Factor failed with return code " << ret << "!");
339#else
340 MFEM_ABORT("STRUMPACK: Factor failed!");
341#endif
342 }
343}
344
// Solves a single system: x is read on host and passed as the first solve
// argument, y is accessed read-write on host and passed as the second.
// FactorInternal() is called on every invocation before the solve. Any return
// code other than SUCCESS aborts.
345template <typename STRUMPACKSolverType>
347Mult(const Vector &x, Vector &y) const
348{
349 MFEM_ASSERT(x.Size() == Width(),
350 "STRUMPACK: Invalid x.Size() = " << x.Size() <<
351 ", expected size = " << Width() << "!");
352 MFEM_ASSERT(y.Size() == Height(),
353 "STRUMPACK: Invalid y.Size() = " << y.Size() <<
354 ", expected size = " << Height() << "!");
355
356 const double *xPtr = x.HostRead();
357 double *yPtr = y.HostReadWrite();
358
359 FactorInternal();
// Switch the shared verbose option from the factor setting to the solve one.
360 solver_->options().set_verbose(solve_verbose_);
// NOTE(review): the trailing false is presumably STRUMPACK's
// use_initial_guess flag - confirm against the STRUMPACK solve() documentation.
361 strumpack::ReturnCode ret = solver_->solve(xPtr, yPtr, false);
362 if (ret != strumpack::ReturnCode::SUCCESS)
363 {
364#if STRUMPACK_VERSION_MAJOR >= 7
365 MFEM_ABORT("STRUMPACK: Solve failed with return code " << ret << "!");
366#else
367 MFEM_ABORT("STRUMPACK: Solve failed!");
368#endif
369 }
370}
371
372template <typename STRUMPACKSolverType>
375{
376 MFEM_ASSERT(X.Size() == Y.Size(),
377 "Number of columns mismatch in STRUMPACK solve!");
378 if (X.Size() == 1)
379 {
380 nrhs_ = 1;
381 MFEM_ASSERT(X[0] && Y[0], "Missing Vector in STRUMPACK solve!");
382 Mult(*X[0], *Y[0]);
383 return;
384 }
385
386 // Multiple RHS case
387 int ldx = Height();
388 if (nrhs_ != X.Size())
389 {
390 rhs_.SetSize(X.Size() * ldx);
391 sol_.SetSize(X.Size() * ldx);
392 nrhs_ = X.Size();
393 }
394 for (int i = 0; i < nrhs_; i++)
395 {
396 MFEM_ASSERT(X[i] && X[i]->Size() == Width(),
397 "STRUMPACK: Missing or invalid sized RHS Vector in solve!");
398 Vector s(rhs_, i * ldx, ldx);
399 s = *X[i];
400 rhs_.SyncMemory(s); // Update flags for rhs_ if updated on device
401 }
402 const double *xPtr = rhs_.HostRead();
403 double *yPtr = sol_.HostReadWrite();
404
405 FactorInternal();
406 solver_->options().set_verbose(solve_verbose_);
407 strumpack::ReturnCode ret = solver_->solve(nrhs_, xPtr, ldx, yPtr, ldx,
408 false);
409 if (ret != strumpack::ReturnCode::SUCCESS)
410 {
411#if STRUMPACK_VERSION_MAJOR >= 7
412 MFEM_ABORT("STRUMPACK: Solve failed with return code " << ret << "!");
413#else
414 MFEM_ABORT("STRUMPACK: Solve failed!");
415#endif
416 }
417
418 for (int i = 0; i < nrhs_; i++)
419 {
420 MFEM_ASSERT(Y[i] && Y[i]->Size() == Width(),
421 "STRUMPACK: Missing or invalid sized solution Vector in solve!");
422 Vector s(sol_, i * ldx, ldx);
423 *Y[i] = s;
424 }
425}
426
// Delegates to the base-class constructor on comm with no command-line
// arguments (argc = 0, argv = NULL), using the double-precision
// SparseSolverMPIDist solver type.
428STRUMPACKSolver(MPI_Comm comm)
429 : STRUMPACKSolverBase<strumpack::
430 SparseSolverMPIDist<double, HYPRE_BigInt>>
431 (comm, 0, NULL) {}
432
435 : STRUMPACKSolverBase<strumpack::
436 SparseSolverMPIDist<double, HYPRE_BigInt>>
437 (A, 0, NULL) {}
438
// Delegates to the base-class constructor on comm, forwarding the command-line
// arguments, using the double-precision SparseSolverMPIDist solver type.
440STRUMPACKSolver(MPI_Comm comm, int argc, char *argv[])
441 : STRUMPACKSolverBase<strumpack::
442 SparseSolverMPIDist<double, HYPRE_BigInt>>
443 (comm, argc, argv) {}
444
// Delegates to the base-class constructor that takes a matrix, forwarding A
// and the command-line arguments, using the double-precision
// SparseSolverMPIDist solver type.
446STRUMPACKSolver(STRUMPACKRowLocMatrix &A, int argc, char *argv[])
447 : STRUMPACKSolverBase<strumpack::
448 SparseSolverMPIDist<double, HYPRE_BigInt>>
449 (A, argc, argv) {}
450
451#if STRUMPACK_VERSION_MAJOR >= 7
454 : STRUMPACKSolverBase<strumpack::
455 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>
456 (comm, 0, NULL) {}
457
460 : STRUMPACKSolverBase<strumpack::
461 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>
462 (A, 0, NULL) {}
463
// Delegates to the base-class constructor on comm, forwarding the command-line
// arguments, using the SparseSolverMixedPrecisionMPIDist<float, double, ...>
// solver type. Only compiled for STRUMPACK major version >= 7.
465STRUMPACKMixedPrecisionSolver(MPI_Comm comm, int argc, char *argv[])
466 : STRUMPACKSolverBase<strumpack::
467 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>
468 (comm, argc, argv) {}
469
472 : STRUMPACKSolverBase<strumpack::
473 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>
474 (A, argc, argv) {}
475#endif
476
// Explicit instantiations of the solver base template for the solver types
// used in this file: the double-precision distributed solver always, and the
// mixed-precision distributed solver only for STRUMPACK major version >= 7.
477template class STRUMPACKSolverBase<strumpack::
478 SparseSolverMPIDist<double, HYPRE_BigInt>>;
479#if STRUMPACK_VERSION_MAJOR >= 7
480template class STRUMPACKSolverBase<strumpack::
481 SparseSolverMixedPrecisionMPIDist<float, double, HYPRE_BigInt>>;
482#endif
483
484} // mfem namespace
485
486#endif // MFEM_USE_MPI
487#endif // MFEM_USE_STRUMPACK
const T * HostRead() const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), false).
Definition array.hpp:321
int Size() const
Return the logical size of the array.
Definition array.hpp:144
T * GetData()
Returns the data.
Definition array.hpp:118
Wrapper for hypre's ParCSR matrix class.
Definition hypre.hpp:388
void HypreRead() const
Update the internal hypre_ParCSRMatrix object, A, to be in hypre memory space.
Definition hypre.hpp:898
void HostRead() const
Update the internal hypre_ParCSRMatrix object, A, to be on host.
Definition hypre.hpp:881
MPI_Comm GetComm() const
MPI communicator.
Definition hypre.hpp:578
Abstract operator.
Definition operator.hpp:25
int width
Dimension of the input / number of columns in the matrix.
Definition operator.hpp:28
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
int height
Dimension of the output / number of rows in the matrix.
Definition operator.hpp:27
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition operator.hpp:72
STRUMPACKMixedPrecisionSolver(MPI_Comm comm)
Constructor with MPI_Comm parameter.
STRUMPACKRowLocMatrix(MPI_Comm comm, int num_loc_rows, HYPRE_BigInt first_loc_row, HYPRE_BigInt glob_nrows, HYPRE_BigInt glob_ncols, int *I, HYPRE_BigInt *J, double *data, bool sym_sparse=false)
Creates a general parallel matrix from a local CSR matrix on each processor.
Definition strumpack.cpp:22
MPI_Comm GetComm() const
Get the MPI Comm being used by the parallel matrix.
Definition strumpack.hpp:62
void SetOperator(const Operator &op)
Set the operator/matrix.
void ArrayMult(const Array< const Vector * > &X, Array< Vector * > &Y) const
Factor and solve the linear systems across the array of vectors.
void SetMatching(strumpack::MatchingJob job)
Configure static pivoting for stability.
void SetCompression(strumpack::CompressionType type)
Select compression for sparse data types.
STRUMPACKSolverBase(MPI_Comm comm, int argc, char *argv[])
Constructor with MPI_Comm parameter and command line arguments.
void SetPrintFactorStatistics(bool print_stat)
Set up verbose printing during the factor step.
void SetKrylovSolver(strumpack::KrylovSolver method)
Set the Krylov solver method to use.
void Mult(const Vector &x, Vector &y) const
Factor and solve the linear system $y = Op^{-1} x$.
void SetCompressionButterflyLevels(int levels)
Set the number of butterfly levels for the HODLR compression option.
void SetCompressionAbsTol(double atol)
Set the absolute tolerance for low rank compression methods.
virtual ~STRUMPACKSolverBase()
Default destructor.
void SetFromCommandLine()
Set options that were captured from the command line.
void SetCompressionLossyPrecision(int precision)
Set the precision for the lossy compression option.
void SetPrintSolveStatistics(bool print_stat)
Set up verbose printing during the solve step.
void SetCompressionRelTol(double rtol)
Set the relative tolerance for low rank compression methods.
void SetReorderingStrategy(strumpack::ReorderingStrategy method)
Set matrix reordering strategy.
STRUMPACKSolver(MPI_Comm comm)
Constructor with the MPI Comm parameter.
Vector data type.
Definition vector.hpp:80
virtual const real_t * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:478
int Size() const
Returns the size of the vector.
Definition vector.hpp:218
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:494
HYPRE_Int HYPRE_BigInt
void Mult(const Table &A, const Table &B, Table &C)
C = A * B (as boolean matrices)
Definition table.cpp:476
RefCoord s[3]