// Token-pasting helpers that put the vendor prefix in front of a GPU BLAS
// identifier: `cu`/`CU` when MFEM_USE_CUDA is defined, `hip`/`HIP` when
// MFEM_USE_HIP is defined instead.
15#if defined(MFEM_USE_CUDA)
// Lower-case prefix, e.g. MFEM_cu_or_hip(blasCreate) -> cublasCreate.
16#define MFEM_cu_or_hip(stub) cu##stub
// Upper-case prefix, e.g. MFEM_CU_or_HIP(BLAS_OP_N) -> CUBLAS_OP_N.
17#define MFEM_CU_or_HIP(stub) CU##stub
18#elif defined(MFEM_USE_HIP)
// Lower-case prefix, e.g. MFEM_cu_or_hip(blasCreate) -> hipblasCreate.
19#define MFEM_cu_or_hip(stub) hip##stub
// Upper-case prefix, e.g. MFEM_CU_or_HIP(BLAS_OP_N) -> HIPBLAS_OP_N.
20#define MFEM_CU_or_HIP(stub) HIP##stub
// Paste three tokens together. The extra level of indirection makes the
// preprocessor fully expand x, y and z (which may themselves be macro calls)
// before MFEM_CONCAT_ applies `##`; pasting directly would use the unexpanded
// argument text.
23#define MFEM_CONCAT(x, y, z) MFEM_CONCAT_(x, y, z)
// Implementation detail of MFEM_CONCAT: performs the actual token pasting.
24#define MFEM_CONCAT_(x, y, z) x ## y ## z
// Build the name of a precision-specific GPU BLAS routine:
// <cu|hip>blas + <S|D> + stub, e.g. with the `cu` prefix
// MFEM_GPUBLAS_PREFIX(getrfBatched) -> cublasSgetrfBatched.
// NOTE(review): the condition guarding this `S` variant is not visible in this
// excerpt - presumably the single-precision build; confirm against the full
// file.
27#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), S, stub)
28#elif defined(MFEM_USE_DOUBLE)
// `D` variant, selected when MFEM_USE_DOUBLE is defined,
// e.g. MFEM_GPUBLAS_PREFIX(getrfBatched) -> cublasDgetrfBatched.
29#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), D, stub)
// Backend-neutral spelling of the BLAS success status: expands to
// CUBLAS_STATUS_SUCCESS or HIPBLAS_STATUS_SUCCESS via MFEM_CU_or_HIP.
32#define MFEM_BLAS_SUCCESS MFEM_CU_or_HIP(BLAS_STATUS_SUCCESS)
37GPUBlas &GPUBlas::Instance()
39 static GPUBlas instance;
45 return Instance().handle;
48#ifndef MFEM_USE_CUDA_OR_HIP
51GPUBlas::~GPUBlas() { }
61 blasStatus_t status = MFEM_cu_or_hip(blasCreate)(&handle);
62 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"Cannot initialize GPU BLAS.");
67 MFEM_cu_or_hip(blasDestroy)(handle);
72 const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
73 Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_ALLOWED));
74 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"GPU BLAS error.");
79 const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
80 Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_NOT_ALLOWED));
81 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"GPU BLAS error.");
88 const bool tr = (op == Op::T);
92 const int n_mat = A.
SizeK();
93 const int k = x.
Size() / n / n_mat;
99 const auto op_A = tr ? MFEM_CU_or_HIP(BLAS_OP_T) : MFEM_CU_or_HIP(BLAS_OP_N);
100 const auto op_B = MFEM_CU_or_HIP(BLAS_OP_N);
102 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(gemmStridedBatched)(
104 &
alpha, d_A, m, m*n, d_x, n, n*k, &
beta, d_y,
106 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"GPU BLAS error.");
111 const int n = A.
SizeI();
112 const int n_mat = A.
SizeK();
123 d_A_ptrs[i] = A_base + i*n*n;
126 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
128 info_array.
Write(), n_mat);
129 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"");
135 const int n = LU.
SizeI();
136 const int n_mat = LU.
SizeK();
137 const int n_rhs = x.
Size() / n / n_mat;
149 d_A_ptrs[i] = A_base + i*n*n;
150 d_B_ptrs[i] = B_base + i*n*n_rhs;
155 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrsBatched)(
157 n, n_rhs, d_A_ptrs, n, P.
Read(), d_B_ptrs, n,
159 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"");
164 const int n = A.
SizeI();
165 const int n_mat = A.
SizeK();
180 d_A_ptrs[i] = A_base + i*n*n;
181 d_LU_ptrs[i] = LU_base + i*n*n;
189 status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
191 info_array.
Write(), n_mat);
192 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"");
194 status = MFEM_GPUBLAS_PREFIX(getriBatched)(
196 info_array.
Write(), n_mat);
197 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS,
"");
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Op
Operation type (transposed or not transposed)
Rank 3 tensor (array of matrices)
Memory< real_t > & GetMemory()
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha=1.0, real_t beta=1.0, Op op=Op::N) const override
See BatchedLinAlg::AddMult.
void LUSolve(const DenseTensor &LU, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
static void EnableAtomics()
Enable atomic operations.
static void DisableAtomics()
Disable atomic operations.
static HandleType Handle()
Return the handle, creating it if needed.
void CopyFrom(const Memory &src, int size)
Copy size entries from src to *this.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
int Size() const
Returns the size of the vector.
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
MFEM_cu_or_hip(blasStatus_t) blasStatus_t
void forall(int N, lambda &&body)