mumps.cpp

// Copyright (c) 2010-2022, Lawrence Livermore National Security, LLC. Produced
// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
// LICENSE and NOTICE for details. LLNL-CODE-806117.
//
// This file is part of the MFEM library. For more information and source code
// availability visit https://mfem.org.
//
// MFEM is free software; you can redistribute it and/or modify it under the
// terms of the BSD-3 license. We welcome feedback and contributions, see file
// CONTRIBUTING.md for details.

#include "../config/config.hpp"

#ifdef MFEM_USE_MUMPS
#ifdef MFEM_USE_MPI

#include "mumps.hpp"

#ifdef HYPRE_BIGINT
#error "MUMPSSolver requires HYPRE_Int == int, for now."
#endif

// Macros so that indices match the 1-based MUMPS documentation
#define MUMPS_ICNTL(I) icntl[(I) -1]
#define MUMPS_INFO(I) info[(I) -1]
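// For example, id->MUMPS_ICNTL(4) expands to id->icntl[3], i.e. the entry the
// MUMPS manual calls ICNTL(4).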

namespace mfem
{

void MUMPSSolver::SetOperator(const Operator &op)
{
   auto APtr = dynamic_cast<const HypreParMatrix *>(&op);

   MFEM_VERIFY(APtr, "Not compatible matrix type");

   height = op.Height();
   width = op.Width();

   comm = APtr->GetComm();
   MPI_Comm_size(comm, &numProcs);
   MPI_Comm_rank(comm, &myid);

   auto parcsr_op = (hypre_ParCSRMatrix *) const_cast<HypreParMatrix &>(*APtr);

   hypre_CSRMatrix *csr_op = hypre_MergeDiagAndOffd(parcsr_op);
#if MFEM_HYPRE_VERSION >= 21600
   hypre_CSRMatrixBigJtoJ(csr_op);
#endif
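   // hypre stores the local rows as a "diag" block (on-rank columns) plus an
   // "offd" block (off-rank columns); hypre_MergeDiagAndOffd combines them into
   // a single local CSR matrix with global column indices, and, for HYPRE
   // 2.16.0 and later, hypre_CSRMatrixBigJtoJ converts its HYPRE_BigInt column
   // array to plain int.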

   int *Iptr = csr_op->i;
   int *Jptr = csr_op->j;
   int n_loc = csr_op->num_rows;

   row_start = parcsr_op->first_row_index;
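   // row_start is the global index of the first row owned by this rank, so
   // local row i corresponds to the 1-based global row row_start + i + 1 below.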
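   // In symmetric mode (sym != 0) MUMPS expects only one triangle of the
   // matrix, so only entries with 1-based global indices ii >= jj (the lower
   // triangle, including the diagonal) are counted here and assembled below.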
   MUMPS_INT8 nnz = 0;
   if (mat_type)
   {
      // count nnz in case of symmetric mode
      int k = 0;
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            int ii = row_start + i + 1;
            int jj = Jptr[k] + 1;
            k++;
            if (ii >= jj) { nnz++; }
         }
      }
   }
   else
   {
      nnz = csr_op->num_nonzeros;
   }

   int * I = new int[nnz];
   int * J = new int[nnz];

   // Fill in I and J arrays for
   // COO format in 1-based indexing
   int k = 0;
   double * data;
   if (mat_type)
   {
      int l = 0;
      data = new double[nnz];
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            int ii = row_start + i + 1;
            int jj = Jptr[k] + 1;
            if (ii >= jj)
            {
               I[l] = ii;
               J[l] = jj;
               data[l++] = csr_op->data[k];
            }
            k++;
         }
      }
   }
   else
   {
      for (int i = 0; i < n_loc; i++)
      {
         for (int j = Iptr[i]; j < Iptr[i + 1]; j++)
         {
            I[k] = row_start + i + 1;
            J[k] = Jptr[k] + 1;
            k++;
         }
      }
      data = csr_op->data;
   }
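   // Note: in the unsymmetric case, "data" aliases csr_op->data directly (no
   // copy), which is why it is freed below only when mat_type is nonzero.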

   // new MUMPS object
   if (id)
   {
      id->job = -2;
      dmumps_c(id);
      delete id;
   }
   id = new DMUMPS_STRUC_C;
   // C to Fortran communicator
   id->comm_fortran = (MUMPS_INT) MPI_Comm_c2f(comm);

   // Host is involved in computation
   id->par = 1;

   id->sym = mat_type;
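   // (MUMPS "sym": 0 = unsymmetric, 1 = symmetric positive definite,
   // 2 = general symmetric, matching the MatType passed to SetMatrixSymType.)
   // MUMPS is driven through the "job" field: -1 initializes the instance,
   // 1 runs the analysis phase, 2 the numerical factorization, 3 the solve,
   // and -2 destroys the instance.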

   // MUMPS init
   id->job = -1;
   dmumps_c(id);

   // Set MUMPS default parameters
   SetParameters();

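   // Describe the distributed input: the global order n, plus this rank's COO
   // triplets (irn_loc, jcn_loc, a_loc) of length nnz_loc, as selected by
   // ICNTL(18) = 3 in SetParameters().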
   id->n = parcsr_op->global_num_rows;

   id->nnz_loc = nnz;

   id->irn_loc = I;

   id->jcn_loc = J;

   id->a_loc = data;

   // MUMPS Analysis
   id->job = 1;
   dmumps_c(id);

   // MUMPS Factorization
   id->job = 2;
   dmumps_c(id);

   hypre_CSRMatrixDestroy(csr_op);
   delete [] I;
   delete [] J;
   if (mat_type) { delete [] data; }

#if MFEM_MUMPS_VERSION >= 530
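   // MUMPS 5.3+ supports a distributed right-hand side: irhs_loc maps each
   // local RHS entry to its 1-based global row, and the row offsets of all
   // ranks are gathered so the distributed solution can be mapped back later.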
   delete [] irhs_loc;
   irhs_loc = new int[n_loc];
   for (int i = 0; i < n_loc; i++)
   {
      irhs_loc[i] = row_start + i + 1;
   }
   row_starts.SetSize(numProcs);
   MPI_Allgather(&row_start, 1, MPI_INT, row_starts, 1, MPI_INT, comm);
#else
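   // Older MUMPS only supports a centralized (rank 0) RHS and solution, so
   // rank 0 allocates a global RHS buffer and the counts/displacements needed
   // by the MPI_Gatherv/MPI_Scatterv calls in Mult().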
   if (myid == 0)
   {
      delete [] rhs_glob;
      delete [] recv_counts;
      rhs_glob = new double[parcsr_op->global_num_rows];
      recv_counts = new int[numProcs];
   }
   MPI_Gather(&n_loc, 1, MPI_INT, recv_counts, 1, MPI_INT, 0, comm);
   if (myid == 0)
   {
      delete [] displs;
      displs = new int[numProcs];
      displs[0] = 0;
      int s = 0;
      for (int k = 0; k < numProcs-1; k++)
      {
         s += recv_counts[k];
         displs[k+1] = s;
      }
   }
#endif
}

void MUMPSSolver::Mult(const Vector &x, Vector &y) const
{
   x.HostRead();
   y.HostReadWrite();
#if MFEM_MUMPS_VERSION >= 530

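   // Distributed path: hand MUMPS the local RHS block directly. After the
   // solve, INFO(23) entries of the solution live on this rank in MUMPS's own
   // ordering (isol_loc/sol_loc); RedistributeSol() maps them back to hypre's
   // contiguous row partition in y.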
   id->nloc_rhs = x.Size();
   id->lrhs_loc = x.Size();
   id->rhs_loc = x.GetData();
   id->irhs_loc = irhs_loc;

   id->lsol_loc = id->MUMPS_INFO(23);
   id->isol_loc = new int[id->MUMPS_INFO(23)];
   id->sol_loc = new double[id->MUMPS_INFO(23)];

   // MUMPS solve
   id->job = 3;
   dmumps_c(id);

   RedistributeSol(id->isol_loc, id->sol_loc, y.GetData());

   delete [] id->sol_loc;
   delete [] id->isol_loc;
#else
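   // Centralized path: gather the distributed RHS onto rank 0, solve there,
   // and scatter the solution back according to each rank's row count.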
   MPI_Gatherv(x.GetData(), x.Size(), MPI_DOUBLE,
               rhs_glob, recv_counts,
               displs, MPI_DOUBLE, 0, comm);

   if (myid == 0) { id->rhs = rhs_glob; }

   // MUMPS solve
   id->job = 3;
   dmumps_c(id);

   MPI_Scatterv(rhs_glob, recv_counts, displs,
                MPI_DOUBLE, y.GetData(), y.Size(),
                MPI_DOUBLE, 0, comm);
#endif
}

void MUMPSSolver::MultTranspose(const Vector &x, Vector &y) const
{
   // Set ICNTL(9) = 0 to solve the transposed system A^T x = b
   // (the default, 1, solves A x = b)
   id->MUMPS_ICNTL(9) = 0;
   Mult(x, y);
   // Reset the flag
   id->MUMPS_ICNTL(9) = 1;
}

void MUMPSSolver::SetPrintLevel(int print_lvl)
{
   print_level = print_lvl;
}

void MUMPSSolver::SetMatrixSymType(MatType mtype)
{
   mat_type = mtype;
}

MUMPSSolver::~MUMPSSolver()
{
   if (id)
   {
#if MFEM_MUMPS_VERSION >= 530
      delete [] irhs_loc;
#else
      delete [] recv_counts;
      delete [] displs;
      delete [] rhs_glob;
#endif
      id->job = -2;
      dmumps_c(id);
      delete id;
   }
}

void MUMPSSolver::SetParameters()
{
   // output stream for error messages
   id->MUMPS_ICNTL(1) = 6;
   // output stream for diagnostic printing local to each proc
   id->MUMPS_ICNTL(2) = 6;
   // output stream for global info
   id->MUMPS_ICNTL(3) = 6;
   // Level of error printing
   id->MUMPS_ICNTL(4) = print_level;
   // input matrix format (assembled)
   id->MUMPS_ICNTL(5) = 0;
   // Use A or A^T
   id->MUMPS_ICNTL(9) = 1;
   // Iterative refinement (disabled)
   id->MUMPS_ICNTL(10) = 0;
   // Error analysis-statistics (disabled)
   id->MUMPS_ICNTL(11) = 0;
   // Use of ScaLAPACK (Parallel factorization on root)
   id->MUMPS_ICNTL(13) = 0;
   // Percentage increase of estimated workspace (default = 20%)
   id->MUMPS_ICNTL(14) = 20;
   // Number of OpenMP threads (default)
   id->MUMPS_ICNTL(16) = 0;
   // Matrix input format (distributed)
   id->MUMPS_ICNTL(18) = 3;
   // Schur complement (no Schur complement matrix returned)
   id->MUMPS_ICNTL(19) = 0;

#if MFEM_MUMPS_VERSION >= 530
   // Distributed RHS
   id->MUMPS_ICNTL(20) = 10;
   // Distributed Sol
   id->MUMPS_ICNTL(21) = 1;
#else
   // Centralized RHS
   id->MUMPS_ICNTL(20) = 0;
   // Centralized Sol
   id->MUMPS_ICNTL(21) = 0;
#endif
   // Out-of-core factorization and solve (disabled)
   id->MUMPS_ICNTL(22) = 0;
   // Max size of working memory (default = based on estimates)
   id->MUMPS_ICNTL(23) = 0;
}

#if MFEM_MUMPS_VERSION >= 530
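// GetRowRank returns the MPI rank that owns global row i under hypre's
// contiguous row partition, found by a binary search over the gathered
// row_starts offsets.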
int MUMPSSolver::GetRowRank(int i, const Array<int> &row_starts_) const
{
   if (row_starts_.Size() == 1)
   {
      return 0;
   }
   auto up = std::upper_bound(row_starts_.begin(), row_starts_.end(), i);
   return std::distance(row_starts_.begin(), up) - 1;
}

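// RedistributeSol scatters the MUMPS distributed solution back onto hypre's
// row partition: count how many entries go to each owner rank, exchange the
// counts with MPI_Alltoall, pack (global row, value) pairs, exchange them with
// MPI_Alltoallv, and finally write both the locally owned and the received
// entries into y using local indices.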
void MUMPSSolver::RedistributeSol(const int * row_map,
                                  const double * x, double * y) const
{
   int size = id->MUMPS_INFO(23);
   int * send_count = new int[numProcs]();
   for (int i = 0; i < size; i++)
   {
      int j = row_map[i] - 1;
      int row_rank = GetRowRank(j, row_starts);
      if (myid == row_rank) { continue; }
      send_count[row_rank]++;
   }

   int * recv_count = new int[numProcs];
   MPI_Alltoall(send_count, 1, MPI_INT, recv_count, 1, MPI_INT, comm);

   int * send_displ = new int[numProcs]; send_displ[0] = 0;
   int * recv_displ = new int[numProcs]; recv_displ[0] = 0;
   int sbuff_size = send_count[numProcs-1];
   int rbuff_size = recv_count[numProcs-1];
   for (int k = 0; k < numProcs - 1; k++)
   {
      send_displ[k + 1] = send_displ[k] + send_count[k];
      recv_displ[k + 1] = recv_displ[k] + recv_count[k];
      sbuff_size += send_count[k];
      rbuff_size += recv_count[k];
   }

   int * sendbuf_index = new int[sbuff_size];
   double * sendbuf_values = new double[sbuff_size];
   int * soffs = new int[numProcs]();

   for (int i = 0; i < size; i++)
   {
      int j = row_map[i] - 1;
      int row_rank = GetRowRank(j, row_starts);
      if (myid == row_rank)
      {
         int local_index = j - row_start;
         y[local_index] = x[i];
      }
      else
      {
         int k = send_displ[row_rank] + soffs[row_rank];
         sendbuf_index[k] = j;
         sendbuf_values[k] = x[i];
         soffs[row_rank]++;
      }
   }

   int * recvbuf_index = new int[rbuff_size];
   double * recvbuf_values = new double[rbuff_size];
   MPI_Alltoallv(sendbuf_index,
                 send_count,
                 send_displ,
                 MPI_INT,
                 recvbuf_index,
                 recv_count,
                 recv_displ,
                 MPI_INT,
                 comm);
   MPI_Alltoallv(sendbuf_values,
                 send_count,
                 send_displ,
                 MPI_DOUBLE,
                 recvbuf_values,
                 recv_count,
                 recv_displ,
                 MPI_DOUBLE,
                 comm);

   // Unpack recv buffer
   for (int i = 0; i < rbuff_size; i++)
   {
      int local_index = recvbuf_index[i] - row_start;
      y[local_index] = recvbuf_values[i];
   }

   delete [] recvbuf_values;
   delete [] recvbuf_index;
   delete [] soffs;
   delete [] sendbuf_values;
   delete [] sendbuf_index;
   delete [] recv_displ;
   delete [] send_displ;
   delete [] recv_count;
   delete [] send_count;
}
#endif

} // namespace mfem

#endif // MFEM_USE_MPI
#endif // MFEM_USE_MUMPS
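
// Example usage (a minimal sketch; "A" is assumed to be an assembled
// HypreParMatrix and "B", "X" to be mfem::Vector objects conforming to its
// row partition; the names are illustrative only):
//
//    MUMPSSolver mumps;
//    mumps.SetPrintLevel(0);
//    mumps.SetOperator(A);        // analysis and factorization
//    mumps.Mult(B, X);            // X = A^{-1} B
//    mumps.MultTranspose(B, X);   // X = A^{-T} B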