12#ifndef MFEM_KERNEL_DISPATCH_HPP
13#define MFEM_KERNEL_DISPATCH_HPP
17#include <unordered_map>
// MFEM_EXPAND(X): expands to X unchanged. Wrapping tokens in it forces one
// more rescan by the preprocessor before the result is used.
// NOTE(review): presumably needed for preprocessors that forward __VA_ARGS__
// as a single token (e.g. MSVC's traditional preprocessor) — confirm.
#define MFEM_EXPAND(X) X
// MFEM_REGISTER_KERNELS(KernelName, KernelType, Params [, OptParams])
//
// Registers a kernel dispatch table class named KernelName. The class itself
// is produced by MFEM_REGISTER_KERNELS_, which derives it from
// ::mfem::KernelDispatchTable. KernelType is the signature type returned by
// the class's Kernel<...>() and Fallback(...) functions.
//   - Params:    parenthesized list of dispatch parameter types, e.g. (int, int)
//   - OptParams: optional parenthesized list of additional template parameter
//                types that are not part of the dispatch key.
//
// The trailing "2,1," arguments let MFEM_REGISTER_KERNELS_N tell whether one
// or two parenthesized lists were passed. It then selects
// MFEM_REGISTER_KERNELS_1 or MFEM_REGISTER_KERNELS_2, which is applied to the
// full argument list. The MFEM_EXPAND wrappers force the extra rescans that
// this selection relies on.
#define MFEM_REGISTER_KERNELS(KernelName, KernelType, ...) \
   MFEM_EXPAND(MFEM_EXPAND(MFEM_REGISTER_KERNELS_N(__VA_ARGS__,2,1,)) \
               (KernelName,KernelType,__VA_ARGS__))
// Selects MFEM_REGISTER_KERNELS_1 or MFEM_REGISTER_KERNELS_2 by argument
// count. It is invoked as MFEM_REGISTER_KERNELS_N(<lists...>, 2, 1, ).
//   - With one parenthesized list, N binds to 1.
//   - With two lists, the counters shift right by one and N binds to 2.
// The trailing empty argument ensures "..." always receives an argument.
#define MFEM_REGISTER_KERNELS_N(_1, _2, N, ...) MFEM_REGISTER_KERNELS_##N
// Expands to its arguments unchanged. When placed directly before a
// parenthesized list, e.g. `MFEM_PARAM_LIST (int, bool)`, the list acts as the
// macro call, so the parentheses are stripped and the result is `int, bool`.
#define MFEM_PARAM_LIST(...) __VA_ARGS__
// One-list form of MFEM_REGISTER_KERNELS:
//   - every type in Params is a dispatch parameter;
//   - the optional parameter list is empty `()`;
//   - the template parameter list of Kernel<...>() is Params itself.
#define MFEM_REGISTER_KERNELS_1(KernelName, KernelType, Params) \
   MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, (), Params)
// Two-list form of MFEM_REGISTER_KERNELS:
//   - Params are the dispatch parameters;
//   - OptParams are the optional parameters;
//   - the template parameter list of Kernel<...>() is the concatenation
//     (Params..., OptParams...).
#define MFEM_REGISTER_KERNELS_2(KernelName, KernelType, Params, OptParams) \
   MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, OptParams, \
                          (MFEM_PARAM_LIST Params, MFEM_PARAM_LIST OptParams))
71#define MFEM_REGISTER_KERNELS_(KernelName, KernelType, P1, P2, P3) \
73 : public ::mfem::KernelDispatchTable< \
74 KernelName, KernelType, \
75 ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P1>, \
76 ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P2>> { \
78 const char *kernel_name = MFEM_KERNEL_NAME(KernelName); \
79 using KernelSignature = KernelType; \
80 template <MFEM_PARAM_LIST P3> static MFEM_EXPORT KernelSignature Kernel(); \
81 static MFEM_EXPORT KernelSignature Fallback(MFEM_PARAM_LIST P1); \
82 static MFEM_EXPORT KernelName &Get() { \
83 static KernelName table; \
92template<
typename ...KernelParameters>
97 size_t operator()(std::tuple<KernelParameters...> value)
const {
return 0; }
101 template<std::size_t N,
typename THead,
typename... TTail>
102 size_t operator()(std::tuple<KernelParameters...> value)
const
104 constexpr int Index = N -
sizeof...(TTail) - 1;
105 auto lhs_hash = std::hash<THead>()(std::get<Index>(value));
106 auto rhs_hash = operator()<N, TTail...>(value);
107 return lhs_hash^(rhs_hash + 0x9e3779b9 + (lhs_hash<<6) + (lhs_hash>>2));
111 size_t operator()(std::tuple<KernelParameters...> value)
const
113 return operator()<
sizeof...(KernelParameters),KernelParameters...>(value);
namespace internal
{
/// Empty tag type that carries a pack of types. It is used to hand
/// KernelDispatchTable two packs: the dispatch parameter types and the
/// optional template parameter types, which are not part of the dispatch key.
template <typename... Types>
struct KernelTypeList { };
} // namespace internal
121template <
typename Kernels,
124 typename... OptParams>
127 internal::KernelTypeList<Params...>,
128 internal::KernelTypeList<OptParams...>>
130 using TableType = std::unordered_map<std::tuple<Params...>,
140 template<
typename...
Args>
141 static void Run(Params... params,
Args&&... args)
143 const auto &table = Kernels::Get().table;
144 const std::tuple<Params...> key = std::make_tuple(params...);
145 const auto it = table.find(key);
146 if (it != table.end())
148 it->second(std::forward<Args>(args)...);
153 Kernels::Fallback(params...)(std::forward<Args>(args)...);
158 template <Params... PARAMS>
164 std::tuple<Params...> param_tuple(PARAMS...);
165 Kernels::Get().table[param_tuple] =
166 Kernels:: template Kernel<PARAMS..., OptParams{}...>();
169 template <OptParams... OPT_PARAMS>
174 std::tuple<Params...> param_tuple(PARAMS...);
175 Kernels::Get().table[param_tuple] =
176 Kernels:: template Kernel<PARAMS..., OPT_PARAMS...>();
184 return Kernels::Get().table;
static void Run(Params... params, Args &&... args)
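Run the registered kernel that matches the given dispatch parameters, forwarding the remaining arguments to it; if no kernel is registered for those parameters, run the fallback kernel instead.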
static const TableType & GetDispatchTable()
Return the dispatch map table.
static void ReportFallback(const std::string &kernel_name, Params &&... params)
Report the fallback kernel with given parameters.
Hashes variadic packs in which each contained type has a specialization of std::hash.
size_t operator()(std::tuple< KernelParameters... > value) const
Returns the hash of the given value.
Register a specialized kernel for dispatch.