MFEM v4.8.0
Finite element discretization library
Loading...
Searching...
No Matches
magma.cpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#include "magma.hpp"
13#include "../lapack.hpp"
15
16#ifdef MFEM_USE_MAGMA
17
// Precision-dependent name mangling for MAGMA entry points: the macros paste
// the routine stub onto the single- ("s") or double- ("d") precision prefix,
// depending on which MFEM precision macro is defined.
#if defined(MFEM_USE_SINGLE)
#define MFEM_MAGMA_PREFIX(name) magma_s ## name
#define MFEM_MAGMABLAS_PREFIX(name) magmablas_s ## name
#elif defined(MFEM_USE_DOUBLE)
#define MFEM_MAGMA_PREFIX(name) magma_d ## name
#define MFEM_MAGMABLAS_PREFIX(name) magmablas_d ## name
#endif
25
26namespace mfem
27{
28
29Magma::Magma()
30{
31 const magma_int_t status = magma_init();
32 MFEM_VERIFY(status == MAGMA_SUCCESS, "Error initializing MAGMA.");
33 magma_device_t dev;
34 magma_getdevice(&dev);
35 magma_queue_create(dev, &queue);
36}
37
38Magma::~Magma()
39{
40 magma_queue_destroy(queue);
41 const magma_int_t status = magma_finalize();
42 MFEM_VERIFY(status == MAGMA_SUCCESS, "Error finalizing MAGMA.");
43}
44
45Magma &Magma::Instance()
46{
47 static Magma magma;
48 return magma;
49}
50
51magma_queue_t Magma::Queue()
52{
53 return Instance().queue;
54}
55
58 Op op) const
59{
60 const bool tr = (op == Op::T);
61
62 const int m = tr ? A.SizeJ() : A.SizeI();
63 const int n = tr ? A.SizeI() : A.SizeJ();
64 const int n_mat = A.SizeK();
65 const int k = x.Size() / n / n_mat;
66
67 auto d_A = A.Read();
68 auto d_x = x.Read(); // Shape (n, k, n_mat);
69 auto d_y = beta == 0.0 ? y.Write() : y.ReadWrite(); // Shape (m, k, n_mat);
70
71 magma_trans_t magma_op = tr ? MagmaNoTrans : MagmaTrans;
72
73 MFEM_MAGMABLAS_PREFIX(gemm_batched_strided)(
74 magma_op, MagmaNoTrans, m, k, n, alpha, d_A, m, m*n, d_x, n, n*k,
75 beta, d_y, m, m*k, n_mat, Magma::Queue());
76}
77
79{
80 const int n = A.SizeI();
81 const int n_mat = A.SizeK();
82
83 P.SetSize(n*n_mat);
84
85 real_t *A_base = A.ReadWrite();
86 int *P_base = P.ReadWrite();
87
88 Array<real_t*> A_ptrs(n_mat);
89 Array<int*> P_ptrs(n_mat);
90 real_t **d_A_ptrs = A_ptrs.Write();
91 int **d_P_ptrs = P_ptrs.Write();
92 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
93 {
94 d_A_ptrs[i] = A_base + i*n*n;
95 d_P_ptrs[i] = P_base + i*n;
96 });
97
98 Array<int> info_array(n_mat);
99 const magma_int_t status = MFEM_MAGMA_PREFIX(getrf_batched)(
100 n, n, d_A_ptrs, n, d_P_ptrs,
101 info_array.Write(), n_mat, Magma::Queue());
102 MFEM_VERIFY(status == MAGMA_SUCCESS, "");
103}
104
106 const DenseTensor &LU, const Array<int> &P, Vector &x) const
107{
108 const int n = LU.SizeI();
109 const int n_mat = LU.SizeK();
110 const int n_rhs = x.Size() / n / n_mat;
111
112 Array<real_t*> A_ptrs(n_mat);
113 Array<real_t*> B_ptrs(n_mat);
114 Array<int*> P_ptrs(n_mat);
115 real_t **d_A_ptrs = A_ptrs.Write();
116 real_t **d_B_ptrs = B_ptrs.Write();
117 int **d_P_ptrs = P_ptrs.Write();
118
119 {
120 real_t *A_base = const_cast<real_t*>(LU.Read());
121 real_t *B_base = x.ReadWrite();
122 int *P_base = const_cast<int*>(P.Read());
123 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
124 {
125 d_A_ptrs[i] = A_base + i*n*n;
126 d_B_ptrs[i] = B_base + i*n*n_rhs;
127 d_P_ptrs[i] = P_base + i*n;
128 });
129 }
130
131 const magma_int_t status = MFEM_MAGMA_PREFIX(getrs_batched)(
132 MagmaNoTrans, n, n_rhs, d_A_ptrs, n, d_P_ptrs,
133 d_B_ptrs, n, n_mat, Magma::Queue());
134 MFEM_VERIFY(status == MAGMA_SUCCESS, "");
135}
136
138{
139 const int n = A.SizeI();
140 const int n_mat = A.SizeK();
141
142 DenseTensor LU(A.SizeI(), A.SizeJ(), A.SizeK());
143 LU.Write();
144 LU.GetMemory().CopyFrom(A.GetMemory(), A.TotalSize());
145
146 Array<int> P(n*n_mat);
147
148 Array<real_t*> LU_ptrs(n_mat);
149 Array<real_t*> A_ptrs(n_mat);
150 Array<int*> P_ptrs(n_mat);
151 real_t **d_A_ptrs = A_ptrs.Write();
152 real_t **d_LU_ptrs = LU_ptrs.Write();
153 int **d_P_ptrs = P_ptrs.Write();
154 {
155 real_t *A_base = A.ReadWrite();
156 real_t *LU_base = LU.Write();
157 int *P_base = P.Write();
158 mfem::forall(n_mat, [=] MFEM_HOST_DEVICE (int i)
159 {
160 d_A_ptrs[i] = A_base + i*n*n;
161 d_LU_ptrs[i] = LU_base + i*n*n;
162 d_P_ptrs[i] = P_base + i*n;
163 });
164 }
165
166 Array<int> info_array(n_mat);
167 magma_int_t status;
168
169 status = MFEM_MAGMA_PREFIX(getrf_batched)(
170 n, n, d_A_ptrs, n, d_P_ptrs, info_array.Write(), n_mat,
171 Magma::Queue());
172 MFEM_VERIFY(status == MAGMA_SUCCESS, "");
173
174 status = MFEM_MAGMA_PREFIX(getri_outofplace_batched)(
175 n, d_LU_ptrs, n, d_P_ptrs, d_A_ptrs, n, info_array.Write(),
176 n_mat, Magma::Queue());
177 MFEM_VERIFY(status == MAGMA_SUCCESS, "");
178}
179
180} // namespace mfem
181
182#endif
T * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:353
void SetSize(int nsize)
Change the logical size of the array, keep existing entries.
Definition array.hpp:758
T * Write(bool on_dev=true)
Shortcut for mfem::Write(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:345
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:337
Op
Operation type (transposed or not transposed)
Definition batched.hpp:54
Rank 3 tensor (array of matrices)
Memory< real_t > & GetMemory()
int SizeJ() const
int TotalSize() const
const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read( GetMemory(), TotalSize(), on_dev).
real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(GetMemory(), TotalSize(), on_dev).
real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(GetMemory(), TotalSize(), on_dev).
int SizeI() const
int SizeK() const
void LUSolve(const DenseTensor &A, const Array< int > &P, Vector &x) const override
See BatchedLinAlg::LUSolve.
Definition magma.cpp:105
void LUFactor(DenseTensor &A, Array< int > &P) const override
See BatchedLinAlg::LUFactor.
Definition magma.cpp:78
void AddMult(const DenseTensor &A, const Vector &x, Vector &y, real_t alpha=1.0, real_t beta=1.0, Op op=Op::N) const override
See BatchedLinAlg::AddMult.
Definition magma.cpp:56
void Invert(DenseTensor &A) const override
See BatchedLinAlg::Invert.
Definition magma.cpp:137
static magma_queue_t Queue()
Return the queue, creating it if needed.
Definition magma.cpp:51
void CopyFrom(const Memory &src, int size)
Copy size entries from src to *this.
Vector data type.
Definition vector.hpp:82
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:494
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:510
int Size() const
Returns the size of the vector.
Definition vector.hpp:226
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:502
Vector beta
const real_t alpha
Definition ex15.cpp:369
float real_t
Definition config.hpp:43
void forall(int N, lambda &&body)
Definition forall.hpp:753