12#ifndef MFEM_TEMPLATE_LAYOUT
13#define MFEM_TEMPLATE_LAYOUT
// Forward declaration of the OffsetStridedLayout1D class template, which takes
// two compile-time int parameters (N1, S1). Only the declaration appears here.
23template <
int N1,
int S1>
24struct OffsetStridedLayout1D;
// Forward declaration of the StridedLayout2D class template, which takes four
// compile-time int parameters (N1, S1, N2, S2). Only the declaration appears
// here.
25template <
int N1,
int S1,
int N2,
int S2>
26struct StridedLayout2D;
28template <
int N1,
int S1>
33 static const int size = N1;
35 MFEM_HOST_DEVICE
static inline int ind(
int i1)
48 template <
int N1_1,
int N1_2>
52 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
// Forward declaration of the OffsetStridedLayout2D class template, which takes
// four compile-time int parameters (N1, S1, N2, S2). Only the declaration
// appears here.
57template <
int N1,
int S1,
int N2,
int S2>
58struct OffsetStridedLayout2D;
60template <
int N1,
int S1>
65 static const int size = N1;
71 MFEM_HOST_DEVICE
inline int ind(
int i1)
const
84 template <
int N1_1,
int N1_2>
88 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
// Forward declaration of the StridedLayout3D class template, which takes six
// compile-time int parameters (N1, S1, N2, S2, N3, S3). Only the declaration
// appears here.
93template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3>
94struct StridedLayout3D;
// Forward declaration of the StridedLayout4D class template, which takes eight
// compile-time int parameters (N1, S1, ..., N4, S4). Only the declaration
// appears here.
95template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3,
int N4,
int S4>
96struct StridedLayout4D;
98template <
int N1,
int S1,
int N2,
int S2>
106 MFEM_HOST_DEVICE
static inline int ind(
int i1,
int i2)
108 return (S1*i1+S2*i2);
119 template <
int M1,
int M2>
127 template <
int N1_1,
int N1_2>
131 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
134 template <
int N2_1,
int N2_2>
138 MFEM_STATIC_ASSERT(N2_1*N2_2 == N2,
"invalid dimensions");
141 template <
int N1_1,
int N1_2,
int N2_1,
int N2_2>
145 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1 && N2_1*N2_2 == N2,
146 "invalid dimensions");
154 MFEM_STATIC_ASSERT(S2 == S1*N1 || S1 == S2*N2,
"invalid reshape");
// Forward declaration of the OffsetStridedLayout3D class template, which takes
// six compile-time int parameters (N1, S1, N2, S2, N3, S3). Only the
// declaration appears here.
163template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3>
164struct OffsetStridedLayout3D;
// Forward declaration of the OffsetStridedLayout4D class template, which takes
// eight compile-time int parameters (N1, S1, ..., N4, S4). Only the
// declaration appears here.
165template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3,
int N4,
int S4>
166struct OffsetStridedLayout4D;
168template <
int N1,
int S1,
int N2,
int S2>
180 MFEM_HOST_DEVICE
inline int ind(
int i1,
int i2)
const
182 return offset+S1*i1+S2*i2;
193 template <
int M1,
int M2>
201 template <
int N1_1,
int N1_2>
205 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
208 template <
int N2_1,
int N2_2>
212 MFEM_STATIC_ASSERT(N2_1*N2_2 == N2,
"invalid dimensions");
215 template <
int N1_1,
int N1_2,
int N2_1,
int N2_2>
220 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1 && N2_1*N2_2 == N2,
221 "invalid dimensions");
223 N1_1,S1,N1_2,S1*N1_1,N2_1,S2,N2_2,S2*N2_1>(
offset);
230 MFEM_STATIC_ASSERT(S2 == S1*N1 || S1 == S2*N2,
"invalid reshape");
239template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3>
246 static const int size = N1*N2*N3;
248 static inline int ind(
int i1,
int i2,
int i3)
250 return S1*i1+S2*i2+S3*i3;
271 MFEM_STATIC_ASSERT(S2 == S1*N1,
"invalid reshape");
282 MFEM_STATIC_ASSERT(S3 == S2*N2,
"invalid reshape");
286 template <
int N1_1,
int N1_2>
290 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
293 template <
int N2_1,
int N2_2>
297 MFEM_STATIC_ASSERT(N2_1*N2_2 == N2,
"invalid dimensions");
300 template <
int N3_1,
int N3_2>
304 MFEM_STATIC_ASSERT(N3_1*N3_2 == N3,
"invalid dimensions");
322template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3>
329 static const int size = N1*N2*N3;
335 inline int ind(
int i1,
int i2,
int i3)
const
337 return offset+S1*i1+S2*i2+S3*i3;
358 MFEM_STATIC_ASSERT(S2 == S1*N1,
"invalid reshape");
365 MFEM_STATIC_ASSERT(S3 == S2*N2,
"invalid reshape");
369 template <
int N1_1,
int N1_2>
373 MFEM_STATIC_ASSERT(N1_1*N1_2 == N1,
"invalid dimensions");
376 template <
int N2_1,
int N2_2>
380 MFEM_STATIC_ASSERT(N2_1*N2_2 == N2,
"invalid dimensions");
385template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3,
int N4,
int S4>
393 static const int size = N1*N2*N3*N4;
395 static inline int ind(
int i1,
int i2,
int i3,
int i4)
397 return S1*i1+S2*i2+S3*i3+S4*i4;
416 MFEM_STATIC_ASSERT(S2 == S1*N1,
"invalid reshape");
423 MFEM_STATIC_ASSERT(S4 == S3*N3,
"invalid reshape");
428template <
int N1,
int S1,
int N2,
int S2,
int N3,
int S3,
int N4,
int S4>
436 static const int size = N1*N2*N3*N4;
442 inline int ind(
int i1,
int i2,
int i3,
int i4)
const
444 return offset+S1*i1+S2*i2+S3*i3+S4*i4;
448template <
int N1,
int N2>
452template <
int N1,
int N2,
int N3>
456template <
int N1,
int N2,
int N3,
int N4>
490 Init(ordering, scalar_size, num_comp);
500 int ind(
int scalar_idx,
int comp_idx)
const
513template <Ordering::Type Ord,
int NumComp = 0>
528 "invalid number of components");
535 MFEM_ASSERT(fes.
GetOrdering() == Ord,
"ordering mismatch");
537 "invalid number of components");
543 int ind(
int scalar_idx,
int comp_idx)
const
551 return comp_idx + (NumComp ? NumComp :
num_components) * scalar_idx;
558 (NumComp == 0 || NumComp == fes.
GetVDim()));
571 MFEM_ASSERT(fes.
GetVDim() == 1,
"invalid number of components");
// Maps a (scalar index, component index) pair to a linear index.
// This overload returns scalar_idx unchanged and does not use comp_idx.
// NOTE(review): presumably a member of the single-component (scalar) layout
// class; the enclosing class header is not visible here -- confirm.
576 int ind(
int scalar_idx,
int comp_idx)
const {
return scalar_idx; }
DynamicVectorLayout(Ordering::Type ordering, int scalar_size, int num_comp)
static bool Matches(const FiniteElementSpace &fes)
int ind(int scalar_idx, int comp_idx) const
int NumComponents() const
DynamicVectorLayout(const FiniteElementSpace &fes)
void Init(Ordering::Type ordering, int scalar_size, int num_comp)
Class FiniteElementSpace - responsible for providing FEM view of the mesh, mainly managing the set of...
int GetNDofs() const
Returns number of degrees of freedom. This is the number of Local Degrees of Freedom.
Ordering::Type GetOrdering() const
Return the ordering method.
int GetVDim() const
Returns the vector dimension of the finite element space.
ScalarLayout(const FiniteElementSpace &fes)
int NumComponents() const
static bool Matches(const FiniteElementSpace &fes)
int ind(int scalar_idx, int comp_idx) const
static bool Matches(const FiniteElementSpace &fes)
VectorLayout(int scalar_size_, int num_comp_=NumComp)
int NumComponents() const
int ind(int scalar_idx, int comp_idx) const
VectorLayout(const FiniteElementSpace &fes)
OffsetStridedLayout1D< M1, S1 > sub(int o1) const
OffsetStridedLayout2D< N1_1, S1, N1_2, S1 *N1_1 > split_1() const
MFEM_HOST_DEVICE int ind(int i1) const
OffsetStridedLayout1D(int offset_)
OffsetStridedLayout3D< N1_1, S1, N1_2, S1 *N1_1, N2, S2 > split_1() const
OffsetStridedLayout2D< M1, S1, M2, S2 > sub(int o1, int o2) const
OffsetStridedLayout2D(int offset_)
OffsetStridedLayout4D< N1_1, S1, N1_2, S1 *N1_1, N2_1, S2, N2_2, S2 *N2_1 > split_12() const
OffsetStridedLayout1D< N1 *N2,(S1< S2)?S1:S2 > merge_12() const
MFEM_HOST_DEVICE int ind(int i1, int i2) const
OffsetStridedLayout2D< N2, S2, N1, S1 > transpose_12() const
OffsetStridedLayout1D< N1, S1 > ind2(int i2) const
OffsetStridedLayout1D< N2, S2 > ind1(int i1) const
OffsetStridedLayout3D< N1, S1, N2_1, S2, N2_2, S2 *N2_1 > split_2() const
OffsetStridedLayout4D< N1, S1, N2_1, S2, N2_2, S2 *N2_1, N3, S3 > split_2() const
OffsetStridedLayout2D< N1 *N2, S1, N3, S3 > merge_12() const
int ind(int i1, int i2, int i3) const
OffsetStridedLayout2D< N1, S1, N2, S2 > ind3(int i3) const
OffsetStridedLayout2D< N2, S2, N3, S3 > ind1(int i1) const
OffsetStridedLayout3D(int offset_)
OffsetStridedLayout2D< N1, S1, N3, S3 > ind2(int i2) const
OffsetStridedLayout2D< N1, S1, N2 *N3, S2 > merge_23() const
OffsetStridedLayout4D< N1_1, S1, N1_2, S1 *N1_1, N2, S2, N3, S3 > split_1() const
int ind(int i1, int i2, int i3, int i4) const
OffsetStridedLayout4D(int offset_)
static MFEM_HOST_DEVICE int ind(int i1)
static StridedLayout2D< N1_1, S1, N1_2, S1 *N1_1 > split_1()
static OffsetStridedLayout1D< M1, S1 > sub(int o1)
static OffsetStridedLayout1D< N1, S1 > ind2(int i2)
static StridedLayout1D< N1 *N2,(S1< S2)?S1:S2 > merge_12()
static MFEM_HOST_DEVICE int ind(int i1, int i2)
static StridedLayout2D< N2, S2, N1, S1 > transpose_12()
static OffsetStridedLayout1D< N2, S2 > ind1(int i1)
static StridedLayout4D< N1_1, S1, N1_2, S1 *N1_1, N2_1, S2, N2_2, S2 *N2_1 > split_12()
static StridedLayout3D< N1_1, S1, N1_2, S1 *N1_1, N2, S2 > split_1()
static OffsetStridedLayout2D< M1, S1, M2, S2 > sub(int o1, int o2)
static StridedLayout3D< N1, S1, N2_1, S2, N2_2, S2 *N2_1 > split_2()
static StridedLayout3D< N3, S3, N2, S2, N1, S1 > transpose_13()
static StridedLayout3D< N2, S2, N1, S1, N3, S3 > transpose_12()
static OffsetStridedLayout2D< N1, S1, N2, S2 > ind3(int i3)
static StridedLayout4D< N1, S1, N2_1, S2, N2_2, S2 *N2_1, N3, S3 > split_2()
static StridedLayout2D< N1, S1, N2 *N3, S2 > merge_23()
static StridedLayout2D< N1 *N2, S1, N3, S3 > merge_12()
static int ind(int i1, int i2, int i3)
static StridedLayout4D< N1, S1, N2, S2, N3_1, S3, N3_2, S3 *N3_1 > split_3()
static StridedLayout3D< N1, S1, N3, S3, N2, S2 > transpose_23()
static StridedLayout4D< N1_1, S1, N1_2, S1 *N1_1, N2, S2, N3, S3 > split_1()
static OffsetStridedLayout2D< N2, S2, N3, S3 > ind1(int i1)
static OffsetStridedLayout2D< N1, S1, N3, S3 > ind2(int i2)
static OffsetStridedLayout2D< N2, S2, N3, S3 > ind14(int i1, int i4)
static int ind(int i1, int i2, int i3, int i4)
static StridedLayout3D< N1, S1, N2, S2, N3 *N4, S3 > merge_34()
static OffsetStridedLayout2D< N1, S1, N4, S4 > ind23(int i2, int i3)
static OffsetStridedLayout3D< N1, S1, N2, S2, N3, S3 > ind4(int i4)
static StridedLayout3D< N1 *N2, S1, N3, S3, N4, S4 > merge_12()