MFEM v4.8.0
Finite element discretization library
Loading...
Searching...
No Matches
gpu_blas.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#include "gpu_blas.hpp"
14
// Backend name helpers: paste the lowercase (cu / hip) or uppercase (CU / HIP)
// backend prefix in front of `stub`. Neither macro is defined when neither
// MFEM_USE_CUDA nor MFEM_USE_HIP is set.
15#if defined(MFEM_USE_CUDA)
16#define MFEM_cu_or_hip(stub) cu##stub
17#define MFEM_CU_or_HIP(stub) CU##stub
18#elif defined(MFEM_USE_HIP)
19#define MFEM_cu_or_hip(stub) hip##stub
20#define MFEM_CU_or_HIP(stub) HIP##stub
21#endif
22
// Two-level token pasting: the outer macro lets its arguments (themselves macro
// invocations) be expanded before the inner macro applies the ## operator.
23#define MFEM_CONCAT(x, y, z) MFEM_CONCAT_(x, y, z)
24#define MFEM_CONCAT_(x, y, z) x ## y ## z
25
// Build the precision-specific BLAS routine name: <cu|hip>blas + S or D + stub.
// For example, in a CUDA build with MFEM_USE_SINGLE,
// MFEM_GPUBLAS_PREFIX(getrfBatched) expands to cublasSgetrfBatched.
26#ifdef MFEM_USE_SINGLE
27#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), S, stub)
28#elif defined(MFEM_USE_DOUBLE)
29#define MFEM_GPUBLAS_PREFIX(stub) MFEM_CONCAT(MFEM_cu_or_hip(blas), D, stub)
30#endif
31
// Backend-neutral spelling of the success status (CUBLAS_STATUS_SUCCESS or
// HIPBLAS_STATUS_SUCCESS).
32#define MFEM_BLAS_SUCCESS MFEM_CU_or_HIP(BLAS_STATUS_SUCCESS)
33
34namespace mfem
35{
36
37GPUBlas &GPUBlas::Instance()
38{
39 static GPUBlas instance;
40 return instance;
41}
42
43GPUBlas::HandleType GPUBlas::Handle()
44{
45 return Instance().handle;
46}
47
48#ifndef MFEM_USE_CUDA_OR_HIP
49
50GPUBlas::GPUBlas() { }
51GPUBlas::~GPUBlas() { }
54
55#else
56
// Backend-neutral alias for the BLAS status type (cublasStatus_t in CUDA
// builds, hipblasStatus_t in HIP builds).
57using blasStatus_t = MFEM_cu_or_hip(blasStatus_t);
58
59GPUBlas::GPUBlas()
60{
61 blasStatus_t status = MFEM_cu_or_hip(blasCreate)(&handle);
62 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "Cannot initialize GPU BLAS.");
63}
64
65GPUBlas::~GPUBlas()
66{
67 MFEM_cu_or_hip(blasDestroy)(handle);
68}
69
71{
 // NOTE(review): the signature line for this body (file line 70) is missing
 // from this listing; presumably this is the CUDA/HIP implementation of
 // GPUBlas::EnableAtomics() - confirm against the original file.
 //
 // Sets the atomics mode of the shared BLAS handle to BLAS_ATOMICS_ALLOWED
 // and aborts through MFEM_VERIFY if the library call does not succeed.
72 const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
73 Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_ALLOWED));
74 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "GPU BLAS error.");
75}
76
78{
 // NOTE(review): the signature line for this body (file line 77) is missing
 // from this listing; presumably this is the CUDA/HIP implementation of
 // GPUBlas::DisableAtomics() - confirm against the original file.
 //
 // Sets the atomics mode of the shared BLAS handle to
 // BLAS_ATOMICS_NOT_ALLOWED and aborts through MFEM_VERIFY if the library
 // call does not succeed.
79 const blasStatus_t status = MFEM_cu_or_hip(blasSetAtomicsMode)(
80 Handle(), MFEM_CU_or_HIP(BLAS_ATOMICS_NOT_ALLOWED));
81 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "GPU BLAS error.");
82}
83
86 Op op) const
87{
88 const bool tr = (op == Op::T);
89
90 const int m = tr ? A.SizeJ() : A.SizeI();
91 const int n = tr ? A.SizeI() : A.SizeJ();
92 const int n_mat = A.SizeK();
93 const int k = x.Size() / n / n_mat;
94
95 auto d_A = A.Read();
96 auto d_x = x.Read(); // Shape: (n, k, n_mat)
97 auto d_y = beta == 0.0 ? y.Write() : y.ReadWrite(); // Shape (m, k, n_mat)
98
99 const auto op_A = tr ? MFEM_CU_or_HIP(BLAS_OP_T) : MFEM_CU_or_HIP(BLAS_OP_N);
100 const auto op_B = MFEM_CU_or_HIP(BLAS_OP_N);
101
102 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(gemmStridedBatched)(
103 GPUBlas::Handle(), op_A, op_B, m, k, n,
104 &alpha, d_A, m, m*n, d_x, n, n*k, &beta, d_y,
105 m, m*k, n_mat);
106 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "GPU BLAS error.");
107}
108
110{
 // NOTE(review): the signature line for this body (file line 109) is missing
 // from this listing; the accompanying index places LUFactor(DenseTensor &A,
 // Array<int> &P) at that line.
 //
 // Batched LU factorization through getrfBatched. The routine is handed
 // pointers into A's own storage, so it operates directly on A, and the pivot
 // indices are written to P (n entries per matrix).
 // NOTE(review): n = SizeI() is used as matrix size, leading dimension and
 // (squared) per-matrix stride, so the matrices are presumably square.
111 const int n = A.SizeI();
112 const int n_mat = A.SizeK();
113
114 P.SetSize(n*n_mat);
115
 // Per-matrix status codes filled in by getrfBatched.
 // NOTE(review): these are never read back, so a failed (e.g. singular)
 // factorization is not reported here - confirm this is intended.
116 Array<int> info_array(n_mat);
117
 // getrfBatched takes an array of per-matrix pointers; fill it with the
 // start address of every n*n matrix inside A.
118 real_t *A_base = A.ReadWrite();
119 Array<real_t*> A_ptrs(n_mat);
120 real_t **d_A_ptrs = A_ptrs.Write();
121 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
122 {
123 d_A_ptrs[i] = A_base + i*n*n;
124 });
125
 // NOTE(review): the verification message is empty; a failure will carry no
 // explanatory text.
126 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
127 GPUBlas::Handle(), n, d_A_ptrs, n, P.Write(),
128 info_array.Write(), n_mat);
129 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "");
130}
131
133 const DenseTensor &LU, const Array<int> &P, Vector &x) const
134{
 // NOTE(review): the first line of the signature (file line 132) is missing
 // from this listing; the accompanying index places LUSolve at that line.
 //
 // Batched triangular solves through getrsBatched, using the factors in LU
 // and the pivots in P. x holds the right-hand sides on entry and is the
 // array handed to the routine as its in/out B argument.
135 const int n = LU.SizeI();
136 const int n_mat = LU.SizeK();
 // Number of right-hand-side columns per matrix.
 // NOTE(review): integer division by n and n_mat; an empty LU (n == 0 or
 // n_mat == 0) would divide by zero - confirm callers never pass one.
137 const int n_rhs = x.Size() / n / n_mat;
138
 // getrsBatched takes arrays of per-matrix pointers for both the factors and
 // the right-hand sides.
139 Array<real_t*> A_ptrs(n_mat);
140 real_t **d_A_ptrs = A_ptrs.Write();
141 Array<real_t*> B_ptrs(n_mat);
142 real_t **d_B_ptrs = B_ptrs.Write();
143
144 {
 // LU is only requested for reading; the const_cast merely lets its
 // address be stored in the real_t* pointer array.
145 real_t *A_base = const_cast<real_t*>(LU.Read());
146 real_t *B_base = x.ReadWrite();
147 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
148 {
149 d_A_ptrs[i] = A_base + i*n*n;
150 d_B_ptrs[i] = B_base + i*n*n_rhs;
151 });
152 }
153
 // NOTE(review): `info` is passed to the routine but not inspected
 // afterwards, and the verification message is empty.
154 int info = 0;
155 const blasStatus_t status = MFEM_GPUBLAS_PREFIX(getrsBatched)(
156 GPUBlas::Handle(), MFEM_CU_or_HIP(BLAS_OP_N),
157 n, n_rhs, d_A_ptrs, n, P.Read(), d_B_ptrs, n,
158 &info, n_mat);
159 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "");
160}
161
163{
 // NOTE(review): the signature line for this body (file line 162) is missing
 // from this listing; the accompanying index places Invert(DenseTensor &A)
 // at that line.
 //
 // Batched matrix inversion: A is copied into a temporary tensor LU, LU is
 // factorized with getrfBatched, and getriBatched then receives the factors
 // as input and pointers into A as its output array.
164 const int n = A.SizeI();
165 const int n_mat = A.SizeK();
166
 // Temporary copy of A; the factorization below is run on this copy.
167 DenseTensor LU(A.SizeI(), A.SizeJ(), A.SizeK());
168 LU.Write();
169 LU.GetMemory().CopyFrom(A.GetMemory(), A.TotalSize());
170
 // Per-matrix pointer arrays required by the batched routines.
171 Array<real_t*> LU_ptrs(n_mat);
172 Array<real_t*> A_ptrs(n_mat);
173 real_t **d_A_ptrs = A_ptrs.Write();
174 real_t **d_LU_ptrs = LU_ptrs.Write();
175 {
176 real_t *A_base = A.ReadWrite();
 // NOTE(review): LU is requested with Write() both before and after the
 // copy from A; confirm the second request cannot invalidate the copied
 // data.
177 real_t *LU_base = LU.Write();
178 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
179 {
180 d_A_ptrs[i] = A_base + i*n*n;
181 d_LU_ptrs[i] = LU_base + i*n*n;
182 });
183 }
184
 // Pivot indices (n per matrix) and per-matrix status codes.
 // NOTE(review): info_array is never read back, so singular matrices are not
 // reported here; the verification messages below are also empty.
185 Array<int> P(n*n_mat);
186 Array<int> info_array(n_mat);
187 blasStatus_t status;
188
 // Step 1: LU-factorize the copy.
189 status = MFEM_GPUBLAS_PREFIX(getrfBatched)(
190 GPUBlas::Handle(), n, d_LU_ptrs, n, P.Write(),
191 info_array.Write(), n_mat);
192 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "");
193
 // Step 2: compute the inverses from the factors, with A as the output.
194 status = MFEM_GPUBLAS_PREFIX(getriBatched)(
195 GPUBlas::Handle(), n, d_LU_ptrs, n, P.ReadWrite(), d_A_ptrs, n,
196 info_array.Write(), n_mat);
197 MFEM_VERIFY(status == MFEM_BLAS_SUCCESS, "");
198}
199
200#endif
201
202} // namespace mfem
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:353
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
Definition array.hpp:758
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:345
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:337
Op
Operation type (transposed or not transposed)
Definition batched.hpp:54
Rank 3 tensor (array of matrices)
Memory< real_t > & GetMemory()
int SizeJ() const
int TotalSize() const
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read( GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
int SizeI() const
int SizeK() const
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha=1.0, real_t beta=1.0, Op op=Op::N) const override
See BatchedLinAlg::AddMult.
Definition gpu_blas.cpp:84
void LUSolve(const DenseTensor &LU, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
Definition gpu_blas.cpp:132
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
Definition gpu_blas.cpp:109
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
Definition gpu_blas.cpp:162
static void EnableAtomics()
Enable atomic operations.
Definition gpu_blas.cpp:52
static void DisableAtomics()
Disable atomic operations.
Definition gpu_blas.cpp:53
static HandleType Handle()
Return the handle, creating it if needed.
Definition gpu_blas.cpp:43
void CopyFrom(const Memory &src, int size)
Copy size entries from src to *this.
Vector data type.
Definition vector.hpp:82
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:494
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:510
int Size() const
Returns the size of the vector.
Definition vector.hpp:226
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:502
Vector beta
const real_t alpha
Definition ex15.cpp:369
MFEM_cu_or_hip(blasStatus_t) blasStatus_t
Definition gpu_blas.cpp:57
float real_t
Definition config.hpp:43
void forall(int N, lambda &&body)
Definition forall.hpp:753