12 #ifndef MFEM_TMOP_PA_HPP
13 #define MFEM_TMOP_PA_HPP
15 #include "../../config/config.hpp"
16 #include "../../linalg/dtensor.hpp"
18 #include "../kernels.hpp"
20 #include <unordered_map>
35 map.template Emplace<N-1>();
50 using Key_t =
typename K::Key_t;
51 using Return_t =
typename K::Return_t;
52 using Kernel_t =
typename K::Kernel_t;
53 using map_t = std::unordered_map<Key_t, Kernel_t>;
60 bool Find(
const Key_t
id) {
return map.find(
id) != map.end(); }
62 Kernel_t
At(
const Key_t
id) {
return map.at(
id); }
66 constexpr Key_t key = K::template GetKey<N>();
67 constexpr Kernel_t ker = K::template GetKer<key>();
68 map.emplace(key, ker);
// Declare a TMOP kernel template and register a fixed set of its size
// specializations.
//
// One invocation expands to:
//  - a forward declaration of `template<int T_D1D, int T_Q1D, int T_MAX>
//    return_t kernel(...)`, with all three template arguments defaulting to 0;
//  - `kernel##_p`, a function-pointer type with the same signature;
//  - a descriptor struct `K##kernel##_T` with N = 14 entries. GetKey<I>()
//    maps index I (0..13) to a one-byte key and returns 0 for any other I.
//    GetKer<key>() returns &kernel<(key>>4)&0xF, key&0xF>, i.e. the key packs
//    (D1D << 4) | Q1D, one hex digit each;
//  - a static kernels::KernelMap `K##kernel` parameterized by that descriptor,
//    which MFEM_LAUNCH_TMOP_KERNEL queries at run time;
//  - the head of the kernel's definition, so the macro invocation must be
//    followed directly by the kernel body `{ ... }`.
// NOTE(review): the `Key_t` alias used by GetKey/GetKer and the closing `};`
// of the descriptor struct do not appear in this listing — confirm they are
// present in the actual file.
156 #define MFEM_REGISTER_TMOP_KERNELS(return_t, kernel, ...) \
157 template<int T_D1D = 0, int T_Q1D = 0, int T_MAX = 0> \
158 return_t kernel(__VA_ARGS__);\
159 typedef return_t (*kernel##_p)(__VA_ARGS__);\
160 struct K##kernel##_T {\
161 static const int N = 14;\
163 using Return_t = return_t;\
164 using Kernel_t = kernel##_p;\
165 template<Key_t I> static constexpr Key_t GetKey() noexcept { return \
166 I==0 ? 0x22 : I==1 ? 0x23 : I==2 ? 0x24 : I==3 ? 0x25 : I==4 ? 0x26 :\
167 I==5 ? 0x33 : I==6 ? 0x34 : I==7 ? 0x35 : I==8 ? 0x36 :\
168 I==9 ? 0x44 : I==10 ? 0x45 : I==11 ? 0x46 :\
169 I==12 ? 0x55 : I==13 ? 0x56 : 0; }\
170 template<Key_t K> static constexpr Kernel_t GetKer() noexcept\
171 { return &kernel<(K>>4)&0xF, K&0xF>; }\
173 static kernels::KernelMap<K##kernel##_T> K##kernel;\
174 template<int T_D1D, int T_Q1D, int T_MAX> return_t kernel(__VA_ARGS__)
// Dispatch to a kernel declared with MFEM_REGISTER_TMOP_KERNELS.
//
// `id` packs the 1D sizes as (d1d << 4) | q1d, one hex digit each.
//  - If K<kernel> holds a specialization for `id`, call it with the trailing
//    run-time sizes set to 0; that specialization has the sizes baked in as
//    template arguments.
//  - Otherwise, fall back to the generic kernel<0,0,T_MAX> instantiation and
//    pass the decoded sizes at run time.
// The macro expands to `return` statements, so it must form the tail of a
// function whose return type matches the kernel's return type.
//
// `id` is parenthesized in every arithmetic use so that compound arguments
// (e.g. `a | b`) decode correctly. It is still evaluated more than once, so
// pass a side-effect-free expression.
// The fallback check also enforces T_MAX. The generic instantiation receives
// T_MAX as its compile-time size limit (presumably used to size per-element
// scratch storage — confirm in the kernels). Without this, sizes above T_MAX
// could pass the MAX_D1D/MAX_Q1D check yet exceed that limit.
#define MFEM_LAUNCH_TMOP_KERNEL(kernel, id, ...)\
if (K##kernel.Find(id)) { return K##kernel.At(id)(__VA_ARGS__,0,0); }\
else {\
constexpr int T_MAX = 4;\
const int d1d = ((id)>>4)&0xF, q1d = (id)&0xF;\
MFEM_VERIFY(d1d <= MAX_D1D && q1d <= MAX_Q1D &&\
            d1d <= T_MAX && q1d <= T_MAX, "Max size error!");\
return kernel<0,0,T_MAX>(__VA_ARGS__,d1d,q1d); }
192 #endif // MFEM_TMOP_PA_HPP
static void Fill(KernelMap< K > &map)
KernelMap: a class that builds an unordered_map from kernel keys to kernels.
static void Fill(KernelMap< K > &map)
Kernel_t At(const Key_t id)
bool Find(const Key_t id)