MFEM v4.7.0
Finite element discretization library
dual.hpp
Go to the documentation of this file.
// Copyright (c) 2010-2024, Lawrence Livermore National Security, LLC. Produced
// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
// LICENSE and NOTICE for details. LLNL-CODE-806117.
//
// This file is part of the MFEM library. For more information and source code
// availability visit https://mfem.org.
//
// MFEM is free software; you can redistribute it and/or modify it under the
// terms of the BSD-3 license. We welcome feedback and contributions, see file
// CONTRIBUTING.md for details.

/**
 * @file dual.hpp
 *
 * @brief This file contains the declaration of a dual number class
 */

#ifndef MFEM_INTERNAL_DUAL_HPP
#define MFEM_INTERNAL_DUAL_HPP

#include <type_traits> // for is_arithmetic
#include <cmath>

namespace mfem
{
namespace internal
{

/**
 * @brief Dual number struct (value plus gradient)
 * @tparam value_type The type of the numerical value
 * @tparam gradient_type The type of the gradient (should support addition,
 * scalar multiplication/division, and unary negation operators)
 */
template <typename value_type, typename gradient_type>
struct dual
{
   /// the actual numerical value
   value_type value;
   /// the partial derivatives of value w.r.t. some other quantity
   gradient_type gradient;

   /** @brief assignment of a real_t to a dual number. Promotes the real_t to a
    * dual with a zero gradient. */
   auto operator=(real_t a) -> dual<value_type, gradient_type>&
   {
      value = a;
      gradient = {};
      return *this;
   }
};
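
// Illustrative sketch (editor's note, not part of the upstream header): seeding
// the gradient with 1 makes `gradient` track the derivative with respect to
// `value`. Using the operator overloads defined below, f(x) = x*x + 3*x at
// x = 2 gives f = 10 and df/dx = 7:
//
//   dual<real_t, real_t> x{2.0, 1.0}; // value = 2, seed dx/dx = 1
//   auto f = x * x + 3.0 * x;         // f.value == 10, f.gradient == 7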

/** @brief type trait for checking whether a type is a dual number */
template <typename T>
struct is_dual_number
{
   /// whether or not type T is a dual number
   static constexpr bool value = false;
};

/** @brief type trait for checking whether a type is a dual number */
template <typename value_type, typename gradient_type>
struct is_dual_number<dual<value_type, gradient_type> >
{
   static constexpr bool value = true; ///< whether or not type T is a dual number
};

/** @brief addition of a dual number and a non-dual number */
template <typename other_type, typename value_type, typename gradient_type,
          typename = typename std::enable_if<
             std::is_arithmetic<other_type>::value ||
             is_dual_number<other_type>::value>::type>
MFEM_HOST_DEVICE
constexpr auto operator+(dual<value_type, gradient_type> a,
                         other_type b) -> dual<value_type, gradient_type>
{
   return {a.value + b, a.gradient};
}

// C++17 version of the above
//
// template <typename value_type, typename gradient_type>
// constexpr auto operator+(dual<value_type, gradient_type> a, value_type b)
// {
//    return dual{a.value + b, a.gradient};
// }

/** @brief addition of a dual number and a non-dual number */
template <typename other_type, typename value_type, typename gradient_type,
          typename = typename std::enable_if<
             std::is_arithmetic<other_type>::value ||
             is_dual_number<other_type>::value>::type>
MFEM_HOST_DEVICE
constexpr auto operator+(other_type a,
                         dual<value_type, gradient_type> b) -> dual<value_type, gradient_type>
{
   return {a + b.value, b.gradient};
}

/** @brief addition of two dual numbers */
template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
MFEM_HOST_DEVICE
constexpr auto operator+(dual<value_type_a, gradient_type_a> a,
                         dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value + b.value),
                                                                        decltype(a.gradient + b.gradient)>
{
   return {a.value + b.value, a.gradient + b.gradient};
}

/** @brief unary negation of a dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator-(dual<value_type, gradient_type> x) ->
dual<value_type, gradient_type>
{
   return {-x.value, -x.gradient};
}

/** @brief subtraction of a non-dual number from a dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator-(dual<value_type, gradient_type> a,
                         real_t b) -> dual<value_type, gradient_type>
{
   return {a.value - b, a.gradient};
}

/** @brief subtraction of a dual number from a non-dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator-(real_t a,
                         dual<value_type, gradient_type> b) -> dual<value_type, gradient_type>
{
   return {a - b.value, -b.gradient};
}

/** @brief subtraction of two dual numbers */
template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
MFEM_HOST_DEVICE
constexpr auto operator-(dual<value_type_a, gradient_type_a> a,
                         dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value - b.value),
                                                                        decltype(a.gradient - b.gradient)>
{
   return {a.value - b.value, a.gradient - b.gradient};
}

/** @brief multiplication of a dual number and a non-dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator*(const dual<value_type, gradient_type>& a,
                         real_t b) -> dual<decltype(a.value * b), decltype(a.gradient * b)>
{
   return {a.value * b, a.gradient * b};
}

/** @brief multiplication of a dual number and a non-dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator*(real_t a,
                         const dual<value_type, gradient_type>& b) ->
dual<decltype(a * b.value), decltype(a * b.gradient)>
{
   return {a * b.value, a * b.gradient};
}

/** @brief multiplication of two dual numbers */
template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
MFEM_HOST_DEVICE
constexpr auto operator*(dual<value_type_a, gradient_type_a> a,
                         dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value * b.value),
                                                                        decltype(b.value * a.gradient + a.value * b.gradient)>
{
   return {a.value * b.value, b.value * a.gradient + a.value * b.gradient};
}

/** @brief division of a dual number by a non-dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator/(const dual<value_type, gradient_type>& a,
                         real_t b) -> dual<decltype(a.value / b), decltype(a.gradient / b)>
{
   return {a.value / b, a.gradient / b};
}

/** @brief division of a non-dual number by a dual number */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
constexpr auto operator/(real_t a,
                         const dual<value_type, gradient_type>& b) -> dual<decltype(a / b.value),
                                                                           decltype(-(a / (b.value * b.value)) * b.gradient)>
{
   return {a / b.value, -(a / (b.value * b.value)) * b.gradient};
}

/** @brief division of two dual numbers */
template <typename value_type_a, typename gradient_type_a, typename value_type_b, typename gradient_type_b>
MFEM_HOST_DEVICE
constexpr auto operator/(dual<value_type_a, gradient_type_a> a,
                         dual<value_type_b, gradient_type_b> b) -> dual<decltype(a.value / b.value),
                                                                        decltype((a.gradient / b.value) -
                                                                                 (a.value * b.gradient) /
                                                                                 (b.value * b.value))>
{
   return {a.value / b.value, (a.gradient / b.value) - (a.value * b.gradient) / (b.value * b.value)};
}
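
// Worked sketch (editor's note, not part of the upstream header): the overloads
// above encode the product and quotient rules. For g(x) = (x*x) / (x + 1) at
// x = 3 with a unit seed:
//
//   dual<real_t, real_t> x{3.0, 1.0};
//   dual<real_t, real_t> c{1.0, 0.0};  // constant has zero gradient
//   auto g = (x * x) / (x + c);
//   // g.value    == 9.0 / 4.0          == 2.25
//   // g.gradient == 6.0/4.0 - 9.0/16.0 == 0.9375  (matches (x*x + 2*x)/(x+1)^2)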

/**
 * @brief Generates overloads of a binary comparison operator for dual/real_t,
 * real_t/dual, and dual/dual arguments.
 * Comparisons are conducted against the "value" part of the dual number
 * @param[in] x The comparison operator to overload
 */
#define mfem_binary_comparator_overload(x)                 \
   template <typename value_type, typename gradient_type>  \
   MFEM_HOST_DEVICE constexpr bool operator x(             \
      const dual<value_type, gradient_type>& a,            \
      real_t b)                                            \
   {                                                       \
      return a.value x b;                                  \
   }                                                       \
                                                            \
   template <typename value_type, typename gradient_type>  \
   MFEM_HOST_DEVICE constexpr bool operator x(             \
      real_t a,                                            \
      const dual<value_type, gradient_type>& b)            \
   {                                                       \
      return a x b.value;                                  \
   }                                                       \
                                                            \
   template <typename value_type_a,                        \
             typename gradient_type_a,                     \
             typename value_type_b,                        \
             typename gradient_type_b> MFEM_HOST_DEVICE    \
   constexpr bool operator x(                              \
      const dual<value_type_a, gradient_type_a>& a,        \
      const dual<value_type_b, gradient_type_b>& b)        \
   {                                                       \
      return a.value x b.value;                            \
   }

mfem_binary_comparator_overload(<) ///< implement operator< for dual numbers
mfem_binary_comparator_overload(<=) ///< implement operator<= for dual numbers
mfem_binary_comparator_overload(==) ///< implement operator== for dual numbers
mfem_binary_comparator_overload(>=) ///< implement operator>= for dual numbers
mfem_binary_comparator_overload(>) ///< implement operator> for dual numbers

#undef mfem_binary_comparator_overload
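
// Note (editor's addition): the comparison overloads generated above inspect
// only the "value" part, so duals with equal values but different gradients
// compare equal:
//
//   dual<real_t, real_t> p{1.0, 2.0}, q{1.0, 5.0};
//   bool same = (p == q); // true -- gradients are ignored by comparisons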

/** @brief compound assignment (+) for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>& a,
                                            const dual<value_type, gradient_type>& b)
{
   a.value += b.value;
   a.gradient += b.gradient;
   return a;
}

/** @brief compound assignment (-) for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>& a,
                                            const dual<value_type, gradient_type>& b)
{
   a.value -= b.value;
   a.gradient -= b.gradient;
   return a;
}

/** @brief compound assignment (+) for dual numbers with `real_t` right-hand side */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>& a,
                                            real_t b)
{
   a.value += b;
   return a;
}

/** @brief compound assignment (-) for dual numbers with `real_t` right-hand side */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>& a,
                                            real_t b)
{
   a.value -= b;
   return a;
}

/** @brief implementation of absolute value function for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> abs(dual<value_type, gradient_type> x)
{
   return (x.value >= 0) ? x : -x;
}

/** @brief implementation of square root for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> sqrt(dual<value_type, gradient_type> x)
{
   using std::sqrt;
   return {sqrt(x.value), x.gradient / (2.0 * sqrt(x.value))};
}
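
// Chain-rule sketch (editor's note): sqrt propagates d/dx sqrt(u) = u' / (2*sqrt(u)).
// For u(x) = x*x at x = 3 with a unit seed:
//
//   dual<real_t, real_t> x{3.0, 1.0};
//   auto r = sqrt(x * x); // r.value == 3, r.gradient == 6 / (2*3) == 1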

/** @brief implementation of cosine for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> cos(dual<value_type, gradient_type> a)
{
   using std::cos;
   using std::sin;
   return {cos(a.value), -a.gradient * sin(a.value)};
}

/** @brief implementation of sine for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> sin(dual<value_type, gradient_type> a)
{
   using std::sin;
   using std::cos;
   return {sin(a.value), a.gradient * cos(a.value)};
}

/** @brief implementation of sinh for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> sinh(dual<value_type, gradient_type> a)
{
   using std::sinh;
   using std::cosh;
   return {sinh(a.value), a.gradient * cosh(a.value)};
}

/** @brief implementation of acos for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> acos(dual<value_type, gradient_type> a)
{
   using std::sqrt;
   using std::acos;
   return {acos(a.value), -a.gradient / sqrt(value_type{1} - a.value * a.value)};
}

/** @brief implementation of asin for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> asin(dual<value_type, gradient_type> a)
{
   using std::sqrt;
   using std::asin;
   return {asin(a.value), a.gradient / sqrt(value_type{1} - a.value * a.value)};
}

/** @brief implementation of tan for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> tan(dual<value_type, gradient_type> a)
{
   using std::tan;
   value_type f = tan(a.value);
   return {f, a.gradient * (value_type{1} + f * f)};
}

/** @brief implementation of atan for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> atan(dual<value_type, gradient_type> a)
{
   using std::atan;
   return {atan(a.value), a.gradient / (value_type{1} + a.value * a.value)};
}

/** @brief implementation of exponential function for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> exp(dual<value_type, gradient_type> a)
{
   using std::exp;
   return {exp(a.value), exp(a.value) * a.gradient};
}

/** @brief implementation of the natural logarithm function for dual numbers */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> log(dual<value_type, gradient_type> a)
{
   using std::log;
   return {log(a.value), a.gradient / a.value};
}

/** @brief implementation of `a` (dual) raised to the `b` (dual) power */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> pow(dual<value_type, gradient_type> a,
                                    dual<value_type, gradient_type> b)
{
   using std::log;
   using std::pow;
   value_type value = pow(a.value, b.value);
   return {value, value * (a.gradient * (b.value / a.value) + b.gradient * log(a.value))};
}

/** @brief implementation of `a` (non-dual) raised to the `b` (dual) power */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> pow(real_t a, dual<value_type, gradient_type> b)
{
   using std::pow;
   using std::log;
   value_type value = pow(a, b.value);
   return {value, value * b.gradient * log(a)};
}

/** @brief implementation of `a` (non-dual) raised to the `b` (non-dual) power */
template <typename value_type> MFEM_HOST_DEVICE
value_type pow(value_type a, value_type b)
{
   using std::pow;
   return pow(a, b);
}

/** @brief implementation of `a` (dual) raised to the `b` (non-dual) power */
template <typename value_type, typename gradient_type> MFEM_HOST_DEVICE
dual<value_type, gradient_type> pow(dual<value_type, gradient_type> a, real_t b)
{
   using std::pow;
   value_type value = pow(a.value, b);
   return {value, value * a.gradient * b / a.value};
}
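
// Sketch (editor's note): the dual-base, real_t-exponent overload reproduces
// the power rule, d/dx x^b = b * x^(b-1). For x = 2 and b = 3 with a unit seed:
//
//   dual<real_t, real_t> x{2.0, 1.0};
//   auto y = pow(x, 3.0); // y.value == 8, y.gradient == 8 * 1 * 3 / 2 == 12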

/** @brief overload of operator<< for `dual` to work with standard output streams */
template <typename value_type, typename gradient_type, int... n>
std::ostream& operator<<(std::ostream& os, dual<value_type, gradient_type> A)
{
   os << '(' << A.value << ' ' << A.gradient << ')';
   return os;
}

/** @brief promote a value to a dual number of the appropriate type */
MFEM_HOST_DEVICE constexpr dual<real_t, real_t> make_dual(real_t x) { return {x, 1.0}; }

/** @brief return the "value" part from a given type. For non-dual types, this is just the identity function */
template <typename T> MFEM_HOST_DEVICE T get_value(const T& arg) { return arg; }

/** @brief return the "value" part from a dual number type */
template <typename value_type, typename gradient_type>
MFEM_HOST_DEVICE value_type get_value(dual<value_type, gradient_type> arg)
{
   return arg.value;
}

/** @brief return the "gradient" part from a dual number type */
template <typename value_type, typename gradient_type>
MFEM_HOST_DEVICE gradient_type get_gradient(dual<value_type, gradient_type> arg)
{
   return arg.gradient;
}
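
// End-to-end sketch (editor's note, not part of the upstream header): make_dual
// seeds the gradient with 1, and get_value / get_gradient split the result back
// apart. For f(x) = x * exp(x) at x = 1:
//
//   auto x  = make_dual(1.0);
//   auto fx = x * exp(x);
//   real_t v = get_value(fx);    // e  (~2.71828)
//   real_t d = get_gradient(fx); // 2e (~5.43656), i.e. exp(1) * (1 + 1)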

} // namespace internal
} // namespace mfem

#endif