MFEM v4.5.2
Finite element discretization library
strumpack.cpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2023, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "../config/config.hpp"
13 
14 #ifdef MFEM_USE_STRUMPACK
15 #ifdef MFEM_USE_MPI
16 
17 #include "strumpack.hpp"
#include <vector>
18 
19 using namespace std;
20 using namespace strumpack;
21 
22 namespace mfem
23 {
24 
25 STRUMPACKRowLocMatrix::STRUMPACKRowLocMatrix(MPI_Comm comm,
26  int num_loc_rows, int first_loc_row,
27  int glob_nrows, int glob_ncols,
28  int *I, int *J, double *data)
29  : comm_(comm), A_(NULL)
30 {
31  // Set mfem::Operator member data
32  height = num_loc_rows;
33  width = num_loc_rows;
34 
35  // Allocate STRUMPACK's CSRMatrixMPI
36  int nprocs, rank;
37  MPI_Comm_rank(comm_, &rank);
38  MPI_Comm_size(comm_, &nprocs);
39  int * dist = new int[nprocs + 1];
40  dist[rank + 1] = first_loc_row + num_loc_rows;
41  dist[0] = 0;
42  MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
43  A_ = new CSRMatrixMPI<double,int>(num_loc_rows, I, J, data, dist, comm_, false);
44  delete[] dist;
45 }
46 
// NOTE(review): the signature line (listing line 47) is missing from this
// extract; judging by the initializer list and the const_cast below, this is
// the STRUMPACKRowLocMatrix constructor taking a const HypreParMatrix
// &hypParMat.
//
// Builds the STRUMPACK CSRMatrixMPI A_ from a HypreParMatrix. The diag and
// offd blocks are merged into one local hypre_CSRMatrix, and its i/j/data
// arrays are passed to CSRMatrixMPI together with a row distribution gathered
// from all ranks.
48  : comm_(hypParMat.GetComm()),
49  A_(NULL)
50 {
51  // First cast the parameter to a hypre_ParCSRMatrix
// Presumably relies on HypreParMatrix's conversion to hypre_ParCSRMatrix*.
52  hypre_ParCSRMatrix * parcsr_op =
53  (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(hypParMat);
54 
// NOTE(review): the message below says "SetOperator", although this check
// runs in the constructor.
55  MFEM_ASSERT(parcsr_op != NULL,"STRUMPACK: const_cast failed in SetOperator");
56 
57  // Create the CSRMatrixMPI A_ by borrowing the internal data from a
58  // hypre_CSRMatrix.
// The matrix is moved to the host for the merge, then back to hypre's memory
// space.
59  hypParMat.HostRead();
60  hypre_CSRMatrix * csr_op = hypre_MergeDiagAndOffd(parcsr_op);
61  hypParMat.HypreRead();
// NOTE(review): with data ownership cleared, hypre_CSRMatrixDestroy() at the
// end presumably leaves csr_op's i/j/data arrays allocated. If CSRMatrixMPI
// copies its inputs (its constructor is not shown here), those arrays leak.
// Confirm against the STRUMPACK and hypre documentation.
62  hypre_CSRMatrixSetDataOwner(csr_op,0);
63 #if MFEM_HYPRE_VERSION >= 21600
64  // For now, this method assumes that HYPRE_Int is int. Also, csr_op->num_cols
65  // is of type HYPRE_Int, so if we want to check for big indices in
66  // csr_op->big_j, we'll have to check all entries and that check will only be
67  // necessary in HYPRE_MIXEDINT mode which is not supported at the moment.
68  hypre_CSRMatrixBigJtoJ(csr_op);
69 #endif
70 
// The local operator is treated as square: width is set to the local row
// count.
71  height = csr_op->num_rows;
72  width = csr_op->num_rows;
73 
74  int nprocs, rank;
75  MPI_Comm_rank(comm_, &rank);
76  MPI_Comm_size(comm_, &nprocs);
// Row distribution: each rank writes the end of its row range into slot
// rank+1, and the in-place allgather fills the remaining slots. Because
// dist[0] = 0, row ranges are assumed to be contiguous and ordered by rank.
77  int * dist = new int[nprocs + 1];
78  dist[rank + 1] = parcsr_op->first_row_index + csr_op->num_rows;
79  dist[0] = 0;
80  MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
81  A_ = new CSRMatrixMPI<double,int>(csr_op->num_rows, csr_op->i, csr_op->j,
82  csr_op->data, dist, comm_, false);
83  delete[] dist;
84 
85  // Everything has been copied or abducted so delete the structure
86  hypre_CSRMatrixDestroy(csr_op);
87 }
88 
// Presumably the STRUMPACKRowLocMatrix destructor body (its signature line,
// listing line 89, is missing from this extract). It deletes the
// CSRMatrixMPI A_ created by the constructors. The NULL check is redundant,
// because deleting a null pointer is a no-op.
90 {
91  // Delete the struct
92  if ( A_ != NULL ) { delete A_; }
93 }
94 
95 STRUMPACKSolver::STRUMPACKSolver( int argc, char* argv[], MPI_Comm comm )
96  : comm_(comm),
97  APtr_(NULL),
98  solver_(NULL)
99 {
100  this->Init(argc, argv);
101 }
102 
// Presumably the STRUMPACKSolver constructor taking a STRUMPACKRowLocMatrix
// &A (its signature line, listing line 103, is missing from this extract).
// It stores a pointer to A, copies A's size, and creates the STRUMPACK solver
// with no command-line arguments.
// NOTE(review): unlike SetOperator(), this path never calls
// solver_->set_matrix(). A Mult() without a prior SetOperator() would
// therefore pass the APtr_ assertion but presumably reach factor() with no
// matrix set (the MATRIX_NOT_SET branch). Consider calling SetOperator(A)
// here.
104  : comm_(A.GetComm()),
105  APtr_(&A),
106  solver_(NULL)
107 {
108  height = A.Height();
109  width = A.Width();
110 
111  this->Init(0, NULL);
112 }
113 
// Presumably the STRUMPACKSolver destructor body (its signature line, listing
// line 114, is missing from this extract). It deletes the STRUMPACK solver
// object. The operator referenced by APtr_ is not deleted here.
115 {
116  if ( solver_ != NULL ) { delete solver_; }
117 }
118 
119 void STRUMPACKSolver::Init( int argc, char* argv[] )
120 {
121  MPI_Comm_size(comm_, &numProcs_);
122  MPI_Comm_rank(comm_, &myid_);
123 
124  factor_verbose_ = false;
125  solve_verbose_ = false;
126 
127  solver_ = new StrumpackSparseSolverMPIDist<double,int>(comm_, argc, argv,
128  false);
129 }
130 
// Body of the method at listing line 131 (its signature is missing from this
// extract). It asks STRUMPACK to set its options from the command line.
// NOTE(review): presumably this parses the argc/argv that Init() passed to
// the STRUMPACK solver. Confirm against STRUMPACK's options documentation.
132 {
133  solver_->options().set_from_command_line( );
134 }
135 
// SetPrintFactorStatistics(bool print_stat) body. Mult() applies the stored
// flag as STRUMPACK's verbosity right before calling factor().
137 {
138  factor_verbose_ = print_stat;
139 }
140 
// SetPrintSolveStatistics(bool print_stat) body. Mult() applies the stored
// flag as STRUMPACK's verbosity right before calling solve().
142 {
143  solve_verbose_ = print_stat;
144 }
145 
146 void STRUMPACKSolver::SetKrylovSolver( strumpack::KrylovSolver method )
147 {
148  solver_->options().set_Krylov_solver( method );
149 }
150 
151 void STRUMPACKSolver::SetReorderingStrategy( strumpack::ReorderingStrategy
152  method )
153 {
154  solver_->options().set_reordering_method( method );
155 }
156 
// Body of the method at listing line 157 (its signature is missing from this
// extract). It turns matching off: MatchingJob::NONE for STRUMPACK 3 and
// later, MC64Job::NONE for older versions.
158 {
159 #if STRUMPACK_VERSION_MAJOR >= 3
160  solver_->options().set_matching( strumpack::MatchingJob::NONE );
161 #else
162  solver_->options().set_mc64job( strumpack::MC64Job::NONE );
163 #endif
164 }
165 
// Body of the method at listing line 166 (its signature is missing from this
// extract). It selects the MAX_DIAGONAL_PRODUCT_SCALING matching job, using
// the option setter that matches the STRUMPACK version.
167 {
168 #if STRUMPACK_VERSION_MAJOR >= 3
169  solver_->options().set_matching
170  ( strumpack::MatchingJob::MAX_DIAGONAL_PRODUCT_SCALING );
171 #else
172  solver_->options().set_mc64job
173  ( strumpack::MC64Job::MAX_DIAGONAL_PRODUCT_SCALING );
174 #endif
175 }
176 
// Compiled only for STRUMPACK 3 and later. This is the body of the method at
// listing line 178 (its signature is missing from this extract), which
// selects the COMBBLAS matching job.
177 #if STRUMPACK_VERSION_MAJOR >= 3
179 {
180  solver_->options().set_matching
181  ( strumpack::MatchingJob::COMBBLAS );
182 }
183 #endif
184 
185 void STRUMPACKSolver::SetRelTol( double rtol )
186 {
187  solver_->options().set_rel_tol( rtol );
188 }
189 
190 void STRUMPACKSolver::SetAbsTol( double atol )
191 {
192  solver_->options().set_abs_tol( atol );
193 }
194 
195 
196 void STRUMPACKSolver::Mult( const Vector & x, Vector & y ) const
197 {
198  MFEM_ASSERT(APtr_ != NULL,
199  "STRUMPACK Error: The operator must be set before"
200  " the system can be solved.");
201  MFEM_ASSERT(x.Size() == Width(), "invalid x.Size() = " << x.Size()
202  << ", expected size = " << Width());
203  MFEM_ASSERT(y.Size() == Height(), "invalid y.Size() = " << y.Size()
204  << ", expected size = " << Height());
205 
206  double* yPtr = y.HostWrite();
207  const double* xPtr = x.HostRead();
208 
209  solver_->options().set_verbose( factor_verbose_ );
210  ReturnCode ret = solver_->factor();
211  switch (ret)
212  {
213  case ReturnCode::SUCCESS: break;
214  case ReturnCode::MATRIX_NOT_SET:
215  {
216  MFEM_ABORT("STRUMPACK: Matrix was not set!");
217  }
218  break;
219  case ReturnCode::REORDERING_ERROR:
220  {
221  MFEM_ABORT("STRUMPACK: Matrix reordering failed!");
222  }
223  break;
224  default:
225  {
226  MFEM_ABORT("STRUMPACK: 'factor()' error code = " << ret);
227  }
228  }
229  solver_->options().set_verbose( solve_verbose_ );
230  solver_->solve(xPtr, yPtr);
231 
232 }
233 
// SetOperator(const Operator &op) body (its signature line, listing line 234,
// is missing from this extract). Only a STRUMPACKRowLocMatrix is accepted;
// any other operator type triggers mfem_error(). The operator's CSRMatrixMPI
// is handed to STRUMPACK, and the operator's size is copied into this
// solver.
235 {
236  // Verify that we have a compatible operator
237  APtr_ = dynamic_cast<const STRUMPACKRowLocMatrix*>(&op);
238  if ( APtr_ == NULL )
239  {
240  mfem_error("STRUMPACKSolver::SetOperator : not STRUMPACKRowLocMatrix!");
241  }
242 
// APtr_ is a non-owning pointer (the solver's destructor only deletes
// solver_), so op presumably has to outlive this solver's use of it.
243  solver_->set_matrix( *(APtr_->getA()) );
244 
245  // Set mfem::Operator member data
246  height = op.Height();
247  width = op.Width();
248 
249 }
250 
251 } // mfem namespace
252 
253 #endif // MFEM_USE_MPI
254 #endif // MFEM_USE_STRUMPACK
const STRUMPACKRowLocMatrix * APtr_
Definition: strumpack.hpp:161
virtual const double * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
Definition: vector.hpp:452
STRUMPACKSolver(int argc, char *argv[], MPI_Comm comm)
Definition: strumpack.cpp:95
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition: operator.hpp:72
int Size() const
Returns the size of the vector.
Definition: vector.hpp:199
virtual double * HostWrite()
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), false).
Definition: vector.hpp:460
strumpack::StrumpackSparseSolverMPIDist< double, int > * solver_
Definition: strumpack.hpp:162
STL namespace.
void HostRead() const
Update the internal hypre_ParCSRMatrix object, A, to be on host.
Definition: hypre.hpp:837
void SetPrintFactorStatistics(bool print_stat)
Definition: strumpack.cpp:136
STRUMPACKRowLocMatrix(MPI_Comm comm, int num_loc_rows, int first_loc_row, int glob_nrows, int glob_ncols, int *I, int *J, double *data)
Definition: strumpack.cpp:25
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
Definition: error.cpp:154
void SetPrintSolveStatistics(bool print_stat)
Definition: strumpack.cpp:141
strumpack::CSRMatrixMPI< double, int > * getA() const
Definition: strumpack.hpp:55
void SetReorderingStrategy(strumpack::ReorderingStrategy method)
Definition: strumpack.cpp:151
void SetRelTol(double rtol)
Definition: strumpack.cpp:185
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition: operator.hpp:66
A class to initialize the size of a Tensor.
Definition: dtensor.hpp:54
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
void SetKrylovSolver(strumpack::KrylovSolver method)
Definition: strumpack.cpp:146
void HypreRead() const
Update the internal hypre_ParCSRMatrix object, A, to be in hypre memory space.
Definition: hypre.hpp:854
void SetAbsTol(double atol)
Definition: strumpack.cpp:190
void SetOperator(const Operator &op)
Set/update the solver for the given operator.
Definition: strumpack.cpp:234
Vector data type.
Definition: vector.hpp:60
void Mult(const Vector &x, Vector &y) const
Operator application: y=A(x).
Definition: strumpack.cpp:196
Abstract operator.
Definition: operator.hpp:24
Wrapper for hypre's ParCSR matrix class.
Definition: hypre.hpp:343
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28