13#include "../lapack.hpp"
19#define MFEM_MAGMA_PREFIX(stub) magma_s ## stub
20#define MFEM_MAGMABLAS_PREFIX(stub) magmablas_s ## stub
21#elif defined(MFEM_USE_DOUBLE)
22#define MFEM_MAGMA_PREFIX(stub) magma_d ## stub
23#define MFEM_MAGMABLAS_PREFIX(stub) magmablas_d ## stub
31 const magma_int_t status = magma_init();
32 MFEM_VERIFY(status == MAGMA_SUCCESS,
"Error initializing MAGMA.");
34 magma_getdevice(&dev);
35 magma_queue_create(dev, &queue);
40 magma_queue_destroy(queue);
41 const magma_int_t status = magma_finalize();
42 MFEM_VERIFY(status == MAGMA_SUCCESS,
"Error finalizing MAGMA.");
45Magma &Magma::Instance()
53 return Instance().queue;
60 const bool tr = (op == Op::T);
64 const int n_mat = A.
SizeK();
65 const int k = x.
Size() / n / n_mat;
71 magma_trans_t magma_op = tr ? MagmaTrans : MagmaNoTrans; // tr == (op == Op::T): a transposed request must select MagmaTrans
73 MFEM_MAGMABLAS_PREFIX(gemm_batched_strided)(
74 magma_op, MagmaNoTrans, m, k, n,
alpha, d_A, m, m*n, d_x, n, n*k,
80 const int n = A.
SizeI();
81 const int n_mat = A.
SizeK();
91 int **d_P_ptrs = P_ptrs.
Write();
94 d_A_ptrs[i] = A_base + i*n*n;
95 d_P_ptrs[i] = P_base + i*n;
99 const magma_int_t status = MFEM_MAGMA_PREFIX(getrf_batched)(
100 n, n, d_A_ptrs, n, d_P_ptrs,
102 MFEM_VERIFY(status == MAGMA_SUCCESS,
"");
108 const int n = LU.
SizeI();
109 const int n_mat = LU.
SizeK();
110 const int n_rhs = x.
Size() / n / n_mat;
117 int **d_P_ptrs = P_ptrs.
Write();
122 int *P_base =
const_cast<int*
>(P.
Read());
125 d_A_ptrs[i] = A_base + i*n*n;
126 d_B_ptrs[i] = B_base + i*n*n_rhs;
127 d_P_ptrs[i] = P_base + i*n;
131 const magma_int_t status = MFEM_MAGMA_PREFIX(getrs_batched)(
132 MagmaNoTrans, n, n_rhs, d_A_ptrs, n, d_P_ptrs,
134 MFEM_VERIFY(status == MAGMA_SUCCESS,
"");
139 const int n = A.
SizeI();
140 const int n_mat = A.
SizeK();
153 int **d_P_ptrs = P_ptrs.
Write();
157 int *P_base = P.
Write();
160 d_A_ptrs[i] = A_base + i*n*n;
161 d_LU_ptrs[i] = LU_base + i*n*n;
162 d_P_ptrs[i] = P_base + i*n;
// Factor the LU buffer, not A: the getri_outofplace_batched call that follows
// takes d_LU_ptrs as its factored input and writes the inverse through d_A_ptrs.
169 status = MFEM_MAGMA_PREFIX(getrf_batched)(
170 n, n, d_LU_ptrs, n, d_P_ptrs, info_array.
Write(), n_mat,
172 MFEM_VERIFY(status == MAGMA_SUCCESS,
"");
174 status = MFEM_MAGMA_PREFIX(getri_outofplace_batched)(
175 n, d_LU_ptrs, n, d_P_ptrs, d_A_ptrs, n, info_array.
Write(),
177 MFEM_VERIFY(status == MAGMA_SUCCESS,
"");
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Op
Operation type (transposed or not transposed)
Rank 3 tensor (array of matrices)
Memory< real_t > & GetMemory()
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read( GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
void LUSolve(const DenseTensor &A, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha=1.0, real_t beta=1.0, Op op=Op::N) const override
See BatchedLinAlg::AddMult.
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
static magma_queue_t Queue()
Return the queue, creating it if needed.
void CopyFrom(const Memory &src, int size)
Copy size entries from src to *this.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
int Size() const
Returns the size of the vector.
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
void forall(int N, lambda &&body)