18 #ifndef MFEM_INTERNAL_DUAL_HPP
19 #define MFEM_INTERNAL_DUAL_HPP
#include <cmath>
#include <type_traits>

#include "../general/backends.hpp"
// Dual number: a value paired with its gradient. The operators and math
// functions in this file compute a result's value from the operands' values
// and its gradient via the usual derivative rules (sum, product, quotient,
// chain rule), i.e. forward-mode differentiation.
// NOTE(review): this listing is missing parts of the definition -- the
// `struct dual {` header, the value member (named `value`, per the `a.value`
// uses throughout this file), the closing `};`, and the body of operator=.
// Restore them from the upstream source before building.
35 template <
typename value_type,
typename gradient_type>
// Gradient (derivative part) carried alongside the value.
41 gradient_type gradient;
// Assignment from a double.
// NOTE(review): body not present in this listing; presumably stores `a` as
// the value and resets the gradient -- confirm against upstream.
45 auto operator=(
double a) -> dual<value_type, gradient_type>&
/** @brief Trait reporting whether a type is a dual number.
 *
 *  The primary template yields false; the partial specialization for
 *  dual<value_type, gradient_type> yields true. Used to constrain the
 *  mixed dual/non-dual operator overloads via std::enable_if.
 */
template <typename T>
struct is_dual_number
{
   /// false for every type that is not a dual
   static constexpr bool value = false;
};
62 template <
typename value_type,
typename gradient_type>
63 struct is_dual_number<dual<value_type, gradient_type> >
65 static constexpr
bool value =
true;
69 template <
typename other_type,
typename value_type,
typename gradient_type,
70 typename =
typename std::enable_if<
71 std::is_arithmetic<other_type>::value ||
72 is_dual_number<other_type>::value>::type>
74 constexpr
auto operator+(dual<value_type, gradient_type>
a,
75 other_type
b) -> dual<value_type, gradient_type>
77 return {
a.value +
b,
a.gradient};
89 template <
typename other_type,
typename value_type,
typename gradient_type,
90 typename =
typename std::enable_if<
91 std::is_arithmetic<other_type>::value ||
92 is_dual_number<other_type>::value>::type>
95 dual<value_type, gradient_type>
b) -> dual<value_type, gradient_type>
97 return {
a +
b.value,
b.gradient};
101 template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
103 constexpr
auto operator+(dual<value_type_a, gradient_type_a>
a,
104 dual<value_type_b, gradient_type_b>
b) -> dual<decltype(
a.value +
b.value),
105 decltype(
a.gradient +
b.gradient)>
107 return {
a.value +
b.value,
a.gradient +
b.gradient};
111 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
112 constexpr
auto operator-(dual<value_type, gradient_type> x) ->
113 dual<value_type, gradient_type>
115 return {-x.value, -x.gradient};
119 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
120 constexpr
auto operator-(dual<value_type, gradient_type>
a,
121 double b) -> dual<value_type, gradient_type>
123 return {
a.value -
b,
a.gradient};
127 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
129 dual<value_type, gradient_type>
b) -> dual<value_type, gradient_type>
131 return {
a -
b.value, -
b.gradient};
135 template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
137 constexpr
auto operator-(dual<value_type_a, gradient_type_a>
a,
138 dual<value_type_b, gradient_type_b>
b) -> dual<decltype(
a.value -
b.value),
139 decltype(
a.gradient -
b.gradient)>
141 return {
a.value -
b.value,
a.gradient -
b.gradient};
145 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
146 constexpr
auto operator*(
const dual<value_type, gradient_type>&
a,
147 double b) -> dual<decltype(
a.value *
b), decltype(
a.gradient *
b)>
149 return {
a.value *
b,
a.gradient *
b};
153 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
155 const dual<value_type, gradient_type>&
b) ->
156 dual<decltype(
a *
b.value), decltype(
a *
b.gradient)>
158 return {
a *
b.value,
a *
b.gradient};
162 template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
164 constexpr
auto operator*(dual<value_type_a, gradient_type_a>
a,
165 dual<value_type_b, gradient_type_b>
b) -> dual<decltype(
a.value *
b.value),
166 decltype(
b.value *
a.gradient +
a.value *
b.gradient)>
168 return {
a.value *
b.value,
b.value *
a.gradient +
a.value *
b.gradient};
172 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
173 constexpr
auto operator/(
const dual<value_type, gradient_type>&
a,
174 double b) -> dual<decltype(
a.value /
b), decltype(
a.gradient /
b)>
176 return {
a.value /
b,
a.gradient /
b};
180 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
182 const dual<value_type, gradient_type>&
b) -> dual<decltype(
a /
b.value),
183 decltype(-(
a / (
b.value *
b.value)) *
b.gradient)>
185 return {
a /
b.value, -(
a / (
b.value *
b.value)) *
b.gradient};
189 template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
191 constexpr
auto operator/(dual<value_type_a, gradient_type_a>
a,
192 dual<value_type_b, gradient_type_b>
b) -> dual<decltype(
a.value /
b.value),
193 decltype((
a.gradient /
b.value) -
194 (
a.value *
b.gradient) /
195 (
b.value *
b.value))>
197 return {
a.value /
b.value, (
a.gradient /
b.value) - (
a.value *
b.gradient) / (
b.value *
b.value)};
205 #define mfem_binary_comparator_overload(x) \
206 template <typename value_type, typename gradient_type> \
207 MFEM_HOST_DEVICE constexpr bool operator x( \
208 const dual<value_type, gradient_type>& a, \
211 return a.value x b; \
214 template <typename value_type, typename gradient_type> \
215 MFEM_HOST_DEVICE constexpr bool operator x( \
217 const dual<value_type, gradient_type>& b) \
219 return a x b.value; \
222 template <typename value_type_a, \
223 typename gradient_type_a, \
224 typename value_type_b, \
225 typename gradient_type_b> MFEM_HOST_DEVICE \
226 constexpr bool operator x( \
227 const dual<value_type_a, gradient_type_a>& a, \
228 const dual<value_type_b, gradient_type_b>& b) \
230 return a.value x b.value; \
233 mfem_binary_comparator_overload(<)
234 mfem_binary_comparator_overload(<=)
235 mfem_binary_comparator_overload(==)
236 mfem_binary_comparator_overload(>=)
237 mfem_binary_comparator_overload(>)
239 #undef mfem_binary_comparator_overload
242 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
243 dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>&
a,
244 const dual<value_type, gradient_type>&
b)
247 a.gradient += b.gradient;
252 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
253 dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>&
a,
254 const dual<value_type, gradient_type>&
b)
257 a.gradient -= b.gradient;
// Compound addition into a dual number (second overload of operator+=).
// NOTE(review): this listing stops after the first parameter `a,`; the
// right-hand parameter and the body are missing. Presumably a non-dual
// (scalar) right-hand side, since the dual/dual overload is defined
// separately -- confirm against the upstream source and restore.
262 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
263 dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>&
a,
// Compound subtraction from a dual number (second overload of operator-=).
// NOTE(review): this listing stops after the first parameter `a,`; the
// right-hand parameter and the body are missing. Presumably a non-dual
// (scalar) right-hand side, since the dual/dual overload is defined
// separately -- confirm against the upstream source and restore.
271 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
272 dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>&
a,
280 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
281 dual<value_type, gradient_type> abs(dual<value_type, gradient_type> x)
283 return (x.value >= 0) ? x : -x;
287 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
288 dual<value_type, gradient_type> sqrt(dual<value_type, gradient_type> x)
290 return {std::sqrt(x.value), x.gradient / (2.0 * std::sqrt(x.value))};
294 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
295 dual<value_type, gradient_type> cos(dual<value_type, gradient_type>
a)
297 return {std::cos(a.value), -a.gradient * std::sin(a.value)};
301 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
302 dual<value_type, gradient_type> sin(dual<value_type, gradient_type>
a)
304 return {std::sin(a.value), a.gradient * std::cos(a.value)};
308 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
309 dual<value_type, gradient_type> sinh(dual<value_type, gradient_type>
a)
311 return {std::sinh(a.value), a.gradient * std::cosh(a.value)};
315 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
316 dual<value_type, gradient_type> acos(dual<value_type, gradient_type>
a)
320 return {acos(a.value), -a.gradient / sqrt(value_type{1} - a.value * a.value)};
324 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
325 dual<value_type, gradient_type> asin(dual<value_type, gradient_type> a)
329 return {asin(a.value), a.gradient / sqrt(value_type{1} - a.value * a.value)};
333 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
334 dual<value_type, gradient_type> tan(dual<value_type, gradient_type> a)
337 value_type
f = tan(a.value);
338 return {
f, a.gradient * (value_type{1} + f *
f)};
342 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
343 dual<value_type, gradient_type> atan(dual<value_type, gradient_type> a)
345 return {atan(a.value), a.gradient / (value_type{1} + a.value * a.value)};
349 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
350 dual<value_type, gradient_type> exp(dual<value_type, gradient_type> a)
352 return {std::exp(a.value), std::exp(a.value) * a.gradient};
356 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
357 dual<value_type, gradient_type> log(dual<value_type, gradient_type> a)
359 return {std::log(a.value), a.gradient / a.value};
363 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
364 dual<value_type, gradient_type> pow(dual<value_type, gradient_type> a,
365 dual<value_type, gradient_type>
b)
367 value_type value = pow(a.value, b.value);
368 return {value, value * (a.gradient * (b.value / a.value) + b.gradient * std::log(a.value))};
372 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
373 dual<value_type, gradient_type> pow(
double a, dual<value_type, gradient_type> b)
375 value_type value = pow(a, b.value);
376 return {value, value * b.gradient * std::log(a)};
380 template <
typename value_type > MFEM_HOST_DEVICE
381 value_type pow(value_type a, value_type b) {
return std::pow(a, b); }
384 template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
385 dual<value_type, gradient_type> pow(dual<value_type, gradient_type> a,
double b)
387 value_type value = pow(a.value, b);
388 return {value, value * a.gradient * b / a.value};
392 template <
typename value_type,
typename gradient_type,
int... n>
393 std::ostream& operator<<(std::ostream& os, dual<value_type, gradient_type> A)
395 os <<
'(' << A.value <<
' ' << A.gradient <<
')';
400 MFEM_HOST_DEVICE constexpr dual<double, double> make_dual(
double x) {
return {x, 1.0}; }
403 template <
typename T> MFEM_HOST_DEVICE T get_value(
const T& arg) {
return arg; }
406 template <
typename value_type,
typename gradient_type>
407 MFEM_HOST_DEVICE gradient_type get_value(dual<value_type, gradient_type> arg)
413 template <
typename value_type,
typename gradient_type>
414 MFEM_HOST_DEVICE gradient_type get_gradient(dual<value_type, gradient_type> arg)
MFEM_ALWAYS_INLINE AutoSIMD< scalar_t, S, A > operator/(const scalar_t &e, const AutoSIMD< scalar_t, S, A > &v)
MFEM_ALWAYS_INLINE AutoSIMD< scalar_t, S, A > operator+(const scalar_t &e, const AutoSIMD< scalar_t, S, A > &v)
double f(const Vector &xvec)
MemoryClass operator*(MemoryClass mc1, MemoryClass mc2)
Return a suitable MemoryClass from a pair of MemoryClasses.
MFEM_ALWAYS_INLINE AutoSIMD< scalar_t, S, A > operator-(const scalar_t &e, const AutoSIMD< scalar_t, S, A > &v)