MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
kernel_dispatch.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12#ifndef MFEM_KERNEL_DISPATCH_HPP
13#define MFEM_KERNEL_DISPATCH_HPP
14
15#include "../config/config.hpp"
16#include "kernel_reporter.hpp"
17#include <unordered_map>
18#include <tuple>
19#include <type_traits>
20#include <cstddef>
21
22namespace mfem
23{
24
25// The MFEM_REGISTER_KERNELS macro registers kernels for runtime dispatch using
26// a dispatch map.
27//
28// This creates a dispatch table (a static member variable) named @a KernelName
29// containing function pointers of type @a KernelType. These are followed by one
30// or two sets of parenthesized argument types.
31//
32// The first set of argument types contains the types that are used to dispatch
33// to either specialized or fallback kernels. The second set of argument types
34// can be used to further specialize the kernel without participating in
35// dispatch (a canonical example is NBZ, determining the size of the thread
36// blocks; this is required to specialize kernels for optimal performance, but
37// is not relevant for dispatch).
38//
39// After calling this macro, the user must implement the Kernel and Fallback
40// static member functions, which return pointers to the appropriate kernel
41// functions depending on the parameters.
42//
43// Specialized functions can be registered using the static AddSpecialization
44// member function.
45
// Identity macro: expands to its argument unchanged. Wrapping an expression in
// MFEM_EXPAND forces an additional macro rescan, which is the workaround needed
// for the MSVC compiler when forwarding __VA_ARGS__ to another macro.
#define MFEM_EXPAND(X) X
47
// Entry point for registering a kernel dispatch table.
//
// The variadic arguments are either one parenthesized list (the dispatch
// parameters) or two parenthesized lists (the dispatch parameters followed by
// the optional, non-dispatch parameters). MFEM_REGISTER_KERNELS_N selects the
// suffix (1 or 2) matching the number of lists, and the resulting macro is then
// invoked with (KernelName, KernelType, lists...). The nested MFEM_EXPAND calls
// are the MSVC workaround for forwarding __VA_ARGS__.
#define MFEM_REGISTER_KERNELS(KernelName, KernelType, ...) \
   MFEM_EXPAND(MFEM_EXPAND(MFEM_REGISTER_KERNELS_N(__VA_ARGS__,2,1,)) \
               (KernelName,KernelType,__VA_ARGS__))
51
// Argument-counting helper: when invoked as
// MFEM_REGISTER_KERNELS_N(<lists...>, 2, 1, ) the third positional argument N
// is 1 if one list was supplied and 2 if two lists were supplied. The result is
// the name MFEM_REGISTER_KERNELS_1 or MFEM_REGISTER_KERNELS_2.
#define MFEM_REGISTER_KERNELS_N(_1, _2, N, ...) MFEM_REGISTER_KERNELS_##N
53
// Expands a variable length macro parameter so that multiple variable length
// parameters can be passed to the same macro.
//
// Writing `MFEM_PARAM_LIST P`, where P is a parenthesized list such as
// (int, int), strips the parentheses and yields `int, int`.
#define MFEM_PARAM_LIST(...) __VA_ARGS__
57
// Version of MFEM_REGISTER_KERNELS without any "optional" (non-dispatch)
// parameters.
//
// The optional list is passed as an empty list `()`, and the concatenated
// parameter list is therefore just Params itself.
#define MFEM_REGISTER_KERNELS_1(KernelName, KernelType, Params) \
   MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, (), Params)
62
// Version of MFEM_REGISTER_KERNELS with optional (non-dispatch) parameters
// (e.g. NBZ).
//
// Params and OptParams are parenthesized lists. Both are forwarded unchanged,
// together with a third list that is their concatenation.
#define MFEM_REGISTER_KERNELS_2(KernelName, KernelType, Params, OptParams) \
   MFEM_REGISTER_KERNELS_(KernelName, KernelType, Params, OptParams, \
                          (MFEM_PARAM_LIST Params, MFEM_PARAM_LIST OptParams))
68
// P1 are the dispatch parameters, P2 are the optional (non-dispatch)
// parameters, and P3 is the concatenation of P1 and P2. We need to pass P3 as a
// separate argument to avoid a trailing comma in the case that P2 is empty.
//
// The macro declares a class named KernelName deriving from
// ::mfem::KernelDispatchTable, with:
//  - kernel_name: the name produced by MFEM_KERNEL_NAME(KernelName);
//  - KernelSignature: an alias for KernelType;
//  - Kernel<P3...>(): declared here, to be defined by the user; returns the
//    kernel specialized on the given compile-time parameters;
//  - Fallback(P1...): declared here, to be defined by the user; returns the
//    kernel to use for the given run-time parameters;
//  - Get(): returns the single function-local static instance of the class.
#define MFEM_REGISTER_KERNELS_(KernelName, KernelType, P1, P2, P3) \
   class KernelName \
      : public ::mfem::KernelDispatchTable< \
           KernelName, KernelType, \
           ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P1>, \
           ::mfem::internal::KernelTypeList<MFEM_PARAM_LIST P2>> { \
   public: \
      const char *kernel_name = MFEM_KERNEL_NAME(KernelName); \
      using KernelSignature = KernelType; \
      template <MFEM_PARAM_LIST P3> static KernelSignature Kernel(); \
      static MFEM_EXPORT KernelSignature Fallback(MFEM_PARAM_LIST P1); \
      static MFEM_EXPORT KernelName &Get() { \
         static KernelName table; \
         return table; \
      } \
   }
88
/// @brief Hashes variadic packs for which each type contained in the variadic
/// pack has a specialization of `std::hash` available.
///
/// For example, packs containing int, bool, enum values, etc.
///
/// The hash of a tuple is obtained by combining the `std::hash` of its first
/// element with the hash of the remaining elements, recursively; the hash of
/// an empty tail is 0.
template<typename ...KernelParameters>
struct KernelDispatchKeyHash
{
private:
   using Key = std::tuple<KernelParameters...>;

   /// End of the recursion: no elements remain, so the tail hashes to 0.
   template<std::size_t N>
   std::size_t HashTail(const Key &) const { return 0; }

   /// @brief Hash the trailing elements of @a value whose types are
   /// THead, TTail..., where N is the total number of elements in the tuple.
   ///
   /// The hashing formula here is taken directly from the Boost library, with
   /// the magic number 0x9e3779b9 chosen to minimize hashing collisions.
   template<std::size_t N, typename THead, typename... TTail>
   std::size_t HashTail(const Key &value) const
   {
      // Position of THead within the full tuple.
      constexpr std::size_t Index = N - sizeof...(TTail) - 1;
      const std::size_t lhs_hash = std::hash<THead>()(std::get<Index>(value));
      const std::size_t rhs_hash = HashTail<N, TTail...>(value);
      return lhs_hash^(rhs_hash + 0x9e3779b9 + (lhs_hash<<6) + (lhs_hash>>2));
   }

public:
   /// Returns the hash of the given @a value.
   std::size_t operator()(std::tuple<KernelParameters...> value) const
   {
      return HashTail<sizeof...(KernelParameters),KernelParameters...>(value);
   }
};
117
namespace internal
{
/// Empty tag type that carries a pack of types as a single template argument,
/// so that more than one pack can be passed to the same template.
template<typename... Types> struct KernelTypeList { };
}
119
/// Primary template of the kernel dispatch table; it is intentionally empty.
/// The functionality lives in the partial specialization taking
/// <Kernels, Signature, KernelTypeList<Params...>, KernelTypeList<OptParams...>>.
template<typename... T> class KernelDispatchTable { };
121
122template <typename Kernels,
123 typename Signature,
124 typename... Params,
125 typename... OptParams>
127 Signature,
128 internal::KernelTypeList<Params...>,
129 internal::KernelTypeList<OptParams...>>
130{
131 using TableType = std::unordered_map<std::tuple<Params...>,
132 Signature, KernelDispatchKeyHash<Params...>>;
133 TableType table;
134
135 /// @brief Call function @a f with arguments @a args (perfect forwaring).
136 ///
137 /// Only valid when the function @a f is not a member function.
138 template <typename F, typename... Args,
139 typename std::enable_if<std::is_pointer<F>::value,bool>::type=true>
140 static void Invoke(F f, Args&&... args)
141 {
142 f(std::forward<Args>(args)...);
143 }
144
145 /// @brief Calls member function @a f on object @a t with arguments @a args
146 /// (perfect forwarding).
147 ///
148 /// Only valid when @a f is a member function of class @a T.
149 template <typename F, typename T, typename... Args,
150 typename std::enable_if<
151 std::is_member_function_pointer<F>::value,bool>::type=true>
152 static void Invoke(F f, T&& t, Args&&... args)
153 {
154 (t.*f)(std::forward<Args>(args)...);
155 }
156
157public:
158 /// @brief Run the kernel with the given dispatch parameters and arguments.
159 ///
160 /// If a compile-time specialized version of the kernel with the given
161 /// parameters has been registered, it will be called. Otherwise, the
162 /// fallback kernel will be called.
163 ///
164 /// If the kernel is a member function, then the first argument after @a
165 /// params should be the object on which it is called.
166 template<typename... Args>
167 static void Run(Params... params, Args&&... args)
168 {
169 const auto &table = Kernels::Get().table;
170 const std::tuple<Params...> key = std::make_tuple(params...);
171 const auto it = table.find(key);
172 if (it != table.end())
173 {
174 Invoke(it->second, std::forward<Args>(args)...);
175 }
176 else
177 {
178 KernelReporter::ReportFallback(Kernels::Get().kernel_name, params...);
179 Invoke(Kernels::Fallback(params...), std::forward<Args>(args)...);
180 }
181 }
182
183 /// Register a specialized kernel for dispatch.
184 template <Params... PARAMS>
186 {
187 // Version without optional parameters
188 static void Add()
189 {
190 std::tuple<Params...> param_tuple(PARAMS...);
191 Kernels::Get().table[param_tuple] =
192 Kernels:: template Kernel<PARAMS..., OptParams{}...>();
193 };
194 // Version with optional parameters
195 template <OptParams... OPT_PARAMS>
196 struct Opt
197 {
198 static void Add()
199 {
200 std::tuple<Params...> param_tuple(PARAMS...);
201 Kernels::Get().table[param_tuple] =
202 Kernels:: template Kernel<PARAMS..., OPT_PARAMS...>();
203 }
204 };
205 };
206
207 /// Return the dispatch map table
208 static const TableType &GetDispatchTable()
209 {
210 return Kernels::Get().table;
211 }
212};
213
214}
215
216#endif
static void Run(Params... params, Args &&... args)
Run the kernel with the given dispatch parameters and arguments.
static void ReportFallback(const std::string &kernel_name, Params &&... params)
Report the fallback kernel with given parameters.
SchrodingerBaseKernels< ParMesh, ParFiniteElementSpace, ParComplexGridFunction, ParGridFunction, ParBilinearForm, ParMixedBilinearForm, ParLinearForm > Kernels
std::function< real_t(const Vector &)> f(real_t mass_coeff)
Definition lor_mms.hpp:30
Hashes variadic packs for which each type contained in the variadic pack has a specialization of std:...
size_t operator()(std::tuple< KernelParameters... > value) const
Returns the hash of the given value.
Base class for Schrodinger solver kernels.