MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
dual.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11
12/**
13 * @file dual.hpp
14 *
15 * @brief This file contains the declaration of a dual number class
16 */
17
18#pragma once
19
20#include <type_traits> // for is_arithmetic
21#include <cmath>
23
24namespace mfem
25{
26namespace future
27{
28
29/**
30 * @brief Dual number struct (value plus gradient)
31 * @tparam gradient_type The type of the gradient (should support addition,
32 * scalar multiplication/division, and unary negation operators)
33 */
34template <typename value_type, typename gradient_type>
35struct dual
36{
37 /// the actual numerical value
38 value_type value;
39 /// the partial derivatives of value w.r.t. some other quantity
40 gradient_type gradient;
41
42 /** @brief assignment of a real_t to a value of a dual. Promotes a real_t to
43 * a dual with a zero gradient value. */
44 MFEM_HOST_DEVICE
46 {
47 value = a;
48 gradient = {};
49 return *this;
50 }
51};
52
/** @brief class for checking if a type is a dual number or not.
 * This is the primary template, used for every type that is not a `dual`. */
template <typename T>
struct is_dual_number
{
   /// whether or not type T is a dual number
   static constexpr bool value = false;
};
60
61/** @brief class for checking if a type is a dual number or not */
62template <typename value_type, typename gradient_type>
63struct is_dual_number<dual<value_type, gradient_type> >
64{
65 static constexpr bool value = true; ///< whether or not type T is a dual number
66};
67
68/** @brief addition of a dual number and a non-dual number */
69template <typename other_type, typename value_type, typename gradient_type,
70 typename = typename std::enable_if<
71 std::is_arithmetic<other_type>::value ||
73MFEM_HOST_DEVICE
76{
77 return {a.value + b, a.gradient};
78}
79
80// C++17 version of the above
81//
82// template <typename value_type, typename gradient_type>
83// constexpr auto operator+(dual<value_type, gradient_type> a, value_type b)
84// {
85// return dual{a.value + b, a.gradient};
86// }
87
88/** @brief addition of a dual number and a non-dual number */
89template <typename other_type, typename value_type, typename gradient_type,
90 typename = typename std::enable_if<
91 std::is_arithmetic<other_type>::value ||
93MFEM_HOST_DEVICE
94constexpr auto operator+(other_type a,
96{
97 return {a + b.value, b.gradient};
98}
99
100/** @brief addition of two dual numbers */
101template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
102MFEM_HOST_DEVICE
104 dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value + b.value),
105 decltype(a.gradient + b.gradient)>
106{
107 return {a.value + b.value, a.gradient + b.gradient};
108}
109
110/** @brief unary negation of a dual number */
111template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
113dual<value_type, gradient_type>
114{
115 return {-x.value, -x.gradient};
116}
117
118/** @brief subtraction of a non-dual number from a dual number */
119template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
122{
123 return {a.value - b, a.gradient};
124}
125
126/** @brief subtraction of a dual number from a non-dual number */
127template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
128constexpr auto operator-(real_t a,
130{
131 return {a - b.value, -b.gradient};
132}
133
134/** @brief subtraction of two dual numbers */
135template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
136MFEM_HOST_DEVICE
138 dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value - b.value),
139 decltype(a.gradient - b.gradient)>
140{
141 return {a.value - b.value, a.gradient - b.gradient};
142}
143
144/** @brief multiplication of a dual number and a non-dual number */
145template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
147 real_t b) -> dual<decltype(a.value * b), decltype(a.gradient * b)>
148{
149 return {a.value * b, a.gradient * b};
150}
151
152/** @brief multiplication of a dual number and a non-dual number */
153template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
154constexpr auto operator*(real_t a,
156dual<decltype(a * b.value), decltype(a * b.gradient)>
157{
158 return {a * b.value, a * b.gradient};
159}
160
161/** @brief multiplication of two dual numbers */
162template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
163MFEM_HOST_DEVICE
165 dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value * b.value),
166 decltype(b.value * a.gradient + a.value * b.gradient)>
167{
168 return {a.value * b.value, b.value * a.gradient + a.value * b.gradient};
169}
170
171/** @brief division of a dual number by a non-dual number */
172template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
174 real_t b) -> dual<decltype(a.value / b), decltype(a.gradient / b)>
175{
176 return {a.value / b, a.gradient / b};
177}
178
179/** @brief division of a non-dual number by a dual number */
180template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
181constexpr auto operator/(real_t a,
182 const dual<value_type, gradient_type>& b) -> dual<decltype(a / b.value),
183 decltype(-(a / (b.value * b.value)) * b.gradient)>
184{
185 return {a / b.value, -(a / (b.value * b.value)) * b.gradient};
186}
187
188/** @brief division of two dual numbers */
189template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
190MFEM_HOST_DEVICE
192 dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value / b.value),
193 decltype((a.gradient / b.value) -
194 (a.value * b.gradient) /
195 (b.value * b.value))>
196{
197 return {a.value / b.value, (a.gradient / b.value) - (a.value * b.gradient) / (b.value * b.value)};
198}
199
200/**
201 * @brief Generates const + non-const overloads for a binary comparison operator
202 * Comparisons are conducted against the "value" part of the dual number
203 * @param[in] x The comparison operator to overload
204 */
205#define mfem_binary_comparator_overload(x) \
206 template <typename value_type, typename gradient_type> \
207 MFEM_HOST_DEVICE constexpr bool operator x( \
208 const dual<value_type, gradient_type>& a, \
209 real_t b) \
210 { \
211 return a.value x b; \
212 } \
213 \
214 template <typename value_type, typename gradient_type> \
215 MFEM_HOST_DEVICE constexpr bool operator x( \
216 real_t a, \
217 const dual<value_type, gradient_type>& b) \
218 { \
219 return a x b.value; \
220 } \
221 \
222 template <typename value_type_a, \
223 typename gradient_type_a, \
224 typename value_type_b, \
225 typename gradient_type_b> MFEM_HOST_DEVICE \
226 constexpr bool operator x( \
227 const dual<value_type_a, gradient_type_a>& a, \
228 const dual<value_type_b, gradient_type_b>& b) \
229 { \
230 return a.value x b.value; \
231 }
232
233mfem_binary_comparator_overload(<) ///< implement operator< for dual numbers
234mfem_binary_comparator_overload(<=) ///< implement operator<= for dual numbers
235mfem_binary_comparator_overload(==) ///< implement operator== for dual numbers
236mfem_binary_comparator_overload(>=) ///< implement operator>= for dual numbers
237mfem_binary_comparator_overload(>) ///< implement operator> for dual numbers
238
239#undef mfem_binary_comparator_overload
240
241/** @brief compound assignment (+) for dual numbers */
242template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
245{
246 a.value += b.value;
247 a.gradient += b.gradient;
248 return a;
249}
250
251/** @brief compound assignment (-) for dual numbers */
252template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
255{
256 a.value -= b.value;
257 a.gradient -= b.gradient;
258 return a;
259}
260
261/** @brief compound assignment (+) for dual numbers with `real_t` righthand side */
262template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
269
270/** @brief compound assignment (-) for dual numbers with `real_t` righthand side */
271template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
278
279/** @brief implementation of absolute value function for dual numbers */
280template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
285
286/** @brief implementation of square root for dual numbers */
287template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
289{
290 using std::sqrt;
291 return {sqrt(x.value), x.gradient / (2 * sqrt(x.value))};
292}
293
294/** @brief implementation of cosine for dual numbers */
295template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
297{
298 using std::cos;
299 using std::sin;
300 return {cos(a.value), -a.gradient * sin(a.value)};
301}
302
303/** @brief implementation of sine for dual numbers */
304template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
306{
307 using std::sin;
308 using std::cos;
309 return {sin(a.value), a.gradient * cos(a.value)};
310}
311
312/** @brief implementation of sinh for dual numbers */
313template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
315{
316 using std::sinh;
317 using std::cosh;
318 return {sinh(a.value), a.gradient * cosh(a.value)};
319}
320
321/** @brief implementation of acos for dual numbers */
322template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
324{
325 using std::sqrt;
326 using std::acos;
327 return {acos(a.value), -a.gradient / sqrt(value_type{1} - a.value * a.value)};
328}
329
330/** @brief implementation of asin for dual numbers */
331template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
333{
334 using std::sqrt;
335 using std::asin;
336 return {asin(a.value), a.gradient / sqrt(value_type{1} - a.value * a.value)};
337}
338
339/** @brief implementation of tan for dual numbers */
340template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
342{
343 using std::tan;
344 value_type f = tan(a.value);
345 return {f, a.gradient * (value_type{1} + f * f)};
346}
347
348/** @brief implementation of atan for dual numbers */
349template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
351{
352 using std::atan;
353 return {atan(a.value), a.gradient / (value_type{1} + a.value * a.value)};
354}
355
356/** @brief implementation of exponential function for dual numbers */
357template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
359{
360 using std::exp;
361 return {exp(a.value), exp(a.value) * a.gradient};
362}
363
364/** @brief implementation of the natural logarithm function for dual numbers */
365template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
367{
368 using std::log;
369 return {log(a.value), a.gradient / a.value};
370}
371
372/** @brief implementation of `a` (dual) raised to the `b` (dual) power */
373template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
376{
377 using std::log;
378 using std::pow;
379 value_type value = pow(a.value, b.value);
380 return {value, value * (a.gradient * (b.value / a.value) + b.gradient * log(a.value))};
381}
382
383/** @brief implementation of `a` (non-dual) raised to the `b` (dual) power */
384template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
386{
387 using std::pow;
388 using std::log;
389 value_type value = pow(a, b.value);
390 return {value, value * b.gradient * log(a)};
391}
392
393/** @brief implementation of `a` (non-dual) raised to the `b` (non-dual) power */
394template <typename value_type > MFEM_HOST_DEVICE
395value_type pow(value_type a, value_type b)
396{
397 using std::pow;
398 return pow(a, b);
399}
400
401/** @brief implementation of `a` (dual) raised to the `b` (non-dual) power */
402template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
404{
405 using std::pow;
406 value_type value = pow(a.value, b);
407 return {value, value * a.gradient * b / a.value};
408}
409
410/** @brief overload of operator<< for `dual` to work with work with standard output streams */
411template <typename value_type, typename gradient_type, int... n>
412std::ostream& operator<<(std::ostream& os, dual<value_type, gradient_type> A)
413{
414 os << '(' << A.value << ' ' << A.gradient << ')';
415 return os;
416}
417
418/** @brief promote a value to a dual number of the appropriate type */
419MFEM_HOST_DEVICE constexpr dual<real_t, real_t> make_dual(real_t x) { return {x, 1.0}; }
420
421/** @brief return the "value" part from a given type. For non-dual types, this is just the identity function */
422template <typename T> MFEM_HOST_DEVICE T get_value(const T& arg) { return arg; }
423
424/** @brief return the "value" part from a dual number type */
425template <typename value_type, typename gradient_type>
426MFEM_HOST_DEVICE gradient_type get_value(dual<value_type, gradient_type> arg)
427{
428 return arg.value;
429}
430
431/** @brief return the "gradient" part from a dual number type */
432template <typename value_type, typename gradient_type>
433MFEM_HOST_DEVICE gradient_type get_gradient(dual<value_type, gradient_type> arg)
434{
435 return arg.gradient;
436}
437
438} // namespace future
439} // namespace mfem
real_t b
Definition lissajous.cpp:42
real_t a
Definition lissajous.cpp:41
MFEM_HOST_DEVICE constexpr auto operator-(dual< value_type, gradient_type > x) -> dual< value_type, gradient_type >
unary negation of a dual number
Definition dual.hpp:112
MFEM_HOST_DEVICE constexpr auto operator+(dual< value_type, gradient_type > a, other_type b) -> dual< value_type, gradient_type >
addition of a dual number and a non-dual number
Definition dual.hpp:74
MFEM_HOST_DEVICE dual< value_type, gradient_type > tan(dual< value_type, gradient_type > a)
implementation of tan for dual numbers
Definition dual.hpp:341
MFEM_HOST_DEVICE T get_value(const T &arg)
return the "value" part from a given type. For non-dual types, this is just the identity function
Definition dual.hpp:422
gradient_type MFEM_HOST_DEVICE dual< value_type, gradient_type > & operator+=(dual< value_type, gradient_type > &a, const dual< value_type, gradient_type > &b)
Definition dual.hpp:243
MFEM_HOST_DEVICE dual< value_type, gradient_type > cos(dual< value_type, gradient_type > a)
implementation of cosine for dual numbers
Definition dual.hpp:296
MFEM_HOST_DEVICE dual< value_type, gradient_type > acos(dual< value_type, gradient_type > a)
implementation of acos for dual numbers
Definition dual.hpp:323
MFEM_HOST_DEVICE dual< value_type, gradient_type > abs(dual< value_type, gradient_type > x)
implementation of absolute value function for dual numbers
Definition dual.hpp:281
MFEM_HOST_DEVICE dual< value_type, gradient_type > sinh(dual< value_type, gradient_type > a)
implementation of sinh for dual numbers
Definition dual.hpp:314
MFEM_HOST_DEVICE dual< value_type, gradient_type > sin(dual< value_type, gradient_type > a)
implementation of sine for dual numbers
Definition dual.hpp:305
MFEM_HOST_DEVICE dual< value_type, gradient_type > & operator-=(dual< value_type, gradient_type > &a, const dual< value_type, gradient_type > &b)
compound assignment (-) for dual numbers
Definition dual.hpp:253
std::ostream & operator<<(std::ostream &os, dual< value_type, gradient_type > A)
overload of operator<< for dual to work with work with standard output streams
Definition dual.hpp:412
MFEM_HOST_DEVICE constexpr auto type(const tuple< T... > &values)
a function intended to be used for extracting the ith type from a tuple.
Definition tuple.hpp:347
MFEM_HOST_DEVICE dual< value_type, gradient_type > log(dual< value_type, gradient_type > a)
implementation of the natural logarithm function for dual numbers
Definition dual.hpp:366
MFEM_HOST_DEVICE dual< value_type, gradient_type > exp(dual< value_type, gradient_type > a)
implementation of exponential function for dual numbers
Definition dual.hpp:358
MFEM_HOST_DEVICE dual< value_type, gradient_type > sqrt(dual< value_type, gradient_type > x)
implementation of square root for dual numbers
Definition dual.hpp:288
MFEM_HOST_DEVICE dual< value_type, gradient_type > asin(dual< value_type, gradient_type > a)
implementation of asin for dual numbers
Definition dual.hpp:332
MFEM_HOST_DEVICE constexpr dual< real_t, real_t > make_dual(real_t x)
promote a value to a dual number of the appropriate type
Definition dual.hpp:419
mfem_binary_comparator_overload(<) mfem_binary_comparator_overload(<
implement operator<= for dual numbers
MFEM_HOST_DEVICE constexpr auto operator*(const dual< value_type, gradient_type > &a, real_t b) -> dual< decltype(a.value *b), decltype(a.gradient *b)>
multiplication of a dual number and a non-dual number
Definition dual.hpp:146
MFEM_HOST_DEVICE constexpr auto operator/(const dual< value_type, gradient_type > &a, real_t b) -> dual< decltype(a.value/b), decltype(a.gradient/b)>
division of a dual number by a non-dual number
Definition dual.hpp:173
MFEM_HOST_DEVICE gradient_type get_gradient(dual< value_type, gradient_type > arg)
return the "gradient" part from a dual number type
Definition dual.hpp:433
MFEM_HOST_DEVICE dual< value_type, gradient_type > atan(dual< value_type, gradient_type > a)
implementation of atan for dual numbers
Definition dual.hpp:350
MFEM_HOST_DEVICE dual< value_type, gradient_type > pow(dual< value_type, gradient_type > a, dual< value_type, gradient_type > b)
implementation of a (dual) raised to the b (dual) power
Definition dual.hpp:374
float real_t
Definition config.hpp:46
std::function< real_t(const Vector &)> f(real_t mass_coeff)
Definition lor_mms.hpp:30
Dual number struct (value plus gradient)
Definition dual.hpp:36
MFEM_HOST_DEVICE auto operator=(real_t a) -> dual< value_type, gradient_type > &
assignment of a real_t to a value of a dual. Promotes a real_t to a dual with a zero gradient value.
Definition dual.hpp:45
gradient_type gradient
the partial derivatives of value w.r.t. some other quantity
Definition dual.hpp:40
value_type value
the actual numerical value
Definition dual.hpp:38
class for checking if a type is a dual number or not
Definition dual.hpp:56
static constexpr bool value
whether or not type T is a dual number
Definition dual.hpp:58