MFEM v4.5.1
Finite element discretization library
mumps.cpp
// Copyright (c) 2010-2022, Lawrence Livermore National Security, LLC. Produced
// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
// LICENSE and NOTICE for details. LLNL-CODE-806117.
//
// This file is part of the MFEM library. For more information and source code
// availability visit https://mfem.org.
//
// MFEM is free software; you can redistribute it and/or modify it under the
// terms of the BSD-3 license. We welcome feedback and contributions, see file
// CONTRIBUTING.md for details.

#include "../config/config.hpp"

#ifdef MFEM_USE_MUMPS
#ifdef MFEM_USE_MPI

#include "mumps.hpp"

#ifdef HYPRE_BIGINT
#error "MUMPSSolver requires HYPRE_Int == int, for now."
#endif

// Macros so that the icntl/info indices match the 1-based numbering used in
// the MUMPS documentation
#define MUMPS_ICNTL(I) icntl[(I)-1]
#define MUMPS_INFO(I) info[(I)-1]

namespace mfem
{

void MUMPSSolver::SetOperator(const Operator &op)
{
   auto APtr = dynamic_cast<const HypreParMatrix *>(&op);

   MFEM_VERIFY(APtr, "Not a compatible matrix type");

   height = op.Height();
   width = op.Width();

   comm = APtr->GetComm();
   MPI_Comm_size(comm, &numProcs);
   MPI_Comm_rank(comm, &myid);

   auto parcsr_op = (hypre_ParCSRMatrix *) const_cast<HypreParMatrix &>(*APtr);

   APtr->HostRead();
   hypre_CSRMatrix *csr_op = hypre_MergeDiagAndOffd(parcsr_op);
   APtr->HypreRead();
#if MFEM_HYPRE_VERSION >= 21600
   hypre_CSRMatrixBigJtoJ(csr_op);
#endif

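   // csr_op now holds this rank's rows of the global matrix as one local CSR
   // block with global column indices of type int, as expected by MUMPS below.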
   int *Iptr = csr_op->i;
   int *Jptr = csr_op->j;
   int n_loc = csr_op->num_rows;

   row_start = parcsr_op->first_row_index;

   MUMPS_INT8 nnz = 0;
   if (mat_type)
   {
      // In symmetric mode, count only the entries on or below the diagonal
      int k = 0;
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            int ii = row_start + i + 1;
            int jj = Jptr[k] + 1;
            k++;
            if (ii >= jj) { nnz++; }
         }
      }
   }
   else
   {
      nnz = csr_op->num_nonzeros;
   }

   int *I = new int[nnz];
   int *J = new int[nnz];

   // Fill the I and J arrays (COO format, 1-based indexing)
   int k = 0;
   double *data;
   if (mat_type)
   {
      int l = 0;
      data = new double[nnz];
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            int ii = row_start + i + 1;
            int jj = Jptr[k] + 1;
            if (ii >= jj)
            {
               I[l] = ii;
               J[l] = jj;
               data[l++] = csr_op->data[k];
            }
            k++;
         }
      }
   }
   else
   {
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            I[k] = row_start + i + 1;
            J[k] = Jptr[k] + 1;
            k++;
         }
      }
      data = csr_op->data;
   }
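   // Note: in the unsymmetric case, data points into csr_op and is not a
   // copy, so only the symmetric-mode buffer is freed after factorization.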

   // New MUMPS object; destroy any previous instance first
   if (id)
   {
      id->job = -2;
      dmumps_c(id);
      delete id;
   }
   id = new DMUMPS_STRUC_C;
   // C to Fortran communicator
   id->comm_fortran = (MUMPS_INT) MPI_Comm_c2f(comm);

   // Host is involved in computation
   id->par = 1;

   id->sym = mat_type;

   // MUMPS init
   id->job = -1;
   dmumps_c(id);

   // Set MUMPS default parameters
   SetParameters();

   id->n = parcsr_op->global_num_rows;
   id->nnz_loc = nnz;
   id->irn_loc = I;
   id->jcn_loc = J;
   id->a_loc = data;

   // MUMPS analysis
   id->job = 1;
   dmumps_c(id);

   // MUMPS factorization
   id->job = 2;
   dmumps_c(id);

   hypre_CSRMatrixDestroy(csr_op);
   delete [] I;
   delete [] J;
   if (mat_type) { delete [] data; }

#if MFEM_MUMPS_VERSION >= 530
   delete [] irhs_loc;
   irhs_loc = new int[n_loc];
   for (int i = 0; i < n_loc; i++)
   {
      irhs_loc[i] = row_start + i + 1;
   }
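   // Gather each rank's first global row index so that RedistributeSol() can
   // look up which rank owns a given global row.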
   row_starts.SetSize(numProcs);
   MPI_Allgather(&row_start, 1, MPI_INT, row_starts, 1, MPI_INT, comm);
#else
   if (myid == 0)
   {
      delete [] rhs_glob;
      delete [] recv_counts;
      rhs_glob = new double[parcsr_op->global_num_rows];
      recv_counts = new int[numProcs];
   }
   MPI_Gather(&n_loc, 1, MPI_INT, recv_counts, 1, MPI_INT, 0, comm);
   if (myid == 0)
   {
      delete [] displs;
      displs = new int[numProcs];
      displs[0] = 0;
      int s = 0;
      for (int k = 0; k < numProcs - 1; k++)
      {
         s += recv_counts[k];
         displs[k+1] = s;
      }
   }
#endif
}

void MUMPSSolver::Mult(const Vector &x, Vector &y) const
{
   x.HostRead();
   y.HostReadWrite();
#if MFEM_MUMPS_VERSION >= 530
   id->nloc_rhs = x.Size();
   id->lrhs_loc = x.Size();
   id->rhs_loc = x.GetData();
   id->irhs_loc = irhs_loc;

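   // MUMPS returns the solution in its own distribution: INFO(23) entries on
   // this rank, which need not match the distribution of y.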
   id->lsol_loc = id->MUMPS_INFO(23);
   id->isol_loc = new int[id->MUMPS_INFO(23)];
   id->sol_loc = new double[id->MUMPS_INFO(23)];

   // MUMPS solve
   id->job = 3;
   dmumps_c(id);

   RedistributeSol(id->isol_loc, id->sol_loc, y.GetData());

   delete [] id->sol_loc;
   delete [] id->isol_loc;
#else
   MPI_Gatherv(x.GetData(), x.Size(), MPI_DOUBLE,
               rhs_glob, recv_counts, displs, MPI_DOUBLE, 0, comm);

   if (myid == 0) { id->rhs = rhs_glob; }

   // MUMPS solve
   id->job = 3;
   dmumps_c(id);

   MPI_Scatterv(rhs_glob, recv_counts, displs, MPI_DOUBLE,
                y.GetData(), y.Size(), MPI_DOUBLE, 0, comm);
#endif
}

void MUMPSSolver::MultTranspose(const Vector &x, Vector &y) const
{
   // ICNTL(9) != 1 tells MUMPS to solve with A^T instead of A
   id->MUMPS_ICNTL(9) = 0;
   Mult(x, y);
   // Restore the default (solve with A)
   id->MUMPS_ICNTL(9) = 1;
}

void MUMPSSolver::SetPrintLevel(int print_lvl)
{
   print_level = print_lvl;
}

void MUMPSSolver::SetMatrixSymType(MatType mtype)
{
   mat_type = mtype;
}

MUMPSSolver::~MUMPSSolver()
{
   if (id)
   {
#if MFEM_MUMPS_VERSION >= 530
      delete [] irhs_loc;
#else
      delete [] recv_counts;
      delete [] displs;
      delete [] rhs_glob;
#endif
      id->job = -2;
      dmumps_c(id);
      delete id;
   }
}

void MUMPSSolver::SetParameters()
{
   // Output stream for error messages
   id->MUMPS_ICNTL(1) = 6;
   // Output stream for diagnostic printing, local to each proc
   id->MUMPS_ICNTL(2) = 6;
   // Output stream for global info
   id->MUMPS_ICNTL(3) = 6;
   // Level of error printing
   id->MUMPS_ICNTL(4) = print_level;
   // Input matrix format (assembled)
   id->MUMPS_ICNTL(5) = 0;
   // Use A or A^T
   id->MUMPS_ICNTL(9) = 1;
   // Iterative refinement (disabled)
   id->MUMPS_ICNTL(10) = 0;
   // Error analysis/statistics (disabled)
   id->MUMPS_ICNTL(11) = 0;
   // Use of ScaLAPACK (parallel factorization on root)
   id->MUMPS_ICNTL(13) = 0;
   // Percentage increase of estimated workspace (default = 20%)
   id->MUMPS_ICNTL(14) = 20;
   // Number of OpenMP threads (default)
   id->MUMPS_ICNTL(16) = 0;
   // Matrix input format (distributed)
   id->MUMPS_ICNTL(18) = 3;
   // Schur complement (no Schur complement matrix returned)
   id->MUMPS_ICNTL(19) = 0;

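   // MUMPS 5.3 introduced support for a distributed RHS and solution; older
   // versions must centralize both on the root rank.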
#if MFEM_MUMPS_VERSION >= 530
   // Distributed RHS
   id->MUMPS_ICNTL(20) = 10;
   // Distributed solution
   id->MUMPS_ICNTL(21) = 1;
#else
   // Centralized RHS
   id->MUMPS_ICNTL(20) = 0;
   // Centralized solution
   id->MUMPS_ICNTL(21) = 0;
#endif
   // Out-of-core factorization and solve (disabled)
   id->MUMPS_ICNTL(22) = 0;
   // Max size of working memory (default = based on estimates)
   id->MUMPS_ICNTL(23) = 0;
}

#if MFEM_MUMPS_VERSION >= 530
int MUMPSSolver::GetRowRank(int i, const Array<int> &row_starts_) const
{
   if (row_starts_.Size() == 1)
   {
      return 0;
   }
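   // Binary search: the owner of global row i is the last rank whose first
   // row index is less than or equal to i.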
   auto up = std::upper_bound(row_starts_.begin(), row_starts_.end(), i);
   return std::distance(row_starts_.begin(), up) - 1;
}

void MUMPSSolver::RedistributeSol(const int *row_map,
                                  const double *x, double *y) const
{
   int size = id->MUMPS_INFO(23);
   int *send_count = new int[numProcs]();
   for (int i = 0; i < size; i++)
   {
      int j = row_map[i] - 1;
      int row_rank = GetRowRank(j, row_starts);
      if (myid == row_rank) { continue; }
      send_count[row_rank]++;
   }

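   // Exchange the counts so that every rank knows how much it will receive
   // from each peer.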
   int *recv_count = new int[numProcs];
   MPI_Alltoall(send_count, 1, MPI_INT, recv_count, 1, MPI_INT, comm);

   int *send_displ = new int[numProcs]; send_displ[0] = 0;
   int *recv_displ = new int[numProcs]; recv_displ[0] = 0;
   int sbuff_size = send_count[numProcs-1];
   int rbuff_size = recv_count[numProcs-1];
   for (int k = 0; k < numProcs - 1; k++)
   {
      send_displ[k + 1] = send_displ[k] + send_count[k];
      recv_displ[k + 1] = recv_displ[k] + recv_count[k];
      sbuff_size += send_count[k];
      rbuff_size += recv_count[k];
   }

   int *sendbuf_index = new int[sbuff_size];
   double *sendbuf_values = new double[sbuff_size];
   int *soffs = new int[numProcs]();

   for (int i = 0; i < size; i++)
   {
      int j = row_map[i] - 1;
      int row_rank = GetRowRank(j, row_starts);
      if (myid == row_rank)
      {
         int local_index = j - row_start;
         y[local_index] = x[i];
      }
      else
      {
         int k = send_displ[row_rank] + soffs[row_rank];
         sendbuf_index[k] = j;
         sendbuf_values[k] = x[i];
         soffs[row_rank]++;
      }
   }

   int *recvbuf_index = new int[rbuff_size];
   double *recvbuf_values = new double[rbuff_size];
   MPI_Alltoallv(sendbuf_index, send_count, send_displ, MPI_INT,
                 recvbuf_index, recv_count, recv_displ, MPI_INT, comm);
   MPI_Alltoallv(sendbuf_values, send_count, send_displ, MPI_DOUBLE,
                 recvbuf_values, recv_count, recv_displ, MPI_DOUBLE, comm);

   // Unpack the receive buffer
   for (int i = 0; i < rbuff_size; i++)
   {
      int local_index = recvbuf_index[i] - row_start;
      y[local_index] = recvbuf_values[i];
   }

   delete [] recvbuf_values;
   delete [] recvbuf_index;
   delete [] soffs;
   delete [] sendbuf_values;
   delete [] sendbuf_index;
   delete [] recv_displ;
   delete [] send_displ;
   delete [] recv_count;
   delete [] send_count;
}
#endif

} // namespace mfem

#endif // MFEM_USE_MPI
#endif // MFEM_USE_MUMPS
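
A minimal usage sketch (not part of the file above): it assumes the parallel
system A*X = B has already been assembled, with A a HypreParMatrix and B, X
parallel vectors, and that MFEM was configured with MFEM_USE_MUMPS and
MFEM_USE_MPI. The MatType enumerator UNSYMMETRIC is assumed to come from
mumps.hpp; otherwise only the methods defined in this file are used.

#include "mfem.hpp"
using namespace mfem;

// Solve A*X = B with the MUMPS-based direct solver defined above.
void SolveWithMumps(HypreParMatrix &A, const Vector &B, Vector &X)
{
   MUMPSSolver mumps;
   mumps.SetPrintLevel(0);                           // ICNTL(4): quiet run
   mumps.SetMatrixSymType(MUMPSSolver::UNSYMMETRIC); // general sparse matrix
   mumps.SetOperator(A);                             // analysis + factorization
   mumps.Mult(B, X);                                 // solve for X
}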