12 #ifndef MFEM_TMOP_PA_HPP 13 #define MFEM_TMOP_PA_HPP 15 #include "../../config/config.hpp" 16 #include "../../linalg/dtensor.hpp" 18 #include "../kernels.hpp" 20 #include <unordered_map> 35 map.template Emplace<N-1>();
50 using Key_t =
typename K::Key_t;
51 using Return_t =
typename K::Return_t;
52 using Kernel_t =
typename K::Kernel_t;
53 using map_t = std::unordered_map<Key_t, Kernel_t>;
60 bool Find(
const Key_t
id) {
return map.find(
id) != map.end(); }
62 Kernel_t
At(
const Key_t
id) {
return map.at(
id); }
66 constexpr Key_t key = K::template GetKey<N>();
67 constexpr Kernel_t ker = K::template GetKer<key>();
68 map.emplace(key, ker);
156 #define MFEM_REGISTER_TMOP_KERNELS(return_t, kernel, ...) \ 157 template<int T_D1D = 0, int T_Q1D = 0, int T_MAX = 0> \ 158 return_t kernel(__VA_ARGS__);\ 159 typedef return_t (*kernel##_p)(__VA_ARGS__);\ 160 struct K##kernel##_T {\ 161 static const int N = 14;\ 163 using Return_t = return_t;\ 164 using Kernel_t = kernel##_p;\ 165 template<Key_t I> static constexpr Key_t GetKey() noexcept { return \ 166 I==0 ? 0x22 : I==1 ? 0x23 : I==2 ? 0x24 : I==3 ? 0x25 : I==4 ? 0x26 :\ 167 I==5 ? 0x33 : I==6 ? 0x34 : I==7 ? 0x35 : I==8 ? 0x36 :\ 168 I==9 ? 0x44 : I==10 ? 0x45 : I==11 ? 0x46 :\ 169 I==12 ? 0x55 : I==13 ? 0x56 : 0; }\ 170 template<Key_t K> static constexpr Kernel_t GetKer() noexcept\ 171 { return &kernel<(K>>4)&0xF, K&0xF>; }\ 173 static kernels::KernelMap<K##kernel##_T> K##kernel;\ 174 template<int T_D1D, int T_Q1D, int T_MAX> return_t kernel(__VA_ARGS__) 180 #define MFEM_LAUNCH_TMOP_KERNEL(kernel, id, ...)\ 181 if (K##kernel.Find(id)) { return K##kernel.At(id)(__VA_ARGS__,0,0); }\ 183 constexpr int T_MAX = 4;\ 184 const int d1d = (id>>4)&0xF, q1d = id&0xF;\ 185 MFEM_VERIFY(d1d <= MAX_D1D && q1d <= MAX_Q1D, "Max size error!");\ 186 return kernel<0,0,T_MAX>(__VA_ARGS__,d1d,q1d); } 192 #endif // MFEM_TMOP_PA_HPP static void Fill(KernelMap< K > &map)
KernelMap class, which wraps an std::unordered_map from kernel keys (Key_t) to registered kernels (Kernel_t).
static void Fill(KernelMap< K > &map)
Kernel_t At(const Key_t id)
bool Find(const Key_t id)