// Backend-selection helpers. NOTE(review): this listing is an excerpt; some
// preprocessor lines (e.g. the closing #else/#endif and the condition guarding
// the "S" variant below) are not visible here.
//
// MFEM_cu_or_hip(stub) / MFEM_CU_or_HIP(stub) paste a lower-/upper-case backend
// prefix onto `stub`: cu/CU when MFEM_USE_CUDA is defined, hip/HIP when
// MFEM_USE_HIP is defined.
15#if defined(MFEM_USE_CUDA) 
   16#define MFEM_cu_or_hip(stub) cu##stub 
   17#define MFEM_CU_or_HIP(stub) CU##stub 
   18#elif defined(MFEM_USE_HIP) 
   19#define MFEM_cu_or_hip(stub) hip##stub 
   20#define MFEM_CU_or_HIP(stub) HIP##stub 
   // Two-level concatenation: MFEM_CONCAT forwards to MFEM_CONCAT_ so that macro
   // arguments such as MFEM_cu_or_hip(blas) are fully expanded before `##`
   // pastes the three tokens together.
   23#define MFEM_CONCAT(x, y, z) MFEM_CONCAT_(x, y, z) 
   24#define MFEM_CONCAT_(x, y, z) x ## y ## z 
   // MFEM_GPUBLAS_PREFIX(stub) builds <cu|hip>blas + precision letter + stub,
   // e.g. gemmStridedBatched -> cublasSgemmStridedBatched. This is the "S"
   // variant; its guarding condition is not visible in this excerpt
   // (presumably a single-precision build -- TODO confirm).
   27#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), S, stub) 
   // "D" variant, selected when MFEM_USE_DOUBLE is defined.
   28#elif defined(MFEM_USE_DOUBLE) 
   29#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), D, stub) 
   // Backend-specific success code: expands to CUBLAS_STATUS_SUCCESS or
   // HIPBLAS_STATUS_SUCCESS.
   32#define MFEM_BLAS_SUCCESS MFEM_CU_or_HIP(BLAS_STATUS_SUCCESS) 
   // Accessor for the single GPUBlas object, which is held as a function-local
   // static and therefore constructed the first time this function is called.
   // NOTE(review): the braces and the return statement are not visible in this
   // excerpt.
   37GPUBlas &GPUBlas::Instance()
 
   39   static GPUBlas instance;
 
   // Body fragment of the handle accessor (signature not visible here): returns
   // the `handle` member of the object obtained from Instance().
   45   return Instance().handle;
 
 
   // Branch taken when MFEM_USE_CUDA_OR_HIP is not defined: the destructor is an
   // empty stub. NOTE(review): other lines of this branch are not visible in
   // this excerpt.
   48#ifndef MFEM_USE_CUDA_OR_HIP 
   51GPUBlas::~GPUBlas() { }
 
   // Fragment (enclosing signature not visible; presumably the GPU-enabled
   // constructor): creates the cuBLAS/hipBLAS handle, storing it in `handle`.
   61   blasStatus_t status = MFEM_cu_or_hip(blasCreate)(&handle);
 
   // Verify that handle creation reported the backend's success status.
   62   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"Cannot initialize GPU BLAS.");
 
   // Fragment (enclosing signature not visible; presumably the GPU-enabled
   // destructor): releases the cuBLAS/hipBLAS handle.
   // NOTE(review): the status returned by blasDestroy is discarded here.
   67   MFEM_cu_or_hip(blasDestroy)(handle);
 
   // Fragment (signature not visible; presumably EnableAtomics): sets the
   // atomics mode on the shared handle to <CU|HIP>BLAS_ATOMICS_ALLOWED and
   // verifies that the call reported success.
   72   const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
 
   73                                  Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_ALLOWED));
 
   74   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"GPU BLAS error.");
 
   // Fragment (signature not visible; presumably DisableAtomics): sets the
   // atomics mode on the shared handle to <CU|HIP>BLAS_ATOMICS_NOT_ALLOWED and
   // verifies that the call reported success.
   79   const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
 
   80                                  Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_NOT_ALLOWED));
 
   81   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"GPU BLAS error.");
 
   // Fragment of the batched multiply-add (signature not visible; presumably
   // AddMult). NOTE(review): the definitions of m, n, d_A, d_x and d_y and part
   // of the argument list are on lines missing from this excerpt.
   // True when the transposed operation was requested.
   88   const bool tr = (op == Op::T);
 
   // Number of matrices in the batch (presumably the third tensor dimension --
   // TODO confirm).
   92   const int n_mat = A.
SizeK();
 
   // Columns per batch entry of x, obtained by integer division; assumes
   // x.Size() is a multiple of n*n_mat -- TODO confirm.
   93   const int k = x.
Size() / n / n_mat;
 
   // A is optionally transposed according to `tr`; the second operand never is.
   99   const auto op_A = tr ? MFEM_CU_or_HIP(BLAS_OP_T) : MFEM_CU_or_HIP(BLAS_OP_N);
 
  100   const auto op_B = MFEM_CU_or_HIP(BLAS_OP_N);
 
   // Strided batched GEMM: A is passed with leading dimension m and batch
   // stride m*n; x with leading dimension n and batch stride n*k.
  102   const blasStatus_t status = MFEM_GPUBLAS_PREFIX(gemmStridedBatched)(
 
  104                                  &
alpha, d_A, m, m*n, d_x, n, n*k, &
beta, d_y,
 
  106   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"GPU BLAS error.");
 
 
  // Fragment of a batched LU factorization (signature not visible; presumably
  // LUFactor). NOTE(review): the loop header, the pointer-array setup and part
  // of the argument list are on lines missing from this excerpt.
  111   const int n = A.
SizeI();
 
  112   const int n_mat = A.
SizeK();
 
  // Entry i points at offset i*n*n from A_base, i.e. consecutive n*n-sized
  // blocks of the tensor data.
  123      d_A_ptrs[i] = A_base + i*n*n;
 
  // Batched getrf over n_mat matrices; info_array is opened for writing.
  126   const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
 
  128                                  info_array.
Write(), n_mat);
 
  // NOTE(review): empty failure message; other checks in this file use
  // "GPU BLAS error.".
  129   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"");
 
 
  // Fragment of a batched LU solve (signature not visible; presumably
  // LUSolve). NOTE(review): the loop header, the pointer-array setup and part
  // of the argument list are on lines missing from this excerpt.
  135   const int n = LU.
SizeI();
 
  136   const int n_mat = LU.
SizeK();
 
  // Right-hand-side columns per matrix, obtained by integer division; assumes
  // x.Size() is a multiple of n*n_mat -- TODO confirm.
  137   const int n_rhs = x.
Size() / n / n_mat;
 
  // Entry i of each pointer array addresses the i-th block: n*n entries per
  // factored matrix, n*n_rhs entries per right-hand-side block.
  149         d_A_ptrs[i] = A_base + i*n*n;
 
  150         d_B_ptrs[i] = B_base + i*n*n_rhs;
 
  // Batched getrs: both the factors and the right-hand sides are passed with
  // leading dimension n; pivots are read from P.
  155   const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrsBatched)(
 
  157                                  n, n_rhs, d_A_ptrs, n, P.
Read(), d_B_ptrs, n,
 
  // NOTE(review): empty failure message; other checks in this file use
  // "GPU BLAS error.".
  159   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"");
 
 
  // Fragment of a batched matrix inversion (signature not visible; presumably
  // Invert). NOTE(review): the declaration of `status`, the loop header, the
  // LU/pointer-array setup and part of both argument lists are on lines missing
  // from this excerpt.
  164   const int n = A.
SizeI();
 
  165   const int n_mat = A.
SizeK();
 
  // Entry i of each pointer array addresses the i-th n*n block of A and of the
  // LU storage, respectively.
  180         d_A_ptrs[i] = A_base + i*n*n;
 
  181         d_LU_ptrs[i] = LU_base + i*n*n;
 
  // Step 1: batched getrf over n_mat matrices.
  189   status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
 
  191               info_array.
Write(), n_mat);
 
  // NOTE(review): empty failure message; other checks in this file use
  // "GPU BLAS error.".
  192   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"");
 
  // Step 2: batched getri over n_mat matrices; info_array is opened for
  // writing again.
  194   status = MFEM_GPUBLAS_PREFIX(getriBatched)(
 
  196               info_array.
Write(), n_mat);
 
  // NOTE(review): empty failure message here as well.
  197   MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, 
"");
 
 
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Op
Operation type (transposed or not transposed)
Rank 3 tensor (array of matrices)
Memory< real_t > & GetMemory()
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha=1.0, real_t beta=1.0, Op op=Op::N) const override
See BatchedLinAlg::AddMult.
void LUSolve(const DenseTensor &LU, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
static void EnableAtomics()
Enable atomic operations.
static void DisableAtomics()
Disable atomic operations.
static HandleType Handle()
Return the handle, creating it if needed.
void CopyFrom(const Memory &src, int size)
Copy size entries from src to *this.
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
int Size() const
Returns the size of the vector.
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
MFEM_cu_or_hip(blasStatus_t) blasStatus_t
void forall(int N, lambda &&body)