24 const bool tr = (op == Op::T);
26 const int m = A.
SizeI();
27 const int n = A.
SizeJ();
28 const int n_mat = A.
SizeK();
29 const int k = x.
Size() / (tr ? m : n) / n_mat;
32 auto d_x =
Reshape(x.
Read(), (tr ? m : n), k, n_mat);
34 (tr ? n : m), k, n_mat);
40 kernels::AddMultAtB(m, n, k, &d_A(0,0,i), &d_x(0,0,i), &d_y(0,0,i),
48 kernels::AddMult(m, k, n, &d_A(0,0,i), &d_x(0,0,i), &d_y(0,0,i),
66 const int m = A.
SizeI();
67 const int NE = A.
SizeK();
81 real_t *X = &inv_all(0, 0, e);
83 const real_t *data = &data_all(0, 0, e);
84 const int *ipiv = &piv_all(0, e);
86 for (
int k = 0; k < m; k++)
88 const real_t minus_x_k = -(x[k] = 1.0 / data[k + k * m]);
89 for (
int i = 0; i < k; i++)
91 x[i] = data[i + k * m] * minus_x_k;
93 for (
int j = k - 1; j >= 0; j--)
95 const real_t x_j = (x[j] /= data[j + j * m]);
96 for (
int i = 0; i < j; i++)
98 x[i] -= data[i + j * m] * x_j;
107 for (
int j = 0; j < k; j++)
109 const real_t minus_L_kj = -data[k + j * m];
110 for (
int i = 0; i <= j; i++)
112 X[i + j * m] += X[i + k * m] * minus_L_kj;
114 for (
int i = j + 1; i < m; i++)
116 X[i + j * m] = X[i + k * m] * minus_L_kj;
120 for (
int k = m - 2; k >= 0; k--)
122 for (
int j = 0; j < k; j++)
124 const real_t L_kj = data[k + j * m];
125 for (
int i = 0; i < m; i++)
127 X[i + j * m] -= X[i + k * m] * L_kj;
133 for (
int k = m - 1; k >= 0; k--)
135 const int piv_k = ipiv[k];
138 for (
int i = 0; i < m; i++)
140 kernels::internal::Swap(X[i + k * m], X[i + piv_k * m]);
149 const int m = A.
SizeI();
150 const int NE = A.
SizeK();
156 pivot_flag[0] =
true;
157 bool *d_pivot_flag = pivot_flag.
ReadWrite();
162 if (!flag) { d_pivot_flag[0] =
false; }
165 MFEM_VERIFY(pivot_flag.
HostRead()[0],
"Batch LU factorization failed");
171 const int m = LU.
SizeI();
172 const int n_mat = LU.
SizeK();
173 const int n_rhs = x.
Size() / m / n_mat;
179 mfem::forall(n_mat * n_rhs, [=] MFEM_HOST_DEVICE (
int idx)
181 const int i_rhs = idx % n_rhs;
182 const int i_mat = idx / n_rhs;
const T * HostRead() const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), false).
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Op
Operation type (transposed or not transposed)
Rank 3 tensor (array of matrices)
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
void LUSolve(const DenseTensor &LU, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha, real_t beta, Op op) const override
See BatchedLinAlg::AddMult.
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
int Size() const
Returns the size of the vector.
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
MFEM_HOST_DEVICE void AddMultAtB(const int Aheight, const int Awidth, const int Bwidth, const TA *Adata, const TB *Bdata, TC *Cdata, const TB alpha, const TA beta)
Compute C = alpha*At*B + beta*C.
MFEM_HOST_DEVICE void AddMult(const int Aheight, const int Awidth, const int Bwidth, const TB *Bdata, const TC *Cdata, TA *Adata, const TB alpha, const TA beta)
Matrix-matrix multiplication: A = alpha * B * C + beta * A, where the matrices A, B and C are of sizes Aheight x Awidth, Aheight x Bwidth and Bwidth x Awidth, respectively.
MFEM_HOST_DEVICE bool LUFactor(real_t *A, const int m, int *ipiv, const real_t tol=0.0)
Compute the LU factorization of the m x m matrix A.
MFEM_HOST_DEVICE void LUSolve(const real_t *data, const int m, const int *ipiv, real_t *x)
Assuming L.U = P.A for a factored matrix (m x m), compute x <- A^{-1} x.
MFEM_HOST_DEVICE DeviceTensor< sizeof...(Dims), T > Reshape(T *ptr, Dims... dims)
Wrap a pointer as a DeviceTensor with automatically deduced template parameters.
void forall(int N, lambda &&body)