MFEM  v4.5.2
Finite element discretization library
operator.cpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2023, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "operator.hpp"
13 
14 #include "../../../config/config.hpp"
15 #include "../../../linalg/vector.hpp"
16 #include "../../fespace.hpp"
17 #include "util.hpp"
18 #include "ceed.hpp"
19 
20 namespace mfem
21 {
22 
23 namespace ceed
24 {
25 
#ifdef MFEM_USE_CEED
/// Wrap the libCEED operator @a op as an MFEM operator.
///
/// The operator height and width are taken from the lengths of @a op's active
/// output and input vectors. Two work CeedVectors are created on the libCEED
/// context: @a v of length height and @a u of length width. Mult(), AddMult()
/// and GetDiagonal() later point them at MFEM vector data.
///
/// Every libCEED return code is checked with PCeedChk, and MFEM_VERIFY checks
/// that the lengths fit in the int-valued height/width.
Operator::Operator(CeedOperator op)
{
   oper = op;
   CeedSize in_len, out_len;
   int ierr = CeedOperatorGetActiveVectorLengths(oper, &in_len, &out_len);
   PCeedChk(ierr);
   // height/width are int while CeedSize can be wider: narrow explicitly and
   // compare against the original value to detect truncation.
   height = static_cast<int>(out_len);
   width = static_cast<int>(in_len);
   MFEM_VERIFY(height == out_len, "height overflow");
   MFEM_VERIFY(width == in_len, "width overflow");
   ierr = CeedVectorCreate(internal::ceed, height, &v);
   PCeedChk(ierr);
   ierr = CeedVectorCreate(internal::ceed, width, &u);
   PCeedChk(ierr);
}
#endif
41 
42 void Operator::Mult(const mfem::Vector &x, mfem::Vector &y) const
43 {
44 #ifdef MFEM_USE_CEED
45  const CeedScalar *x_ptr;
46  CeedScalar *y_ptr;
47  CeedMemType mem;
48  CeedGetPreferredMemType(mfem::internal::ceed, &mem);
49  if ( Device::Allows(Backend::DEVICE_MASK) && mem==CEED_MEM_DEVICE )
50  {
51  x_ptr = x.Read();
52  y_ptr = y.Write();
53  }
54  else
55  {
56  x_ptr = x.HostRead();
57  y_ptr = y.HostWrite();
58  mem = CEED_MEM_HOST;
59  }
60  CeedVectorSetArray(u, mem, CEED_USE_POINTER, const_cast<CeedScalar*>(x_ptr));
61  CeedVectorSetArray(v, mem, CEED_USE_POINTER, y_ptr);
62 
63  CeedOperatorApply(oper, u, v, CEED_REQUEST_IMMEDIATE);
64 
65  CeedVectorTakeArray(u, mem, const_cast<CeedScalar**>(&x_ptr));
66  CeedVectorTakeArray(v, mem, &y_ptr);
67 #else
68  MFEM_ABORT("MFEM must be built with MFEM_USE_CEED=YES to use libCEED.");
69 #endif
70 }
71 
73  const double a) const
74 {
75 #ifdef MFEM_USE_CEED
76  MFEM_VERIFY(a == 1.0, "General coefficient case is not yet supported!");
77  const CeedScalar *x_ptr;
78  CeedScalar *y_ptr;
79  CeedMemType mem;
80  CeedGetPreferredMemType(mfem::internal::ceed, &mem);
81  if ( Device::Allows(Backend::DEVICE_MASK) && mem==CEED_MEM_DEVICE )
82  {
83  x_ptr = x.Read();
84  y_ptr = y.ReadWrite();
85  }
86  else
87  {
88  x_ptr = x.HostRead();
89  y_ptr = y.HostReadWrite();
90  mem = CEED_MEM_HOST;
91  }
92  CeedVectorSetArray(u, mem, CEED_USE_POINTER, const_cast<CeedScalar*>(x_ptr));
93  CeedVectorSetArray(v, mem, CEED_USE_POINTER, y_ptr);
94 
95  CeedOperatorApplyAdd(oper, u, v, CEED_REQUEST_IMMEDIATE);
96 
97  CeedVectorTakeArray(u, mem, const_cast<CeedScalar**>(&x_ptr));
98  CeedVectorTakeArray(v, mem, &y_ptr);
99 #else
100  MFEM_ABORT("MFEM must be built with MFEM_USE_CEED=YES to use libCEED.");
101 #endif
102 }
103 
105 {
106 #ifdef MFEM_USE_CEED
107  CeedScalar *d_ptr;
108  CeedMemType mem;
109  CeedGetPreferredMemType(mfem::internal::ceed, &mem);
110  if ( Device::Allows(Backend::DEVICE_MASK) && mem==CEED_MEM_DEVICE )
111  {
112  d_ptr = diag.ReadWrite();
113  }
114  else
115  {
116  d_ptr = diag.HostReadWrite();
117  mem = CEED_MEM_HOST;
118  }
119  CeedVectorSetArray(v, mem, CEED_USE_POINTER, d_ptr);
120 
121  CeedOperatorLinearAssembleAddDiagonal(oper, v, CEED_REQUEST_IMMEDIATE);
122 
123  CeedVectorTakeArray(v, mem, &d_ptr);
124 #else
125  MFEM_ABORT("MFEM must be built with MFEM_USE_CEED=YES to use libCEED.");
126 #endif
127 }
128 
129 } // namespace ceed
130 
131 } // namespace mfem
virtual const double * HostRead() const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), false).
Definition: vector.hpp:452
virtual double * HostWrite()
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), false).
Definition: vector.hpp:460
virtual const double * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:448
void AddMult(const mfem::Vector &x, mfem::Vector &y, const double a=1.0) const override
Operator application: y+=A(x) (default) or y+=a*A(x).
Definition: operator.cpp:72
CeedOperator oper
Definition: operator.hpp:29
virtual double * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:456
static bool Allows(unsigned long b_mask)
Return true if any of the backends in the backend mask, b_mask, are allowed.
Definition: device.hpp:258
int height
Dimension of the output / number of rows in the matrix.
Definition: operator.hpp:27
double a
Definition: lissajous.cpp:41
virtual double * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition: vector.hpp:464
void GetDiagonal(mfem::Vector &diag) const
Definition: operator.cpp:104
void Mult(const mfem::Vector &x, mfem::Vector &y) const override
Operator application: y=A(x).
Definition: operator.cpp:42
Vector data type.
Definition: vector.hpp:60
Bitwise-OR of all device backends.
Definition: device.hpp:96
virtual double * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
Definition: vector.hpp:468
int width
Dimension of the input / number of columns in the matrix.
Definition: operator.hpp:28