18#ifndef MFEM_INTERNAL_DUAL_HPP
19#define MFEM_INTERNAL_DUAL_HPP
35template <
typename value_type,
typename gradient_type>
41 gradient_type gradient;
45 auto operator=(
real_t a) -> dual<value_type, gradient_type>&
58 static constexpr bool value =
false;
62template <
typename value_type,
typename gradient_type>
63struct is_dual_number<dual<value_type, gradient_type> >
65 static constexpr bool value =
true;
69template <
typename other_type,
typename value_type,
typename gradient_type,
70 typename =
typename std::enable_if<
71 std::is_arithmetic<other_type>::value ||
72 is_dual_number<other_type>::value>::type>
74constexpr auto operator+(dual<value_type, gradient_type>
a,
75 other_type
b) -> dual<value_type, gradient_type>
77 return {
a.value +
b,
a.gradient};
89template <
typename other_type,
typename value_type,
typename gradient_type,
90 typename =
typename std::enable_if<
91 std::is_arithmetic<other_type>::value ||
92 is_dual_number<other_type>::value>::type>
94constexpr auto operator+(other_type
a,
95 dual<value_type, gradient_type>
b) -> dual<value_type, gradient_type>
97 return {
a +
b.value,
b.gradient};
101template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
103constexpr auto operator+(dual<value_type_a, gradient_type_a>
a,
104 dual<value_type_b, gradient_type_b>
b) -> dual<
decltype(
a.value +
b.value),
105 decltype(
a.gradient +
b.gradient)>
107 return {
a.value +
b.value,
a.gradient +
b.gradient};
111template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
112constexpr auto operator-(dual<value_type, gradient_type> x) ->
113dual<value_type, gradient_type>
115 return {-x.value, -x.gradient};
119template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
120constexpr auto operator-(dual<value_type, gradient_type>
a,
121 real_t b) -> dual<value_type, gradient_type>
123 return {
a.value -
b,
a.gradient};
127template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
128constexpr auto operator-(
real_t a,
129 dual<value_type, gradient_type>
b) -> dual<value_type, gradient_type>
131 return {
a -
b.value, -
b.gradient};
135template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
137constexpr auto operator-(dual<value_type_a, gradient_type_a>
a,
138 dual<value_type_b, gradient_type_b>
b) -> dual<
decltype(
a.value -
b.value),
139 decltype(
a.gradient -
b.gradient)>
141 return {
a.value -
b.value,
a.gradient -
b.gradient};
145template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
146constexpr auto operator*(
const dual<value_type, gradient_type>&
a,
147 real_t b) -> dual<
decltype(
a.value *
b),
decltype(
a.gradient *
b)>
149 return {
a.value *
b,
a.gradient *
b};
153template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
155 const dual<value_type, gradient_type>&
b) ->
156dual<
decltype(
a *
b.value),
decltype(
a *
b.gradient)>
158 return {
a *
b.value,
a *
b.gradient};
162template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
164constexpr auto operator*(dual<value_type_a, gradient_type_a>
a,
165 dual<value_type_b, gradient_type_b>
b) -> dual<
decltype(
a.value *
b.value),
166 decltype(
b.value *
a.gradient +
a.value *
b.gradient)>
168 return {
a.value *
b.value,
b.value *
a.gradient +
a.value *
b.gradient};
172template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
173constexpr auto operator/(
const dual<value_type, gradient_type>&
a,
174 real_t b) -> dual<
decltype(
a.value /
b),
decltype(
a.gradient /
b)>
176 return {
a.value /
b,
a.gradient /
b};
180template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
181constexpr auto operator/(
real_t a,
182 const dual<value_type, gradient_type>&
b) -> dual<
decltype(
a /
b.value),
183 decltype(-(
a / (
b.value *
b.value)) *
b.gradient)>
185 return {
a /
b.value, -(
a / (
b.value *
b.value)) *
b.gradient};
189template <
typename value_type_a,
typename gradient_type_a,
typename value_type_b,
typename gradient_type_b>
191constexpr auto operator/(dual<value_type_a, gradient_type_a>
a,
192 dual<value_type_b, gradient_type_b>
b) -> dual<
decltype(
a.value /
b.value),
193 decltype((
a.gradient /
b.value) -
194 (
a.value *
b.gradient) /
195 (
b.value *
b.value))>
197 return {
a.value /
b.value, (
a.gradient /
b.value) - (
a.value *
b.gradient) / (
b.value *
b.value)};
205#define mfem_binary_comparator_overload(x) \
206 template <typename value_type, typename gradient_type> \
207 MFEM_HOST_DEVICE constexpr bool operator x( \
208 const dual<value_type, gradient_type>& a, \
211 return a.value x b; \
214 template <typename value_type, typename gradient_type> \
215 MFEM_HOST_DEVICE constexpr bool operator x( \
217 const dual<value_type, gradient_type>& b) \
219 return a x b.value; \
222 template <typename value_type_a, \
223 typename gradient_type_a, \
224 typename value_type_b, \
225 typename gradient_type_b> MFEM_HOST_DEVICE \
226 constexpr bool operator x( \
227 const dual<value_type_a, gradient_type_a>& a, \
228 const dual<value_type_b, gradient_type_b>& b) \
230 return a.value x b.value; \
233mfem_binary_comparator_overload(<)
234mfem_binary_comparator_overload(<=)
235mfem_binary_comparator_overload(==)
236mfem_binary_comparator_overload(>=)
237mfem_binary_comparator_overload(>)
239#undef mfem_binary_comparator_overload
242template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
243dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>&
a,
244 const dual<value_type, gradient_type>&
b)
247 a.gradient +=
b.gradient;
252template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
253dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>&
a,
254 const dual<value_type, gradient_type>&
b)
257 a.gradient -=
b.gradient;
262template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
263dual<value_type, gradient_type>& operator+=(dual<value_type, gradient_type>&
a,
271template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
272dual<value_type, gradient_type>& operator-=(dual<value_type, gradient_type>&
a,
280template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
281dual<value_type, gradient_type> abs(dual<value_type, gradient_type> x)
283 return (x.value >= 0) ? x : -x;
287template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
288dual<value_type, gradient_type> sqrt(dual<value_type, gradient_type> x)
291 return {sqrt(x.value), x.gradient / (2.0 * sqrt(x.value))};
295template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
296dual<value_type, gradient_type> cos(dual<value_type, gradient_type>
a)
300 return {cos(
a.value), -
a.gradient * sin(
a.value)};
304template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
305dual<value_type, gradient_type> sin(dual<value_type, gradient_type>
a)
309 return {sin(
a.value),
a.gradient * cos(
a.value)};
313template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
314dual<value_type, gradient_type> sinh(dual<value_type, gradient_type>
a)
318 return {sinh(
a.value),
a.gradient * cosh(
a.value)};
322template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
323dual<value_type, gradient_type> acos(dual<value_type, gradient_type>
a)
327 return {acos(
a.value), -
a.gradient / sqrt(value_type{1} -
a.value *
a.value)};
331template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
332dual<value_type, gradient_type> asin(dual<value_type, gradient_type>
a)
336 return {asin(
a.value),
a.gradient / sqrt(value_type{1} -
a.value *
a.value)};
340template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
341dual<value_type, gradient_type> tan(dual<value_type, gradient_type>
a)
344 value_type
f = tan(
a.value);
345 return {
f,
a.gradient * (value_type{1} +
f *
f)};
349template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
350dual<value_type, gradient_type> atan(dual<value_type, gradient_type>
a)
353 return {atan(
a.value),
a.gradient / (value_type{1} +
a.value *
a.value)};
357template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
358dual<value_type, gradient_type> exp(dual<value_type, gradient_type>
a)
361 return {exp(
a.value), exp(
a.value) *
a.gradient};
365template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
366dual<value_type, gradient_type> log(dual<value_type, gradient_type>
a)
369 return {log(
a.value),
a.gradient /
a.value};
373template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
374dual<value_type, gradient_type> pow(dual<value_type, gradient_type>
a,
375 dual<value_type, gradient_type>
b)
379 value_type value = pow(
a.value,
b.value);
380 return {value, value * (
a.gradient * (
b.value /
a.value) +
b.gradient * log(
a.value))};
384template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
385dual<value_type, gradient_type> pow(
real_t a, dual<value_type, gradient_type>
b)
389 value_type value = pow(
a,
b.value);
390 return {value, value *
b.gradient * log(
a)};
394template <
typename value_type > MFEM_HOST_DEVICE
395value_type pow(value_type
a, value_type
b)
402template <
typename value_type,
typename gradient_type> MFEM_HOST_DEVICE
403dual<value_type, gradient_type> pow(dual<value_type, gradient_type>
a,
real_t b)
406 value_type value = pow(
a.value,
b);
407 return {value, value *
a.gradient *
b /
a.value};
411template <
typename value_type,
typename gradient_type,
int... n>
412std::ostream& operator<<(std::ostream& os, dual<value_type, gradient_type> A)
414 os <<
'(' << A.value <<
' ' << A.gradient <<
')';
419MFEM_HOST_DEVICE
constexpr dual<real_t, real_t> make_dual(
real_t x) {
return {x, 1.0}; }
422template <
typename T> MFEM_HOST_DEVICE T get_value(
const T& arg) {
return arg; }
425template <
typename value_type,
typename gradient_type>
426MFEM_HOST_DEVICE gradient_type get_value(dual<value_type, gradient_type> arg)
432template <
typename value_type,
typename gradient_type>
433MFEM_HOST_DEVICE gradient_type get_gradient(dual<value_type, gradient_type> arg)
MemoryClass operator*(MemoryClass mc1, MemoryClass mc2)
Return a suitable MemoryClass from a pair of MemoryClasses.
std::function< real_t(const Vector &)> f(real_t mass_coeff)