MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
util.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11#pragma once
12
13#include <algorithm>
14#include <array>
15#include <cstdlib>
16#include <iostream>
17#include <unordered_map>
18#include <utility>
19#include <variant>
20#include <vector>
21#include <type_traits>
22#include <numeric>
23#include <iomanip>
24
27#ifdef MFEM_USE_MPI
28#include "../fe/fe_base.hpp"
29#include "../fespace.hpp"
30#include "../pfespace.hpp"
31#include "../../mesh/mesh.hpp"
33
34#include "fieldoperator.hpp"
35#include "parameterspace.hpp"
36#include "tuple.hpp"
37
38namespace mfem::future
39{
40
/// @brief Convert a std::tuple into a std::array.
///
/// The array element type is the std::common_type of all tuple element types.
///
/// @param tuple the tuple whose elements are copied into the array.
/// @returns a std::array holding sizeof...(Ts) entries, in tuple order.
template<typename... Ts>
constexpr auto to_array(const std::tuple<Ts...>& tuple)
{
   using value_t = std::common_type_t<Ts...>;
   using array_t = std::array<value_t, sizeof...(Ts)>;
   return std::apply([](const Ts&... elems) { return array_t { elems... }; },
                     tuple);
}
47
namespace detail
{

/// Innermost step of the nested compile-time loop: every index has been turned
/// into a std::integral_constant, so @a f is invoked with all of them.
template <typename lambda, std::size_t... idx>
constexpr void for_constexpr(lambda&& f,
                             std::integral_constant<std::size_t, idx>... indices)
{
   f(indices...);
}


/// Peels off the leading index sequence: for each value n in it, recurses with
/// the remaining arguments followed by std::integral_constant<std::size_t, n>.
template <std::size_t... n, typename lambda, typename... arg_types>
constexpr void for_constexpr(lambda&& f,
                             std::index_sequence<n...>,
                             arg_types... args)
{
   (detail::for_constexpr(f, args..., std::integral_constant<std::size_t, n> {}),
    ...);
}

} // namespace detail
69
/// @brief Invoke @a f once for every entry of a compile-time index sequence.
///
/// @param f callable taking a std::integral_constant<std::size_t, i>; it is
/// called for each index i of the (unnamed) sequence argument, in order.
template <typename lambda, std::size_t... i>
constexpr void for_constexpr(lambda&& f, std::index_sequence<i...>)
{
   (f(std::integral_constant<std::size_t, i> {}), ...);
}
76
/// Overload for an empty index sequence: there is nothing to iterate over, so
/// the callable is never invoked.
template <typename lambda>
constexpr void for_constexpr(lambda&&, std::index_sequence<>)
{
   // Intentionally empty.
}
79
80template <int... n, typename lambda>
81constexpr void for_constexpr(lambda&& f)
82{
83 detail::for_constexpr(f, std::make_integer_sequence<std::size_t, n> {}...);
84}
85
/// Recursion terminator of for_constexpr_with_arg: the index sequence is empty,
/// so neither the callable nor the argument is touched.
template <typename lambda, typename arg_t>
constexpr void for_constexpr_with_arg(lambda&&, arg_t&&,
                                      std::index_sequence<>)
{
   // No indices left - nothing to do.
}
92
93template <typename lambda, typename arg_t, std::size_t i, std::size_t... Is>
94constexpr void for_constexpr_with_arg(lambda&& f, arg_t&& arg,
95 std::integer_sequence<std::size_t, i, Is...>)
96{
97 f(std::integral_constant<std::size_t, i> {}, get<i>(arg));
98 for_constexpr_with_arg(f, std::forward<arg_t>(arg),
99 std::integer_sequence<std::size_t, Is...> {});
100}
101
102template <typename lambda, typename arg_t>
103constexpr void for_constexpr_with_arg(lambda&& f, arg_t&& arg)
104{
105 using indices =
106 std::make_index_sequence<tuple_size<std::remove_reference_t<arg_t>>::value>;
107 for_constexpr_with_arg(std::forward<lambda>(f), std::forward<arg_t>(arg),
108 indices{});
109}
110
111template <std::size_t I, typename Tuple, std::size_t... Is>
112std::array<bool, sizeof...(Is)>
113make_dependency_array(const Tuple& inputs, std::index_sequence<Is...>)
114{
115 return { (get<I>(inputs).GetFieldId() == get<Is>(inputs).GetFieldId())... };
116}
117
118template <typename... input_ts, std::size_t... Is>
120 std::index_sequence<Is...>)
121{
122 constexpr std::size_t N = sizeof...(input_ts);
123
124 if constexpr (N == 0)
125 return std::unordered_map<int, std::array<bool, 0>> {};
126
127 std::unordered_map<int, std::array<bool, N>> map;
128
129 (void)std::initializer_list<int>
130 {
131 (
132 map[get<Is>(inputs).GetFieldId()] =
133 make_dependency_array<Is>(inputs, std::make_index_sequence<N>{}),
134 0
135 )...
136 };
137
138 return map;
139}
140
141// @brief Create a dependency map from a tuple of inputs.
142//
143// @param inputs a tuple of objects derived from FieldOperator.
144// @returns an unordered_map where the keys are the field IDs and the values
145// are arrays of booleans indicating which inputs depend on each field ID.
146template <typename... input_ts>
148{
149 return make_dependency_map_impl(inputs, std::index_sequence_for<input_ts...> {});
150}
151
/// @brief Get the type name of a template parameter T.
///
/// Convenient helper function for debugging. The name is cut out of the
/// compiler-specific "pretty function" string, so its exact spelling depends
/// on the compiler.
///
/// Usage example
/// ```c++
/// mfem::out << get_type_name<int>() << std::endl;
/// ```
/// prints "int".
///
/// @returns a view into the compiler-provided function signature string. If the
/// expected markers are not found, as much of the string as could be located is
/// returned instead of reading out of range.
template <typename T>
constexpr auto get_type_name() -> std::string_view
{
#if defined(__clang__)
   constexpr auto prefix = std::string_view {"[T = "};
   constexpr auto suffix = std::string_view {"]"};
   constexpr auto function = std::string_view {__PRETTY_FUNCTION__};
#elif defined(__GNUC__)
   constexpr auto prefix = std::string_view {"with T = "};
   constexpr auto suffix = std::string_view {"; "};
   constexpr auto function = std::string_view {__PRETTY_FUNCTION__};
#elif defined(_MSC_VER)
   constexpr auto prefix = std::string_view {"get_type_name<"};
   constexpr auto suffix = std::string_view {">(void)"};
   constexpr auto function = std::string_view {__FUNCSIG__};
#else
#error Unsupported compiler
#endif

   const auto prefix_pos = function.find(prefix);
   if (prefix_pos == std::string_view::npos)
   {
      return function;
   }
   const auto start = prefix_pos + prefix.size();

#if defined(__clang__) || defined(_MSC_VER)
   // The suffix closes the signature string, so search from the back; the type
   // name itself may contain the suffix characters (e.g. "int[3]").
   const auto end = function.rfind(suffix);
#else
   // The suffix separates T from further entries; take the first one after T.
   const auto end = function.find(suffix, start);
#endif

   if (end == std::string_view::npos || end < start)
   {
      return function.substr(start);
   }
   return function.substr(start, end - start);
}
185
186template <typename Tuple, std::size_t... Is>
187void print_tuple_impl(const Tuple& t, std::index_sequence<Is...>)
188{
189 ((out << (Is == 0 ? "" : ", ") << std::get<Is>(t)), ...);
190}
191
192// @brief Helper function to print a single tuple.
193//
194// @param t The tuple to print.
195template <typename... Args>
196void print_tuple(const std::tuple<Args...>& t)
197{
198 out << "(";
199 print_tuple_impl(t, std::index_sequence_for<Args...> {});
200 out << ")";
201}
202
203/// @brief Pretty print an mfem::DenseMatrix to out
204///
205/// Formatted s.t. the output is
206/// [[v00, v01, ..., v0n],
207/// [v10, v11, ..., v1n],
208/// ..., vmn]]
209/// which is compatible with numpy syntax.
210///
211/// @param out ostream to print to
212/// @param A mfem::DenseMatrix to print
213inline
214void pretty_print(std::ostream &out, const mfem::DenseMatrix &A)
215{
216 // Determine the max width of any entry in scientific notation
217 int max_width = 0;
218 for (int i = 0; i < A.NumRows(); ++i)
219 {
220 for (int j = 0; j < A.NumCols(); ++j)
221 {
222 std::ostringstream oss;
223 oss << std::scientific << std::setprecision(2) << A(i, j);
224 max_width = std::max(max_width, static_cast<int>(oss.str().length()));
225 }
226 }
227
228 out << "[\n";
229 for (int i = 0; i < A.NumRows(); ++i)
230 {
231 out << " [";
232 for (int j = 0; j < A.NumCols(); ++j)
233 {
234 out << std::setw(max_width) << std::scientific << std::setprecision(2) <<
235 A(i, j);
236
237 if (j < A.NumCols() - 1)
238 {
239 out << ", ";
240 }
241 }
242 out << "]";
243 if (i < A.NumRows() - 1)
244 {
245 out << ",\n";
246 }
247 else
248 {
249 out << "\n";
250 }
251 }
252 out << "]\n";
253}
254
255/// @brief Pretty print an mfem::Vector to out
256///
257/// Formatted s.t. the output is [v0, v1, ..., vn] which
258/// is compatible with numpy syntax.
259///
260/// @param v vector of vectors to print
261inline
263{
264 out << "[";
265 for (int i = 0; i < v.Size(); i++)
266 {
267 out << v(i);
268 if (i < v.Size() - 1)
269 {
270 out << ", ";
271 }
272 }
273 out << "]\n";
274}
275
276/// @brief Pretty print an mfem::Array to out
277///
278/// T has to have an overloaded operator<<
279///
280/// Formatted s.t. the output is [v0, v1, ..., vn] which
281/// is compatible with numpy syntax.
282///
283/// @param v vector of vectors to print
284template <typename T>
286{
287 out << "[";
288 for (int i = 0; i < v.Size(); i++)
289 {
290 out << v[i];
291 if (i < v.Size() - 1)
292 {
293 out << ", ";
294 }
295 }
296 out << "]\n";
297}
298
299/// @brief Pretty prints an unordered map of std::array to out
300///
301/// Useful for printing the output of make_dependency_map
302///
303/// @param map unordered map to print
304/// @tparam T type of array elements
305/// @tparam N size of array
306template<typename K, typename T, std::size_t N>
307void pretty_print(const std::unordered_map<K,std::array<T,N>>& map)
308{
309 out << "{";
310 std::size_t count = 0;
311 for (const auto& [key, value] : map)
312 {
313 out << key << ": [";
314 for (std::size_t i = 0; i < N; i++)
315 {
316 out << value[i];
317 if (i < N-1) { out << ", "; }
318 }
319 out << "]";
320 if (count < map.size() - 1)
321 {
322 out << ", ";
323 }
324 count++;
325 }
326 out << "}\n";
327}
328
329inline
330void print_mpi_root(const std::string& msg)
331{
332 auto myrank = Mpi::WorldRank();
333 if (myrank == 0)
334 {
335 out << msg << std::endl;
336 out.flush(); // Ensure output is flushed
337 }
338}
339
340/// @brief print with MPI rank synchronization
341///
342/// @param msg Message to print
343inline
344void print_mpi_sync(const std::string& msg)
345{
346 auto myrank = static_cast<size_t>(Mpi::WorldRank());
347 auto nranks = static_cast<size_t>(Mpi::WorldSize());
348
349 if (nranks == 1)
350 {
351 // Single process case - just print directly
352 out << msg << std::endl;
353 return;
354 }
355
356 // First gather string lengths
357 size_t msg_len = msg.length();
358 std::vector<size_t> lengths(nranks);
359 MPI_Gather(&msg_len, 1, MPITypeMap<size_t>::mpi_type,
360 lengths.data(), 1, MPITypeMap<size_t>::mpi_type,
361 0, MPI_COMM_WORLD);
362
363 if (myrank == 0)
364 {
365 // Rank 0: Allocate receive buffer based on gathered lengths
366 std::vector<std::string> messages(nranks);
367 messages[0] = msg; // Store rank 0's message
368
369 // Receive messages from other ranks
370 for (size_t r = 1; r < nranks; r++)
371 {
372 std::vector<char> buffer(lengths[r] + 1);
373 MPI_Recv(buffer.data(), static_cast<int>(lengths[r]), MPI_CHAR,
374 static_cast<int>(r), 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
375 messages[r] = std::string(buffer.data(), static_cast<size_t>(lengths[r]));
376 }
377
378 // Print all messages in rank order
379 for (size_t r = 0; r < nranks; r++)
380 {
381 out << "[Rank " << r << "] " << messages[r] << std::endl;
382 }
383 out.flush();
384 }
385 else
386 {
387 // Other ranks: Send message to rank 0
388 MPI_Send(const_cast<char*>(msg.c_str()), static_cast<int>(msg_len), MPI_CHAR,
389 0, 0, MPI_COMM_WORLD);
390 }
391
392 // Final barrier to ensure completion
393 MPI_Barrier(MPI_COMM_WORLD);
394}
395
396/// @brief Pretty print an mfem::Vector with MPI rank
397///
398/// @param v vector to print
399inline
401{
402 std::stringstream ss;
403 ss << "[";
404 for (int i = 0; i < v.Size(); i++)
405 {
406 ss << v(i);
407 if (i < v.Size() - 1) { ss << ", "; }
408 }
409 ss << "]";
410
411 print_mpi_sync(ss.str());
412}
413
414
415template <typename ... Ts>
416constexpr auto decay_types(tuple<Ts...> const &)
418
419template <typename T>
420using decay_tuple = decltype(decay_types(std::declval<T>()));
421
422template <class F> struct FunctionSignature;
423
424template <typename output_t, typename... input_ts>
425struct FunctionSignature<output_t(input_ts...)>
426{
427 using return_t = output_t;
428 using parameter_ts = tuple<input_ts...>;
429};
430
431template <class T> struct create_function_signature;
432
433// Specialization for member functions (lambdas)
434template <typename output_t, typename T, typename... input_ts>
435struct create_function_signature<output_t (T::*)(input_ts...) const>
436{
437 using type = FunctionSignature<output_t(input_ts...)>;
438};
439
440// Specialization for function pointers
441template <typename output_t, typename... input_ts>
442struct create_function_signature<output_t (*)(input_ts...)>
443{
444 using type = FunctionSignature<output_t(input_ts...)>;
445};
446
/// @brief Get the field ID reported by the static member T::GetFieldId().
///
/// @tparam T a type providing a static GetFieldId().
/// @returns the value of T::GetFieldId().
template <typename T>
constexpr int GetFieldId()
{
   const int field_id = T::GetFieldId();
   return field_id;
}
452
/// Collects the field IDs of the tuple elements selected by @a Is.
///
/// Only the element *types* of @a t are used: each ID is read from a
/// value-initialized object of the element type, so the element types have to
/// be default constructible.
template <typename Tuple, std::size_t... Is>
constexpr auto extract_field_ids_impl(Tuple&& t, std::index_sequence<Is...>)
{
   using ids_t = std::array<int, sizeof...(Is)>;
   return ids_t { std::decay_t<decltype(std::get<Is>(t))> {}.GetFieldId()... };
}
461
/// @brief Extracts field IDs from a tuple of objects derived from FieldOperator.
///
/// @param t the tuple to extract field IDs from.
/// @returns a std::array<int, sizeof...(Ts)> of field IDs, in tuple order.
template <typename... Ts>
constexpr auto extract_field_ids(const std::tuple<Ts...>& t)
{
   constexpr auto positions = std::make_index_sequence<sizeof...(Ts)> {};
   return extract_field_ids_impl(t, positions);
}
471
/// @brief Helper function to check if an element is in the array.
///
/// Performs a linear scan over the first @a size entries of @a arr.
///
/// @param arr the array to search in.
/// @param size the size of the array.
/// @param value the value to search for.
/// @returns true if the value is found, false otherwise.
constexpr bool contains(const int* arr, std::size_t size, int value)
{
   std::size_t pos = 0;
   while (pos < size && arr[pos] != value)
   {
      ++pos;
   }
   return pos != size;
}
489
/// @brief Function to count unique field IDs in a tuple.
///
/// The IDs are obtained with extract_field_ids and de-duplicated with a linear
/// search over the IDs seen so far.
///
/// @param t the tuple to count unique field IDs from.
/// @returns the number of unique field IDs.
template <typename... Ts>
constexpr std::size_t count_unique_field_ids(const std::tuple<Ts...>& t)
{
   constexpr std::size_t num_fields = sizeof...(Ts);
   const auto all_ids = extract_field_ids(t);

   std::array<int, num_fields> seen = {};
   std::size_t num_seen = 0;

   for (const int id : all_ids)
   {
      if (contains(seen.data(), num_seen, id))
      {
         continue;
      }
      seen[num_seen] = id;
      ++num_seen;
   }

   return num_seen;
}
514
515/// @brief Get marked entries from an std::array based on a marker array.
516///
517/// @param a the std::array to get entries from.
518/// @param marker the marker std::array indicating which entries to get.
519/// @returns a std::vector containing the marked entries.
520template <typename T, std::size_t N>
522 const std::array<T, N> &a,
523 const std::array<bool, N> &marker)
524{
525 std::vector<T> r;
526 for (int i = 0; i < N; i++)
527 {
528 if (marker[i])
529 {
530 r.push_back(a[i]);
531 }
532 }
533 return r;
534}
535
/// @brief Filter fields from a tuple based on their field IDs.
///
/// The result is built from value-initialized objects of the kept element
/// types; the values stored in @a t are not read.
///
/// @param t the tuple to filter fields from.
/// @returns a tuple containing only the fields with field IDs not equal to -1.
template <typename... Ts>
constexpr auto filter_fields(const std::tuple<Ts...>& t)
{
   static_cast<void>(t); // only the element types matter
   return std::tuple_cat(
             std::conditional_t<(Ts::GetFieldId() == -1),
             std::tuple<>, std::tuple<Ts>> {}...);
}
546
547/// @brief FieldDescriptor struct
548///
549/// This struct is used to store information about a field.
551{
553 std::variant<const FiniteElementSpace *,
554 const ParFiniteElementSpace *,
555 const ParameterSpace *>;
556
557 /// Field ID
558 std::size_t id;
559
560 /// Field variant
562
563 /// Default constructor
565 id(SIZE_MAX), data(data_variant_t{}) {}
566
567 /// Constructor
568 template <typename T>
569 FieldDescriptor(std::size_t field_id, const T* v) :
570 id(field_id), data(v) {}
571};
572
namespace dfem
{
/// Always evaluates to false, but depends on the template arguments. Used in
/// this file as the condition of static_asserts in the final `else` of
/// `if constexpr` chains.
template <class... Ts> constexpr bool always_false { false };
}
577
/// @brief Entity namespace
///
/// Holds forward-declared types that name an entity type. They are only
/// declared here, never defined; elsewhere in this file they are used as
/// template arguments and compared with std::is_same_v.
namespace Entity
{
struct Element;
struct BoundaryElement;
struct Face;
struct BoundaryFace;
} // namespace Entity
588
589/// @brief ThreadBlocks struct
590///
591/// This struct is used to store information about thread blocks
592/// for GPU dispatch.
594{
595 int x = 1;
596 int y = 1;
597 int z = 1;
598};
599
#if defined(MFEM_USE_CUDA_OR_HIP)
/// GPU kernel that calls f(i, shmem) with i = blockIdx.x whenever i < n, where
/// shmem is the block declared `extern __shared__` (its size is chosen at
/// launch).
template <typename func_t>
__global__ void forall_kernel_shmem(func_t f, int n)
{
   extern __shared__ real_t shmem[];
   const int block = blockIdx.x;
   if (block >= n)
   {
      return;
   }
   f(block, shmem);
}
#endif
612
613template <typename func_t>
614void forall(func_t f,
615 const int &N,
616 const ThreadBlocks &blocks,
617 int num_shmem = 0,
618 real_t *shmem = nullptr)
619{
622 {
623#if defined(MFEM_USE_CUDA_OR_HIP)
624 // int gridsize = (N + Z - 1) / Z;
625 int num_bytes = num_shmem * sizeof(decltype(shmem));
626 dim3 block_size(blocks.x, blocks.y, blocks.z);
628#if defined(MFEM_USE_CUDA)
629 MFEM_GPU_CHECK(cudaGetLastError());
630#elif defined(MFEM_USE_HIP)
631 MFEM_GPU_CHECK(hipGetLastError());
632#endif
633 MFEM_DEVICE_SYNC;
634#endif
635 }
637 {
638 MFEM_ASSERT(!((bool)num_shmem != (bool)shmem),
639 "Backend::CPU needs a pre-allocated shared memory block");
640 for (int i = 0; i < N; i++)
641 {
642 f(i, shmem);
643 }
644 }
645 else
646 {
647 MFEM_ABORT("no compute backend available");
648 }
649}
650
651/// @todo To be removed.
652class FDJacobian : public Operator
653{
654public:
655 FDJacobian(const Operator &op, const Vector &x, real_t fixed_eps = 0.0) :
656 Operator(op.Height(), op.Width()),
657 op(op),
658 x(x),
659 fixed_eps(fixed_eps)
660 {
661 f.UseDevice(x.UseDevice());
662 f.SetSize(Height());
663
664 xpev.UseDevice(x.UseDevice());
665 xpev.SetSize(Width());
666
667 op.Mult(x, f);
668
669 const real_t xnorm_local = x.Norml2();
670 MPI_Allreduce(&xnorm_local, &xnorm, 1, MPITypeMap<real_t>::mpi_type, MPI_SUM,
671 MPI_COMM_WORLD);
672 }
673
674 void Mult(const Vector &v, Vector &y) const override
675 {
676 // See [1] for choice of eps.
677 //
678 // [1] Woodward, C.S., Gardner, D.J. and Evans, K.J., 2015. On the use of
679 // finite difference matrix-vector products in Newton-Krylov solvers for
680 // implicit climate dynamics with spectral elements. Procedia Computer
681 // Science, 51, pp.2036-2045.
682 real_t eps;
683 if (fixed_eps > 0.0)
684 {
685 eps = fixed_eps;
686 }
687 else
688 {
689 const real_t vnorm_local = v.Norml2();
690 real_t vnorm;
691 MPI_Allreduce(&vnorm_local, &vnorm, 1, MPITypeMap<real_t>::mpi_type, MPI_SUM,
692 MPI_COMM_WORLD);
693 eps = lambda * (lambda + xnorm / vnorm);
694 }
695
696 // x + eps * v
697 {
698 const auto d_v = v.Read();
699 const auto d_x = x.Read();
700 auto d_xpev = xpev.Write();
701 mfem::forall(x.Size(), [=] MFEM_HOST_DEVICE (int i)
702 {
703 d_xpev[i] = d_x[i] + eps * d_v[i];
704 });
705 }
706
707 // y = f(x + eps * v)
708 op.Mult(xpev, y);
709
710 // y = (f(x + eps * v) - f(x)) / eps
711 {
712 const auto d_f = f.Read();
713 auto d_y = y.ReadWrite();
714 mfem::forall(f.Size(), [=] MFEM_HOST_DEVICE (int i)
715 {
716 d_y[i] = (d_y[i] - d_f[i]) / eps;
717 });
718 }
719 }
720
721 virtual MemoryClass GetMemoryClass() const override
722 {
724 }
725
726private:
727 const Operator &op;
728 Vector x, f;
729 mutable Vector xpev;
730 real_t lambda = 1.0e-6;
731 real_t fixed_eps;
732 real_t xnorm;
733};
734
735/// @brief Find the index of a field descriptor in a vector of field descriptors.
736///
737/// @param id the field ID to search for.
738/// @param fields the vector of field descriptors.
739/// @returns the index of the field descriptor with the given ID,
740/// or SIZE_MAX if not found.
741inline
742std::size_t FindIdx(const std::size_t& id,
743 const std::vector<FieldDescriptor>& fields)
744{
745 for (std::size_t i = 0; i < fields.size(); i++)
746 {
747 if (fields[i].id == id)
748 {
749 return i;
750 }
751 }
752 return SIZE_MAX;
753}
754
755/// @brief Get the vdof size of a field descriptor.
756///
757/// @param f the field descriptor.
758/// @returns the vdof size of the field descriptor.
759inline
761{
762 return std::visit([](auto arg)
763 {
764 if (arg == nullptr)
765 {
766 MFEM_ABORT("FieldDescriptor data is nullptr");
767 }
768
769 using T = std::decay_t<decltype(arg)>;
770 if constexpr (std::is_same_v<T, const FiniteElementSpace *> ||
771 std::is_same_v<T, const ParFiniteElementSpace *>)
772 {
773 return arg->GetVSize();
774 }
775 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
776 {
777 return arg->GetVSize();
778 }
779 else
780 {
781 static_assert(dfem::always_false<T>, "can't use GetVSize on type");
782 }
783 return 0; // Unreachable, but avoids compiler warning
784 }, f.data);
785}
786
787/// @brief Get the element vdofs of a field descriptor.
788///
789/// @note Can't be used with ParameterSpace.
790///
791/// @param f the field descriptor.
792/// @param el the element index.
793/// @param vdofs the array to store the element vdofs.
794inline
795void GetElementVDofs(const FieldDescriptor &f, int el, Array<int> &vdofs)
796{
797 return std::visit([&](auto arg)
798 {
799 if (arg == nullptr)
800 {
801 MFEM_ABORT("FieldDescriptor data is nullptr");
802 }
803
804 using T = std::decay_t<decltype(arg)>;
805 if constexpr (std::is_same_v<T, const FiniteElementSpace *>)
806 {
807 arg->GetElementVDofs(el, vdofs);
808 }
809 else if constexpr (std::is_same_v<T, const ParFiniteElementSpace *>)
810 {
811 arg->GetElementVDofs(el, vdofs);
812 }
813 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
814 {
815 MFEM_ABORT("internal error");
816 }
817 else
818 {
819 static_assert(dfem::always_false<T>, "can't use GetElementVdofs on type");
820 }
821 }, f.data);
822}
823
824/// @brief Get the true dof size of a field descriptor.
825///
826/// @param f the field descriptor.
827/// @returns the true dof size of the field descriptor.
828inline
830{
831 return std::visit([](auto arg)
832 {
833 if (arg == nullptr)
834 {
835 MFEM_ABORT("FieldDescriptor data is nullptr");
836 }
837
838 using T = std::decay_t<decltype(arg)>;
839 if constexpr (std::is_same_v<T, const FiniteElementSpace *>)
840 {
841 return arg->GetTrueVSize();
842 }
843 else if constexpr (std::is_same_v<T, const ParFiniteElementSpace *>)
844 {
845 return arg->GetTrueVSize();
846 }
847 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
848 {
849 return arg->GetTrueVSize();
850 }
851 else
852 {
853 static_assert(dfem::always_false<T>, "can't use GetTrueVSize on type");
854 }
855 return 0; // Unreachable, but avoids compiler warning
856 }, f.data);
857}
858
859/// @brief Get the vdim of a field descriptor.
860///
861/// @param f the field descriptor.
862/// @returns the vdim of the field descriptor.
863inline
865{
866 return std::visit([](auto && arg)
867 {
868 using T = std::decay_t<decltype(arg)>;
869 if constexpr (std::is_same_v<T, const FiniteElementSpace *>)
870 {
871 return arg->GetVDim();
872 }
873 else if constexpr (std::is_same_v<T, const ParFiniteElementSpace *>)
874 {
875 return arg->GetVDim();
876 }
877 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
878 {
879 return arg->GetVDim();
880 }
881 else
882 {
883 static_assert(dfem::always_false<T>, "can't use GetVDim on type");
884 }
885 return 0; // Unreachable, but avoids compiler warning
886 }, f.data);
887}
888
889/// @brief Get the spatial dimension of a field descriptor.
890///
891/// @param f the field descriptor.
892/// @tparam entity_t the entity type (see Entity).
893/// @returns the spatial dimension of the field descriptor.
894template <typename entity_t>
896{
897 return std::visit([](auto && arg)
898 {
899 using T = std::decay_t<decltype(arg)>;
900 if constexpr (std::is_same_v<T, const FiniteElementSpace *> ||
901 std::is_same_v<T, const ParFiniteElementSpace *>)
902 {
903 if constexpr (std::is_same_v<entity_t, Entity::Element>)
904 {
905 return arg->GetMesh()->Dimension();
906 }
907 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
908 {
909 return arg->GetMesh()->Dimension() - 1;
910 }
911 }
912 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
913 {
914 return arg->Dimension();
915 }
916 else
917 {
918 static_assert(dfem::always_false<T>, "can't use GetDimension on type");
919 }
920 return 0; // Unreachable, but avoids compiler warning
921 }, f.data);
922}
923
924
925/// @brief Get the prolongation operator for a field descriptor.
926///
927/// @param f the field descriptor.
928/// @returns the prolongation operator for the field descriptor.
929inline
931{
932 return std::visit([](auto&& arg) -> const Operator*
933 {
934 using T = std::decay_t<decltype(arg)>;
935 if constexpr (std::is_same_v<T, const FiniteElementSpace *> ||
936 std::is_same_v<T, const ParFiniteElementSpace *>)
937 {
938 return arg->GetProlongationMatrix();
939 }
940 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
941 {
942 return arg->GetProlongationMatrix();
943 }
944 else
945 {
946 static_assert(dfem::always_false<T>, "can't use GetProlongation on type");
947 }
948 return nullptr; // Unreachable, but avoids compiler warning
949 }, f.data);
950}
951
952/// @brief Get the element restriction operator for a field descriptor.
953///
954/// @param f the field descriptor.
955/// @param o the element dof ordering.
956/// @returns the element restriction operator for the field descriptor in
957/// specified ordering.
958inline
961{
962 return std::visit([&o](auto&& arg) -> const Operator*
963 {
964 using T = std::decay_t<decltype(arg)>;
965 if constexpr (std::is_same_v<T, const FiniteElementSpace *>
966 || std::is_same_v<T, const ParFiniteElementSpace *>)
967 {
968 return arg->GetElementRestriction(o);
969 }
970 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
971 {
972 return arg->GetElementRestriction(o);
973 }
974 else
975 {
976 static_assert(dfem::always_false<T>,
977 "can't use get_element_restriction on type");
978 }
979 return nullptr; // Unreachable, but avoids compiler warning
980 }, f.data);
981}
982
983/// @brief Get the face restriction operator for a field descriptor.
984///
985/// @param f the field descriptor.
986/// @param o the face dof ordering.
987/// @param ft the face type
988/// @param m indicator if single or double valued
989/// @returns the face restriction operator for the field descriptor in
990/// specified ordering.
991inline
994 FaceType ft,
995 L2FaceValues m)
996{
997 return std::visit([&o, &ft, &m](auto&& arg) -> const Operator*
998 {
999 using T = std::decay_t<decltype(arg)>;
1000 if constexpr (std::is_same_v<T, const FiniteElementSpace *> ||
1001 std::is_same_v<T, const ParFiniteElementSpace *>)
1002 {
1003 return arg->GetFaceRestriction(o, ft, m);
1004 }
1005 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
1006 {
1007 // ParameterSpace does not support face restrictions
1008 MFEM_ABORT("internal error");
1009 }
1010 else
1011 {
1012 static_assert(dfem::always_false<T>,
1013 "can't use get_face_restriction on type");
1014 }
1015 return nullptr; // Unreachable, but avoids compiler warning
1016 }, f.data);
1017}
1018
1019/// @brief Get the restriction operator for a field descriptor.
1020///
1021/// @param f the field descriptor.
1022/// @param o the element dof ordering.
1023/// @returns the restriction operator for the field descriptor in
1024/// specified ordering.
1025template <typename entity_t>
1026inline
1028 const ElementDofOrdering &o)
1029{
1030 if constexpr (std::is_same_v<entity_t, Entity::Element>)
1031 {
1032 return get_element_restriction(f, o);
1033 }
1034 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
1035 {
1038 }
1039 MFEM_ABORT("restriction not implemented for Entity");
1040 return nullptr;
1041}
1042
1043/// @brief Get a transpose restriction callback for a field descriptor.
1044///
1045/// @param f the field descriptor.
1046/// @param o the element dof ordering.
1047/// @param fop the field operator.
1048/// @returns a tuple containing a std::function with the transpose
/// restriction callback and its height.
1050template <typename entity_t, typename fop_t>
1051inline std::tuple<std::function<void(const Vector&, Vector&)>, int>
1053 const FieldDescriptor &f,
1054 const ElementDofOrdering &o,
1055 const fop_t &fop)
1056{
1057 if constexpr (is_sum_fop<fop_t>::value)
1058 {
1059 auto RT = [=](const Vector &v_e, Vector &v_l)
1060 {
1061 v_l += v_e;
1062 };
1063 return std::make_tuple(RT, 1);
1064 }
1065 else
1066 {
1067 const Operator *R = get_restriction<entity_t>(f, o);
1068 std::function<void(const Vector&, Vector&)> RT = [=](const Vector &x, Vector &y)
1069 {
1070 R->AddMultTranspose(x, y);
1071 };
1072 return std::make_tuple(RT, R->Height());
1073 }
1074 return std::make_tuple(
1075 std::function<void(const Vector&, Vector&)>([](const Vector&, Vector&)
1076 {
1077 /* no-op */
1078 }), 0); // Never reached, but avoids compiler warning.
1079}
1080
1081/// @brief Apply the prolongation operator to a field.
1082///
1083/// @param field the field descriptor.
1084/// @param x the input vector in tdofs.
1085/// @param field_l the output vector in vdofs.
1086inline
1087void prolongation(const FieldDescriptor field, const Vector &x, Vector &field_l)
1088{
1089 const auto P = get_prolongation(field);
1090 field_l.SetSize(P->Height());
1091 P->Mult(x, field_l);
1092}
1093
1094/// @brief Apply the prolongation operator to a vector of fields.
1095///
1096/// x is a long vector containing the data for all fields on tdofs and
1097/// fields contains the information about each individual field to retrieve
1098/// it's corresponding prolongation.
1099///
1100/// @param fields the array of field descriptors.
1101/// @param x the input vector in tdofs.
1102/// @param fields_l the array of output vectors in vdofs.
1103/// @tparam N the number of fields.
1104/// @tparam M the number of output fields.
1105template <std::size_t N, std::size_t M>
1106void prolongation(const std::array<FieldDescriptor, N> fields,
1107 const Vector &x,
1108 std::array<Vector, M> &fields_l)
1109{
1110 int data_offset = 0;
1111 for (int i = 0; i < N; i++)
1112 {
1113 const auto P = get_prolongation(fields[i]);
1114 const int width = P->Width();
1115 // const Vector x_i(x.GetData() + data_offset, width);
1116 const Vector x_i(const_cast<Vector&>(x), data_offset, width);
1117 fields_l[i].SetSize(P->Height());
1118
1119 P->Mult(x_i, fields_l[i]);
1120 data_offset += width;
1121 }
1122}
1123
1124/// @brief Apply the prolongation operator to a vector of fields.
1125///
1126/// x is a long vector containing the data for all fields on tdofs and
1127/// fields contains the information about each individual field to retrieve
1128/// it's corresponding prolongation.
1129///
1130/// @param fields the array of field descriptors.
1131/// @param x the input vector in tdofs.
1132/// @param fields_l the array of output vectors in vdofs.
1133inline
1134void prolongation(const std::vector<FieldDescriptor> fields,
1135 const Vector &x,
1136 std::vector<Vector> &fields_l)
1137{
1138 int data_offset = 0;
1139 for (std::size_t i = 0; i < fields.size(); i++)
1140 {
1141 const auto P = get_prolongation(fields[i]);
1142 const int width = P->Width();
1143 const Vector x_i(const_cast<Vector&>(x), data_offset, width);
1144 fields_l[i].SetSize(P->Height());
1145 P->Mult(x_i, fields_l[i]);
1146 data_offset += width;
1147 }
1148}
1149
1150inline
1151void get_lvectors(const std::vector<FieldDescriptor> fields,
1152 const Vector &x,
1153 std::vector<Vector> &fields_l)
1154{
1155 int data_offset = 0;
1156 for (std::size_t i = 0; i < fields.size(); i++)
1157 {
1158 const int sz = GetVSize(fields[i]);
1159 fields_l[i].SetSize(sz);
1160
1161 const Vector x_i(const_cast<Vector&>(x), data_offset, sz);
1162 fields_l[i] = x_i;
1163
1164 data_offset += sz;
1165 }
1166}
1167
1168/// @brief Get a transpose prolongation callback for a field descriptor.
1169///
1170/// In the special case of a one field operator, the transpose prolongation
1171/// is a simple sum of the local vector that is reduced to the global vector.
1172///
1173/// @param f the field descriptor.
1174/// @param fop the field operator.
1175/// @param mpi_comm the MPI communicator.
1176/// @tparam fop_t the field operator type.
1177template <typename fop_t>
1178inline
1179std::function<void(const Vector&, Vector&)> get_prolongation_transpose(
1180 const FieldDescriptor &f,
1181 const fop_t &fop,
1182 MPI_Comm mpi_comm)
1183{
1184 if constexpr (is_sum_fop<fop_t>::value)
1185 {
1186 auto PT = [=](const Vector &r_local, Vector &y)
1187 {
1188 MFEM_ASSERT(y.Size() == 1, "output size doesn't match kernel description");
1189 real_t local_sum = r_local.Sum();
1190 MPI_Allreduce(&local_sum, y.GetData(), 1, MPI_DOUBLE, MPI_SUM, mpi_comm);
1191 };
1192 return PT;
1193 }
1194 else if constexpr (is_identity_fop<fop_t>::value)
1195 {
1196 auto PT = [=](const Vector &r_local, Vector &y)
1197 {
1198 y = r_local;
1199 };
1200 return PT;
1201 }
1202 const Operator *P = get_prolongation(f);
1203 auto PT = [=](const Vector &r_local, Vector &y)
1204 {
1205 P->MultTranspose(r_local, y);
1206 };
1207 return PT;
1208}
1209
1210/// @brief Apply the restriction operator to a field.
1211///
1212/// @param u the field descriptor.
1213/// @param u_l the input vector in vdofs.
1214/// @param field_e the output vector in edofs.
1215/// @param ordering the element dof ordering.
1216/// @tparam entity_t the entity type (see Entity).
1217template <typename entity_t>
1219 const Vector &u_l,
1220 Vector &field_e,
1221 ElementDofOrdering ordering)
1222{
1223 const auto R = get_restriction<entity_t>(u, ordering);
1224 MFEM_ASSERT(R->Width() == u_l.Size(),
1225 "restriction not applicable to given data size");
1226 const int height = R->Height();
1227 field_e.SetSize(height);
1228 R->Mult(u_l, field_e);
1229}
1230
1231/// @brief Apply the restriction operator to a vector of fields.
1232///
1233/// @param u the vector of field descriptors.
1234/// @param u_l the vector of input vectors in vdofs.
1235/// @param fields_e the vector of output vectors in edofs.
1236/// @param ordering the element dof ordering.
1237/// @param offset the array index offset to start writing in fields_e.
1238/// @tparam entity_t the entity type (see Entity).
1239template <typename entity_t>
1240void restriction(const std::vector<FieldDescriptor> u,
1241 const std::vector<Vector> &u_l,
1242 std::vector<Vector> &fields_e,
1243 ElementDofOrdering ordering,
1244 const int offset = 0)
1245{
1246 for (std::size_t i = 0; i < u.size(); i++)
1247 {
1248 const auto R = get_restriction<entity_t>(u[i], ordering);
1249 MFEM_ASSERT(R->Width() == u_l[i].Size(),
1250 "restriction not applicable to given data size");
1251 const int height = R->Height();
1252 fields_e[i + offset].SetSize(height);
1253 R->Mult(u_l[i], fields_e[i + offset]);
1254 }
1255}
1256
1257// TODO: keep this temporarily
1258template <std::size_t N, std::size_t M>
1259void element_restriction(const std::array<FieldDescriptor, N> u,
1260 const std::array<Vector, N> &u_l,
1261 std::array<Vector, M> &fields_e,
1262 ElementDofOrdering ordering,
1263 const int offset = 0)
1264{
1265 for (int i = 0; i < N; i++)
1266 {
1267 const auto R = get_element_restriction(u[i], ordering);
1268 MFEM_ASSERT(R->Width() == u_l[i].Size(),
1269 "element restriction not applicable to given data size");
1270 const int height = R->Height();
1271 fields_e[i + offset].SetSize(height);
1272 R->Mult(u_l[i], fields_e[i + offset]);
1273 }
1274}
1275
1276/// @brief Get the number of entities of a given type.
1277///
1278/// @param mesh the mesh.
1279/// @tparam entity_t the entity type (see Entity).
1280/// @returns the number of entities of the given type.
1281template <typename entity_t>
1283{
1284 if constexpr (std::is_same_v<entity_t, Entity::Element>)
1285 {
1286 return mesh.GetNE();
1287 }
1288 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
1289 {
1290 return mesh.GetNBE();
1291 }
1292 else
1293 {
1294 static_assert(dfem::always_false<entity_t>, "can't use GetNumEntites on type");
1295 }
1296 return 0; // Unreachable, but avoids compiler warning
1297}
1298
1299/// @brief Get the GetDofToQuad object for a given entity type.
1300///
1301/// This function retrieves the DofToQuad object for a given field descriptor
1302/// and integration rule.
1303///
1304/// @param f the field descriptor.
1305/// @param ir the integration rule.
1306/// @param mode the mode of the DofToQuad object.
1307/// @tparam entity_t the entity type (see Entity).
1308template <typename entity_t>
1309inline
1311 const IntegrationRule &ir,
1312 DofToQuad::Mode mode)
1313{
1314 return std::visit([&ir, &mode](auto&& arg) -> const DofToQuad*
1315 {
1316 using T = std::decay_t<decltype(arg)>;
1317 if constexpr (std::is_same_v<T, const FiniteElementSpace *>
1318 || std::is_same_v<T, const ParFiniteElementSpace *>)
1319 {
1320 if constexpr (std::is_same_v<entity_t, Entity::Element>)
1321 {
1322 return &arg->GetTypicalFE()->GetDofToQuad(ir, mode);
1323 }
1324 else if constexpr (std::is_same_v<entity_t, Entity::BoundaryElement>)
1325 {
1326 return &arg->GetTypicalTraceElement()->GetDofToQuad(ir, mode);
1327 }
1328 }
1329 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
1330 {
1331 return &arg->GetDofToQuad();
1332 }
1333 else
1334 {
1335 static_assert(dfem::always_false<T>, "can't use GetDofToQuad on type");
1336 }
1337 return nullptr; // Unreachable, but avoids compiler warning
1338 }, f.data);
1339}
1340
1341/// @brief Check the compatibility of a field operator type with a
1342/// FieldDescriptor.
1343///
1344/// This function checks if the field operator type is compatible with the
1345/// FieldDescriptor type.
1346///
1347/// @param f the field descriptor.
1348/// @tparam field_operator_t the field operator type.
1349template <typename field_operator_t>
1351{
1352 std::visit([](auto && arg)
1353 {
1354 using T = std::decay_t<decltype(arg)>;
1355 if constexpr (std::is_same_v<T, const FiniteElementSpace *> ||
1356 std::is_same_v<T, const ParFiniteElementSpace *>)
1357 {
1358 if constexpr (std::is_same_v<field_operator_t, Value<>>)
1359 {
1360 // Supported by all FE spaces
1361 }
1362 else if constexpr (std::is_same_v<field_operator_t, Gradient<>>)
1363 {
1364 MFEM_ASSERT(arg->GetTypicalElement()->GetMapType() ==
1366 "Gradient not compatible with FE");
1367 }
1368 else
1369 {
1371 "FieldOperator not compatible with FiniteElementSpace");
1372 }
1373 }
1374 else if constexpr (std::is_same_v<T, const ParameterSpace *>)
1375 {
1376 if constexpr (std::is_same_v<field_operator_t, Identity<>>)
1377 {
1378 // Only supported field operation for ParameterSpace
1379 }
1380 else
1381 {
1383 "FieldOperator not compatible with ParameterSpace");
1384 }
1385 }
1386 else
1387 {
1389 "Operator not compatible with FE");
1390 }
1391 }, f.data);
1392}
1393
1394/// @brief Get the size on quadrature point for a field operator type
1395/// and FieldDescriptor combination.
1396///
1397/// @tparam entity_t the entity type (see Entity).
1398/// @tparam field_operator_t the field operator type.
1399/// @param f the field descriptor.
1400/// @returns the size on quadrature point.
1401template <typename entity_t, typename field_operator_t>
1402int GetSizeOnQP(const field_operator_t &, const FieldDescriptor &f)
1403{
1404 // CheckCompatibility<field_operator_t>(f);
1405
1407 {
1408 return GetVDim(f);
1409 }
1411 {
1412 return GetVDim(f) * GetDimension<entity_t>(f);
1413 }
1415 {
1416 return GetVDim(f);
1417 }
1418 else if constexpr (is_sum_fop<field_operator_t>::value)
1419 {
1420 return 1;
1421 }
1422 else
1423 {
1424 MFEM_ABORT("can't get size on quadrature point for field descriptor");
1425 }
1426 return 0; // Unreachable, but avoids compiler warning
1427}
1428
1429/// @brief Create a map from field operator types to FieldDescriptor indices.
1430///
1431/// @param fields the vector of field descriptors.
1432/// @param fops the field operator types.
1433/// @tparam entity_t the entity type (see Entity).
1434/// @returns an array mapping field operator types to field descriptor indices.
1435template <typename entity_t, typename field_operator_ts>
1436std::array<size_t, tuple_size<field_operator_ts>::value>
1438 const std::vector<FieldDescriptor> &fields,
1439 field_operator_ts &fops)
1440{
1441 std::array<size_t, tuple_size<field_operator_ts>::value> map;
1442
1443 auto find_id = [](const std::vector<FieldDescriptor> &fields, std::size_t i)
1444 {
1445 auto it = std::find_if(begin(fields), end(fields),
1446 [&](const FieldDescriptor &field)
1447 {
1448 return field.id == i;
1449 });
1450
1451 if (it == fields.end())
1452 {
1453 return SIZE_MAX;
1454 }
1455 return static_cast<size_t>(it - fields.begin());
1456 };
1457
1458 auto f = [&](auto &fop, auto &map)
1459 {
1460 if constexpr (std::is_same_v<std::decay_t<decltype(fop)>, Weight>)
1461 {
1462 // TODO-bug: stealing dimension from the first field
1463 fop.dim = GetDimension<entity_t>(fields[0]);
1464 fop.vdim = 1;
1465 fop.size_on_qp = 1;
1466 map = SIZE_MAX;
1467 }
1468 else
1469 {
1470 int i = find_id(fields, fop.GetFieldId());
1471 if (i != -1)
1472 {
1473 fop.dim = GetDimension<entity_t>(fields[i]);
1474 fop.vdim = GetVDim(fields[i]);
1475 fop.size_on_qp = GetSizeOnQP<entity_t>(fop, fields[i]);
1476 map = i;
1477 }
1478 else
1479 {
1480 MFEM_ABORT("can't find field for id: " << fop.GetFieldId());
1481 }
1482 }
1483 };
1484
1485 for_constexpr<tuple_size<field_operator_ts>::value>([&](auto idx)
1486 {
1487 f(get<idx>(fops), map[idx]);
1488 });
1489
1490 return map;
1491}
1492
1493/// @brief Wrap input memory for a given set of inputs.
1494template <typename input_t, std::size_t... i>
1495std::array<DeviceTensor<3>, sizeof...(i)> wrap_input_memory(
1496 std::array<Vector, sizeof...(i)> &input_qp_mem, int num_qp, int num_entities,
1497 const input_t &inputs, std::index_sequence<i...>)
1498{
1499 return {DeviceTensor<3>(input_qp_mem[i].Write(), get<i>(inputs).size_on_qp, num_qp, num_entities) ...};
1500}
1501
1502/// @brief Create input memory for a given set of inputs.
1503template <typename input_t, std::size_t... i>
1504std::array<Vector, sizeof...(i)> create_input_qp_memory(
1505 int num_qp,
1506 int num_entities,
1507 input_t &inputs,
1508 std::index_sequence<i...>)
1509{
1510 return {Vector(get<i>(inputs).size_on_qp * num_qp * num_entities)...};
1511}
1512
1513/// @brief DofToQuadMap struct
1514///
1515/// This struct is used to store the mapping from degrees of freedom to
1516/// quadrature points for a given field operator type.
1518{
1519 /// Enumeration for the indices of the mappings B and G.
1521 {
1524 DOF
1526
1527 /// @brief Basis functions evaluated at quadrature points.
1528 ///
1529 /// This is a 3D tensor with dimensions (num_qp, dim, num_dofs).
1531
1532 /// @brief Gradient of the basis functions evaluated at quadrature points.
1533 ///
1534 /// This is a 3D tensor with dimensions (num_qp, dim, num_dofs).
1536
1537 /// Reverse mapping indicating which input this map belongs to.
1538 int which_input = -1;
1539};
1540
1541/// @brief Get the size on quadrature point for a given set of inputs.
1542///
1543/// @param inputs the inputs tuple.
1544/// @returns a vector containing the size on quadrature point for each input.
1545template <typename input_t, std::size_t... i>
1546std::vector<int> get_input_size_on_qp(
1547 const input_t &inputs,
1548 std::index_sequence<i...>)
1549{
1550 return {get<i>(inputs).size_on_qp...};
1551}
1552
1567
1568template <std::size_t num_fields, std::size_t num_inputs, std::size_t num_outputs>
1570{
1572 std::array<int, 8> offsets;
1573 std::array<std::array<int, 2>, num_inputs> input_dtq_sizes;
1574 std::array<std::array<int, 2>, num_outputs> output_dtq_sizes;
1575 std::array<int, num_fields> field_sizes;
1577 std::array<int, num_inputs> input_sizes;
1578 std::array<int, num_inputs> shadow_sizes;
1580 std::array<int, 6> temp_sizes;
1581};
1582
1583template <typename entity_t, std::size_t num_fields, std::size_t num_inputs, std::size_t num_outputs, typename input_t>
1586 const std::array<DofToQuadMap, num_inputs> &input_dtq_maps,
1587 const std::array<DofToQuadMap, num_outputs> &output_dtq_maps,
1588 const std::vector<FieldDescriptor> &fields,
1589 const int &num_entities,
1590 const input_t &inputs,
1591 const int &num_qp,
1592 const std::vector<int> &input_size_on_qp,
1593 const int &residual_size_on_qp,
1594 const ElementDofOrdering &dof_ordering,
1595 const int &derivative_action_field_idx = -1)
1596{
1597 std::array<int, 8> offsets = {0};
1598 int total_size = 0;
1599
1600 offsets[SharedMemory::Index::INPUT_DTQ] = total_size;
1601 std::array<std::array<int, 2>, num_inputs> input_dtq_sizes;
1602 int max_dtq_qps = 0;
1603 int max_dtq_dofs = 0;
1604 for (std::size_t i = 0; i < num_inputs; i++)
1605 {
1606 auto a = input_dtq_maps[i].B.GetShape();
1607 input_dtq_sizes[i][0] = a[0] * a[1] * a[2];
1608 auto b = input_dtq_maps[i].G.GetShape();
1609 input_dtq_sizes[i][1] = b[0] * b[1] * b[2];
1610
1611 max_dtq_qps = std::max(max_dtq_qps, a[DofToQuadMap::Index::QP]);
1612 max_dtq_dofs = std::max(max_dtq_dofs, a[DofToQuadMap::Index::DOF]);
1613
1614 total_size += std::accumulate(std::begin(input_dtq_sizes[i]),
1615 std::end(input_dtq_sizes[i]),
1616 0);
1617 }
1618
1619 offsets[SharedMemory::Index::OUTPUT_DTQ] = total_size;
1620 std::array<std::array<int, 2>, num_outputs> output_dtq_sizes;
1621 for (std::size_t i = 0; i < num_outputs; i++)
1622 {
1623 auto a = output_dtq_maps[i].B.GetShape();
1624 output_dtq_sizes[i][0] = a[0] * a[1] * a[2];
1625 auto b = output_dtq_maps[i].G.GetShape();
1626 output_dtq_sizes[i][1] = b[0] * b[1] * b[2];
1627
1628 max_dtq_qps = std::max(max_dtq_qps, a[DofToQuadMap::Index::QP]);
1629 max_dtq_dofs = std::max(max_dtq_dofs, a[DofToQuadMap::Index::DOF]);
1630
1631 total_size += std::accumulate(std::begin(output_dtq_sizes[i]),
1632 std::end(output_dtq_sizes[i]),
1633 0);
1634 }
1635
1636 offsets[SharedMemory::Index::FIELD] = total_size;
1637 std::array<int, num_fields> field_sizes;
1638 for (std::size_t i = 0; i < num_fields; i++)
1639 {
1640 field_sizes[i] =
1641 num_entities
1642 ? (get_restriction<entity_t>(fields[i], dof_ordering)->Height()
1643 / num_entities)
1644 : 0;
1645 }
1646 total_size += std::accumulate(
1647 std::begin(field_sizes), std::end(field_sizes), 0);
1648
1649 offsets[SharedMemory::Index::DIRECTION] = total_size;
1650 int direction_size = 0;
1651 if (derivative_action_field_idx != -1)
1652 {
1653 direction_size =
1654 num_entities ? (get_restriction<entity_t>(
1655 fields[derivative_action_field_idx], dof_ordering)
1656 ->Height()
1657 / num_entities)
1658 : 0;
1659 total_size += direction_size;
1660 }
1661
1662 offsets[SharedMemory::Index::INPUT] = total_size;
1663 std::array<int, num_inputs> input_sizes;
1664 for (std::size_t i = 0; i < num_inputs; i++)
1665 {
1666 input_sizes[i] = input_size_on_qp[i] * num_qp;
1667 }
1668 total_size += std::accumulate(
1669 std::begin(input_sizes), std::end(input_sizes), 0);
1670
1671 offsets[SharedMemory::Index::SHADOW] = total_size;
1672 std::array<int, num_inputs> shadow_sizes{0};
1673 if (derivative_action_field_idx != -1)
1674 {
1675 for (std::size_t i = 0; i < num_inputs; i++)
1676 {
1677 shadow_sizes[i] = input_size_on_qp[i] * num_qp;
1678 }
1679 total_size += std::accumulate(
1680 std::begin(shadow_sizes), std::end(shadow_sizes), 0);
1681 }
1682
1683 offsets[SharedMemory::Index::OUTPUT] = total_size;
1684 const int residual_size = residual_size_on_qp;
1685 total_size += residual_size * num_qp;
1686
1687 offsets[SharedMemory::Index::TEMP] = total_size;
1688 constexpr int num_temp = 6;
1689 std::array<int, num_temp> temp_sizes = {0};
1690 // TODO-bug: this assumes q1d >= d1d
1691 const int q1d = max_dtq_qps;
1692 [[maybe_unused]] const int d1d = max_dtq_dofs;
1693
1694 // TODO-bug: this depends on the dimension
1695 constexpr int hardcoded_temp_num = 6;
1696 for (std::size_t i = 0; i < hardcoded_temp_num; i++)
1697 {
1698 // TODO-bug: over-allocates if q1d <= d1d
1699 temp_sizes[i] = q1d * q1d * q1d;
1700 }
1701 total_size += std::accumulate(
1702 std::begin(temp_sizes), std::end(temp_sizes), 0);
1703
1705 {
1706 total_size,
1707 offsets,
1708 input_dtq_sizes,
1709 output_dtq_sizes,
1710 field_sizes,
1711 direction_size,
1712 input_sizes,
1713 shadow_sizes,
1714 residual_size,
1715 temp_sizes
1716 };
1717}
1718
1719template <typename shmem_info_t>
1720void print_shared_memory_info(shmem_info_t &shmem_info)
1721{
1722 out << "Shared Memory Info\n"
1723 << "total size: " << shmem_info.total_size
1724 << " " << "(" << shmem_info.total_size * real_t(sizeof(real_t))/1024.0 << "kb)";
1725 out << "\ninput dtq sizes (B G): ";
1726 for (auto &i : shmem_info.input_dtq_sizes)
1727 {
1728 out << "(";
1729 for (int j = 0; j < 2; j++)
1730 {
1731 out << i[j];
1732 if (j < 1)
1733 {
1734 out << " ";
1735 }
1736 }
1737 out << ") ";
1738 }
1739 out << "\noutput dtq sizes (B G): ";
1740 for (auto &i : shmem_info.output_dtq_sizes)
1741 {
1742 out << "(";
1743 for (int j = 0; j < 2; j++)
1744 {
1745 out << i[j];
1746 if (j < 1)
1747 {
1748 out << " ";
1749 }
1750 }
1751 out << ") ";
1752 }
1753 out << "\nfield sizes: ";
1754 for (auto &i : shmem_info.field_sizes)
1755 {
1756 out << i << " ";
1757 }
1758 out << "\ndirection size: ";
1759 out << shmem_info.direction_size << " ";
1760 out << "\ninput sizes: ";
1761 for (auto &i : shmem_info.input_sizes)
1762 {
1763 out << i << " ";
1764 }
1765 out << "\nshadow sizes: ";
1766 for (auto &i : shmem_info.shadow_sizes)
1767 {
1768 out << i << " ";
1769 }
1770 out << "\ntemp sizes: ";
1771 for (auto &i : shmem_info.temp_sizes)
1772 {
1773 out << i << " ";
1774 }
1775 out << "\noffsets: ";
1776 for (auto &i : shmem_info.offsets)
1777 {
1778 out << i << " ";
1779 }
1780 out << "\n\n";
1781}
1782
1783template <std::size_t N>
1784MFEM_HOST_DEVICE inline
1785std::array<DofToQuadMap, N> load_dtq_mem(
1786 void *mem,
1787 int offset,
1788 const std::array<std::array<int, 2>, N> &sizes,
1789 const std::array<DofToQuadMap, N> &dtq)
1790{
1791 std::array<DofToQuadMap, N> f;
1792 for (std::size_t i = 0; i < N; i++)
1793 {
1794 if (dtq[i].which_input != -1)
1795 {
1796 const auto [nqp_b, dim_b, ndof_b] = dtq[i].B.GetShape();
1797 const auto B = Reshape(&dtq[i].B[0], nqp_b, dim_b, ndof_b);
1798 auto mem_Bi = Reshape(reinterpret_cast<real_t *>(mem) + offset, nqp_b, dim_b,
1799 ndof_b);
1800
1801 MFEM_FOREACH_THREAD(q, x, nqp_b)
1802 {
1803 MFEM_FOREACH_THREAD(d, y, ndof_b)
1804 {
1805 for (int b = 0; b < dim_b; b++)
1806 {
1807 auto v = B(q, b, d);
1808 mem_Bi(q, b, d) = v;
1809 }
1810 }
1811 }
1812
1813 offset += sizes[i][0];
1814
1815 const auto [nqp_g, dim_g, ndof_g] = dtq[i].G.GetShape();
1816 const auto G = Reshape(&dtq[i].G[0], nqp_g, dim_g, ndof_g);
1817 auto mem_Gi = Reshape(reinterpret_cast<real_t *>(mem) + offset, nqp_g, dim_g,
1818 ndof_g);
1819
1820 MFEM_FOREACH_THREAD(q, x, nqp_g)
1821 {
1822 MFEM_FOREACH_THREAD(d, y, ndof_g)
1823 {
1824 for (int b = 0; b < dim_g; b++)
1825 {
1826 mem_Gi(q, b, d) = G(q, b, d);
1827 }
1828 }
1829 }
1830
1831 offset += sizes[i][1];
1832
1833 f[i] = DofToQuadMap{DeviceTensor<3, const real_t>(&mem_Bi[0], nqp_b, dim_b, ndof_b),
1834 DeviceTensor<3, const real_t>(&mem_Gi[0], nqp_g, dim_g, ndof_g),
1835 dtq[i].which_input};
1836 }
1837 else
1838 {
1839 // When which_input is -1, just copy the original DofToQuadMap with empty data.
1840 f[i] = dtq[i];
1841 }
1842 }
1843 return f;
1844}
1845
1846template <std::size_t num_fields>
1847MFEM_HOST_DEVICE inline
1848std::array<DeviceTensor<1>, num_fields>
1850 void *mem,
1851 int offset,
1852 const std::array<int, num_fields> &sizes,
1853 const std::array<DeviceTensor<2>, num_fields> &fields_e,
1854 const int &entity_idx)
1855{
1856 std::array<DeviceTensor<1>, num_fields> f;
1857
1858 for_constexpr<num_fields>([&](auto field_idx)
1859 {
1860 int block_size = MFEM_THREAD_SIZE(x) *
1861 MFEM_THREAD_SIZE(y) *
1862 MFEM_THREAD_SIZE(z);
1863 int tid = MFEM_THREAD_ID(x) +
1864 MFEM_THREAD_SIZE(x) *
1865 (MFEM_THREAD_ID(y) + MFEM_THREAD_SIZE(y) * MFEM_THREAD_ID(z));
1866 for (int k = tid; k < sizes[field_idx]; k += block_size)
1867 {
1868 reinterpret_cast<real_t *>(mem)[offset + k] =
1869 fields_e[field_idx](k, entity_idx);
1870 }
1871
1872 f[field_idx] =
1873 DeviceTensor<1>(&reinterpret_cast<real_t *> (mem)[offset], sizes[field_idx]);
1874
1875 offset += sizes[field_idx];
1876 });
1877
1878 return f;
1879}
1880
1881MFEM_HOST_DEVICE inline
1883 void *mem,
1884 int offset,
1885 const int &size,
1887 const int &entity_idx)
1888{
1889 int block_size = MFEM_THREAD_SIZE(x) *
1890 MFEM_THREAD_SIZE(y) *
1891 MFEM_THREAD_SIZE(z);
1892 int tid = MFEM_THREAD_ID(x) +
1893 MFEM_THREAD_SIZE(x) *
1894 (MFEM_THREAD_ID(y) + MFEM_THREAD_SIZE(y) * MFEM_THREAD_ID(z));
1895 for (int k = tid; k < size; k += block_size)
1896 {
1897 reinterpret_cast<real_t *>(mem)[offset + k] = direction(k, entity_idx);
1898 }
1899 MFEM_SYNC_THREAD;
1900
1901 return DeviceTensor<1>(
1902 &reinterpret_cast<real_t *>(mem)[offset], size);
1903}
1904
1905template <std::size_t N>
1906MFEM_HOST_DEVICE inline
1907std::array<DeviceTensor<2>, N> load_input_mem(
1908 void *mem,
1909 int offset,
1910 const std::array<int, N> &sizes,
1911 const int &num_qp)
1912{
1913 std::array<DeviceTensor<2>, N> f;
1914 for (std::size_t i = 0; i < N; i++)
1915 {
1916 f[i] = DeviceTensor<2>(&reinterpret_cast<real_t *>(mem)[offset],
1917 sizes[i] / num_qp,
1918 num_qp);
1919 offset += sizes[i];
1920 }
1921 return f;
1922}
1923
1924MFEM_HOST_DEVICE inline
1926 void *mem,
1927 int offset,
1928 const int &residual_size,
1929 const int &num_qp)
1930{
1931 return DeviceTensor<2>(reinterpret_cast<real_t *>(mem) + offset, residual_size,
1932 num_qp);
1933}
1934
1935template <std::size_t N>
1936MFEM_HOST_DEVICE inline
1937std::array<DeviceTensor<1>, 6> load_scratch_mem(
1938 void *mem,
1939 int offset,
1940 const std::array<int, N> &sizes)
1941{
1942 std::array<DeviceTensor<1>, N> f;
1943 for (std::size_t i = 0; i < N; i++)
1944 {
1945 f[i] = DeviceTensor<1>(&reinterpret_cast<real_t *>(mem)[offset], sizes[i]);
1946 offset += sizes[i];
1947 }
1948 return f;
1949}
1950
1951template <typename shared_mem_info_t, std::size_t num_inputs, std::size_t num_outputs, std::size_t num_fields>
1952MFEM_HOST_DEVICE inline
1954 void *shmem,
1955 const shared_mem_info_t &shmem_info,
1956 const std::array<DofToQuadMap, num_inputs> &input_dtq_maps,
1957 const std::array<DofToQuadMap, num_outputs> &output_dtq_maps,
1958 const std::array<DeviceTensor<2>, num_fields> &wrapped_fields_e,
1959 const int &num_qp,
1960 const int &e)
1961{
1962 auto input_dtq_shmem =
1964 shmem,
1965 shmem_info.offsets[SharedMemory::Index::INPUT_DTQ],
1966 shmem_info.input_dtq_sizes,
1967 input_dtq_maps);
1968
1969 auto output_dtq_shmem =
1971 shmem,
1972 shmem_info.offsets[SharedMemory::Index::OUTPUT_DTQ],
1973 shmem_info.output_dtq_sizes,
1974 output_dtq_maps);
1975
1976 auto fields_shmem =
1978 shmem,
1979 shmem_info.offsets[SharedMemory::Index::FIELD],
1980 shmem_info.field_sizes,
1981 wrapped_fields_e,
1982 e);
1983
1984 // These functions don't copy, they simply create a `DeviceTensor` object
1985 // that points to correct chunks of the shared memory pool.
1986 auto input_shmem =
1988 shmem,
1989 shmem_info.offsets[SharedMemory::Index::INPUT],
1990 shmem_info.input_sizes,
1991 num_qp);
1992
1993 auto residual_shmem =
1995 shmem,
1996 shmem_info.offsets[SharedMemory::Index::OUTPUT],
1997 shmem_info.residual_size,
1998 num_qp);
1999
2000 auto scratch_mem =
2002 shmem,
2003 shmem_info.offsets[SharedMemory::Index::TEMP],
2004 shmem_info.temp_sizes);
2005
2006 MFEM_SYNC_THREAD;
2007
2008 // nvcc needs make_tuple to be fully qualified
2010 input_dtq_shmem, output_dtq_shmem, fields_shmem,
2011 input_shmem, residual_shmem, scratch_mem);
2012}
2013
2014template <typename shared_mem_info_t, std::size_t num_inputs, std::size_t num_outputs, std::size_t num_fields>
2015MFEM_HOST_DEVICE inline
2017 void *shmem,
2018 const shared_mem_info_t &shmem_info,
2019 const std::array<DofToQuadMap, num_inputs> &input_dtq_maps,
2020 const std::array<DofToQuadMap, num_outputs> &output_dtq_maps,
2021 const std::array<DeviceTensor<2>, num_fields> &wrapped_fields_e,
2022 const DeviceTensor<2> &wrapped_direction_e,
2023 const int &num_qp,
2024 const int &e)
2025{
2026 auto input_dtq_shmem =
2028 shmem,
2029 shmem_info.offsets[SharedMemory::Index::INPUT_DTQ],
2030 shmem_info.input_dtq_sizes,
2031 input_dtq_maps);
2032
2033 auto output_dtq_shmem =
2035 shmem,
2036 shmem_info.offsets[SharedMemory::Index::OUTPUT_DTQ],
2037 shmem_info.output_dtq_sizes,
2038 output_dtq_maps);
2039
2040 auto fields_shmem =
2042 shmem,
2043 shmem_info.offsets[SharedMemory::Index::FIELD],
2044 shmem_info.field_sizes,
2045 wrapped_fields_e,
2046 e);
2047
2048 auto direction_shmem =
2050 shmem,
2051 shmem_info.offsets[SharedMemory::Index::DIRECTION],
2052 shmem_info.direction_size,
2053 wrapped_direction_e,
2054 e);
2055
2056 // These methods don't copy, they simply create a `DeviceTensor` object
2057 // that points to correct chunks of the shared memory pool.
2058 auto input_shmem =
2060 shmem,
2061 shmem_info.offsets[SharedMemory::Index::INPUT],
2062 shmem_info.input_sizes,
2063 num_qp);
2064
2065 auto shadow_shmem =
2067 shmem,
2068 shmem_info.offsets[SharedMemory::Index::SHADOW],
2069 shmem_info.input_sizes,
2070 num_qp);
2071
2072 auto residual_shmem =
2074 shmem,
2075 shmem_info.offsets[SharedMemory::Index::OUTPUT],
2076 shmem_info.residual_size,
2077 num_qp);
2078
2079 auto scratch_mem =
2081 shmem,
2082 shmem_info.offsets[SharedMemory::Index::TEMP],
2083 shmem_info.temp_sizes);
2084
2085 MFEM_SYNC_THREAD;
2086
2087 // nvcc needs make_tuple to be fully qualified
2089 input_dtq_shmem, output_dtq_shmem, fields_shmem,
2090 direction_shmem, input_shmem, shadow_shmem,
2091 residual_shmem, scratch_mem);
2092}
2093
2094template <std::size_t... i>
2095MFEM_HOST_DEVICE inline
2096std::array<DeviceTensor<2>, sizeof...(i)> get_local_input_qp(
2097 const std::array<DeviceTensor<3>, sizeof...(i)> &input_qp_global, int e,
2098 std::index_sequence<i...>)
2099{
2100 return
2101 {
2103 &input_qp_global[i](0, 0, e),
2104 input_qp_global[i].GetShape()[0],
2105 input_qp_global[i].GetShape()[1]) ...
2106 };
2107}
2108
2109template <std::size_t N>
2110MFEM_HOST_DEVICE inline
2111void set_zero(std::array<DeviceTensor<2>, N> &v)
2112{
2113 for (std::size_t i = 0; i < N; i++)
2114 {
2115 int size = v[i].GetShape()[0] * v[i].GetShape()[1];
2116 auto vi = Reshape(&v[i][0], size);
2117 for (int j = 0; j < size; j++)
2118 {
2119 vi[j] = 0.0;
2120 }
2121 }
2122}
2123
2124template <std::size_t n>
2125MFEM_HOST_DEVICE inline
2127{
2128 int s = 1;
2129 for (int i = 0; i < n; i++)
2130 {
2131 s *= u.GetShape()[i];
2132 }
2133 auto ui = Reshape(&u[0], s);
2134 for (int j = 0; j < s; j++)
2135 {
2136 ui[j] = 0.0;
2137 }
2138}
2139
2140/// @brief Copy data from DeviceTensor u to DeviceTensor v
2141///
2142/// @param u source DeviceTensor
2143/// @param v destination DeviceTensor
2144/// @tparam n DeviceTensor rank
2145template <int n>
2146MFEM_HOST_DEVICE inline
2148{
2149 int s = 1;
2150 for (int i = 0; i < n; i++)
2151 {
2152 s *= u.GetShape()[i];
2153 }
2154 auto ui = Reshape(&u[0], s);
2155 auto vi = Reshape(&v[0], s);
2156 for (int j = 0; j < s; j++)
2157 {
2158 vi[j] = ui[j];
2159 }
2160}
2161
2162/// @brief Copy data from array of DeviceTensor u to array of DeviceTensor v
2163///
2164/// @param u source DeviceTensor array
2165/// @param v destination DeviceTensor array
2166/// @tparam n DeviceTensor rank
2167/// @tparam m number of DeviceTensors
2168template <int n, std::size_t m>
2169MFEM_HOST_DEVICE inline
2170void copy(std::array<DeviceTensor<n>, m> &u,
2171 std::array<DeviceTensor<n>, m> &v)
2172{
2173 for (int i = 0; i < m; i++)
2174 {
2175 copy(u[i], v[i]);
2176 }
2177}
2178
2179/// @brief Wraps plain data in DeviceTensors for fields
2180///
2181/// @param fields array of field data
2182/// @param field_sizes for each field, number of values stored for each entity
2183/// @param num_entities number of entities (elements, faces, etc) in mesh
2184/// @tparam num_fields number of fields
2185/// @return array of field data wrapped in DeviceTensors
2186template <std::size_t num_fields>
2187std::array<DeviceTensor<2>, num_fields> wrap_fields(
2188 std::vector<Vector> &fields,
2189 std::array<int, num_fields> &field_sizes,
2190 const int &num_entities)
2191{
2192 std::array<DeviceTensor<2>, num_fields> f;
2193
2194 for_constexpr<num_fields>([&](auto i)
2195 {
2196 f[i] = DeviceTensor<2>(fields[i].ReadWrite(), field_sizes[i], num_entities);
2197 });
2198
2199 return f;
2200}
2201
2202/// @brief Accumulates the sizes of field operators on quadrature points for
2203/// dependent inputs
2204///
2205/// @tparam input_t Type of input field operators tuple
2206/// @tparam num_fields Number of fields
2207/// @tparam i Parameter pack indices for field operators
2208///
2209/// @param inputs Tuple of input field operators
2210/// @param kinput_is_dependent Array indicating which inputs are dependent
2211/// @param input_to_field Array mapping input indices to field indices
2212/// @param fields Array of field descriptors
2213/// @param seq Index sequence for inputs
2214///
2215/// @return Sum of sizes on quadrature points for all dependent inputs
2216///
2217/// @details
2218/// This function accumulates the sizes needed on quadrature points for all
2219/// dependent input field operators. For each dependent input, it calculates the
2220/// size required on quadrature points using GetSizeOnQP() and adds it to the
2221/// total. Non-dependent inputs contribute zero to the total size.
2222template <typename input_t, std::size_t num_fields, std::size_t... i>
2224 const input_t &inputs,
2225 std::array<bool, sizeof...(i)> &kinput_is_dependent,
2226 const std::array<int, sizeof...(i)> &input_to_field,
2227 const std::array<FieldDescriptor, num_fields> &fields,
2228 std::index_sequence<i...> seq)
2229{
2230 MFEM_CONTRACT_VAR(seq); // 'seq' is needed for doxygen
2231 return (... + [](auto &input, auto is_dependent, auto field)
2232 {
2233 if (!is_dependent)
2234 {
2235 return 0;
2236 }
2237 return GetSizeOnQP(input, field);
2238 }
2239 (get<i>(inputs),
2240 get<i>(kinput_is_dependent),
2241 fields[input_to_field[i]]));
2242}
2243
2244template <
2245 typename entity_t,
2246 typename field_operator_ts,
2247 std::size_t N = tuple_size<field_operator_ts>::value,
2248 std::size_t... Is>
2249std::array<DofToQuadMap, N> create_dtq_maps_impl(
2250 field_operator_ts &fops,
2251 std::vector<const DofToQuad*> &dtqs,
2252 const std::array<size_t, N> &field_map,
2253 std::index_sequence<Is...>)
2254{
2255 auto f = [&](auto fop, std::size_t idx)
2256 {
2257 [[maybe_unused]] auto g = [&](int idx)
2258 {
2259 auto dtq = dtqs[field_map[idx]];
2260
2261 int value_dim = 1;
2262 int grad_dim = 1;
2263
2264 if ((dtq->mode != DofToQuad::Mode::TENSOR) &&
2265 (!is_identity_fop<decltype(fop)>::value))
2266 {
2267 value_dim = dtq->FE->GetRangeDim() ? dtq->FE->GetRangeDim() : 1;
2268 grad_dim = dtq->FE->GetDim();
2269 }
2270
2271 return std::tuple{dtq, value_dim, grad_dim};
2272 };
2273
2274 if constexpr (is_value_fop<decltype(fop)>::value ||
2275 is_gradient_fop<decltype(fop)>::value)
2276 {
2277 auto [dtq, value_dim, grad_dim] = g(idx);
2278 return DofToQuadMap
2279 {
2280 DeviceTensor<3, const real_t>(dtq->B.Read(), dtq->nqpt, value_dim, dtq->ndof),
2281 DeviceTensor<3, const real_t>(dtq->G.Read(), dtq->nqpt, grad_dim, dtq->ndof),
2282 static_cast<int>(idx)
2283 };
2284 }
2285 else if constexpr (std::is_same_v<decltype(fop), Weight>)
2286 {
2287 return DofToQuadMap
2288 {
2289 DeviceTensor<3, const real_t>(nullptr, 1, 1, 1),
2290 DeviceTensor<3, const real_t>(nullptr, 1, 1, 1),
2291 -1
2292 };
2293 }
2294 else if constexpr (is_identity_fop<decltype(fop)>::value ||
2295 is_sum_fop<decltype(fop)>::value)
2296 {
2297 auto [dtq, value_dim, grad_dim] = g(idx);
2298 return DofToQuadMap
2299 {
2300 DeviceTensor<3, const real_t>(nullptr, dtq->nqpt, value_dim, dtq->ndof),
2301 DeviceTensor<3, const real_t>(nullptr, dtq->nqpt, grad_dim, dtq->ndof),
2302 -1
2303 };
2304 }
2305 else
2306 {
2307 static_assert(dfem::always_false<decltype(fop)>,
2308 "field operator type is not implemented");
2309 }
2310 return DofToQuadMap
2311 {
2312 DeviceTensor<3, const real_t>(nullptr, 0, 0, 0),
2313 DeviceTensor<3, const real_t>(nullptr, 0, 0, 0),
2314 -1
2315 }; // Unreachable, but avoids compiler warning
2316 };
2317 return std::array<DofToQuadMap, N>
2318 {
2319 f(get<Is>(fops), Is)...
2320 };
2321}
2322
2323/// @brief Create DofToQuad maps for a given set of field operators.
2324///
2325/// @param fops field operators
2326/// @param dtqmaps DofToQuad maps
2327/// @param to_field_map mapping from input indices to field indices
2328/// @tparam entity_t type of the entity
2329/// @return array of DofToQuad maps
2330template <
2331 typename entity_t,
2332 typename field_operator_ts,
2333 std::size_t num_fields>
2334std::array<DofToQuadMap, num_fields> create_dtq_maps(
2335 field_operator_ts &fops,
2336 std::vector<const DofToQuad*> &dtqmaps,
2337 const std::array<size_t, num_fields> &to_field_map)
2338{
2340 fops, dtqmaps,
2341 to_field_map,
2342 std::make_index_sequence<num_fields> {});
2343}
2344
2345} // namespace mfem::future
2346#endif
int Size() const
Return the logical size of the array.
Definition array.hpp:166
Data type dense matrix using column-major storage.
Definition densemat.hpp:24
A basic generic Tensor class, appropriate for use on the GPU.
Definition dtensor.hpp:84
MFEM_HOST_DEVICE auto & GetShape() const
Returns the shape of the tensor.
Definition dtensor.hpp:131
static bool Allows(unsigned long b_mask)
Return true if any of the backends in the backend mask, b_mask, are allowed.
Definition device.hpp:262
static MemoryClass GetDeviceMemoryClass()
Get the current Device MemoryClass. This is the MemoryClass used by most MFEM device kernels to acces...
Definition device.hpp:285
Structure representing the matrices/tensors needed to evaluate (in reference space) the values,...
Definition fe_base.hpp:141
Mode
Type of data stored in the arrays B, Bt, G, and Gt.
Definition fe_base.hpp:154
@ TENSOR
Tensor product representation using 1D matrices/tensors with dimensions using 1D number of quadrature...
Definition fe_base.hpp:165
Abstract data type element.
Definition element.hpp:29
Class FiniteElementSpace - responsible for providing FEM view of the mesh, mainly managing the set of...
Definition fespace.hpp:208
Class for an integration rule - an Array of IntegrationPoint.
Definition intrules.hpp:100
Mesh data type.
Definition mesh.hpp:65
int GetNE() const
Returns number of elements.
Definition mesh.hpp:1377
int GetNBE() const
Returns number of boundary elements.
Definition mesh.hpp:1380
static int WorldRank()
Return the MPI rank in MPI_COMM_WORLD.
static int WorldSize()
Return the size of MPI_COMM_WORLD.
Abstract operator.
Definition operator.hpp:25
int Height() const
Get the height (size of output) of the Operator. Synonym with NumRows().
Definition operator.hpp:66
virtual void Mult(const Vector &x, Vector &y) const =0
Operator application: y=A(x).
int NumCols() const
Get the number of columns (size of input) of the Operator. Synonym with Width().
Definition operator.hpp:75
int Width() const
Get the width (size of input) of the Operator. Synonym with NumCols().
Definition operator.hpp:72
int NumRows() const
Get the number of rows (size of output) of the Operator. Synonym with Height().
Definition operator.hpp:69
virtual void AddMultTranspose(const Vector &x, Vector &y, const real_t a=1.0) const
Operator transpose application: y+=A^t(x) (default) or y+=a*A^t(x).
Definition operator.cpp:58
virtual void MultTranspose(const Vector &x, Vector &y) const
Action of the transpose operator: y=A^t(x). The default behavior in class Operator is to generate an ...
Definition operator.hpp:100
Abstract parallel finite element space.
Definition pfespace.hpp:31
Vector data type.
Definition vector.hpp:82
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:520
virtual real_t * ReadWrite(bool on_dev=true)
Shortcut for mfem::ReadWrite(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:536
real_t Norml2() const
Returns the l2 norm of the vector.
Definition vector.cpp:968
int Size() const
Returns the size of the vector.
Definition vector.hpp:234
virtual void UseDevice(bool use_dev) const
Enable execution of Vector operations using the mfem::Device.
Definition vector.hpp:145
real_t Sum() const
Return the sum of the vector entries.
Definition vector.cpp:1246
void SetSize(int s)
Resize the vector to size s.
Definition vector.hpp:584
virtual real_t * Write(bool on_dev=true)
Shortcut for mfem::Write(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:528
virtual MemoryClass GetMemoryClass() const override
Return the MemoryClass preferred by the Operator.
Definition util.hpp:721
void Mult(const Vector &v, Vector &y) const override
Operator application: y=A(x).
Definition util.hpp:674
FDJacobian(const Operator &op, const Vector &x, real_t fixed_eps=0.0)
Definition util.hpp:655
Base class for parametric spaces.
Weight FieldOperator.
real_t b
Definition lissajous.cpp:42
real_t a
Definition lissajous.cpp:41
string direction
constexpr void for_constexpr(lambda &&f, std::integral_constant< std::size_t, i >... Is)
Definition util.hpp:52
constexpr bool always_false
Definition util.hpp:575
constexpr auto decay_types(tuple< Ts... > const &) -> tuple< std::remove_cv_t< std::remove_reference_t< Ts > >... >
constexpr auto filter_fields(const std::tuple< Ts... > &t)
Filter fields from a tuple based on their field IDs.
Definition util.hpp:541
const Operator * get_element_restriction(const FieldDescriptor &f, ElementDofOrdering o)
Get the element restriction operator for a field descriptor.
Definition util.hpp:959
const Operator * get_face_restriction(const FieldDescriptor &f, ElementDofOrdering o, FaceType ft, L2FaceValues m)
Get the face restriction operator for a field descriptor.
Definition util.hpp:992
void prolongation(const FieldDescriptor field, const Vector &x, Vector &field_l)
Apply the prolongation operator to a field.
Definition util.hpp:1087
constexpr bool contains(const int *arr, std::size_t size, int value)
Helper function to check if an element is in the array.
Definition util.hpp:478
MFEM_HOST_DEVICE auto unpack_shmem(void *shmem, const shared_mem_info_t &shmem_info, const std::array< DofToQuadMap, num_inputs > &input_dtq_maps, const std::array< DofToQuadMap, num_outputs > &output_dtq_maps, const std::array< DeviceTensor< 2 >, num_fields > &wrapped_fields_e, const int &num_qp, const int &e)
Definition util.hpp:1953
void restriction(const FieldDescriptor u, const Vector &u_l, Vector &field_e, ElementDofOrdering ordering)
Apply the restriction operator to a field.
Definition util.hpp:1218
void GetElementVDofs(const FieldDescriptor &f, int el, Array< int > &vdofs)
Get the element vdofs of a field descriptor.
Definition util.hpp:795
MFEM_HOST_DEVICE tuple< T... > make_tuple(const T &... args)
helper function for combining a list of values into a tuple
Definition tuple.hpp:224
void print_mpi_root(const std::string &msg)
Definition util.hpp:330
const Operator * get_prolongation(const FieldDescriptor &f)
Get the prolongation operator for a field descriptor.
Definition util.hpp:930
void get_lvectors(const std::vector< FieldDescriptor > fields, const Vector &x, std::vector< Vector > &fields_l)
Definition util.hpp:1151
std::array< Vector, sizeof...(i)> create_input_qp_memory(int num_qp, int num_entities, input_t &inputs, std::index_sequence< i... >)
Create input memory for a given set of inputs.
Definition util.hpp:1504
MFEM_HOST_DEVICE DeviceTensor< 1 > load_direction_mem(void *mem, int offset, const int &size, const DeviceTensor< 2 > &direction, const int &entity_idx)
Definition util.hpp:1882
MFEM_HOST_DEVICE std::array< DeviceTensor< 2 >, N > load_input_mem(void *mem, int offset, const std::array< int, N > &sizes, const int &num_qp)
Definition util.hpp:1907
decltype(decay_types(std::declval< T >())) decay_tuple
Definition util.hpp:420
int GetNumEntities(const mfem::Mesh &mesh)
Get the number of entities of a given type.
Definition util.hpp:1282
constexpr auto get_type_name() -> std::string_view
Definition util.hpp:161
void pretty_print(std::ostream &out, const mfem::DenseMatrix &A)
Pretty print an mfem::DenseMatrix to out.
Definition util.hpp:214
void print_tuple(const std::tuple< Args... > &t)
Definition util.hpp:196
std::array< bool, sizeof...(Is)> make_dependency_array(const Tuple &inputs, std::index_sequence< Is... >)
Definition util.hpp:113
MFEM_HOST_DEVICE void copy(DeviceTensor< n > &u, DeviceTensor< n > &v)
Copy data from DeviceTensor u to DeviceTensor v.
Definition util.hpp:2147
constexpr auto extract_field_ids(const std::tuple< Ts... > &t)
Extracts field IDs from a tuple of objects derived from FieldOperator.
Definition util.hpp:467
void print_mpi_sync(const std::string &msg)
print with MPI rank synchronization
Definition util.hpp:344
std::array< DeviceTensor< 3 >, sizeof...(i)> wrap_input_memory(std::array< Vector, sizeof...(i)> &input_qp_mem, int num_qp, int num_entities, const input_t &inputs, std::index_sequence< i... >)
Wrap input memory for a given set of inputs.
Definition util.hpp:1495
std::tuple< std::function< void(const Vector &, Vector &)>, int > get_restriction_transpose(const FieldDescriptor &f, const ElementDofOrdering &o, const fop_t &fop)
Get a transpose restriction callback for a field descriptor.
Definition util.hpp:1052
std::array< DeviceTensor< 2 >, num_fields > wrap_fields(std::vector< Vector > &fields, std::array< int, num_fields > &field_sizes, const int &num_entities)
Wraps plain data in DeviceTensors for fields.
Definition util.hpp:2187
std::vector< int > get_input_size_on_qp(const input_t &inputs, std::index_sequence< i... >)
Get the size on quadrature point for a given set of inputs.
Definition util.hpp:1546
MFEM_HOST_DEVICE std::array< DofToQuadMap, N > load_dtq_mem(void *mem, int offset, const std::array< std::array< int, 2 >, N > &sizes, const std::array< DofToQuadMap, N > &dtq)
Definition util.hpp:1785
SharedMemoryInfo< num_fields, num_inputs, num_outputs > get_shmem_info(const std::array< DofToQuadMap, num_inputs > &input_dtq_maps, const std::array< DofToQuadMap, num_outputs > &output_dtq_maps, const std::vector< FieldDescriptor > &fields, const int &num_entities, const input_t &inputs, const int &num_qp, const std::vector< int > &input_size_on_qp, const int &residual_size_on_qp, const ElementDofOrdering &dof_ordering, const int &derivative_action_field_idx=-1)
Definition util.hpp:1585
const Operator * get_restriction(const FieldDescriptor &f, const ElementDofOrdering &o)
Get the restriction operator for a field descriptor.
Definition util.hpp:1027
std::array< size_t, tuple_size< field_operator_ts >::value > create_descriptors_to_fields_map(const std::vector< FieldDescriptor > &fields, field_operator_ts &fops)
Create a map from field operator types to FieldDescriptor indices.
Definition util.hpp:1437
void element_restriction(const std::array< FieldDescriptor, N > u, const std::array< Vector, N > &u_l, std::array< Vector, M > &fields_e, ElementDofOrdering ordering, const int offset=0)
Definition util.hpp:1259
constexpr std::size_t count_unique_field_ids(const std::tuple< Ts... > &t)
Function to count unique field IDs in a tuple.
Definition util.hpp:495
constexpr void for_constexpr_with_arg(lambda &&f, arg_t &&arg, std::integer_sequence< std::size_t >)
Definition util.hpp:87
MFEM_HOST_DEVICE constexpr auto type(const tuple< T... > &values)
a function intended to be used for extracting the ith type from a tuple.
Definition tuple.hpp:347
std::array< DofToQuadMap, num_fields > create_dtq_maps(field_operator_ts &fops, std::vector< const DofToQuad * > &dtqmaps, const std::array< size_t, num_fields > &to_field_map)
Create DofToQuad maps for a given set of field operators.
Definition util.hpp:2334
auto make_dependency_map_impl(tuple< input_ts... > inputs, std::index_sequence< Is... >)
Definition util.hpp:119
constexpr auto to_array(const std::tuple< Ts... > &tuple)
Definition util.hpp:42
int accumulate_sizes_on_qp(const input_t &inputs, std::array< bool, sizeof...(i)> &kinput_is_dependent, const std::array< int, sizeof...(i)> &input_to_field, const std::array< FieldDescriptor, num_fields > &fields, std::index_sequence< i... > seq)
Accumulates the sizes of field operators on quadrature points for dependent inputs.
Definition util.hpp:2223
const DofToQuad * GetDofToQuad(const FieldDescriptor &f, const IntegrationRule &ir, DofToQuad::Mode mode)
Get the DofToQuad object for a given entity type.
Definition util.hpp:1310
std::array< DofToQuadMap, N > create_dtq_maps_impl(field_operator_ts &fops, std::vector< const DofToQuad * > &dtqs, const std::array< size_t, N > &field_map, std::index_sequence< Is... >)
Definition util.hpp:2249
constexpr void for_constexpr(lambda &&f, std::integer_sequence< std::size_t, i ... >)
Definition util.hpp:71
MFEM_HOST_DEVICE std::array< DeviceTensor< 2 >, sizeof...(i)> get_local_input_qp(const std::array< DeviceTensor< 3 >, sizeof...(i)> &input_qp_global, int e, std::index_sequence< i... >)
Definition util.hpp:2096
void CheckCompatibility(const FieldDescriptor &f)
Check the compatibility of a field operator type with a FieldDescriptor.
Definition util.hpp:1350
std::function< void(const Vector &, Vector &)> get_prolongation_transpose(const FieldDescriptor &f, const fop_t &fop, MPI_Comm mpi_comm)
Get a transpose prolongation callback for a field descriptor.
Definition util.hpp:1179
std::size_t FindIdx(const std::size_t &id, const std::vector< FieldDescriptor > &fields)
Find the index of a field descriptor in a vector of field descriptors.
Definition util.hpp:742
MFEM_HOST_DEVICE std::array< DeviceTensor< 1 >, num_fields > load_field_mem(void *mem, int offset, const std::array< int, num_fields > &sizes, const std::array< DeviceTensor< 2 >, num_fields > &fields_e, const int &entity_idx)
Definition util.hpp:1849
void forall(func_t f, const int &N, const ThreadBlocks &blocks, int num_shmem=0, real_t *shmem=nullptr)
Definition util.hpp:614
int GetVDim(const FieldDescriptor &f)
Get the vdim of a field descriptor.
Definition util.hpp:864
int GetSizeOnQP(const field_operator_t &, const FieldDescriptor &f)
Get the size on quadrature point for a field operator type and FieldDescriptor combination.
Definition util.hpp:1402
MFEM_HOST_DEVICE DeviceTensor< 2 > load_residual_mem(void *mem, int offset, const int &residual_size, const int &num_qp)
Definition util.hpp:1925
auto make_dependency_map(tuple< input_ts... > inputs)
Definition util.hpp:147
constexpr int GetFieldId()
Definition util.hpp:448
int GetVSize(const FieldDescriptor &f)
Get the vdof size of a field descriptor.
Definition util.hpp:760
auto get_marked_entries(const std::array< T, N > &a, const std::array< bool, N > &marker)
Get marked entries from an std::array based on a marker array.
Definition util.hpp:521
void pretty_print_mpi(const mfem::Vector &v)
Pretty print an mfem::Vector with MPI rank.
Definition util.hpp:400
MFEM_HOST_DEVICE std::array< DeviceTensor< 1 >, 6 > load_scratch_mem(void *mem, int offset, const std::array< int, N > &sizes)
Definition util.hpp:1937
int GetDimension(const FieldDescriptor &f)
Get the spatial dimension of a field descriptor.
Definition util.hpp:895
MFEM_HOST_DEVICE void set_zero(std::array< DeviceTensor< 2 >, N > &v)
Definition util.hpp:2111
int GetTrueVSize(const FieldDescriptor &f)
Get the true dof size of a field descriptor.
Definition util.hpp:829
constexpr auto extract_field_ids_impl(Tuple &&t, std::index_sequence< Is... >)
Definition util.hpp:454
void print_tuple_impl(const Tuple &t, std::index_sequence< Is... >)
Definition util.hpp:187
void print_shared_memory_info(shmem_info_t &shmem_info)
Definition util.hpp:1720
MFEM_HOST_DEVICE zero & get(zero &x)
let zero be accessed like a tuple
Definition tensor.hpp:281
__global__ void forall_kernel_shmem(func_t f, int n)
Definition util.hpp:602
real_t u(const Vector &xvec)
Definition lor_mms.hpp:22
T * Write(Memory< T > &mem, int size, bool on_dev=true)
Get a pointer for write access to mem with the mfem::Device's DeviceMemoryClass, if on_dev = true,...
Definition device.hpp:365
OutStream out(std::cout)
Global stream used by the library for standard output. Initially it uses the same std::streambuf as s...
Definition globals.hpp:66
MemoryClass
Memory classes identify sets of memory types.
T * ReadWrite(Memory< T > &mem, int size, bool on_dev=true)
Get a pointer for read+write access to mem with the mfem::Device's DeviceMemoryClass,...
Definition device.hpp:382
MFEM_HOST_DEVICE DeviceTensor< sizeof...(Dims), T > Reshape(T *ptr, Dims... dims)
Wrap a pointer as a DeviceTensor with automatically deduced template parameters.
Definition dtensor.hpp:138
float real_t
Definition config.hpp:46
ElementDofOrdering
Constants describing the possible orderings of the DOFs in one element.
Definition fespace.hpp:47
std::function< real_t(const Vector &)> f(real_t mass_coeff)
Definition lor_mms.hpp:30
void forall(int N, lambda &&body)
Definition forall.hpp:839
FaceType
Definition mesh.hpp:49
@ HIP_MASK
Bitwise-OR of all HIP backends.
Definition device.hpp:93
@ CPU_MASK
Bitwise-OR of all CPU backends.
Definition device.hpp:89
@ CUDA_MASK
Bitwise-OR of all CUDA backends.
Definition device.hpp:91
Helper struct to convert a C++ type to an MPI type.
DofToQuadMap struct.
Definition util.hpp:1518
DeviceTensor< 3, const real_t > G
Gradient of the basis functions evaluated at quadrature points.
Definition util.hpp:1535
Index
Enumeration for the indices of the mappings B and G.
Definition util.hpp:1521
int which_input
Reverse mapping indicating which input this map belongs to.
Definition util.hpp:1538
DeviceTensor< 3, const real_t > B
Basis functions evaluated at quadrature points.
Definition util.hpp:1530
FieldDescriptor struct.
Definition util.hpp:551
std::size_t id
Field ID.
Definition util.hpp:558
FieldDescriptor(std::size_t field_id, const T *v)
Constructor.
Definition util.hpp:569
data_variant_t data
Field variant.
Definition util.hpp:561
FieldDescriptor()
Default constructor.
Definition util.hpp:564
std::variant< const FiniteElementSpace *, const ParFiniteElementSpace *, const ParameterSpace * > data_variant_t
Definition util.hpp:552
std::array< int, num_fields > field_sizes
Definition util.hpp:1575
std::array< std::array< int, 2 >, num_inputs > input_dtq_sizes
Definition util.hpp:1573
std::array< std::array< int, 2 >, num_outputs > output_dtq_sizes
Definition util.hpp:1574
std::array< int, num_inputs > shadow_sizes
Definition util.hpp:1578
std::array< int, num_inputs > input_sizes
Definition util.hpp:1577
std::array< int, 6 > temp_sizes
Definition util.hpp:1580
std::array< int, 8 > offsets
Definition util.hpp:1572
ThreadBlocks struct.
Definition util.hpp:594
This is a class that mimics most of std::tuple's interface, except that it is usable in CUDA kernels ...
Definition tuple.hpp:49