33 std::size_t num_fields>
34MFEM_HOST_DEVICE
inline
43 const bool &use_sum_factorization)
45 if (use_sum_factorization)
49 MFEM_FOREACH_THREAD_DIRECT(q, x, q1d)
52 auto r =
Reshape(&residual_shmem(0, q), rs_qp);
58 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
60 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
62 const int q = qx + q1d * qy;
64 auto r =
Reshape(&residual_shmem(0, q), rs_qp);
71 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
73 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
75 MFEM_FOREACH_THREAD_DIRECT(qz, z, q1d)
77 const int q = qx + q1d * (qy + q1d * qz);
79 auto r =
Reshape(&residual_shmem(0, q), rs_qp);
87#if !(defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP))
88 MFEM_ABORT(
"unsupported dimension for sum factorization");
95 MFEM_FOREACH_THREAD_DIRECT(q, x, num_qp)
98 auto r =
Reshape(&residual_shmem(0, q), rs_qp);
118 typename qf_param_ts,
120 std::size_t num_fields>
121MFEM_HOST_DEVICE
inline
131 const bool &use_sum_factorization)
133 if (use_sum_factorization)
137 MFEM_FOREACH_THREAD_DIRECT(q, x, q1d)
139 auto r =
Reshape(&residual_shmem(0, q), das_qp);
141#ifdef MFEM_USE_ENZYME
152 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
154 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
156 const int q = qx + q1d * qy;
157 auto r =
Reshape(&residual_shmem(0, q), das_qp);
159#ifdef MFEM_USE_ENZYME
171 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
173 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
175 MFEM_FOREACH_THREAD_DIRECT(qz, z, q1d)
177 const int q = qx + q1d * (qy + q1d * qz);
178 auto r =
Reshape(&residual_shmem(0, q), das_qp);
180#ifdef MFEM_USE_ENZYME
193 MFEM_ABORT_KERNEL(
"unsupported dimension");
198 MFEM_FOREACH_THREAD_DIRECT(q, x, num_qp)
200 auto r =
Reshape(&residual_shmem(0, q), das_qp);
202#ifdef MFEM_USE_ENZYME
217 typename qf_param_ts,
219 std::size_t num_fields>
220MFEM_HOST_DEVICE
inline
231 const int test_vdim = qpdc.
GetShape()[0];
232 const int test_op_dim = qpdc.
GetShape()[1];
233 const int trial_vdim = qpdc.
GetShape()[2];
234 const int num_qp = qpdc.
GetShape()[4];
235 const size_t num_inputs = itod.
GetShape()[0];
237 for (
int j = 0; j < trial_vdim; j++)
240 for (
size_t s = 0; s < num_inputs; s++)
242 const int trial_op_dim =
static_cast<int>(itod(s));
243 if (trial_op_dim == 0)
248 auto d_qp =
Reshape(&(shadow_shmem[s])[0], trial_vdim, trial_op_dim, num_qp);
249 for (
int m = 0; m < trial_op_dim; m++)
253 auto r =
Reshape(&residual_shmem(0, q), das_qp);
255#ifdef MFEM_USE_ENZYME
264 auto f =
Reshape(&r(0), test_vdim, test_op_dim);
265 for (
int i = 0; i < test_vdim; i++)
267 for (
int k = 0; k < test_op_dim; k++)
269 qpdc(i, k, j, m + m_offset, q) =
f(i, k);
273 m_offset += trial_op_dim;
298 typename qf_param_ts,
300 std::size_t num_fields>
301MFEM_HOST_DEVICE
inline
312 const bool &use_sum_factorization)
314 if (use_sum_factorization)
318 MFEM_FOREACH_THREAD_DIRECT(q, x, q1d)
321 qfunc, input_shmem, shadow_shmem, residual_shmem, qpdc, itod, das_qp, q);
326 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
328 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
330 const int q = qx + q1d * qy;
332 qfunc, input_shmem, shadow_shmem, residual_shmem, qpdc, itod, das_qp, q);
338 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
340 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
342 MFEM_FOREACH_THREAD_DIRECT(qz, z, q1d)
344 const int q = qx + q1d * (qy + q1d * qz);
346 qfunc, input_shmem, shadow_shmem, residual_shmem, qpdc, itod, das_qp, q);
353 MFEM_ABORT_KERNEL(
"unsupported dimension");
358 const int num_qp = qpdc.
GetShape()[4];
359 MFEM_FOREACH_THREAD_DIRECT(q, x, num_qp)
362 qfunc, input_shmem, shadow_shmem, residual_shmem, qpdc, itod, das_qp, q);
386template <
size_t num_fields>
387MFEM_HOST_DEVICE
inline
395 const int test_vdim = qpdc.
GetShape()[0];
396 const int test_op_dim = qpdc.
GetShape()[1];
397 const int trial_vdim = qpdc.
GetShape()[2];
398 const int num_qp = qpdc.
GetShape()[4];
399 const size_t num_inputs = itod.
GetShape()[0];
401 for (
int i = 0; i < test_vdim; i++)
403 for (
int k = 0; k < test_op_dim; k++)
407 for (
size_t s = 0; s < num_inputs; s++)
409 const int trial_op_dim =
static_cast<int>(itod(s));
410 if (trial_op_dim == 0)
415 Reshape(&(shadow_shmem[s])[0], trial_vdim, trial_op_dim, num_qp);
416 for (
int j = 0; j < trial_vdim; j++)
418 for (
int m = 0; m < trial_op_dim; m++)
420 sum += qpdc(i, k, j, m + m_offset, q) * d_qp(j, m, q);
423 m_offset += trial_op_dim;
448template <
size_t num_fields>
449MFEM_HOST_DEVICE
inline
457 const bool &use_sum_factorization)
459 if (use_sum_factorization)
463 MFEM_FOREACH_THREAD_DIRECT(q, x, q1d)
470 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
472 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
474 const int q = qx + q1d * qy;
481 MFEM_FOREACH_THREAD_DIRECT(qx, x, q1d)
483 MFEM_FOREACH_THREAD_DIRECT(qy, y, q1d)
485 MFEM_FOREACH_THREAD_DIRECT(qz, z, q1d)
487 const int q = qx + q1d * (qy + q1d * qz);
495 MFEM_ABORT_KERNEL(
"unsupported dimension");
500 const int num_qp = qpdc.
GetShape()[4];
501 MFEM_FOREACH_THREAD_DIRECT(q, x, num_qp)
508template <
typename qfunc_t,
typename args_ts,
size_t num_args>
509MFEM_HOST_DEVICE
inline
512 const qfunc_t &qfunc,
521template <
typename qfunc_t,
typename arg_ts,
size_t num_args>
522MFEM_HOST_DEVICE
inline
525 const qfunc_t &qfunc,
536#ifdef MFEM_USE_ENZYME
538template <
typename func_t,
typename... arg_ts>
539MFEM_HOST_DEVICE
inline
548template <
typename qfunc_t,
typename arg_ts, std::size_t... Is,
549 typename inactive_arg_ts>
550MFEM_HOST_DEVICE
inline
552 arg_ts &&shadow_args,
553 std::index_sequence<Is...>,
554 inactive_arg_ts &&inactive_args,
555 std::index_sequence<>)
558 decltype(&qfunc_t::operator())>::type::return_t;
566template <
typename qfunc_t,
typename arg_ts, std::size_t... Is,
567 typename inactive_arg_ts, std::size_t... Js>
568MFEM_HOST_DEVICE
inline
570 arg_ts &&shadow_args,
571 std::index_sequence<Is...>,
572 inactive_arg_ts &&inactive_args,
573 std::index_sequence<Js...>)
576 decltype(&qfunc_t::operator())>::type::return_t;
579 decltype(
get<Js>(inactive_args))...>,
585template <
typename qfunc_t,
typename arg_ts,
typename inactive_arg_ts>
586MFEM_HOST_DEVICE
inline
588 arg_ts &&shadow_args,
589 inactive_arg_ts &&inactive_args)
591 auto arg_indices = std::make_index_sequence<
594 auto inactive_arg_indices = std::make_index_sequence<
598 inactive_args, inactive_arg_indices);
601template <
typename qfunc_t,
typename arg_ts,
size_t num_args>
602MFEM_HOST_DEVICE
inline
A basic generic Tensor class, appropriate for use on the GPU.
MFEM_HOST_DEVICE auto & GetShape() const
Returns the shape of the tensor.
MFEM_HOST_DEVICE return_type __enzyme_fwddiff(Args...)
constexpr int dimension
This example only works in 3D. Kernels for 2D are not implemented.
MFEM_HOST_DEVICE void apply_qpdc(DeviceTensor< 3 > &fhat, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, const DeviceTensor< 5, const real_t > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &q)
Apply the quadrature point data cache (qpdc) to a vector (usually a direction) on quadrature point q.
MFEM_HOST_DEVICE void call_qfunction_derivative(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, DeviceTensor< 2 > &residual_shmem, DeviceTensor< 5 > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &das_qp, const int &q)
MFEM_HOST_DEVICE auto qfunction_wrapper(const func_t &f, arg_ts &&...args)
MFEM_HOST_DEVICE auto fwddiff_apply_enzyme_indexed(qfunc_t &qfunc, arg_ts &&args, arg_ts &&shadow_args, std::index_sequence< Is... >, inactive_arg_ts &&inactive_args, std::index_sequence<>)
MFEM_HOST_DEVICE void process_derivative_from_native_dual(DeviceTensor< 1, T > &r, const tensor< dual< T, T >, n, m > &x)
MFEM_HOST_DEVICE void process_qf_result(DeviceTensor< 1, T > &r, const tensor< dual< T, T >, n > &x)
decltype(decay_types(std::declval< T >())) decay_tuple
MFEM_HOST_DEVICE void call_qfunction_derivative_action(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, DeviceTensor< 2 > &residual_shmem, const int &das_qp, const int &num_qp, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Call a qfunction with the given parameters and compute its derivative action.
MFEM_HOST_DEVICE auto fwddiff_apply_enzyme(qfunc_t &qfunc, arg_ts &&args, arg_ts &&shadow_args, inactive_arg_ts &&inactive_args)
MFEM_HOST_DEVICE void apply_qpdc(DeviceTensor< 3 > &fhat, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, const DeviceTensor< 5, const real_t > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Apply the quadrature point data cache (qpdc) to a vector (usually a direction).
MFEM_HOST_DEVICE auto apply(lambda f, tuple< T... > &args)
a way of passing an n-tuple to a function that expects n separate arguments
MFEM_HOST_DEVICE void call_qfunction(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, DeviceTensor< 2 > &residual_shmem, const int &rs_qp, const int &num_qp, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Call a qfunction with the given parameters.
MFEM_HOST_DEVICE void call_qfunction_derivative(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, DeviceTensor< 2 > &residual_shmem, DeviceTensor< 5 > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &das_qp, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Call a qfunction with the given parameters and compute its derivative represented by the Jacobian on...
MFEM_HOST_DEVICE void apply_kernel(DeviceTensor< 1, real_t > &f_qp, const qfunc_t &qfunc, args_ts &args, const std::array< DeviceTensor< 2 >, num_args > &u, int qp)
MFEM_HOST_DEVICE void apply_kernel_fwddiff_enzyme(DeviceTensor< 1, real_t > &f_qp, qfunc_t &qfunc, arg_ts &args, arg_ts &shadow_args, const std::array< DeviceTensor< 2 >, num_args > &u, const std::array< DeviceTensor< 2 >, num_args > &v, int qp_idx)
MFEM_HOST_DEVICE void apply_kernel_native_dual(DeviceTensor< 1, real_t > &f_qp, const qfunc_t &qfunc, arg_ts &args, const std::array< DeviceTensor< 2 >, num_args > &u, const std::array< DeviceTensor< 2 >, num_args > &v, const int &qp_idx)
MFEM_HOST_DEVICE zero & get(zero &x)
let zero be accessed like a tuple
MFEM_HOST_DEVICE void process_qf_args(const std::array< DeviceTensor< 2 >, num_fields > &u, const std::array< DeviceTensor< 2 >, num_fields > &v, qf_args &args, const int &qp)
real_t u(const Vector &xvec)
MFEM_HOST_DEVICE DeviceTensor< sizeof...(Dims), T > Reshape(T *ptr, Dims... dims)
Wrap a pointer as a DeviceTensor with automatically deduced template parameters.
std::function< real_t(const Vector &)> f(real_t mass_coeff)
This is a class that mimics most of std::tuple's interface, except that it is usable in CUDA kernels ...