MFEM v4.8.0
Finite element discretization library
Loading...
Searching...
No Matches
kernel_dispatch.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#ifndef MFEM_KERNEL_DISPATCH_HPP
13#define MFEM_KERNEL_DISPATCH_HPP
14
15#include "../config/config.hpp"
16#include "kernel_reporter.hpp"
17#include <unordered_map>
18#include <tuple>
19#include <cstddef>
20
21namespace mfem
22{
23
24// The MFEM_REGISTER_KERNELS macro registers kernels for runtime dispatch using
25// a dispatch map.
26//
27// This creates a dispatch table (a static member variable) named @a KernelName
28// containing function pointers of type @a KernelType. These are followed by one
29// or two sets of parenthesized argument types.
30//
31// The first set of argument types contains the types that are used to dispatch
32// to either specialized or fallback kernels. The second set of argument types
33// can be used to further specialize the kernel without participating in
34// dispatch (a canonical example is NBZ, determining the size of the thread
35// blocks; this is required to specialize kernels for optimal performance, but
36// is not relevant for dispatch).
37//
38// After calling this macro, the user must implement the Kernel and Fallback
39// static member functions, which return pointers to the appropriate kernel
40// functions depending on the parameters.
41//
42// Specialized functions can be registered using the static AddSpecialization
43// member function.
44
45#define MFEM_EXPAND(X) X // Workaround needed for MSVC compiler
46
47#define MFEM_REGISTER_KERNELS(KernelName, KernelType, ...) \
48 MFEM_EXPAND(MFEM_EXPAND(MFEM_REGISTER_KERNELS_N(__VA_ARGS__,2,1,)) \
49 (KernelName,KernelType,__VA_ARGS__))
50
51#define MFEM_REGISTER_KERNELS_N(_1, _2, N, ...) MFEM_REGISTER_KERNELS_##N
52
53// Expands a variable length macro parameter so that multiple variable length
54// parameters can be passed to the same macro.
55#define MFEM_PARAM_LIST(...) __VA_ARGS__
56
57// Version of MFEM_REGISTER_KERNELS without any "optional" (non-dispatch)
58// parameters.
59#define MFEM_REGISTER_KERNELS_1(KernelName, KernelType, Params) \
60 MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, (), Params)
61
62// Version of MFEM_REGISTER_KERNELS with optional (non-dispatch) parameters
63// (e.g. NBZ) given as a second parenthesized list.
64#define MFEM_REGISTER_KERNELS_2(KernelName, KernelType, Params, OptParams) \
65 MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, OptParams, \
66 (MFEM_PARAM_LIST Params, MFEM_PARAM_LIST OptParams))
67
68// P1 are the dispatch parameters, P2 are the optional (non-dispatch)
69// parameters, and P3 is the concatenation of P1 and P2. We need to pass P3 as
70// a separate argument to avoid a trailing comma in the case that P2 is empty.
//
// The macro expands to the definition of a class named KernelName that
// publicly derives from
//   ::mfem::KernelDispatchTable<KernelName, KernelType,
//                               KernelTypeList<P1...>, KernelTypeList<P2...>>
// and declares the following public members:
//  - kernel_name:     a `const char *` data member initialized from
//                     MFEM_KERNEL_NAME(KernelName);
//  - KernelSignature: an alias for KernelType;
//  - Kernel<P3...>(): a static member function template returning a
//                     KernelSignature (declared here, not defined);
//  - Fallback(P1...): a static member function returning a KernelSignature
//                     (declared here, not defined);
//  - Get():           returns a reference to a function-local static instance
//                     of KernelName.
//
// The expansion does not end with a semicolon; the macro invocation is
// expected to be followed by one.
71#define MFEM_REGISTER_KERNELS_(KernelName, KernelType, P1, P2, P3) \
72 class KernelName \
73 : public ::mfem::KernelDispatchTable< \
74 KernelName, KernelType, \
75 ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P1>, \
76 ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P2>> { \
77 public: \
78 const char *kernel_name = MFEM_KERNEL_NAME(KernelName); \
79 using KernelSignature = KernelType; \
80 template <MFEM_PARAM_LIST P3> static MFEM_EXPORT KernelSignature Kernel(); \
81 static MFEM_EXPORT KernelSignature Fallback(MFEM_PARAM_LIST P1); \
82 static MFEM_EXPORT KernelName &Get() { \
83 static KernelName table; \
84 return table; \
85 } \
86 }
87
/// @brief Hashes variadic packs for which each type contained in the variadic
/// pack has a specialization of `std::hash` available.
///
/// For example, packs containing int, bool, enum values, etc.
///
/// The functor is callable with a `std::tuple<KernelParameters...>` and
/// returns a `std::size_t`, so it can be used as the hasher of an
/// `std::unordered_map` keyed by such tuples. The hash of an empty tuple is 0.
template<typename ...KernelParameters>
struct KernelDispatchKeyHash
{
private:
   using KeyType = std::tuple<KernelParameters...>;

   /// Recursion terminator: selected when no element types remain; contributes
   /// 0 to the combined hash.
   template<std::size_t N>
   std::size_t operator()(const KeyType &) const { return 0; }

   // The hashing formula here is taken directly from the Boost library, with
   // the magic number 0x9e3779b9 chosen to minimize hashing collisions.
   //
   // N is the total number of tuple elements; THead is the type of the element
   // currently being hashed and TTail are the types of the elements after it.
   // The key is taken by const reference so that the recursion does not copy
   // the tuple at every level.
   template<std::size_t N, typename THead, typename... TTail>
   std::size_t operator()(const KeyType &value) const
   {
      // Position of THead in the tuple: everything not in the tail, minus one.
      constexpr std::size_t Index = N - sizeof...(TTail) - 1;
      const std::size_t lhs_hash =
         std::hash<THead>()(std::get<Index>(value));
      // Hash of the remaining elements (0 once the tail is empty).
      const std::size_t rhs_hash = operator()<N, TTail...>(value);
      return lhs_hash^(rhs_hash + 0x9e3779b9 + (lhs_hash<<6) + (lhs_hash>>2));
   }
public:
   /// Returns the hash of the given @a value.
   std::size_t operator()(std::tuple<KernelParameters...> value) const
   {
      return operator()<sizeof...(KernelParameters),KernelParameters...>(value);
   }
};
116
117namespace internal { template<typename... Types> struct KernelTypeList { }; }
118
119template<typename... T> class KernelDispatchTable { };
120
121template <typename Kernels,
122 typename Signature,
123 typename... Params,
124 typename... OptParams>
126 Signature,
127 internal::KernelTypeList<Params...>,
128 internal::KernelTypeList<OptParams...>>
129{
130 using TableType = std::unordered_map<std::tuple<Params...>,
131 Signature, KernelDispatchKeyHash<Params...>>;
132 TableType table;
133
134public:
135 /// @brief Run the kernel with the given dispatch parameters and arguments.
136 ///
137 /// If a compile-time specialized version of the kernel with the given
138 /// parameters has been registered, it will be called. Otherwise, the
139 /// fallback kernel will be called.
140 template<typename... Args>
141 static void Run(Params... params, Args&&... args)
142 {
143 const auto &table = Kernels::Get().table;
144 const std::tuple<Params...> key = std::make_tuple(params...);
145 const auto it = table.find(key);
146 if (it != table.end())
147 {
148 it->second(std::forward<Args>(args)...);
149 }
150 else
151 {
152 KernelReporter::ReportFallback(Kernels::Get().kernel_name, params...);
153 Kernels::Fallback(params...)(std::forward<Args>(args)...);
154 }
155 }
156
157 /// Register a specialized kernel for dispatch.
158 template <Params... PARAMS>
160 {
161 // Version without optional parameters
162 static void Add()
163 {
164 std::tuple<Params...> param_tuple(PARAMS...);
165 Kernels::Get().table[param_tuple] =
166 Kernels:: template Kernel<PARAMS..., OptParams{}...>();
167 };
168 // Version with optional parameters
169 template <OptParams... OPT_PARAMS>
170 struct Opt
171 {
172 static void Add()
173 {
174 std::tuple<Params...> param_tuple(PARAMS...);
175 Kernels::Get().table[param_tuple] =
176 Kernels:: template Kernel<PARAMS..., OPT_PARAMS...>();
177 }
178 };
179 };
180
181 /// Return the dispatch map table
182 static const TableType &GetDispatchTable()
183 {
184 return Kernels::Get().table;
185 }
186};
187
188}
189
190#endif
static void Run(Params... params, Args &&... args)
Run the kernel with the given dispatch parameters and arguments.
static void ReportFallback(const std::string &kernel_name, Params &&... params)
Report the fallback kernel with given parameters.
Hashes variadic packs for which each type contained in the variadic pack has a specialization of std:...
size_t operator()(std::tuple< KernelParameters... > value) const
Returns the hash of the given value.