MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
doperator.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11#pragma once
12
13#include <type_traits>
14#include <utility>
15
17
18#ifdef MFEM_USE_MPI
19#include "../fespace.hpp"
20
21#include "util.hpp"
22#include "interpolate.hpp"
23#include "integrate.hpp"
24#include "qfunction_apply.hpp"
25#include "assemble.hpp"
26
27namespace mfem::future
28{
29
30/// @brief Type alias for a function that computes the action of an operator
31using action_t =
32 std::function<void(std::vector<Vector> &, const std::vector<Vector> &, Vector &)>;
33
34/// @brief Type alias for a function that computes the cache for the action of a derivative
36 std::function<void(std::vector<Vector> &, const Vector &)>;
37
38/// @brief Type alias for a function that computes the action of a derivative
40 std::function<void(std::vector<Vector> &, const Vector &, Vector &)>;
41
42/// @brief Type alias for a function that assembles the SparseMatrix of a
43/// derivative operator
45 std::function<void(std::vector<Vector> &, SparseMatrix *&)>;
46
47/// @brief Type alias for a function that assembles the HypreParMatrix of a
48/// derivative operator
50 std::function<void(std::vector<Vector> &, HypreParMatrix *&)>;
51
52/// @brief Type alias for a function that applies the appropriate restriction to
53/// the solution and parameters
55 std::function<void(std::vector<Vector> &,
56 const std::vector<Vector> &,
57 std::vector<Vector> &)>;
58
59/// Class representing the derivative (Jacobian) operator of a
60/// DifferentiableOperator.
61///
62/// This class implements a derivative operator that computes directional
63/// derivatives for a given set of solution and parameter fields. It supports
64/// both forward and transpose operations, as well as assembly into sparse
65/// matrices.
66///
67/// @note The derivative operator uses only forward mode differentiation in Mult
68/// and MultTranspose. It does not support reverse mode differentiation. The
69/// MultTranspose operation is achieved by using the transpose of the derivative
70/// actions on each quadrature point.
71///
72/// @see DifferentiableOperator
74{
75public:
76 /// Constructor for the DerivativeOperator class.
77 ///
78 /// This is usually not called directly from a user. A DifferentiableOperator
79 /// calls this constructor when using
80 /// DifferentiableOperator::GetDerivative().
82 const int &height,
83 const int &width,
84 const std::vector<derivative_action_t> &derivative_actions,
85 const FieldDescriptor &direction,
86 const int &daction_l_size,
87 const std::vector<derivative_action_t> &derivative_actions_transpose,
88 const FieldDescriptor &transpose_direction,
89 const int &daction_transpose_l_size,
90 const std::vector<Vector *> &solutions_l,
91 const std::vector<Vector *> &parameters_l,
92 const restriction_callback_t &restriction_callback,
93 const std::function<void(Vector &, Vector &)> &prolongation_transpose,
94 const std::vector<assemble_derivative_sparsematrix_callback_t>
95 &assemble_derivative_sparsematrix_callbacks,
96 const std::vector<assemble_derivative_hypreparmatrix_callback_t>
97 &assemble_derivative_hypreparmatrix_callbacks) :
99 derivative_actions(derivative_actions),
101 daction_l(daction_l_size),
102 daction_l_size(daction_l_size),
103 derivative_actions_transpose(derivative_actions_transpose),
104 transpose_direction(transpose_direction),
105 prolongation_transpose(prolongation_transpose),
106 assemble_derivative_sparsematrix_callbacks(
107 assemble_derivative_sparsematrix_callbacks),
108 assemble_derivative_hypreparmatrix_callbacks(
109 assemble_derivative_hypreparmatrix_callbacks)
110 {
111 std::vector<Vector> s_l(solutions_l.size());
112 for (size_t i = 0; i < s_l.size(); i++)
113 {
114 s_l[i] = *solutions_l[i];
115 }
116
117 std::vector<Vector> p_l(parameters_l.size());
118 for (size_t i = 0; i < p_l.size(); i++)
119 {
120 p_l[i] = *parameters_l[i];
121 }
122
123 fields_e.resize(solutions_l.size() + parameters_l.size());
124 restriction_callback(s_l, p_l, fields_e);
125 }
126
127 /// @brief Compute the action of the derivative operator on a given vector.
128 ///
129 /// @param direction_t The direction vector in which to compute the
130 /// derivative. This has to be a T-dof vector.
131 /// @param result_t Result vector of the action of the derivative on
132 /// direction_t on T-dofs.
133 void Mult(const Vector &direction_t, Vector &result_t) const override
134 {
135 daction_l.SetSize(daction_l_size);
136 daction_l = 0.0;
137
138 prolongation(direction, direction_t, direction_l);
139 for (const auto &f : derivative_actions)
140 {
141 f(fields_e, direction_l, daction_l);
142 }
143 prolongation_transpose(daction_l, result_t);
144 };
145
146 /// @brief Compute the transpose of the derivative operator on a given
147 /// vector.
148 ///
149 /// This function computes the transpose of the derivative operator on a
150 /// given vector by transposing the quadrature point local forward derivative
151 /// action. It does not use reverse mode automatic differentiation.
152 ///
153 /// @param direction_t The direction vector in which to compute the
154 /// derivative. This has to be a T-dof vector.
155 /// @param result_t Result vector of the transpose action of the derivative on
156 /// direction_t on T-dofs.
157 void MultTranspose(const Vector &direction_t, Vector &result_t) const override
158 {
159 MFEM_ASSERT(!derivative_actions_transpose.empty(),
160 "derivative can't be used to be multiplied in transpose mode");
161
162 daction_l.SetSize(width);
163 daction_l = 0.0;
164
165 prolongation(transpose_direction, direction_t, direction_l);
166 for (const auto &f : derivative_actions_transpose)
167 {
168 f(fields_e, direction_l, daction_l);
169 }
170 prolongation_transpose(daction_l, result_t);
171 };
172
173 /// @brief Assemble the derivative operator into a SparseMatrix.
174 ///
175 /// @param A The SparseMatrix to assemble the derivative operator into. Can
176 /// be an uninitialized object.
178 {
179 MFEM_ASSERT(!assemble_derivative_sparsematrix_callbacks.empty(),
180 "derivative can't be assembled into a SparseMatrix");
181
182 for (const auto &f : assemble_derivative_sparsematrix_callbacks)
183 {
184 f(fields_e, A);
185 }
186 }
187
188 /// @brief Assemble the derivative operator into a HypreParMatrix.
189 ///
190 /// @param A The HypreParMatrix to assemble the derivative operator into. Can
191 /// be an uninitialized object.
193 {
194 MFEM_ASSERT(!assemble_derivative_hypreparmatrix_callbacks.empty(),
195 "derivative can't be assembled into a HypreParMatrix");
196
197 for (const auto &f : assemble_derivative_hypreparmatrix_callbacks)
198 {
199 f(fields_e, A);
200 }
201 }
202
203private:
204 /// Derivative action callbacks. Depending on the requested derivatives in
205 /// DifferentiableOperator the callbacks represent certain combinations of
206 /// actions of derivatives of the forward operator.
207 std::vector<derivative_action_t> derivative_actions;
208
210
211 mutable Vector daction_l;
212
213 const int daction_l_size;
214
215 /// Transpose Derivative action callbacks. Depending on the requested
216 /// derivatives in DifferentiableOperator the callbacks represent certain
217 /// combinations of actions of derivatives of the forward operator.
218 std::vector<derivative_action_t> derivative_actions_transpose;
219
220 FieldDescriptor transpose_direction;
221
222 mutable std::vector<Vector> fields_e;
223
224 mutable Vector direction_l;
225
226 std::function<void(Vector &, Vector &)> prolongation_transpose;
227
228 /// Callbacks that assemble derivatives into a SparseMatrix.
229 std::vector<assemble_derivative_sparsematrix_callback_t>
230 assemble_derivative_sparsematrix_callbacks;
231
232 /// Callbacks that assemble derivatives into a HypreParMatrix.
233 std::vector<assemble_derivative_hypreparmatrix_callback_t>
234 assemble_derivative_hypreparmatrix_callbacks;
235};
236
237/// Class representing a differentiable operator which acts on solution and
238/// parameter fields to compute residuals.
239///
240/// This class provides functionality to define differentiable operators by
241/// composing functions that compute values at quadrature points. It supports
242/// automatic differentiation to compute derivatives with respect to solutions
243/// (Jacobians) and parameter fields (general derivative operators).
244///
245/// The operator is constructed with solution fields that it will act on and
246/// parameter fields that define coefficients. Quadrature functions are added by
247/// e.g. using AddDomainIntegrator() which specify how the operator evaluates
248/// those functions and parameters at quadrature points.
249///
250/// Derivatives can be computed by obtaining a DerivativeOperator using
251/// GetDerivative().
252///
253/// @see DerivativeOperator
255{
256public:
257 /// Constructor for the DifferentiableOperator class.
258 ///
259 /// @param solutions The solution fields that the operator will act on.
260 /// @param parameters The parameter fields that define coefficients.
261 /// @param mesh The mesh on which the operator is defined.
263 const std::vector<FieldDescriptor> &solutions,
264 const std::vector<FieldDescriptor> &parameters,
265 const ParMesh &mesh);
266
267 /// MultLevel enum to indicate if the T->L Operators are used in the
268 /// Mult method.
274
275 /// @brief Set the MultLevel mode for the DifferentiableOperator.
276 /// The default is TVECTOR, which means that the Operator will use
277 /// T->L before Mult and L->T Operators after.
279 {
280 mult_level = level;
281 }
282
283 /// @brief Compute the action of the operator on a given vector.
284 ///
285 /// @param solutions_in The solution vector in which to compute the action.
286 /// This has to be a T-dof vector if MultLevel is set to TVECTOR, or L-dof
287 /// Vector if MultLevel is set to LVECTOR.
288 /// @param result_in Result vector of the action of the operator on
289 /// solutions. The result is a T-dof vector or L-dof vector depending on
290 /// the MultLevel.
291 void Mult(const Vector &solutions_in, Vector &result_in) const override
292 {
293 MFEM_ASSERT(!action_callbacks.empty(), "no integrators have been set");
294
295 if (mult_level == MultLevel::LVECTOR)
296 {
297 get_lvectors(solutions, solutions_in, solutions_l);
298 result_in = 0.0;
299 for (auto &action : action_callbacks)
300 {
301 action(solutions_l, parameters_l, result_in);
302 }
303 }
304 else
305 {
306 prolongation(solutions, solutions_in, solutions_l);
307 residual_l = 0.0;
308 for (auto &action : action_callbacks)
309 {
310 action(solutions_l, parameters_l, residual_l);
311 }
312 prolongation_transpose(residual_l, result_in);
313 }
314 }
315
316 /// @brief Add an integrator to the operator.
317 /// Called only from AddDomainIntegrator() and AddBoundaryIntegrator().
318 template <
319 typename entity_t,
320 typename qfunc_t,
321 typename input_t,
322 typename output_t,
323 typename derivative_ids_t>
324 void AddIntegrator(
325 qfunc_t &qfunc,
326 input_t inputs,
327 output_t outputs,
328 const IntegrationRule &integration_rule,
329 const Array<int> &attributes,
330 derivative_ids_t derivative_ids);
331
332 /// @brief Add a domain integrator to the operator.
333 ///
334 /// @param qfunc The quadrature function to be added.
335 /// @param inputs Tuple of FieldOperators for the inputs of the quadrature
336 /// function.
337 /// @param outputs Tuple of FieldOperators for the outputs of the quadrature
338 /// function.
339 /// @param integration_rule IntegrationRule to use with this integrator.
340 /// @param domain_attributes Domain attributes marker array indicating over
341 /// which attributes this integrator will integrate over.
342 /// @param derivative_ids Derivatives to be made available for this
343 /// integrator.
344 template <
345 typename qfunc_t,
346 typename input_t,
347 typename output_t,
348 typename derivative_ids_t = decltype(std::make_index_sequence<0> {})>
350 qfunc_t &qfunc,
351 input_t inputs,
352 output_t outputs,
353 const IntegrationRule &integration_rule,
354 const Array<int> &domain_attributes,
355 derivative_ids_t derivative_ids = std::make_index_sequence<0> {});
356
357 /// @brief Add a boundary integrator to the operator.
358 ///
359 /// @param qfunc The quadrature function to be added.
360 /// @param inputs Tuple of FieldOperators for the inputs of the quadrature
361 /// function.
362 /// @param outputs Tuple of FieldOperators for the outputs of the quadrature
363 /// function.
364 /// @param integration_rule IntegrationRule to use with this integrator.
365 /// @param boundary_attributes Boundary attributes marker array indicating over
366 /// which attributes this integrator will integrate over.
367 /// @param derivative_ids Derivatives to be made available for this
368 /// integrator.
369 template <
370 typename qfunc_t,
371 typename input_t,
372 typename output_t,
373 typename derivative_ids_t = decltype(std::make_index_sequence<0> {})>
375 qfunc_t &qfunc,
376 input_t inputs,
377 output_t outputs,
378 const IntegrationRule &integration_rule,
379 const Array<int> &boundary_attributes,
380 derivative_ids_t derivative_ids = std::make_index_sequence<0> {});
381
382 /// @brief Set the parameters for the operator.
383 ///
384 /// This has to be called before using Mult() or MultTranspose().
385 ///
386 /// @param p The parameters to be set. This should be a vector of pointers to
387 /// the parameter vectors. The vectors have to be L-vectors (e.g.
388 /// GridFunctions).
389 void SetParameters(std::vector<Vector *> p) const;
390
391 /// @brief Disable the use of tensor product structure.
392 ///
393 /// This function disables the use of tensor product structure for the
394 /// operator. Usually, DifferentiableOperator creates callbacks based on
395 /// heuristics that achieve good performance for each element type. Some
396 /// functionality is not implemented for these performant algorithms but only
397 /// for generic assembly. Therefore the user can decide to use fallback
398 /// methods.
399 void DisableTensorProductStructure(bool disable = true)
400 {
401 use_tensor_product_structure = !disable;
402 }
403
404 /// @brief Get the derivative operator for a given derivative ID.
405 ///
406 /// This function returns a shared pointer to a DerivativeOperator that
407 /// computes the derivative of the operator with respect to the given
408 /// derivative ID. The derivative ID is used to identify the specific
409 /// derivative action to be performed.
410 ///
411 /// @param derivative_id The ID of the derivative to be computed.
412 /// @param sol_l The solution vectors to be used for the derivative
413 /// computation. This should be a vector of pointers to the solution
414 /// vectors. The vectors have to be L-vectors (e.g. GridFunctions).
415 /// @param par_l The parameter vectors to be used for the derivative
416 /// computation. This should be a vector of pointers to the parameter
417 /// vectors. The vectors have to be L-vectors (e.g. GridFunctions).
418 /// @return A shared pointer to the DerivativeOperator.
419 std::shared_ptr<DerivativeOperator> GetDerivative(
420 size_t derivative_id, std::vector<Vector *> sol_l, std::vector<Vector *> par_l)
421 {
422 MFEM_ASSERT(derivative_action_callbacks.find(derivative_id) !=
423 derivative_action_callbacks.end(),
424 "no derivative action has been found for ID " << derivative_id);
425
426 MFEM_ASSERT(sol_l.size() == solutions.size(),
427 "wrong number of solutions");
428
429 MFEM_ASSERT(par_l.size() == parameters.size(),
430 "wrong number of parameters");
431
432 const size_t derivative_idx = FindIdx(derivative_id, fields);
433
434 std::vector<Vector> s_l(solutions_l.size());
435 for (size_t i = 0; i < s_l.size(); i++)
436 {
437 s_l[i] = *sol_l[i];
438 }
439
440 std::vector<Vector> p_l(parameters_l.size());
441 for (size_t i = 0; i < p_l.size(); i++)
442 {
443 p_l[i] = *par_l[i];
444 }
445
446 fields_e.resize(solutions_l.size() + parameters_l.size());
447 restriction_callback(s_l, p_l, fields_e);
448
449 // Dummy
450 Vector dir_l;
451 if (derivative_idx > s_l.size())
452 {
453 dir_l = p_l[derivative_idx - s_l.size()];
454 }
455 else
456 {
457 dir_l = s_l[derivative_idx];
458 }
459
460 derivative_setup_callbacks[derivative_id][0](fields_e, dir_l);
461
462 return std::make_shared<DerivativeOperator>(
463 height,
464 GetTrueVSize(fields[derivative_idx]),
465 derivative_action_callbacks[derivative_id],
466 fields[derivative_idx],
467 residual_l.Size(),
468 daction_transpose_callbacks[derivative_id],
469 fields[test_space_field_idx],
470 GetVSize(fields[test_space_field_idx]),
471 sol_l,
472 par_l,
473 restriction_callback,
474 prolongation_transpose,
475 assemble_derivative_sparsematrix_callbacks[derivative_id],
476 assemble_derivative_hypreparmatrix_callbacks[derivative_id]);
477 }
478
479private:
480 const ParMesh &mesh;
481
482 MultLevel mult_level = TVECTOR;
483
484 std::vector<action_t> action_callbacks;
485 std::map<size_t, std::vector<derivative_setup_t>> derivative_setup_callbacks;
486 std::map<size_t,
487 std::vector<derivative_action_t>> derivative_action_callbacks;
488 std::map<size_t,
489 std::vector<derivative_action_t>> daction_transpose_callbacks;
490 std::map<size_t,
491 std::vector<assemble_derivative_sparsematrix_callback_t>>
492 assemble_derivative_sparsematrix_callbacks;
493 std::map<size_t,
494 std::vector<assemble_derivative_hypreparmatrix_callback_t>>
495 assemble_derivative_hypreparmatrix_callbacks;
496
497 std::vector<FieldDescriptor> solutions;
498 std::vector<FieldDescriptor> parameters;
499 // solutions and parameters
500 std::vector<FieldDescriptor> fields;
501
502 mutable std::vector<Vector> solutions_l;
503 mutable std::vector<Vector> parameters_l;
504 mutable Vector residual_l;
505
506 mutable std::vector<Vector> fields_e;
507 mutable Vector residual_e;
508
509 std::function<void(Vector &, Vector &)> prolongation_transpose;
510 std::function<void(Vector &, Vector &)> output_restriction_transpose;
511 restriction_callback_t restriction_callback;
512
513 std::map<size_t, Vector> derivative_qp_caches;
514
515 std::map<size_t, size_t> assembled_vector_sizes;
516
517 bool use_tensor_product_structure = true;
518
519 size_t test_space_field_idx = SIZE_MAX;
520};
521
522template <
523 typename qfunc_t,
524 typename input_t,
525 typename output_t,
526 typename derivative_ids_t>
528 qfunc_t &qfunc,
529 input_t inputs,
530 output_t outputs,
531 const IntegrationRule &integration_rule,
532 const Array<int> &domain_attributes,
533 derivative_ids_t derivative_ids)
534{
536 qfunc, inputs, outputs, integration_rule, domain_attributes, derivative_ids);
537}
538
539template <
540 typename qfunc_t,
541 typename input_t,
542 typename output_t,
543 typename derivative_ids_t>
545 qfunc_t &qfunc,
546 input_t inputs,
547 output_t outputs,
548 const IntegrationRule &integration_rule,
549 const Array<int> &boundary_attributes,
550 derivative_ids_t derivative_ids)
551{
552
553 if (mesh.GetNFbyType(FaceType::Boundary) != mesh.GetNBE())
554 {
555 MFEM_ABORT("AddBoundaryIntegrator on meshes with interior boundaries is not supported.");
556 }
558 qfunc, inputs, outputs, integration_rule, boundary_attributes, derivative_ids);
559}
560
561template <
562 typename entity_t,
563 typename qfunc_t,
564 typename input_t,
565 typename output_t,
566 typename derivative_ids_t>
568 qfunc_t &qfunc,
569 input_t inputs,
570 output_t outputs,
571 const IntegrationRule &integration_rule,
572 const Array<int> &attributes,
573 derivative_ids_t derivative_ids)
574{
575 if constexpr (!(std::is_same_v<entity_t, Entity::Element> ||
576 std::is_same_v<entity_t, Entity::BoundaryElement>))
577 {
578 static_assert(dfem::always_false<entity_t>,
579 "entity type not supported in AddIntegrator");
580 }
581
582 static constexpr size_t num_inputs =
583 tuple_size<decltype(inputs)>::value;
584
585 static constexpr size_t num_outputs =
586 tuple_size<decltype(outputs)>::value;
587
588 using qf_signature =
589 typename create_function_signature<decltype(&qfunc_t::operator())>::type;
590 using qf_param_ts = typename qf_signature::parameter_ts;
591 using qf_output_t = typename qf_signature::return_t;
592
593 // Consistency checks
594 if constexpr (num_outputs > 1)
595 {
596 static_assert(dfem::always_false<qfunc_t>,
597 "more than one output per quadrature functions is not supported right now");
598 }
599
600 if constexpr (std::is_same_v<qf_output_t, void>)
601 {
602 static_assert(dfem::always_false<qfunc_t>,
603 "quadrature function has no return value");
604 }
605
606 constexpr size_t num_qfinputs = tuple_size<qf_param_ts>::value;
607 static_assert(num_qfinputs == num_inputs,
608 "quadrature function inputs and descriptor inputs have to match");
609
610 constexpr size_t num_qf_outputs = tuple_size<qf_output_t>::value;
611 static_assert(num_qf_outputs == num_outputs,
612 "quadrature function outputs and descriptor outputs have to match");
613
614 constexpr auto inout_tuple =
616 constexpr auto filtered_inout_tuple = filter_fields(inout_tuple);
617 static constexpr size_t num_fields =
618 count_unique_field_ids(filtered_inout_tuple);
619
620 MFEM_ASSERT(num_fields == solutions.size() + parameters.size(),
621 "Total number of fields doesn't match sum of solutions and parameters."
622 " This indicates that some fields are not used in the integrator,"
623 " which currently is not supported.");
624
625 auto dependency_map = make_dependency_map(inputs);
626
627 // pretty_print(dependency_map);
628
629 auto input_to_field =
631 auto output_to_field =
633
634 // TODO: factor out
635 std::vector<int> inputs_vdim(num_inputs);
636 for_constexpr<num_inputs>([&](auto i)
637 {
638 inputs_vdim[i] = get<i>(inputs).vdim;
639 });
640
641 const Array<int> *elem_attributes = nullptr;
642 if constexpr (std::is_same_v<entity_t, Entity::Element>)
643 {
644 elem_attributes = &mesh.GetElementAttributes();
645 }
646 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
647 {
648 elem_attributes = &mesh.GetBdrFaceAttributes();
649 }
650
651 const auto output_fop = get<0>(outputs);
652 test_space_field_idx = FindIdx(output_fop.GetFieldId(), fields);
653
654 bool use_sum_factorization = false;
655 Element::Type entity_element_type;
656 if constexpr (std::is_same_v<entity_t, Entity::Element>)
657 {
658 entity_element_type =
660
661 if ((entity_element_type == Element::QUADRILATERAL ||
662 entity_element_type == Element::HEXAHEDRON) &&
663 use_tensor_product_structure == true)
664 {
665 use_sum_factorization = true;
666 }
667 }
668 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
669 {
670 entity_element_type =
672
673 if ((entity_element_type == Element::SEGMENT ||
674 entity_element_type == Element::QUADRILATERAL) &&
675 use_tensor_product_structure == true)
676 {
677 use_sum_factorization = true;
678 }
679 }
680
681 ElementDofOrdering element_dof_ordering = ElementDofOrdering::NATIVE;
682 DofToQuad::Mode doftoquad_mode = DofToQuad::Mode::FULL;
683 if (use_sum_factorization)
684 {
685 element_dof_ordering = ElementDofOrdering::LEXICOGRAPHIC;
686 doftoquad_mode = DofToQuad::Mode::TENSOR;
687 }
688
689 auto [output_rt,
691 (fields[test_space_field_idx],
692 element_dof_ordering, output_fop);
693 auto &output_e_size = output_e_sz;
694
695 output_restriction_transpose = output_rt;
696 residual_e.SetSize(output_e_size);
697
698 // The explicit captures are necessary to avoid dependency on
699 // the specific instance of this class (this pointer).
700 restriction_callback =
701 [=, solutions = this->solutions, parameters = this->parameters]
702 (std::vector<Vector> &sol,
703 const std::vector<Vector> &par,
704 std::vector<Vector> &f)
705 {
706 restriction<entity_t>(solutions, sol, f,
707 element_dof_ordering);
708 restriction<entity_t>(parameters, par, f,
709 element_dof_ordering,
710 solutions.size());
711 };
712
713 prolongation_transpose = get_prolongation_transpose(
714 fields[test_space_field_idx], output_fop, mesh.GetComm());
715
716 int dimension;
717 if constexpr (std::is_same_v<entity_t, Entity::Element>)
718 {
719 dimension = mesh.Dimension();
720 }
721 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
722 {
723 dimension = mesh.Dimension() - 1;
724 }
725
726 [[maybe_unused]] const int num_elements = GetNumEntities<entity_t>(mesh);
727 const int num_entities = GetNumEntities<entity_t>(mesh);
728 const int num_qp = integration_rule.GetNPoints();
729
730 if constexpr (is_sum_fop<decltype(output_fop)>::value)
731 {
732 residual_l.SetSize(1);
733 height = 1;
734 }
735 else
736 {
737 const int residual_lsize = GetVSize(fields[test_space_field_idx]);
738 residual_l.SetSize(residual_lsize);
739 height = GetTrueVSize(fields[test_space_field_idx]);
740 }
741
742 // TODO: Is this a hack?
743 width = GetTrueVSize(fields[0]);
744
745 std::vector<const DofToQuad*> dtq;
746 for (const auto &field : fields)
747 {
748 dtq.emplace_back(GetDofToQuad<entity_t>(
749 field,
750 integration_rule,
751 doftoquad_mode));
752 }
753 const int q1d = (int)floor(std::pow(num_qp, 1.0/dimension) + 0.5);
754
755 const int residual_size_on_qp =
756 GetSizeOnQP<entity_t>(output_fop,
757 fields[test_space_field_idx]);
758
759 auto input_dtq_maps = create_dtq_maps<entity_t>(inputs, dtq, input_to_field);
760 auto output_dtq_maps = create_dtq_maps<entity_t>(outputs, dtq, output_to_field);
761
762 const int test_vdim = output_fop.vdim;
763 const int test_op_dim = output_fop.size_on_qp / output_fop.vdim;
764 const int num_test_dof =
765 num_entities ? (output_e_size / output_fop.vdim / num_entities) : 0;
766
767 auto ir_weights = Reshape(integration_rule.GetWeights().Read(), num_qp);
768
769 auto input_size_on_qp =
770 get_input_size_on_qp(inputs, std::make_index_sequence<num_inputs> {});
771
772 auto action_shmem_info =
774 (input_dtq_maps, output_dtq_maps, fields, num_entities, inputs, num_qp,
775 input_size_on_qp, residual_size_on_qp, element_dof_ordering);
776
777 Vector shmem_cache(action_shmem_info.total_size);
778
779 // print_shared_memory_info(action_shmem_info);
780
781 ThreadBlocks thread_blocks;
782 if (dimension == 3)
783 {
784 if (use_sum_factorization)
785 {
786 thread_blocks.x = q1d;
787 thread_blocks.y = q1d;
788 thread_blocks.z = q1d;
789 }
790 }
791 else if (dimension == 2)
792 {
793 if (use_sum_factorization)
794 {
795 thread_blocks.x = q1d;
796 thread_blocks.y = q1d;
797 thread_blocks.z = 1;
798 }
799 }
800 else if (dimension == 1)
801 {
802 thread_blocks.x = q1d;
803 thread_blocks.y = 1;
804 thread_blocks.z = 1;
805 }
806
807 action_callbacks.push_back(
808 // Explicitly capture everything we need, so we can make explicit choice
809 // how to capture every variable, by copy or by ref.
810 [
811 // capture by copy:
812 dimension, // int
813 num_entities, // int
814 num_test_dof, // int
815 num_qp, // int
816 q1d, // int
817 residual_size_on_qp, // int
818 test_vdim, // int (= output_fop.vdim)
819 test_op_dim, // int (derived from output_fop)
820 inputs, // mfem::future::tuple
821 attributes, // Array<int>
822 ir_weights, // DeviceTensor
823 use_sum_factorization, // bool
824 input_dtq_maps, // std::array<DofToQuadMap, num_fields>
825 output_dtq_maps, // std::array<DofToQuadMap, num_fields>
826 input_to_field, // std::array<int, s>
827 output_fop, // class derived from FieldOperator
828 qfunc, // qfunc_t
829 thread_blocks, // ThreadBlocks
830 shmem_cache, // Vector (local)
831 action_shmem_info, // SharedMemoryInfo
832 // TODO: make this Array<int> a member of the DifferentiableOperator
833 // and capture it by ref.
834 elem_attributes, // Array<int>
835
836 // capture by ref:
837 &restriction_cb = this->restriction_callback,
838 &fields_e = this->fields_e,
839 &residual_e = this->residual_e,
840 &output_restriction_transpose = this->output_restriction_transpose
841 ]
842 (std::vector<Vector> &sol, const std::vector<Vector> &par, Vector &res)
843 mutable // mutable: needed to modify 'shmem_cache'
844 {
845 restriction_cb(sol, par, fields_e);
846
847 residual_e = 0.0;
848 auto ye = Reshape(residual_e.ReadWrite(), test_vdim, num_test_dof, num_entities);
849
850 auto wrapped_fields_e = wrap_fields(fields_e,
851 action_shmem_info.field_sizes,
852 num_entities);
853
854 const bool has_attr = attributes.Size() > 0;
855 const auto d_attr = attributes.Read();
856 const auto d_elem_attr = elem_attributes->Read();
857
858 forall([=] MFEM_HOST_DEVICE (int e, void *shmem)
859 {
860 if (has_attr && !d_attr[d_elem_attr[e] - 1]) { return; }
861
862 auto [input_dtq_shmem, output_dtq_shmem, fields_shmem, input_shmem,
863 residual_shmem, scratch_shmem] =
864 unpack_shmem(shmem, action_shmem_info, input_dtq_maps, output_dtq_maps,
865 wrapped_fields_e, num_qp, e);
866
868 input_shmem, fields_shmem, input_dtq_shmem, input_to_field, inputs, ir_weights,
869 scratch_shmem, dimension, use_sum_factorization);
870
872 qfunc, input_shmem, residual_shmem,
873 residual_size_on_qp, num_qp, q1d, dimension, use_sum_factorization);
874
875 auto fhat = Reshape(&residual_shmem(0, 0), test_vdim, test_op_dim, num_qp);
876 auto y = Reshape(&ye(0, 0, e), num_test_dof, test_vdim);
878 y, fhat, output_fop, output_dtq_shmem[0],
879 scratch_shmem, dimension, use_sum_factorization);
880 }, num_entities, thread_blocks, action_shmem_info.total_size, shmem_cache.ReadWrite());
881 output_restriction_transpose(residual_e, res);
882 });
883
884 // Without this compile-time check, some valid instantiations of this method
885 // will fail.
886 if constexpr (derivative_ids_t::size() != 0)
887 {
888 // Create the action of the derivatives
889 for_constexpr([&, &or_transpose =
890 this->output_restriction_transpose](const std::size_t derivative_id)
891 {
892 const size_t d_field_idx = FindIdx(derivative_id, fields);
893 const auto direction = fields[d_field_idx];
894 const int da_size_on_qp =
895 GetSizeOnQP<entity_t>(output_fop, fields[test_space_field_idx]);
896
897 auto shmem_info =
899 input_dtq_maps, output_dtq_maps, fields, num_entities, inputs,
900 num_qp, input_size_on_qp, residual_size_on_qp,
901 element_dof_ordering, d_field_idx);
902
903 Vector shmem_cache(shmem_info.total_size);
904
905 // print_shared_memory_info(shmem_info);
906
907 Vector direction_e(get_restriction<entity_t>(fields[d_field_idx],
908 element_dof_ordering)->Height());
909 Vector derivative_action_e(output_e_size);
910 derivative_action_e = 0.0;
911
912 // Lookup the derivative_id key in the dependency map
913 auto it = dependency_map.find(derivative_id);
914 if (it == dependency_map.end())
915 {
916 MFEM_ABORT("Derivative ID not found in dependency map");
917 }
918 const auto input_is_dependent = it->second;
919
920 // Trial operator dimension for each input.
921 // The trial operator dimension is set for each input that is
922 // dependent and if it is independent the dimension is 0.
923 Vector inputs_trial_op_dim(num_inputs);
924 int total_trial_op_dim = 0;
925 {
926 auto itod = Reshape(inputs_trial_op_dim.HostReadWrite(), num_inputs);
927 int idx = 0;
928 for_constexpr<num_inputs>([&](auto s)
929 {
930 if (!input_is_dependent[s])
931 {
932 itod(idx) = 0;
933 }
934 else
935 {
936 // TODO: BUG! Make this a general function that works for all kinds of inputs.
937 itod(idx) = input_size_on_qp[s] / get<s>(inputs).vdim;
938 }
939 total_trial_op_dim += static_cast<int>(itod(idx));
940 idx++;
941 });
942 }
943
944 // First Input index of the derivative
945 const size_t d_input_idx = [d_field_idx, &input_to_field]
946 {
947 for (size_t i = 0; i < input_to_field.size(); i++)
948 {
949 if (input_to_field[i] == d_field_idx)
950 {
951 return i;
952 }
953 }
954 return size_t(SIZE_MAX);
955 }();
956
957 const int trial_vdim = GetVDim(fields[d_field_idx]);
958 const int num_trial_dof =
959 get_restriction<entity_t>(fields[d_field_idx], element_dof_ordering)->Height() /
960 inputs_vdim[d_input_idx] / num_entities;
961 const int num_trial_dof_1d =
962 input_dtq_maps[d_input_idx].B.GetShape()[DofToQuadMap::Index::DOF];
963
964 Vector Ae_mem(num_test_dof * test_vdim * num_trial_dof * trial_vdim *
965 num_entities);
966 Ae_mem = 0.0;
967
968 // Quadrature point local derivative cache for each element, with data
969 // layout:
970 // [test_vdim, test_op_dim, trial_vdim, trial_op_dim, qp, num_entities].
971 derivative_qp_caches[derivative_id] = Vector(test_vdim * test_op_dim *
972 trial_vdim *
973 total_trial_op_dim * num_qp * num_entities);
974 // Create local references for MSVC lambda capture compatibility
975 auto& fields_ref = this->fields;
976 auto& derivative_qp_caches_ref = this->derivative_qp_caches[derivative_id];
977
978 // In each of the callbacks we're saving the derivatives in the quadrature point
979 // caches. This trades memory with computational effort but also minimizes
980 // data movement on each multiplication of the gradient with a directional
981 // vector.
982 derivative_setup_callbacks[derivative_id].push_back(
983 [
984 // capture by copy:
985 dimension, // int
986 num_entities, // int
987 num_qp, // int
988 q1d, // int
989 test_vdim, // int (= output_fop.vdim)
990 test_op_dim, // int (derived from output_fop)
991 inputs, // mfem::future::tuple
992 attributes, // Array<int>
993 ir_weights, // DeviceTensor
994 use_sum_factorization, // bool
995 input_dtq_maps, // std::array<DofToQuadMap, num_fields>
996 output_dtq_maps, // std::array<DofToQuadMap, num_fields>
997 input_to_field, // std::array<int, s>
998 qfunc, // qfunc_t
999 thread_blocks, // ThreadBlocks
1000 shmem_cache, // Vector (local)
1001 shmem_info, // SharedMemoryInfo
1002 // TODO: make this Array<int> a member of the DifferentiableOperator
1003 // and capture it by ref.
1004 elem_attributes, // Array<int>
1005 element_dof_ordering, // ElementDofOrdering
1006
1007 direction, // FieldDescriptor
1008 direction_e, // Vector
1009 da_size_on_qp, // int
1010
1011 total_trial_op_dim,
1012 trial_vdim,
1013 inputs_trial_op_dim,
1014
1015 // capture by ref:
1016 &qpdc_mem = derivative_qp_caches_ref
1017 ](std::vector<Vector> &f_e, const Vector &dir_l) mutable
1018 {
1019 restriction<entity_t>(direction, dir_l, direction_e,
1020 element_dof_ordering);
1021 auto wrapped_fields_e = wrap_fields(f_e, shmem_info.field_sizes,
1022 num_entities);
1023 auto wrapped_direction_e = Reshape(direction_e.ReadWrite(),
1024 shmem_info.direction_size,
1025 num_entities);
1026
1027 auto qpdc = Reshape(qpdc_mem.ReadWrite(), test_vdim, test_op_dim,
1028 trial_vdim, total_trial_op_dim, num_qp, num_entities);
1029
1030 auto itod = Reshape(inputs_trial_op_dim.Read(), num_inputs);
1031
1032 const auto d_elem_attr = elem_attributes->Read();
1033 const bool has_attr = attributes.Size() > 0;
1034 const auto d_domain_attr = attributes.Read();
1035
1036 forall([=] MFEM_HOST_DEVICE (int e, real_t *shmem)
1037 {
1038 if (has_attr && !d_domain_attr[d_elem_attr[e] - 1]) { return; }
1039
1040 auto [input_dtq_shmem, output_dtq_shmem, fields_shmem,
1041 direction_shmem, input_shmem,
1042 shadow_shmem_, residual_shmem,
1043 scratch_shmem] =
1044 unpack_shmem(shmem, shmem_info, input_dtq_maps, output_dtq_maps,
1045 wrapped_fields_e, wrapped_direction_e, num_qp, e);
1046 auto &shadow_shmem = shadow_shmem_;
1047
1049 input_shmem, fields_shmem, input_dtq_shmem, input_to_field,
1050 inputs, ir_weights, scratch_shmem, dimension,
1051 use_sum_factorization);
1052
1053 set_zero(shadow_shmem);
1054
1055 auto qpdc_e = Reshape(&qpdc(0, 0, 0, 0, 0, e), test_vdim, test_op_dim,
1056 trial_vdim, total_trial_op_dim, num_qp);
1058 qfunc, input_shmem, shadow_shmem, residual_shmem, qpdc_e, itod, da_size_on_qp,
1059 q1d, dimension, use_sum_factorization);
1060 }, num_entities, thread_blocks, shmem_info.total_size,
1061 shmem_cache.ReadWrite());
1062 });
1063
1064 // The derivative action only uses the quadrature point caches and applies
1065 // them to an input vector before integrating with the desired trial operator.
1066 derivative_action_callbacks[derivative_id].push_back(
1067 [
1068 // capture by copy:
1069 dimension, // int
1070 num_entities, // int
1071 num_test_dof, // int
1072 num_qp, // int
1073 q1d, // int
1074 test_vdim, // int (= output_fop.vdim)
1075 test_op_dim, // int (derived from output_fop)
1076 inputs, // mfem::future::tuple
1077 attributes, // Array<int>
1078 ir_weights, // DeviceTensor
1079 use_sum_factorization, // bool
1080 input_dtq_maps, // std::array<DofToQuadMap, num_fields>
1081 output_dtq_maps, // std::array<DofToQuadMap, num_fields>
1082 output_fop, // class derived from FieldOperator
1083 thread_blocks, // ThreadBlocks
1084 shmem_cache, // Vector (local)
1085 shmem_info, // SharedMemoryInfo
1086 // TODO: make this Array<int> a member of the DifferentiableOperator
1087 // and capture it by ref.
1088 elem_attributes, // Array<int>
1089
1090 input_is_dependent, // std::array<bool, num_inputs>
1091 direction, // FieldDescriptor
1092 direction_e, // Vector
1093 derivative_action_e, // Vector
1094 element_dof_ordering, // ElementDofOrdering
1095 inputs_trial_op_dim,
1096 total_trial_op_dim,
1097 trial_vdim,
1098 // capture by ref:
1099 &qpdc_mem = derivative_qp_caches_ref,
1100 &or_transpose
1101 ](
1102 std::vector<Vector> &f_e, const Vector &dir_l,
1103 Vector &der_action_l) mutable
1104 {
1105 restriction<entity_t>(direction, dir_l, direction_e,
1106 element_dof_ordering);
1107 auto ye = Reshape(derivative_action_e.ReadWrite(), num_test_dof,
1108 test_vdim, num_entities);
1109 auto wrapped_fields_e = wrap_fields(f_e, shmem_info.field_sizes,
1110 num_entities);
1111 auto wrapped_direction_e = Reshape(direction_e.ReadWrite(),
1112 shmem_info.direction_size,
1113 num_entities);
1114
1115 auto qpdc = Reshape(qpdc_mem.Read(), test_vdim, test_op_dim,
1116 trial_vdim, total_trial_op_dim, num_qp, num_entities);
1117
1118 auto itod = Reshape(inputs_trial_op_dim.Read(), num_inputs);
1119
1120 const bool has_attr = attributes.Size() > 0;
1121 const auto d_attr = attributes.Read();
1122 const auto d_elem_attr = elem_attributes->Read();
1123
1124 derivative_action_e = 0.0;
1125 forall([=] MFEM_HOST_DEVICE (int e, real_t *shmem)
1126 {
1127 if (has_attr && !d_attr[d_elem_attr[e] - 1]) { return; }
1128
1129 auto [input_dtq_shmem, output_dtq_shmem, fields_shmem,
1130 direction_shmem, input_shmem,
1131 shadow_shmem_, residual_shmem,
1132 scratch_shmem] =
1133 unpack_shmem(shmem, shmem_info, input_dtq_maps, output_dtq_maps,
1134 wrapped_fields_e, wrapped_direction_e, num_qp, e);
1135 auto &shadow_shmem = shadow_shmem_;
1136
1138 shadow_shmem, direction_shmem, input_dtq_shmem, inputs,
1139 ir_weights, scratch_shmem, input_is_dependent, dimension,
1140 use_sum_factorization);
1141
1142 auto fhat = Reshape(&residual_shmem(0, 0), test_vdim,
1143 test_op_dim, num_qp);
1144
1145 auto qpdce = Reshape(&qpdc(0, 0, 0, 0, 0, e), test_vdim, test_op_dim,
1146 trial_vdim, total_trial_op_dim, num_qp);
1147
1148 apply_qpdc(fhat, shadow_shmem, qpdce, itod, q1d, dimension,
1149 use_sum_factorization);
1150
1151 auto y = Reshape(&ye(0, 0, e), num_test_dof, test_vdim);
1153 y, fhat, output_fop, output_dtq_shmem[0],
1154 scratch_shmem, dimension, use_sum_factorization);
1155 }, num_entities, thread_blocks, shmem_info.total_size,
1156 shmem_cache.ReadWrite());
1157 or_transpose(derivative_action_e, der_action_l);
1158 });
1159
1160 assemble_derivative_sparsematrix_callbacks[derivative_id].push_back(
1161 [
1162 // capture by copy:
1163 dimension, // int
1164 num_entities, // int
1165 num_test_dof, // int
1166 num_qp, // int
1167 q1d, // int
1168 test_vdim, // int (= output_fop.vdim)
1169 test_op_dim, // int (derived from output_fop)
1170 inputs, // mfem::future::tuple
1171 attributes, // Array<int>
1172 use_sum_factorization, // bool
1173 input_dtq_maps, // std::array<DofToQuadMap, num_fields>
1174 output_dtq_maps, // std::array<DofToQuadMap, num_fields>
1175 input_to_field, // std::array<int, s>
1176 output_fop, // class derived from FieldOperator
1177 thread_blocks, // ThreadBlocks
1178 shmem_cache, // Vector (local)
1179 shmem_info, // SharedMemoryInfo
1180 // TODO: make this Array<int> a member of the DifferentiableOperator
1181 // and capture it by ref.
1182 elem_attributes, // Array<int>
1183
1184 input_is_dependent, // std::array<bool, num_inputs>
1185 direction_e, // Vector
1186 total_trial_op_dim,
1187 trial_vdim,
1188 num_trial_dof,
1189 num_trial_dof_1d,
1190 inputs_trial_op_dim,
1191 Ae_mem,
1192 output_to_field,
1193
1194 // capture by ref:
1195 &qpdc_mem = derivative_qp_caches_ref,
1196 &fields = fields_ref
1197 ](std::vector<Vector> &f_e, SparseMatrix *&A) mutable
1198 {
1199 auto wrapped_fields_e = wrap_fields(f_e, shmem_info.field_sizes,
1200 num_entities);
1201 auto wrapped_direction_e = Reshape(direction_e.ReadWrite(),
1202 shmem_info.direction_size,
1203 num_entities);
1204
1205 auto qpdc = Reshape(qpdc_mem.Read(), test_vdim, test_op_dim,
1206 trial_vdim, total_trial_op_dim, num_qp, num_entities);
1207
1208 auto itod = Reshape(inputs_trial_op_dim.Read(), num_inputs);
1209
1210 auto Ae = Reshape(Ae_mem.ReadWrite(), num_test_dof, test_vdim, num_trial_dof,
1211 trial_vdim, num_entities);
1212
1213 const auto d_elem_attr = elem_attributes->Read();
1214 const bool has_attr = attributes.Size() > 0;
1215 const auto d_domain_attr = attributes.Read();
1216
1217 forall([=] MFEM_HOST_DEVICE (int e, real_t *shmem)
1218 {
1219 if (has_attr && !d_domain_attr[d_elem_attr[e] - 1]) { return; }
1220
1221 auto [input_dtq_shmem, output_dtq_shmem, fields_shmem,
1222 direction_shmem, input_shmem,
1223 shadow_shmem_, residual_shmem,
1224 scratch_shmem] =
1225 unpack_shmem(shmem, shmem_info, input_dtq_maps, output_dtq_maps,
1226 wrapped_fields_e, wrapped_direction_e, num_qp, e);
1227
1228 auto fhat = Reshape(&residual_shmem(0, 0), test_vdim, test_op_dim, num_qp);
1229 auto Aee = Reshape(&Ae(0, 0, 0, 0, e), num_test_dof, test_vdim, num_trial_dof,
1230 trial_vdim);
1231 auto qpdce = Reshape(&qpdc(0, 0, 0, 0, 0, e), test_vdim, test_op_dim,
1232 trial_vdim, total_trial_op_dim, num_qp);
1233 assemble_element_mat_naive(Aee, fhat, qpdce, itod, inputs, output_fop,
1234 input_dtq_shmem, output_dtq_shmem[0], scratch_shmem, dimension, q1d,
1235 num_trial_dof_1d, use_sum_factorization);
1236 }, num_entities, thread_blocks, shmem_info.total_size,
1237 shmem_cache.ReadWrite());
1238
1239 FieldDescriptor *trial_field = nullptr;
1240 for (size_t s = 0; s < num_inputs; s++)
1241 {
1242 if (input_is_dependent[s])
1243 {
1244 trial_field = &fields[input_to_field[s]];
1245 }
1246 }
1247
1248 auto trial_fes = *std::get_if<const ParFiniteElementSpace *>
1249 (&trial_field->data);
1250 auto test_fes = *std::get_if<const ParFiniteElementSpace *>
1251 (&fields[output_to_field[0]].data);
1252
1253 A = new SparseMatrix(test_fes->GetVSize(), trial_fes->GetVSize());
1254
1255 auto tmp = Reshape(Ae_mem.HostReadWrite(), num_test_dof * test_vdim,
1256 num_trial_dof * trial_vdim, num_entities);
1257 for (int e = 0; e < num_entities; e++)
1258 {
1259 DenseMatrix Aee(&tmp(0, 0, e), num_test_dof * test_vdim,
1260 num_trial_dof * trial_vdim);
1261
1262 Array<int> test_vdofs, trial_vdofs;
1263 test_fes->GetElementVDofs(e, test_vdofs);
1264 trial_fes->GetElementVDofs(e, trial_vdofs);
1265
1266 if (use_sum_factorization)
1267 {
1268 Array<int> test_vdofs_mapped(test_vdofs.Size());
1269
1270 const Array<int> &test_dofmap =
1271 dynamic_cast<const TensorBasisElement&>(*test_fes->GetFE(0)).GetDofMap();
1272
1273 if (test_dofmap.Size() == 0)
1274 {
1275 test_vdofs_mapped = test_vdofs;
1276 }
1277 else
1278 {
1279 MFEM_ASSERT(test_dofmap.Size() == num_test_dof,
1280 "internal error: dof map of the test space does not "
1281 "match previously determined number of test space dofs");
1282
1283 for (int vd = 0; vd < test_vdim; vd++)
1284 {
1285 for (int i = 0; i < num_test_dof; i++)
1286 {
1287 test_vdofs_mapped[i + vd * num_test_dof] =
1288 test_vdofs[test_dofmap[i] + vd * num_test_dof];
1289 }
1290 }
1291 }
1292
1293 Array<int> trial_vdofs_mapped(trial_vdofs.Size());
1294 const Array<int> &trial_dofmap =
1295 dynamic_cast<const TensorBasisElement&>(*trial_fes->GetFE(0)).GetDofMap();
1296
1297 if (trial_dofmap.Size() == 0)
1298 {
1299 trial_vdofs_mapped = trial_vdofs;
1300 }
1301 else
1302 {
1303 MFEM_ASSERT(trial_dofmap.Size() == num_trial_dof,
1304 "internal error: dof map of the test space does not "
1305 "match previously determined number of test space dofs");
1306
1307 for (int vd = 0; vd < trial_vdim; vd++)
1308 {
1309 for (int i = 0; i < num_trial_dof; i++)
1310 {
1311 trial_vdofs_mapped[i + vd * num_trial_dof] =
1312 trial_vdofs[trial_dofmap[i] + vd * num_trial_dof];
1313 }
1314 }
1315 }
1316
1317 A->AddSubMatrix(test_vdofs_mapped, trial_vdofs_mapped, Aee, 1);
1318 }
1319 else
1320 {
1321 A->AddSubMatrix(test_vdofs, trial_vdofs, Aee, 1);
1322 }
1323 }
1324 A->Finalize();
1325 });
1326
1327 // Create local references for MSVC lambda capture compatibility
1328 auto& assemble_derivative_sparsematrix_callbacks_ref =
1329 this->assemble_derivative_sparsematrix_callbacks[derivative_id];
1330
1331 assemble_derivative_hypreparmatrix_callbacks[derivative_id].push_back(
1332 [
1333 input_is_dependent,
1334 input_to_field,
1335 output_to_field,
1336 &spmatcb = assemble_derivative_sparsematrix_callbacks_ref,
1337 &fields = fields_ref
1338 ](std::vector<Vector> &f_e, HypreParMatrix *&A) mutable
1339 {
1340 SparseMatrix *spmat = nullptr;
1341 for (const auto &f : spmatcb)
1342 {
1343 f(f_e, spmat);
1344 }
1345
1346 if (spmat == nullptr)
1347 {
1348 MFEM_ABORT("internal error");
1349 }
1350
1351 bool same_test_and_trial = false;
1352 for (size_t s = 0; s < num_inputs; s++)
1353 {
1354 if (input_is_dependent[s])
1355 {
1356 if (output_to_field[0] == input_to_field[s])
1357 {
1358 same_test_and_trial = true;
1359 break;
1360 }
1361 }
1362 }
1363
1364 FieldDescriptor *trial_field = nullptr;
1365 for (size_t s = 0; s < num_inputs; s++)
1366 {
1367 if (input_is_dependent[s])
1368 {
1369 trial_field = &fields[input_to_field[s]];
1370 }
1371 }
1372
1373 auto trial_fes = *std::get_if<const ParFiniteElementSpace *>
1374 (&trial_field->data);
1375 auto test_fes = *std::get_if<const ParFiniteElementSpace *>
1376 (&fields[output_to_field[0]].data);
1377
1378 if (same_test_and_trial)
1379 {
1380 HypreParMatrix tmp(test_fes->GetComm(),
1381 test_fes->GlobalVSize(),
1382 test_fes->GetDofOffsets(),
1383 spmat);
1384 A = RAP(&tmp, test_fes->Dof_TrueDof_Matrix());
1385 }
1386 else
1387 {
1388 HypreParMatrix tmp(test_fes->GetComm(),
1389 test_fes->GlobalVSize(),
1390 trial_fes->GlobalVSize(),
1391 test_fes->GetDofOffsets(),
1392 trial_fes->GetDofOffsets(),
1393 spmat);
1394 A = RAP(test_fes->Dof_TrueDof_Matrix(), &tmp,
1395 trial_fes->Dof_TrueDof_Matrix());
1396 }
1397 delete spmat;
1398 });
1399 }, derivative_ids);
1400 }
1401}
1402
1403} // namespace mfem::future
1404#endif
int Size() const
Return the logical size of the array.
Definition array.hpp:166
const T * Read(bool on_dev=true) const
Shortcut for mfem::Read(a.GetMemory(), a.Size(), on_dev).
Definition array.hpp:381
Data type dense matrix using column-major storage.
Definition densemat.hpp:24
Mode
Type of data stored in the arrays B, Bt, G, and Gt.
Definition fe_base.hpp:154
@ FULL
Full multidimensional representation which does not use tensor product structure. The ordering of the...
Definition fe_base.hpp:158
@ TENSOR
Tensor product representation using 1D matrices/tensors with dimensions using 1D number of quadrature...
Definition fe_base.hpp:165
Type
Constants for the classes derived from Element.
Definition element.hpp:41
static Type TypeFromGeometry(const Geometry::Type geom)
Return the Element::Type associated with the given Geometry::Type.
Definition element.cpp:17
Wrapper for hypre's ParCSR matrix class.
Definition hypre.hpp:419
Class for an integration rule - an Array of IntegrationPoint.
Definition intrules.hpp:100
int GetNPoints() const
Returns the number of the points in the integration rule.
Definition intrules.hpp:256
const Array< real_t > & GetWeights() const
Return the quadrature weights in a contiguous array.
Definition intrules.cpp:86
Geometry::Type GetTypicalElementGeometry() const
If the local mesh is not empty, return GetElementGeometry(0); otherwise, return a typical Geometry pr...
Definition mesh.cpp:1628
const Array< int > & GetElementAttributes() const
Returns the attributes for all elements in this mesh. The i'th entry of the array is the attribute of...
Definition mesh.cpp:965
int Dimension() const
Dimension of the reference space used within the elements.
Definition mesh.hpp:1306
const Array< int > & GetBdrFaceAttributes() const
Returns the attributes for all boundary elements in this mesh.
Definition mesh.cpp:925
int GetNBE() const
Returns number of boundary elements.
Definition mesh.hpp:1380
Geometry::Type GetTypicalFaceGeometry() const
If the local mesh is not empty, return GetFaceGeometry(0); otherwise return a typical face geometry p...
Definition mesh.cpp:1596
Abstract operator.
Definition operator.hpp:25
int width
Dimension of the input / number of columns in the matrix.
Definition operator.hpp:28
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
int height
Dimension of the output / number of rows in the matrix.
Definition operator.hpp:27
Class for parallel meshes.
Definition pmesh.hpp:34
MPI_Comm GetComm() const
Definition pmesh.hpp:405
int GetNFbyType(FaceType type) const override
Returns the number of local faces according to the requested type, does not count master non-conformi...
Definition pmesh.cpp:3205
Data type sparse matrix.
Definition sparsemat.hpp:51
void AddSubMatrix(const Array< int > &rows, const Array< int > &cols, const DenseMatrix &subm, int skip_zeros=1)
void Finalize(int skip_zeros=1) override
Finalize the matrix initialization, switching the storage format from LIL to CSR.
Vector data type.
Definition vector.hpp:82
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:520
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:536
int Size() const
Returns the size of the vector.
Definition vector.hpp:234
void SetSize(int s)
Resize the vector to size s.
Definition vector.hpp:584
virtual real_t * HostReadWrite()
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), false).
Definition vector.hpp:540
void MultTranspose(const Vector &direction_t, Vector &result_t) const override
Compute the transpose of the derivative operator on a given vector.
void Assemble(SparseMatrix *&A)
Assemble the derivative operator into a SparseMatrix.
DerivativeOperator(const int &height, const int &width, const std::vector< derivative_action_t > &derivative_actions, const FieldDescriptor &direction, const int &daction_l_size, const std::vector< derivative_action_t > &derivative_actions_transpose, const FieldDescriptor &transpose_direction, const int &daction_transpose_l_size, const std::vector< Vector * > &solutions_l, const std::vector< Vector * > &parameters_l, const restriction_callback_t &restriction_callback, const std::function< void(Vector &, Vector &)> &prolongation_transpose, const std::vector< assemble_derivative_sparsematrix_callback_t > &assemble_derivative_sparsematrix_callbacks, const std::vector< assemble_derivative_hypreparmatrix_callback_t > &assemble_derivative_hypreparmatrix_callbacks)
Definition doperator.hpp:81
void Assemble(HypreParMatrix *&A)
Assemble the derivative operator into a HypreParMatrix.
void Mult(const Vector &direction_t, Vector &result_t) const override
Compute the action of the derivative operator on a given vector.
std::shared_ptr< DerivativeOperator > GetDerivative(size_t derivative_id, std::vector< Vector * > sol_l, std::vector< Vector * > par_l)
Get the derivative operator for a given derivative ID.
void AddIntegrator(qfunc_t &qfunc, input_t inputs, output_t outputs, const IntegrationRule &integration_rule, const Array< int > &attributes, derivative_ids_t derivative_ids)
Add an integrator to the operator. Called only from AddDomainIntegrator() and AddBoundaryIntegrator()...
DifferentiableOperator(const std::vector< FieldDescriptor > &solutions, const std::vector< FieldDescriptor > &parameters, const ParMesh &mesh)
Definition doperator.cpp:30
void SetMultLevel(MultLevel level)
Set the MultLevel mode for the DifferentiableOperator. The default is TVECTOR, which means that the O...
void SetParameters(std::vector< Vector * > p) const
Set the parameters for the operator.
Definition doperator.cpp:19
void Mult(const Vector &solutions_in, Vector &result_in) const override
Compute the action of the operator on a given vector.
void DisableTensorProductStructure(bool disable=true)
Disable the use of tensor product structure.
void AddDomainIntegrator(qfunc_t &qfunc, input_t inputs, output_t outputs, const IntegrationRule &integration_rule, const Array< int > &domain_attributes, derivative_ids_t derivative_ids=std::make_index_sequence< 0 > {})
Add a domain integrator to the operator.
void AddBoundaryIntegrator(qfunc_t &qfunc, input_t inputs, output_t outputs, const IntegrationRule &integration_rule, const Array< int > &boundary_attributes, derivative_ids_t derivative_ids=std::make_index_sequence< 0 > {})
Add a boundary integrator to the operator.
constexpr int dimension
This example only works in 3D. Kernels for 2D are not implemented.
Definition hooke.cpp:45
string direction
constexpr bool always_false
Definition util.hpp:575
constexpr auto filter_fields(const std::tuple< Ts... > &t)
Filter fields from a tuple based on their field IDs.
Definition util.hpp:541
void prolongation(const FieldDescriptor field, const Vector &x, Vector &field_l)
Apply the prolongation operator to a field.
Definition util.hpp:1087
constexpr auto merge_mfem_tuples_as_empty_std_tuple(const mfem::future::tuple< T1s... > &, const mfem::future::tuple< T2s... > &)
Auxiliary template function that merges (concatenates) two mfem::future::tuple types into a single st...
Definition tuple.hpp:878
MFEM_HOST_DEVICE auto unpack_shmem(void *shmem, const shared_mem_info_t &shmem_info, const std::array< DofToQuadMap, num_inputs > &input_dtq_maps, const std::array< DofToQuadMap, num_outputs > &output_dtq_maps, const std::array< DeviceTensor< 2 >, num_fields > &wrapped_fields_e, const int &num_qp, const int &e)
Definition util.hpp:1953
void restriction(const FieldDescriptor u, const Vector &u_l, Vector &field_e, ElementDofOrdering ordering)
Apply the restriction operator to a field.
Definition util.hpp:1218
void get_lvectors(const std::vector< FieldDescriptor > fields, const Vector &x, std::vector< Vector > &fields_l)
Definition util.hpp:1151
std::function< void(std::vector< Vector > &, const Vector &)> derivative_setup_t
Type alias for a function that computes the cache for the action of a derivative.
Definition doperator.hpp:35
int GetNumEntities(const mfem::Mesh &mesh)
Get the number of entities of a given type.
Definition util.hpp:1282
MFEM_HOST_DEVICE void map_fields_to_quadrature_data(std::array< DeviceTensor< 2 >, num_inputs > &fields_qp, const std::array< DeviceTensor< 1 >, num_fields > &fields_e, const std::array< DofToQuadMap, num_inputs > &dtqmaps, const std::array< size_t, num_inputs > &input_to_field, const field_operator_ts &fops, const DeviceTensor< 1, const real_t > &integration_weights, const std::array< DeviceTensor< 1 >, 6 > &scratch_mem, const int &dimension, const bool &use_sum_factorization=false)
std::tuple< std::function< void(const Vector &, Vector &)>, int > get_restriction_transpose(const FieldDescriptor &f, const ElementDofOrdering &o, const fop_t &fop)
Get a transpose restriction callback for a field descriptor.
Definition util.hpp:1052
std::function< void(std::vector< Vector > &, const Vector &, Vector &)> derivative_action_t
Type alias for a function that computes the action of a derivative.
Definition doperator.hpp:39
std::array< DeviceTensor< 2 >, num_fields > wrap_fields(std::vector< Vector > &fields, std::array< int, num_fields > &field_sizes, const int &num_entities)
Wraps plain data in DeviceTensors for fields.
Definition util.hpp:2187
std::vector< int > get_input_size_on_qp(const input_t &inputs, std::index_sequence< i... >)
Get the size on quadrature point for a given set of inputs.
Definition util.hpp:1546
MFEM_HOST_DEVICE void apply_qpdc(DeviceTensor< 3 > &fhat, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, const DeviceTensor< 5, const real_t > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Apply the quadrature point data cache (qpdc) to a vector (usually a direction).
std::function< void(std::vector< Vector > &, SparseMatrix *&)> assemble_derivative_sparsematrix_callback_t
Type alias for a function that assembles the SparseMatrix of a derivative operator.
Definition doperator.hpp:44
SharedMemoryInfo< num_fields, num_inputs, num_outputs > get_shmem_info(const std::array< DofToQuadMap, num_inputs > &input_dtq_maps, const std::array< DofToQuadMap, num_outputs > &output_dtq_maps, const std::vector< FieldDescriptor > &fields, const int &num_entities, const input_t &inputs, const int &num_qp, const std::vector< int > &input_size_on_qp, const int &residual_size_on_qp, const ElementDofOrdering &dof_ordering, const int &derivative_action_field_idx=-1)
Definition util.hpp:1585
const Operator * get_restriction(const FieldDescriptor &f, const ElementDofOrdering &o)
Get the restriction operator for a field descriptor.
Definition util.hpp:1027
std::array< size_t, tuple_size< field_operator_ts >::value > create_descriptors_to_fields_map(const std::vector< FieldDescriptor > &fields, field_operator_ts &fops)
Create a map from field operator types to FieldDescriptor indices.
Definition util.hpp:1437
constexpr std::size_t count_unique_field_ids(const std::tuple< Ts... > &t)
Function to count unique field IDs in a tuple.
Definition util.hpp:495
MFEM_HOST_DEVICE constexpr auto type(const tuple< T... > &values)
a function intended to be used for extracting the ith type from a tuple.
Definition tuple.hpp:347
std::array< DofToQuadMap, num_fields > create_dtq_maps(field_operator_ts &fops, std::vector< const DofToQuad * > &dtqmaps, const std::array< size_t, num_fields > &to_field_map)
Create DofToQuad maps for a given set of field operators.
Definition util.hpp:2334
const DofToQuad * GetDofToQuad(const FieldDescriptor &f, const IntegrationRule &ir, DofToQuad::Mode mode)
Get the GetDofToQuad object for a given entity type.
Definition util.hpp:1310
MFEM_HOST_DEVICE void call_qfunction(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, DeviceTensor< 2 > &residual_shmem, const int &rs_qp, const int &num_qp, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Call a qfunction with the given parameters.
MFEM_HOST_DEVICE void map_direction_to_quadrature_data_conditional(std::array< DeviceTensor< 2 >, num_inputs > &directions_qp, const DeviceTensor< 1 > &direction_e, const std::array< DofToQuadMap, num_inputs > &dtqmaps, field_operator_ts fops, const DeviceTensor< 1, const real_t > &integration_weights, const std::array< DeviceTensor< 1 >, 6 > &scratch_mem, const std::array< bool, num_inputs > &conditions, const int &dimension, const bool &use_sum_factorization)
MFEM_HOST_DEVICE void assemble_element_mat_naive(const DeviceTensor< 4, real_t > &A, const DeviceTensor< 3, real_t > &fhat, const DeviceTensor< 5, const real_t > &qpdc, const DeviceTensor< 1, const real_t > &itod, const input_fop_ts &inputs, const output_fop_t &output, const std::array< DofToQuadMap, num_inputs > &input_dtqmaps, const DofToQuadMap &output_dtqmap, std::array< DeviceTensor< 1 >, 6 > &scratch_shmem, const int &dimension, const int &q1d, const int &td1d, const bool &use_sum_factorization)
Assemble element matrix for two or three dimensional data.
Definition assemble.hpp:366
constexpr void for_constexpr(lambda &&f, std::integer_sequence< std::size_t, i ... >)
Definition util.hpp:71
MFEM_HOST_DEVICE void call_qfunction_derivative(qfunc_t &qfunc, const std::array< DeviceTensor< 2 >, num_fields > &input_shmem, const std::array< DeviceTensor< 2 >, num_fields > &shadow_shmem, DeviceTensor< 2 > &residual_shmem, DeviceTensor< 5 > &qpdc, const DeviceTensor< 1, const real_t > &itod, const int &das_qp, const int &q1d, const int &dimension, const bool &use_sum_factorization)
Call a qfunction with the given parameters and compute its derivative represented by the Jacobian on...
std::function< void(std::vector< Vector > &, const std::vector< Vector > &, std::vector< Vector > &)> restriction_callback_t
Type alias for a function that applies the appropriate restriction to the solution and parameters.
Definition doperator.hpp:54
MFEM_HOST_DEVICE void map_quadrature_data_to_fields(DeviceTensor< 2, real_t > &y, const DeviceTensor< 3, real_t > &f, const output_t &output, const DofToQuadMap &dtq, std::array< DeviceTensor< 1 >, 6 > &scratch_mem, const int &dimension, const bool &use_sum_factorization)
std::function< void(const Vector &, Vector &)> get_prolongation_transpose(const FieldDescriptor &f, const fop_t &fop, MPI_Comm mpi_comm)
Get a transpose prolongation callback for a field descriptor.
Definition util.hpp:1179
std::size_t FindIdx(const std::size_t &id, const std::vector< FieldDescriptor > &fields)
Find the index of a field descriptor in a vector of field descriptors.
Definition util.hpp:742
void forall(func_t f, const int &N, const ThreadBlocks &blocks, int num_shmem=0, real_t *shmem=nullptr)
Definition util.hpp:614
int GetVDim(const FieldDescriptor &f)
Get the vdim of a field descriptor.
Definition util.hpp:864
int GetSizeOnQP(const field_operator_t &, const FieldDescriptor &f)
Get the size on quadrature point for a field operator type and FieldDescriptor combination.
Definition util.hpp:1402
auto make_dependency_map(tuple< input_ts... > inputs)
Definition util.hpp:147
int GetVSize(const FieldDescriptor &f)
Get the vdof size of a field descriptor.
Definition util.hpp:760
std::function< void(std::vector< Vector > &, HypreParMatrix *&)> assemble_derivative_hypreparmatrix_callback_t
Type alias for a function that assembles the HypreParMatrix of a derivative operator.
Definition doperator.hpp:49
MFEM_HOST_DEVICE void set_zero(std::array< DeviceTensor< 2 >, N > &v)
Definition util.hpp:2111
int GetTrueVSize(const FieldDescriptor &f)
Get the true dof size of a field descriptor.
Definition util.hpp:829
MFEM_HOST_DEVICE zero & get(zero &x)
let zero be accessed like a tuple
Definition tensor.hpp:281
std::function< void(std::vector< Vector > &, const std::vector< Vector > &, Vector &)> action_t
Type alias for a function that computes the action of an operator.
Definition doperator.hpp:31
MFEM_HOST_DEVICE DeviceTensor< sizeof...(Dims), T > Reshape(T *ptr, Dims... dims)
Wrap a pointer as a DeviceTensor with automatically deduced template parameters.
Definition dtensor.hpp:138
void RAP(const DenseMatrix &A, const DenseMatrix &P, DenseMatrix &RAP)
float real_t
Definition config.hpp:46
ElementDofOrdering
Constants describing the possible orderings of the DOFs in one element.
Definition fespace.hpp:47
@ NATIVE
Native ordering as defined by the FiniteElement.
std::function< real_t(const Vector &)> f(real_t mass_coeff)
Definition lor_mms.hpp:30
real_t p(const Vector &x, real_t t)
FieldDescriptor struct.
Definition util.hpp:551
data_variant_t data
Field variant.
Definition util.hpp:561
ThreadBlocks struct.
Definition util.hpp:594