MFEM  v4.4.0
Finite element discretization library
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Pages
strumpack.cpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2022, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "../config/config.hpp"
13 
14 #ifdef MFEM_USE_STRUMPACK
15 #ifdef MFEM_USE_MPI
16 
#include "strumpack.hpp"

#include <vector>

using namespace std;
using namespace strumpack;
21 
22 namespace mfem
23 {
24 
25 STRUMPACKRowLocMatrix::STRUMPACKRowLocMatrix(MPI_Comm comm,
26  int num_loc_rows, int first_loc_row,
27  int glob_nrows, int glob_ncols,
28  int *I, int *J, double *data)
29  : comm_(comm), A_(NULL)
30 {
31  // Set mfem::Operator member data
32  height = num_loc_rows;
33  width = num_loc_rows;
34 
35  // Allocate STRUMPACK's CSRMatrixMPI
36  int nprocs, rank;
37  MPI_Comm_rank(comm_, &rank);
38  MPI_Comm_size(comm_, &nprocs);
39  int * dist = new int[nprocs + 1];
40  dist[rank + 1] = first_loc_row + num_loc_rows;
41  dist[0] = 0;
42  MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
43  A_ = new CSRMatrixMPI<double,int>(num_loc_rows, I, J, data, dist, comm_, false);
44  delete[] dist;
45 }
46 
48  : comm_(hypParMat.GetComm()),
49  A_(NULL)
50 {
51  // First cast the parameter to a hypre_ParCSRMatrix
52  hypre_ParCSRMatrix * parcsr_op =
53  (hypre_ParCSRMatrix *)const_cast<HypreParMatrix&>(hypParMat);
54 
55  MFEM_ASSERT(parcsr_op != NULL,"STRUMPACK: const_cast failed in SetOperator");
56 
57  // Create the CSRMatrixMPI A_ by borrowing the internal data from a
58  // hypre_CSRMatrix.
59  hypre_CSRMatrix * csr_op = hypre_MergeDiagAndOffd(parcsr_op);
60  hypre_CSRMatrixSetDataOwner(csr_op,0);
61 #if MFEM_HYPRE_VERSION >= 21600
62  // For now, this method assumes that HYPRE_Int is int. Also, csr_op->num_cols
63  // is of type HYPRE_Int, so if we want to check for big indices in
64  // csr_op->big_j, we'll have to check all entries and that check will only be
65  // necessary in HYPRE_MIXEDINT mode which is not supported at the moment.
66  hypre_CSRMatrixBigJtoJ(csr_op);
67 #endif
68 
69  height = csr_op->num_rows;
70  width = csr_op->num_rows;
71 
72  int nprocs, rank;
73  MPI_Comm_rank(comm_, &rank);
74  MPI_Comm_size(comm_, &nprocs);
75  int * dist = new int[nprocs + 1];
76  dist[rank + 1] = parcsr_op->first_row_index + csr_op->num_rows;
77  dist[0] = 0;
78  MPI_Allgather(MPI_IN_PLACE, 0, MPI_INT, dist + 1, 1, MPI_INT, comm_);
79  A_ = new CSRMatrixMPI<double,int>(csr_op->num_rows, csr_op->i, csr_op->j,
80  csr_op->data, dist, comm_, false);
81  delete[] dist;
82 
83  // Everything has been copied or abducted so delete the structure
84  hypre_CSRMatrixDestroy(csr_op);
85 }
86 
88 {
89  // Delete the struct
90  if ( A_ != NULL ) { delete A_; }
91 }
92 
93 STRUMPACKSolver::STRUMPACKSolver( int argc, char* argv[], MPI_Comm comm )
94  : comm_(comm),
95  APtr_(NULL),
96  solver_(NULL)
97 {
98  this->Init(argc, argv);
99 }
100 
102  : comm_(A.GetComm()),
103  APtr_(&A),
104  solver_(NULL)
105 {
106  height = A.Height();
107  width = A.Width();
108 
109  this->Init(0, NULL);
110 }
111 
113 {
114  if ( solver_ != NULL ) { delete solver_; }
115 }
116 
117 void STRUMPACKSolver::Init( int argc, char* argv[] )
118 {
119  MPI_Comm_size(comm_, &numProcs_);
120  MPI_Comm_rank(comm_, &myid_);
121 
122  factor_verbose_ = false;
123  solve_verbose_ = false;
124 
125  solver_ = new StrumpackSparseSolverMPIDist<double,int>(comm_, argc, argv,
126  false);
127 }
128 
130 {
131  solver_->options().set_from_command_line( );
132 }
133 
135 {
136  factor_verbose_ = print_stat;
137 }
138 
140 {
141  solve_verbose_ = print_stat;
142 }
143 
144 void STRUMPACKSolver::SetKrylovSolver( strumpack::KrylovSolver method )
145 {
146  solver_->options().set_Krylov_solver( method );
147 }
148 
149 void STRUMPACKSolver::SetReorderingStrategy( strumpack::ReorderingStrategy
150  method )
151 {
152  solver_->options().set_reordering_method( method );
153 }
154 
156 {
157 #if STRUMPACK_VERSION_MAJOR >= 3
158  solver_->options().set_matching( strumpack::MatchingJob::NONE );
159 #else
160  solver_->options().set_mc64job( strumpack::MC64Job::NONE );
161 #endif
162 }
163 
165 {
166 #if STRUMPACK_VERSION_MAJOR >= 3
167  solver_->options().set_matching
168  ( strumpack::MatchingJob::MAX_DIAGONAL_PRODUCT_SCALING );
169 #else
170  solver_->options().set_mc64job
171  ( strumpack::MC64Job::MAX_DIAGONAL_PRODUCT_SCALING );
172 #endif
173 }
174 
#if STRUMPACK_VERSION_MAJOR >= 3
{
   // Select the COMBBLAS matching job (compiled only for STRUMPACK >= 3).
   solver_->options().set_matching(strumpack::MatchingJob::COMBBLAS);
}
#endif
182 
183 void STRUMPACKSolver::SetRelTol( double rtol )
184 {
185  solver_->options().set_rel_tol( rtol );
186 }
187 
188 void STRUMPACKSolver::SetAbsTol( double atol )
189 {
190  solver_->options().set_abs_tol( atol );
191 }
192 
193 
194 void STRUMPACKSolver::Mult( const Vector & x, Vector & y ) const
195 {
196  MFEM_ASSERT(APtr_ != NULL,
197  "STRUMPACK Error: The operator must be set before"
198  " the system can be solved.");
199  MFEM_ASSERT(x.Size() == Width(), "invalid x.Size() = " << x.Size()
200  << ", expected size = " << Width());
201  MFEM_ASSERT(y.Size() == Height(), "invalid y.Size() = " << y.Size()
202  << ", expected size = " << Height());
203 
204  double* yPtr = (double*)y;
205  double* xPtr = (double*)(const_cast<Vector&>(x));
206 
207  solver_->options().set_verbose( factor_verbose_ );
208  ReturnCode ret = solver_->factor();
209  switch (ret)
210  {
211  case ReturnCode::SUCCESS: break;
212  case ReturnCode::MATRIX_NOT_SET:
213  {
214  MFEM_ABORT("STRUMPACK: Matrix was not set!");
215  }
216  break;
217  case ReturnCode::REORDERING_ERROR:
218  {
219  MFEM_ABORT("STRUMPACK: Matrix reordering failed!");
220  }
221  break;
222  }
223  solver_->options().set_verbose( solve_verbose_ );
224  solver_->solve(xPtr, yPtr);
225 
226 }
227 
229 {
230  // Verify that we have a compatible operator
231  APtr_ = dynamic_cast<const STRUMPACKRowLocMatrix*>(&op);
232  if ( APtr_ == NULL )
233  {
234  mfem_error("STRUMPACKSolver::SetOperator : not STRUMPACKRowLocMatrix!");
235  }
236 
237  solver_->set_matrix( *(APtr_->getA()) );
238 
239  // Set mfem::Operator member data
240  height = op.Height();
241  width = op.Width();
242 
243 }
244 
245 } // mfem namespace
246 
247 #endif // MFEM_USE_MPI
248 #endif // MFEM_USE_STRUMPACK
const STRUMPACKRowLocMatrix * APtr_
Definition: strumpack.hpp:161
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition: operator.hpp:72
STRUMPACKSolver(int argc, char *argv[], MPI_Comm comm)
Definition: strumpack.cpp:93
int Size() const
Returns the size of the vector.
Definition: vector.hpp:199
strumpack::StrumpackSparseSolverMPIDist< double, int > * solver_
Definition: strumpack.hpp:162
strumpack::CSRMatrixMPI< double, int > * getA() const
Definition: strumpack.hpp:55
void SetPrintFactorStatistics(bool print_stat)
Definition: strumpack.cpp:134
STRUMPACKRowLocMatrix(MPI_Comm comm, int num_loc_rows, int first_loc_row, int glob_nrows, int glob_ncols, int *I, int *J, double *data)
Definition: strumpack.cpp:25
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition: operator.hpp:66
void mfem_error(const char *msg)
Function called when an error is encountered. Used by the macros MFEM_ABORT, MFEM_ASSERT, MFEM_VERIFY.
Definition: error.cpp:154
void SetPrintSolveStatistics(bool print_stat)
Definition: strumpack.cpp:139
void SetReorderingStrategy(strumpack::ReorderingStrategy method)
Definition: strumpack.cpp:149
void SetRelTol(double rtol)
Definition: strumpack.cpp:183
A class to initialize the size of a Tensor.
Definition: dtensor.hpp:54
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
void Mult(const Vector &x, Vector &y) const
Operator application: y=A(x).
Definition: strumpack.cpp:194
void SetKrylovSolver(strumpack::KrylovSolver method)
Definition: strumpack.cpp:144
void SetAbsTol(double atol)
Definition: strumpack.cpp:188
void SetOperator(const Operator &op)
Set/update the solver for the given operator.
Definition: strumpack.cpp:228
Vector data type.
Definition: vector.hpp:60
Abstract operator.
Definition: operator.hpp:24
Wrapper for hypre's ParCSR matrix class.
Definition: hypre.hpp:337
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28