MFEM v4.9.0
Finite element discretization library
Loading...
Searching...
No Matches
schrodinger_flow.hpp
Go to the documentation of this file.
1// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
2// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3// LICENSE and NOTICE for details. LLNL-CODE-806117.
4//
5// This file is part of the MFEM library. For more information and source code
6// availability visit https://mfem.org.
7//
8// MFEM is free software; you can redistribute it and/or modify it under the
9// terms of the BSD-3 license. We welcome feedback and contributions, see file
10// CONTRIBUTING.md for details.
11//
12// ---------------------------------------------
13// Incompressible Schrödinger Flow (ISF) Miniapp
14// ---------------------------------------------
15//
16// This miniapp introduces the Incompressible Schrödinger Flow (ISF) method,
17// an approach for simulating inviscid fluid dynamics by solving the linear
18// Schrödinger equation, leveraging the hydrodynamical analogy to quantum
19// mechanics proposed by Madelung in 1926. ISF offers a simple and efficient
20// framework, particularly effective for capturing vortex dynamics.
21// See README for more details.
22
23#pragma once
24
25#include "mfem.hpp"
26
27using namespace mfem;
28
29namespace mfem
30{
31
32/// @brief Options for the Incompressible Schrödinger Flow solver.
33struct Options: public OptionsParser
34{
35 const char *device = "cpu";
36 int order = 1;
37 // Simulation setup
38 real_t dt = 0.0;
39 real_t hbar = 1e-1;
40 int max_steps = 256;
41 // Mesh setup
42 int dim = 2;
43 int nx = 64, ny = 64, nz = 64;
44 real_t sx = 4.0, sy = 4.0, sz = 4.0;
45 bool periodic = true, set_bc = false;
46 // Leapfrog setup
47 bool leapfrog = false;
49 leapfrog_r1 = 0.4, leapfrog_r2 = 0.26;
50 // Jet setup
51 bool jet = false;
53 enum class JetGeom : int { Band = 0, Disc = 1, Rect = 2 };
54 int jet_geom = 1;
55 // Solvers setup
56 real_t rtol = 1e-6, atol = 0.0, ftz = 1e-15;
57 int max_iters = 1000, print_level = -1;
58 // Visualization setup
59 enum class VisData : int
60 {
61 Velocity, // Velocity norm: jet, leapfrog
62 Vorticity, // Vorticity norm: leapfrog only
63 X, Y, Z, Jet, // (debug: Coordinates and Jet geometry)
65 };
66 bool visualization = true, paraview = false;
67 int vis_steps = 1, vis_width = 1024, vis_height = 1024;
68 int vis_data = static_cast<int>(VisData::Vorticity);
69 const char *vis_keys = "cgjR";
70
71 Options(int argc, char *argv[]): OptionsParser(argc, argv)
72 {
73 AddOption(&device, "-d", "--device",
74 "Device configuration string, see Device::Configure().");
75 AddOption(&order, "-o", "--order", "Finite element order");
76 AddOption(&dt, "-dt", "--dt", "Timestep size");
77 AddOption(&hbar, "-hbar", "--hbar", "Planck constant");
78 AddOption(&max_steps, "-ms", "--max-steps", "Maximum steps");
79 AddOption(&dim, "-dim", "--dim", "Dimension of the problem (2 or 3)");
80 AddOption(&nx, "-nx", "--nx", "Number of elements in x direction");
81 AddOption(&ny, "-ny", "--ny", "Number of elements in y direction");
82 AddOption(&nz, "-nz", "--nz", "Number of elements in z direction");
83 AddOption(&sx, "-sx", "--sx", "Size of the domain in x direction");
84 AddOption(&sy, "-sy", "--sy", "Size of the domain in y direction");
85 AddOption(&sz, "-sz", "--sz", "Size of the domain in z direction");
86 AddOption(&periodic, "-per", "--periodic", "-no-per",
87 "--no-periodic", "Use a periodic mesh.");
88 AddOption(&set_bc, "-bc", "--impose-bc", "-no-bc", "--dont-impose-bc",
89 "Impose or not essential boundary conditions.");
90 AddOption(&leapfrog, "-lf", "--leapfrog", "-no-lf", "--no-leapfrog",
91 "Enable or disable leapfrog.");
92 AddOption(&leapfrog_vx, "-lvx", "--leapfrog-vx", "Leapfrog X velocity");
93 AddOption(&leapfrog_r1, "-lr1", "--leapfrog-r1", "Leapfrog ring 1 radius");
94 AddOption(&leapfrog_r2, "-lr2", "--leapfrog-r2", "Leapfrog ring 2 radius");
95 AddOption(&leapfrog_sw, "-lsw", "--leapfrog-sw", "Leapfrog swirling strength");
96 AddOption(&jet, "-jet", "--jet", "-no-jet", "--no-jet",
97 "Enable or disable jet.");
98 AddOption(&jet_vx, "-jvx", "--jet-vx", "Jet X velocity");
99 AddOption(&jet_geom, "-jg", "--jet-geom", "0: strip, 1: disc, 2: rect");
100 AddOption(&rtol, "-rtol", "--rtol", "Solvers relative tolerance");
101 AddOption(&atol, "-atol", "--atol", "Solvers absolute tolerance");
102 AddOption(&ftz, "-ftz", "--ftz", "Flush to zero threshold");
103 AddOption(&max_iters, "-mi", "--max-iterations", "Solvers max iterations");
104 AddOption(&print_level, "-pl", "--print-level", "Solvers print level");
105 AddOption(&visualization, "-vis", "--visualization", "-no-vis",
106 "--no-visualization", "Enable or not GLVis visualization");
107 AddOption(&paraview, "-pv", "--paraview", "-no-pv",
108 "--no-paraview", "Enable or not Paraview visualization");
109 AddOption(&vis_steps, "-vs", "--vis-steps", "Visualization steps");
110 AddOption(&vis_width, "-vw", "--vis-width", "vis width");
111 AddOption(&vis_height, "-vh", "--vis-height", "vis height");
112 AddOption(&vis_data, "-vd", "--vis-data",
113 "Velocity: 0: Vorticity: 1 (leapfrog only)");
114 AddOption(&vis_keys, "-vk", "--vis-keys", "Visualization keys, default: cgjR");
115 ParseCheck();
116 MFEM_VERIFY(jet ^ leapfrog, "'jet' or 'leapfrog' option must be set");
117 MFEM_VERIFY(vis_data < static_cast<int>(VisData::Unknown),
118 "Invalid visualization data option.");
119 if (dt == 0.0)
120 {
121 const auto dx = sx / static_cast<real_t>(order * nx);
122 dt = (dx*dx) / hbar;
123 }
124 }
125};
126
127} // namespace mfem
128
129/// @brief Complex number type for device.
130#if !(defined(MFEM_USE_CUDA) || defined(MFEM_USE_HIP))
131#include <cmath>
132#include <complex>
133#include <utility>
134using complex_t = std::complex<real_t>;
135#else // CUDA or HIP
136
137#ifdef MFEM_USE_CUDA
138#include <cuComplex.h>
139#ifdef MFEM_USE_SINGLE
140using RealComplex_t = cuFloatComplex;
141#else
142using RealComplex_t = cuDoubleComplex;
143#endif // MFEM_USE_SINGLE
144#endif // MFEM_USE_CUDA
145
146#ifdef MFEM_USE_HIP
147#include <hip/hip_complex.h>
148#ifdef MFEM_USE_SINGLE
149using RealComplex_t = hipFloatComplex;
150#else
151using RealComplex_t = hipDoubleComplex;
152#endif // MFEM_USE_SINGLE
153#endif // MFEM_USE_HIP
154
155struct Complex : public RealComplex_t
156{
157 MFEM_HOST_DEVICE Complex() = default;
158 MFEM_HOST_DEVICE Complex(real_t r) { x = r, y = 0.0; }
159 MFEM_HOST_DEVICE Complex(real_t r, real_t i) { x = r, y = i; }
160 MFEM_HOST_DEVICE real_t real() const { return x; }
161 MFEM_HOST_DEVICE void real(real_t r) { x = r; }
162 MFEM_HOST_DEVICE real_t imag() const { return y; }
163 MFEM_HOST_DEVICE void imag(real_t i) { y = i; }
164
165 template <typename U>
166 MFEM_HOST_DEVICE inline Complex &operator*=(const U &z)
167 {
168 return (*this = *this * z, *this);
169 }
170
171 template <typename U>
172 MFEM_HOST_DEVICE inline Complex &operator/=(const U &z)
173 {
174 return (*this = *this / z, *this);
175 }
176};
177
178MFEM_HOST_DEVICE inline Complex operator*(const Complex &x, const real_t &y)
179{
180 return Complex(x.real() * y, x.imag() * y);
181}
182
183MFEM_HOST_DEVICE inline Complex operator+(const Complex &a, const Complex &b)
184{
185 return Complex(a.real() + b.real(), a.imag() + b.imag());
186}
187
188MFEM_HOST_DEVICE inline Complex operator*(const real_t d, const Complex &z)
189{
190 return Complex(z.real() * d, z.imag() * d);
191}
192
193MFEM_HOST_DEVICE inline Complex operator*(const Complex &a, const Complex &b)
194{
195 return Complex(a.real() * b.real() - a.imag() * b.imag(),
196 a.real() * b.imag() + a.imag() * b.real());
197}
198
199MFEM_HOST_DEVICE inline Complex operator/(const Complex &z, const real_t &d)
200{
201 return Complex(z.real() / d, z.imag() / d);
202}
203
204MFEM_HOST_DEVICE inline real_t abs(const Complex &z)
205{
206 return std::hypot(z.real(), z.imag());
207}
208
209MFEM_HOST_DEVICE inline Complex exp(const Complex &q)
210{
211 real_t s, c, e = std::exp(q.real());
212#ifdef MFEM_USE_SINGLE
213 sincosf(q.imag(), &s, &c);
214#else
215 sincos(q.imag(), &s, &c);
216#endif
217 return Complex(c * e, s * e);
218}
219
220MFEM_HOST_DEVICE inline real_t norm(const Complex &z)
221{
222 return z.real() * z.real() + z.imag() * z.imag();
223}
224
225using complex_t = Complex;
226#endif // MFEM_USE_CUDA || MFEM_USE_HIP
227
228namespace mfem
229{
230
231using real3_t = std::array<real_t, 3>;
232
233/// @brief Base class for Schrodinger solver kernels.
234template <typename TMesh,
235 typename TFiniteElementSpace,
236 typename TComplexGridFunction,
237 typename TGridFunction,
238 typename TBilinearForm,
239 typename TMixedBilinearForm,
240 typename TLinearForm>
242{
243 std::function<Mesh()> CreateMesh2D, CreateMesh3D;
248 TFiniteElementSpace h1_fes, nd_fes, nodal_fes;
249 TGridFunction nodes;
250 const int ne, ndofs;
256 std::function<void()> SetEssentialTrueDofs;
261 TComplexGridFunction psi1, psi2;
262 TComplexGridFunction delta_psi1, delta_psi2, gpsi1_nd, gpsi2_nd;
263 TComplexGridFunction gpsi1_x, gpsi2_x, gpsi1_y, gpsi2_y, gpsi1_z, gpsi2_z;
264 TGridFunction div_u, q, h1_gf, nd_gf;
265 TLinearForm rhs;
269
271 std::function<TMesh(Mesh&)> CreateMesh,
272 std::function<OrthoSolver()> CreateOrthoSolver,
273 std::function<CGSolver()> CreateCGSolver):
274 Options(config),
275 CreateMesh2D([&]()
276 {
277 const auto type = Element::QUADRILATERAL;
278 Mesh xy = Mesh::MakeCartesian2D(nx, ny, type, false, sx, sy, false);
280 if (!periodic) { return xy; }
281 std::vector<Vector> Tr2 = { Vector({ sx, 0.0_r }),
282 Vector({ 0.0_r, sy })
283 };
285 }),
286 CreateMesh3D([&]()
287 {
288 const auto type = Element::HEXAHEDRON;
289 Mesh xyz = Mesh::MakeCartesian3D(nx, ny, nz, type, sx, sy, sz, false);
290 xyz.SetCurvature(order);
291 if (!periodic) { return xyz; }
292 std::vector<Vector> Tr3 = { Vector({ sx, 0.0_r, 0.0_r }),
293 Vector({ 0.0_r, sy, 0.0_r }),
294 Vector({ 0.0_r, 0.0_r, sz })
295 };
297 }),
299 mesh(CreateMesh(serial_mesh)),
300 h1_fec(order, dim),
301 nd_fec(order, dim),
302 h1_fes(&mesh, &h1_fec),
303 nd_fes(&mesh, &nd_fec),
306 ne(mesh.GetNE()),
307 ndofs(h1_fes.GetNDofs()),
308 one(1.0),
309 Vx(dim, [&](const Vector &, Vector &v)
310 {
311 v.SetSize(dim), v[0] = 1.0, v[1] = 0.0;
312 if (dim == 3) { v[2] = 0.0; }
313 }),
314 Vy(dim, [&](const Vector &, Vector &v)
315 {
316 v.SetSize(dim), v[0] = 0.0, v[1] = 1.0;
317 if (dim == 3) { v[2] = 0.0; }
318 }),
319 Vz(dim, [&](const Vector &, Vector &v)
320 {
321 v.SetSize(dim), v[0] = 0.0, v[1] = 0.0;
322 if (dim == 3) { v[2] = 1.0; }
323 }),
324 mass_h1(&h1_fes),
325 mass_nd(&nd_fes),
326 diff_h1(&h1_fes),
333 {
334 if (periodic || mesh.bdr_attributes.Size() == 0) { return; }
335 Array<int> ess_bdr(mesh.bdr_attributes.Max());
336 ess_bdr = set_bc ? 1 : 0;
337 h1_fes.GetEssentialTrueDofs(ess_bdr, ess_tdof_list);
338 }),
340 diff_h1_setup((diff_h1.AddDomainIntegrator(new DiffusionIntegrator(one)),
341 diff_h1.SetAssemblyLevel(AssemblyLevel::PARTIAL),
343 true)),
345 diff_h1_ortho(CreateOrthoSolver()),
346 mass_h1_cgs(CreateCGSolver()),
347 mass_nd_cgs(CreateCGSolver()),
348 diff_h1_cgs(CreateCGSolver()),
349 psi1(&h1_fes),
350 psi2(&h1_fes),
356 div_u(&h1_fes),
357 q(&h1_fes),
358 h1_gf(&h1_fes),
359 nd_gf(&nd_fes),
360 rhs(&h1_fes),
362 {
363 mesh.GetNodes(nodes);
364
365 mass_h1.AddDomainIntegrator(new MassIntegrator(one));
366 mass_nd.AddDomainIntegrator(new VectorFEMassIntegrator(one));
367 grad_nd.AddDomainIntegrator(new MixedVectorGradientIntegrator());
368
369 nd_dot_x_h1.AddDomainIntegrator(new MixedDotProductIntegrator(Vx));
370 nd_dot_y_h1.AddDomainIntegrator(new MixedDotProductIntegrator(Vy));
371 nd_dot_z_h1.AddDomainIntegrator(new MixedDotProductIntegrator(Vz));
372
373 mass_h1.SetAssemblyLevel(AssemblyLevel::PARTIAL);
374 mass_nd.SetAssemblyLevel(AssemblyLevel::PARTIAL);
375 grad_nd.SetAssemblyLevel(AssemblyLevel::PARTIAL);
376
377 nd_dot_x_h1.SetAssemblyLevel(AssemblyLevel::LEGACY);
378 nd_dot_y_h1.SetAssemblyLevel(AssemblyLevel::LEGACY);
379 nd_dot_z_h1.SetAssemblyLevel(AssemblyLevel::LEGACY);
380
383 grad_nd.Assemble();
384
385 nd_dot_x_h1.Assemble();
386 nd_dot_y_h1.Assemble();
387 if (dim == 3) { nd_dot_z_h1.Assemble(); }
388
389 nd_dot_x_h1.Finalize();
390 nd_dot_y_h1.Finalize();
391 if (dim == 3) { nd_dot_z_h1.Finalize(); }
392
393 mass_h1.FormSystemMatrix(ess_tdof_list, mass_h1_op);
394 mass_nd.FormSystemMatrix(ess_tdof_list, mass_nd_op);
395 diff_h1.FormSystemMatrix(ess_tdof_list, diff_h1_op);
396
397 // Only used for velocity computation
398 if (visualization && static_cast<VisData>(vis_data) == VisData::Velocity)
399 {
400 grad_nd.FormRectangularSystemMatrix(ess_tdof_list, ess_tdof_list, grad_nd_op);
401 nd_dot_x_h1.FormRectangularSystemMatrix(ess_tdof_list, ess_tdof_list,
403 nd_dot_y_h1.FormRectangularSystemMatrix(ess_tdof_list, ess_tdof_list,
405 if (dim == 3)
406 {
407 nd_dot_z_h1.FormRectangularSystemMatrix(ess_tdof_list, ess_tdof_list,
409 }
410 }
411
418
425
434
435 rhs.AddDomainIntegrator(new DomainLFIntegrator(div_u_coeff));
436 rhs.UseFastAssembly(true);
437 }
438
439 /// @brief Initialize the wavefunctions psi1 and psi2.
440 void Initialize(Vector &phase_r)
441 {
442 psi1 = 0.0, psi2 = 0.0;
443 if (leapfrog && phase_r.Size() > 0)
444 {
445 const auto phase = phase_r.Read();
446 auto psi1_r = Reshape(psi1.real().ReadWrite(), ndofs);
447 auto psi1_i = Reshape(psi1.imag().ReadWrite(), ndofs);
448 auto psi2_r = Reshape(psi2.real().ReadWrite(), ndofs);
449 auto psi2_i = Reshape(psi2.imag().ReadWrite(), ndofs);
450 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
451 {
452 const complex_t eps = 0.01, i_phase(0, phase[n]);
453 const complex_t z1 = exp(i_phase);
454 const complex_t z2 = eps * exp(i_phase);
455 psi1_r(n) = z1.real(), psi1_i(n) = z1.imag();
456 psi2_r(n) = z2.real(), psi2_i(n) = z2.imag();
457 });
458 }
459 if (jet) { psi1.real() = 1.0, psi2.real() = 0.0; }
460 }
461
462 /// @brief Normalize the wavefunctions psi1 and psi2.
464 {
465 auto psi1_r = Reshape(psi1.real().ReadWrite(), ndofs);
466 auto psi1_i = Reshape(psi1.imag().ReadWrite(), ndofs);
467 auto psi2_r = Reshape(psi2.real().ReadWrite(), ndofs);
468 auto psi2_i = Reshape(psi2.imag().ReadWrite(), ndofs);
469 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
470 {
471 complex_t psi1(psi1_r(n), psi1_i(n)), psi2(psi2_r(n), psi2_i(n));
472 const real_t psi_norm = std::sqrt(norm(psi1) + norm(psi2));
473 if (fabs(psi_norm) < 1e-16) { return; }
474 psi1_r(n) /= psi_norm, psi1_i(n) /= psi_norm;
475 psi2_r(n) /= psi_norm, psi2_i(n) /= psi_norm;
476 });
477 }
478
479 /// @brief Restrict the wavefunctions psi1 and psi2.
480 void Restrict(const real_t t, const TGridFunction &isJet_in,
481 const real_t omega, const TGridFunction &phase_in)
482 {
483 MFEM_VERIFY(jet, "Jet must be enabled use restrict.");
484 const auto isJet = isJet_in.Read();
485 const auto phase = phase_in.Read();
486 auto psi1_r = Reshape(psi1.real().ReadWrite(), ndofs);
487 auto psi1_i = Reshape(psi1.imag().ReadWrite(), ndofs);
488 auto psi2_r = Reshape(psi2.real().ReadWrite(), ndofs);
489 auto psi2_i = Reshape(psi2.imag().ReadWrite(), ndofs);
490 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
491 {
492 if (isJet[n] == 0) { return; }
493 const complex_t i_pn_omega_t(0.0, phase[n] - omega * t);
494 complex_t psi1(psi1_r(n), psi1_i(n)), psi2(psi2_r(n), psi2_i(n));
495 const real_t amp1 = abs(psi1), amp2 = abs(psi2);
496 psi1 = amp1 * exp(i_pn_omega_t);
497 psi2 = amp2 * exp(i_pn_omega_t);
498 psi1_r(n) = psi1.real(), psi1_i(n) = psi1.imag();
499 psi2_r(n) = psi2.real(), psi2_i(n) = psi2.imag();
500 });
501 }
502
503 template<typename Gfn, typename Xfn, typename Yfn, typename Zfn>
504 void GradPsi(Gfn &Grad_nd, Xfn &x_dot_Mm1, Yfn &y_dot_Mm1, Zfn &z_dot_Mm1)
505 {
506 Grad_nd(psi1.real(), gpsi1_nd.real());
507 Grad_nd(psi1.imag(), gpsi1_nd.imag());
508 Grad_nd(psi2.real(), gpsi2_nd.real());
509 Grad_nd(psi2.imag(), gpsi2_nd.imag());
510
511 x_dot_Mm1(gpsi1_nd.real(), gpsi1_x.real());
512 x_dot_Mm1(gpsi1_nd.imag(), gpsi1_x.imag());
513 x_dot_Mm1(gpsi2_nd.real(), gpsi2_x.real());
514 x_dot_Mm1(gpsi2_nd.imag(), gpsi2_x.imag());
515
516 y_dot_Mm1(gpsi1_nd.real(), gpsi1_y.real());
517 y_dot_Mm1(gpsi1_nd.imag(), gpsi1_y.imag());
518 y_dot_Mm1(gpsi2_nd.real(), gpsi2_y.real());
519 y_dot_Mm1(gpsi2_nd.imag(), gpsi2_y.imag());
520
521 if (dim == 3)
522 {
523 z_dot_Mm1(gpsi1_nd.real(), gpsi1_z.real());
524 z_dot_Mm1(gpsi1_nd.imag(), gpsi1_z.imag());
525 z_dot_Mm1(gpsi2_nd.real(), gpsi2_z.real());
526 z_dot_Mm1(gpsi2_nd.imag(), gpsi2_z.imag());
527 }
528 }
529
530 // u = ℏRe{−𝑖𝝭ᵀ·∇𝝭} = ħ[𝝭1r.∇𝝭1i - 𝝭1i.∇𝝭1r + 𝝭2r.∇𝝭2i - 𝝭2i.∇𝝭2r]
531 void GradPsiVelocity(const real_t hbar, TGridFunction &ux,
532 TGridFunction &uy, TGridFunction &uz)
533 {
534 const auto psi1r = Reshape(psi1.real().Read(), ndofs);
535 const auto psi1i = Reshape(psi1.imag().Read(), ndofs);
536 const auto psi2r = Reshape(psi2.real().Read(), ndofs);
537 const auto psi2i = Reshape(psi2.imag().Read(), ndofs);
538
539 const auto Gpsi1rx = Reshape(gpsi1_x.real().Read(), ndofs);
540 const auto Gpsi1ix = Reshape(gpsi1_x.imag().Read(), ndofs);
541 const auto Gpsi1ry = Reshape(gpsi1_y.real().Read(), ndofs);
542 const auto Gpsi1iy = Reshape(gpsi1_y.imag().Read(), ndofs);
543 const auto Gpsi1rz = Reshape(gpsi1_z.real().Read(), ndofs);
544 const auto Gpsi1iz = Reshape(gpsi1_z.imag().Read(), ndofs);
545
546 const auto Gpsi2rx = Reshape(gpsi2_x.real().Read(), ndofs);
547 const auto Gpsi2ix = Reshape(gpsi2_x.imag().Read(), ndofs);
548 const auto Gpsi2ry = Reshape(gpsi2_y.real().Read(), ndofs);
549 const auto Gpsi2iy = Reshape(gpsi2_y.imag().Read(), ndofs);
550 const auto Gpsi2rz = Reshape(gpsi2_z.real().Read(), ndofs);
551 const auto Gpsi2iz = Reshape(gpsi2_z.imag().Read(), ndofs);
552
553 auto vx = Reshape(ux.Write(), ndofs);
554 auto vy = Reshape(uy.Write(), ndofs);
555 auto vz = Reshape(uz.Write(), ndofs);
556 const real_t FTZ = ftz;
557 const int DIM = dim;
558
559 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
560 {
561 vx(n) = psi1r(n) * Gpsi1ix(n) - psi1i(n) * Gpsi1rx(n);
562 vx(n) += psi2r(n) * Gpsi2ix(n) - psi2i(n) * Gpsi2rx(n);
563 vx(n) *= (fabs(vx(n)) < FTZ) ? 0.0 : hbar;
564 vy(n) = psi1r(n) * Gpsi1iy(n) - psi1i(n) * Gpsi1ry(n);
565 vy(n) += psi2r(n) * Gpsi2iy(n) - psi2i(n) * Gpsi2ry(n);
566 vy(n) *= (fabs(vy(n)) < FTZ) ? 0.0 : hbar;
567 if (DIM == 2) { return; }
568 vz(n) = psi1r(n) * Gpsi1iz(n) - psi1i(n) * Gpsi1rz(n);
569 vz(n) += psi2r(n) * Gpsi2iz(n) - psi2i(n) * Gpsi2rz(n);
570 vz(n) *= (fabs(vz(n)) < FTZ) ? 0.0 : hbar;
571 });
572 }
573
574 // ∇∙u = -ℏ.Re{𝝭ᵀ·𝑖∆𝝭} = -ℏ[𝝭1i.∆𝝭1r - 𝝭1r.∆𝝭1i + 𝝭2i.∆𝝭2r - 𝝭2r.∆𝝭2i]
576 {
577 const auto psi1r = Reshape(psi1.real().Read(), ndofs);
578 const auto psi1i = Reshape(psi1.imag().Read(), ndofs);
579 const auto psi2r = Reshape(psi2.real().Read(), ndofs);
580 const auto psi2i = Reshape(psi2.imag().Read(), ndofs);
581 const auto Dpsi1r = Reshape(delta_psi1.real().Read(), ndofs);
582 const auto Dpsi1i = Reshape(delta_psi1.imag().Read(), ndofs);
583 const auto Dpsi2r = Reshape(delta_psi2.real().Read(), ndofs);
584 const auto Dpsi2i = Reshape(delta_psi2.imag().Read(), ndofs);
585 auto div_u_w = Reshape(div_u.Write(), ndofs);
586 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
587 {
588 div_u_w(n) = psi1i(n) * Dpsi1r(n) - psi1r(n) * Dpsi1i(n);
589 div_u_w(n) += psi2i(n) * Dpsi2r(n) - psi2r(n) * Dpsi2i(n);
590 div_u_w(n) *= -1.0;
591 });
592 }
593
594 // 𝝭ⁿ⁺¹ = exp(−i.q/ħ).𝝭ⁿ
596 {
597 const auto q_r = Reshape(q.Read(), ndofs);
598 auto psi1_r = Reshape(psi1.real().ReadWrite(), ndofs);
599 auto psi1_i = Reshape(psi1.imag().ReadWrite(), ndofs);
600 auto psi2_r = Reshape(psi2.real().ReadWrite(), ndofs);
601 auto psi2_i = Reshape(psi2.imag().ReadWrite(), ndofs);
602 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
603 {
604 const complex_t minus_i(0, -1.0);
605 const complex_t eiq = exp(minus_i * q_r(n));
606 complex_t psi1(psi1_r(n), psi1_i(n)), psi2(psi2_r(n), psi2_i(n));
607 psi1 *= eiq, psi2 *= eiq;
608 psi1_r(n) = psi1.real(), psi1_i(n) = psi1.imag();
609 psi2_r(n) = psi2.real(), psi2_i(n) = psi2.imag();
610 });
611 }
612
613 /// @brief Add a circular vortex to the wavefunctions psi1 and psi2.
614 void AddCircularVortex(const real3_t center, const real3_t normal,
615 const real_t radius, const real_t swirling)
616 {
617 MFEM_VERIFY(swirling > 0.0, "Swirling strength must be positive");
618 const auto DIM = dim;
619 const real3_t o = { center[0], center[1], center[2] };
620 const real_t norm2 = std::sqrt(normal[0] * normal[0] +
621 normal[1] * normal[1] +
622 normal[2] * normal[2]);
623 const real_t n0 = normal[0] / norm2, n1 = normal[1] / norm2,
624 n2 = normal[2] / norm2;
625 const auto &X = Reshape(nodes.Read(), dim, ndofs);
626 auto psi1_r = Reshape(psi1.real().ReadWrite(), ndofs);
627 auto psi1_i = Reshape(psi1.imag().ReadWrite(), ndofs);
628 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
629 {
630 const real_t px = X(0, n), py = X(1, n),
631 pz = DIM == 3 ? X(2, n) : 0.0;
632 const real_t rx = px - o[0], ry = py - o[1], rz = pz - o[2];
633 const real_t z = rx * n0 + ry * n1 + rz * n2;
634 const bool inRange = (rx * rx + ry * ry + rz * rz - z * z) < radius * radius;
635 const bool inLayerP = inRange && (z > 0.0 && z <= (+swirling / 2.0));
636 const bool inLayerM = inRange && (z <= 0.0 && z >= (-swirling / 2.0));
637 real_t alpha = 0.0;
638 if (inLayerP) { alpha = -M_PI * (2.0 * z / swirling - 1.0); }
639 if (inLayerM) { alpha = -M_PI * (2.0 * z / swirling + 1.0); }
640 complex_t psi1(psi1_r(n), psi1_i(n));
641 const complex_t alpha_i(0, alpha);
642 psi1 *= exp(alpha_i);
643 psi1_r(n) = psi1.real(), psi1_i(n) = psi1.imag();
644 });
645 }
646};
647
648/// @brief Crank-Nicolson time solver for the Schrodinger equation.
649template<typename TFiniteElementSpace,
650 typename TSesquilinearForm,
651 typename TComplexGridFunction>
653{
655 TSesquilinearForm C_form, R_form;
658 TComplexGridFunction z;
660
661 // ∂ₜ𝝭 = ½ℏ𝑖∆𝝭
662 // 𝝭ⁿ⁺¹ - 𝝭ⁿ = ¼𝛅t𝑖ℏ(∆𝝭ⁿ⁺¹ + ∆𝝭ⁿ)
663 // 𝝭ⁿ⁺¹ - ¼𝑖ℏ𝛅t∆𝝭ⁿ⁺¹ = 𝝭ⁿ + ¼𝑖ℏ𝛅t∆𝝭ⁿ
664 // [M + ¼𝑖ℏ𝛅tA]𝝭ⁿ⁺¹ = [M - ¼𝑖ℏ𝛅tA]𝝭ⁿ
665 // C = M + ¼𝑖ℏ𝛅tA, R = M - ¼𝑖ℏ𝛅tA
666 CrankNicolsonTimeBaseSolver(TFiniteElementSpace &fes,
667 real_t hbar, real_t dt,
668 std::function<GMRESSolver()> CreateGMRESSolver,
669 real_t rtol, real_t atol, int maxiter,
670 int print_level):
671 one(1.0), dthq(dt * hbar / 4.0), mdthq(-dt * hbar / 4.0),
672 C_form(&fes), R_form(&fes),
673 z(&fes),
674 gmres_solver(CreateGMRESSolver())
675 {
676 // C = M + ¼𝑖ℏ𝛅tA
677 C_form.AddDomainIntegrator(new MassIntegrator(one),
679 C_form.SetAssemblyLevel(AssemblyLevel::PARTIAL);
680 C_form.Assemble();
681 C_form.FormSystemMatrix(no_bc, C_op);
682
683 // R = M - ¼𝑖ℏ𝛅tA
684 R_form.AddDomainIntegrator(new MassIntegrator(one),
686 R_form.SetAssemblyLevel(AssemblyLevel::PARTIAL);
687 R_form.Assemble();
688 R_form.FormSystemMatrix(no_bc, R_op);
689
690 gmres_solver.SetPrintLevel(print_level);
692 gmres_solver.SetMaxIter(maxiter);
696 }
697
698 // 𝝭ⁿ⁺¹ = C⁻¹ R 𝝭ⁿ
699 virtual void Mult(TComplexGridFunction &psi) = 0;
700};
701
702/// @brief Class for simulating incompressible Schrodinger flow.
703template<typename TSchrodingerSolver, typename TGridFunction>
705{
706public:
707 TSchrodingerSolver &solver;
708 const int ndofs;
710 TGridFunction isJet, phase, vx, vy, vz;
711
712 IncompressibleBaseFlow(Options &config, TSchrodingerSolver &solver):
713 Options(config),
714 solver(solver),
716 omega(0.0),
717 isJet(&solver.h1_fes),
718 phase(&solver.h1_fes),
719 vx(&solver.h1_fes), vy(&solver.h1_fes), vz(&solver.h1_fes)
720 {
721 isJet = 0.0, phase = 0.0;
722 vx = 0.0, vy = 0.0, vz = 0.0;
723 Setup();
724 }
725
726 /**
727 * @brief Setup the solver.
728 *
729 * This function initializes the solver by setting up the phase,
730 * the Jet vectors and normalizing the wave functions.
731 * It also adds circular vortex rings if leapfrog is enabled.
732 */
733 void Setup()
734 {
735 real_t velocity[3];
736 velocity[0] = leapfrog ? leapfrog_vx : jet ? jet_vx : 0.0;
737 velocity[1] = velocity[2] = 0.0;
738
739 const real_t kvec0 = velocity[0] / hbar;
740 const real_t kvec1 = velocity[1] / hbar;
741 const real_t kvec2 = velocity[2] / hbar;
742
743 omega = 0.0;
744 for (int i = 0; i < 3; i++) { omega += velocity[i] * velocity[i]; }
745 omega /= 2.0 * hbar;
746
747 auto phase_w = phase.Write();
748 auto isJet_w = isJet.Write();
749 const int DIM = dim, JET_GEOM = jet_geom;
750 const real_t SX = sx, SY = sy, SZ = sz;
751 const auto X = Reshape(solver.nodes.Read(), dim, ndofs);
752 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int n)
753 {
754 const real_t r = SX / 16.0;
755 const real_t px = X(0, n), py = X(1, n),
756 pz = DIM == 3 ? X(2, n) : 0.0;
757 const real_t dx = px - (SX/8.0), dy = py - (SY/2.0),
758 dz = DIM == 3 ? pz - (SZ/2.0) : 0.0;
759 const auto geom = static_cast<Options::JetGeom>(JET_GEOM);
760 if (geom == JetGeom::Band) { isJet_w[n] = (fabs(dy*dy) < (r*r)) ? 1 : 0; }
761 if (geom == JetGeom::Disc) { isJet_w[n] = (fabs(dx*dx + dy*dy + dz*dz) < (r*r)) ? 1 : 0; }
762 if (geom == JetGeom::Rect) { isJet_w[n] = (px > 0.2 && dx < 1.0 && fabs(dy*dy + dz*dz) < (r*r)) ? 1 : 0; }
763 phase_w[n] = kvec0 * px + kvec1 * py + kvec2 * pz;
764 });
765
766 solver.Initialize(phase);
767
768 if (jet) { ConstrainJetVelocity(); }
769
770 if (leapfrog) // Add vortex rings
771 {
772 const real_t zh = sx / 2.0, yh = sy / 2.0;
773 const real_t z2 = zh * zh, y2 = yh * yh, r = sqrt(z2 + y2);
774 const real3_t n = { -1.0, 0.0, 0.0 },
775 o = { sx / 2.0_r, sy / 2.0_r, DIM == 3 ? sz / 2.0_r : 0.0_r };
776 solver.AddCircularVortex(o, n, r * leapfrog_r1, leapfrog_sw);
777 solver.AddCircularVortex(o, n, r * leapfrog_r2, leapfrog_sw);
778 solver.Normalize(), solver.PressureProject();
779 }
780 }
781
782 /// @brief This function performs a single time step of the solver.
783 void Step(const real_t &t)
784 {
785 // Solve linear Schrödinger equation
786 solver.Step();
787
788 // Normalization of the wave function
789 solver.Normalize();
790
791 // Pressure projection
792 solver.PressureProject();
793
794 // Enforce geometry constraints
795 if (jet)
796 {
797 solver.Restrict(t, isJet, omega, phase);
798 solver.Normalize();
799 solver.PressureProject();
800 }
801
802 // Compute velocity field for visualization
803 if (visualization &&
804 static_cast<VisData>(vis_data) == VisData::Velocity)
805 {
806 solver.VelocityOneForm(vx, vy, vz);
807 }
808 }
809
810 /// @brief Restricts the velocity of the wavefunctions.
812 {
813 solver.Normalize();
814 MFEM_VERIFY(jet, "ConstrainJetVelocity() only for jet geometry");
815 for (int i = 0; i < 10; i++)
816 {
817 solver.Restrict(0.0, isJet, omega, phase);
818 solver.PressureProject();
819 }
820 }
821};
822
823/// @brief Base class for visualization.
824template <typename TMesh,
825 typename TGridFunction,
826 typename TFiniteElementSpace,
827 typename TSchrodingerSolver,
828 typename TIncompressibleFlow>
829struct VisualizerBase : private Options
830{
833 TGridFunction vis_gf;
834 std::function<void()> vis_fn;
835 std::function<std::string()> vis_prefix;
836 const int ndofs;
838 const TIncompressibleFlow &isf;
839 const TSchrodingerSolver &solver;
840
841#ifndef MFEM_USE_HDF5
843#else
845#endif
846
848 TSchrodingerSolver &solver,
849 const TIncompressibleFlow &isf,
850 std::function<std::string()> prefix):
851 Options(config),
853 vis_gf(&solver.h1_fes),
854 vis_prefix(std::move(prefix)),
855 ndofs(solver.h1_fes.GetNDofs()),
856 vis_data(static_cast<Options::VisData>(config.vis_data)),
857 isf(isf),
858 solver(solver),
859 dc("ISF", &mesh) { Visualize(); }
860
862 {
863 vis_gf = 0.0;
864
865 // Coordinates visualization (debug)
867 {
868 vis_fn = [&]()
869 {
870 const auto X = Reshape(solver.nodes.Read(), dim, ndofs);
871 auto viz_h1_w = vis_gf.Write();
872 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int i)
873 {
874 viz_h1_w[i] = vis_data == Options::VisData::X ? X(0,i):
876 vis_data == Options::VisData::Z ? dim == 3 ? X(2,i) : 0.0:
877 0.0;
878 });
879 vis_gf.HostRead();
880 };
881 }
882
883 // Jet geometry visualization (debug)
885 {
886 vis_fn = [&]()
887 {
888 auto isJet_r = isf.isJet.Read();
889 auto viz_h1_w = vis_gf.Write();
890 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int i) { viz_h1_w[i] = isJet_r[i]; });
891 vis_gf.HostRead();
892 };
893 }
894
895 // Velocity norm visualization
897 {
898 vis_fn = [&]()
899 {
900 const auto vx_r = isf.vx.Read(), vy_r = isf.vy.Read(), vz_r = isf.vz.Read();
901 auto viz_h1_w = vis_gf.Write();
902 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int i)
903 {
904 const real_t vx = vx_r[i], vy = vy_r[i], vz = vz_r[i];
905 viz_h1_w[i] = std::sqrt(vx*vx + vy*vy + vz*vz);
906 });
907 vis_gf.HostRead();
908 };
909 }
910
911 // Vorticity norm visualization
913 {
914 vis_fn = [&]()
915 {
916 const auto psi1_r = solver.psi1.real().Read();
917 const auto psi1_i = solver.psi1.imag().Read();
918 const auto psi2_r = solver.psi2.real().Read();
919 const auto psi2_i = solver.psi2.imag().Read();
920 auto viz_h1_w = vis_gf.Write();
921 mfem::forall(ndofs, [=] MFEM_HOST_DEVICE(int i)
922 {
923 const auto psi1 = psi1_r[i] * psi1_r[i] + psi1_i[i] * psi1_i[i];
924 const auto psi2 = psi2_r[i] * psi2_r[i] + psi2_i[i] * psi2_i[i];
925 viz_h1_w[i] = psi1 * psi1 + psi2 * psi2;
926 });
927 vis_gf.HostRead();
928 };
929 }
930
931 if (visualization)
932 {
933 glvis.open("localhost", 19916);
934 if (glvis.is_open())
935 {
936 glvis.precision(8);
937 glvis << vis_prefix().c_str();
938 glvis << "solution\n" << mesh << *(this->operator()())
939 << "window_geometry 0 0 " << vis_width << " " << vis_height << "\n"
940 << "keys " << vis_keys << "\n" << std::flush;
941 }
942 }
943
944 if (paraview)
945 {
947 dc.SetPrefixPath("ParaView");
950 dc.RegisterField("vis_gf", this->operator()());
951 dc.SetCycle(0);
952 dc.SetTime(0.0);
953 dc.Save();
954 }
955 }
956
958
959 /// @brief Get the visualization data.
960 TGridFunction* operator()()
961 {
962 vis_fn();
963 return &vis_gf;
964 }
965
966 /// @brief Save the visualization data to a file.
967 void Save(int cycle, real_t time)
968 {
969 if (!paraview) { return; }
970 this->operator()();
971 dc.SetCycle(cycle);
972 dc.SetTime(time);
973 dc.Save();
974 }
975
976 /// @brief Send the visualization data to GLVis.
977 void GLVis()
978 {
979 if (!glvis.is_open()) { return; }
980 glvis << vis_prefix().c_str();
981 glvis << "solution\n" << mesh << *(this->operator()()) << std::flush;
982 }
983};
984
985} // namespace mfem
Conjugate gradient method.
Definition solvers.hpp:627
void SetOperator(const Operator &op) override
Set/update the solver for the given operator.
Definition solvers.hpp:640
A coefficient that is constant across space and time.
virtual void RegisterField(const std::string &field_name, GridFunction *gf)
Add a grid function to the collection.
void SetCycle(int c)
Set time cycle (for time-dependent simulations)
void SetTime(real_t t)
Set physical time (for time-dependent simulations)
void SetPrefixPath(const std::string &prefix)
Set the path where the DataCollection will be saved.
Class for domain integration $L(v) := (f, v)$.
Definition lininteg.hpp:108
GMRES method.
Definition solvers.hpp:661
Coefficient defined by a GridFunction. This coefficient is mesh dependent.
Arbitrary order H1-conforming (continuous) finite elements.
Definition fe_coll.hpp:279
Class for simulating incompressible Schrodinger flow.
IncompressibleBaseFlow(Options &config, TSchrodingerSolver &solver)
void Setup()
Setup the solver.
void Step(const real_t &t)
This function performs a single time step of the solver.
void ConstrainJetVelocity()
Restricts the velocity of the wavefunctions.
void SetOperator(const Operator &op) override
Also calls SetOperator for the preconditioner if there is one.
Definition solvers.cpp:184
void SetRelTol(real_t rtol)
Definition solvers.hpp:238
virtual void SetPreconditioner(Solver &pr)
This should be called before SetOperator.
Definition solvers.cpp:178
virtual void SetPrintLevel(int print_lvl)
Legacy method to set the level of verbosity of the solver output.
Definition solvers.cpp:76
void SetMaxIter(int max_it)
Definition solvers.hpp:240
void SetAbsTol(real_t atol)
Definition solvers.hpp:239
Mesh data type.
Definition mesh.hpp:65
static Mesh MakeCartesian3D(int nx, int ny, int nz, Element::Type type, real_t sx=1.0, real_t sy=1.0, real_t sz=1.0, bool sfc_ordering=true)
Creates a mesh for the parallelepiped [0,sx]x[0,sy]x[0,sz], divided into nx*ny*nz hexahedra if type =...
Definition mesh.cpp:4627
static Mesh MakePeriodic(const Mesh &orig_mesh, const std::vector< int > &v2v)
Create a periodic mesh by identifying vertices of orig_mesh.
Definition mesh.cpp:6042
std::vector< int > CreatePeriodicVertexMapping(const std::vector< Vector > &translations, real_t tol=1e-8) const
Creates a mapping v2v from the vertex indices of the mesh such that coincident vertices under the giv...
Definition mesh.cpp:6076
static Mesh MakeCartesian2D(int nx, int ny, Element::Type type, bool generate_edges=false, real_t sx=1.0, real_t sy=1.0, bool sfc_ordering=true)
Creates mesh for the rectangle [0,sx]x[0,sy], divided into nx*ny quadrilaterals if type = QUADRILATER...
Definition mesh.cpp:4617
virtual void SetCurvature(int order, bool discont=false, int space_dim=-1, int ordering=1)
Set the curvature of the mesh nodes using the given polynomial degree.
Definition mesh.cpp:6799
Arbitrary order H(curl)-conforming Nedelec finite elements.
Definition fe_coll.hpp:500
Pointer to an Operator of a specified type.
Definition handle.hpp:34
Jacobi smoothing for a given bilinear form (no matrix necessary).
Definition solvers.hpp:422
void ParseCheck(std::ostream &out=mfem::out)
void AddOption(bool *var, const char *enable_short_name, const char *enable_long_name, const char *disable_short_name, const char *disable_long_name, const char *description, bool required=false)
Add a boolean option and set 'var' to receive the value. Enable/disable tags are used to set the bool...
Definition optparser.hpp:82
The ordering method used when the number of unknowns per mesh node (vector dimension) is bigger than ...
Definition ordering.hpp:13
Solver wrapper which orthogonalizes the input and output vector.
Definition solvers.hpp:1332
void SetSolver(Solver &s)
Set the solver used by the OrthoSolver.
Definition solvers.cpp:3702
void SetLevelsOfDetail(int levels_of_detail_)
Set the refinement level.
void SetHighOrderOutput(bool high_order_output_)
Sets whether or not to output the data as high-order elements (false by default).
void SetDataFormat(VTKFormat fmt)
Set the data format for the ParaView output files.
Writer for ParaView visualization (PVD and VTU format)
Writer for ParaView visualization (VTKHDF format)
bool iterative_mode
If true, use the second argument of Mult() as an initial guess.
Definition operator.hpp:795
Templated bilinear form class, cf. bilinearform.?pp.
void Assemble()
Partial assembly of quadrature point data.
A general vector function coefficient.
Vector data type.
Definition vector.hpp:82
virtual const real_t * Read(bool on_dev=true) const
Shortcut for mfem::Read(vec.GetMemory(), vec.Size(), on_dev).
Definition vector.hpp:520
int Size() const
Returns the size of the vector.
Definition vector.hpp:234
void SetSize(int s)
Resize the vector to size s.
Definition vector.hpp:584
int open(const char hostname[], int port)
Open the socket stream on 'port' at 'hostname'.
int close()
Close the socketstream.
bool is_open()
True if the socketstream is open, false otherwise.
const real_t radius
Definition distance.cpp:109
const real_t alpha
Definition ex15.cpp:369
real_t omega
Definition ex25.cpp:142
real_t b
Definition lissajous.cpp:42
real_t a
Definition lissajous.cpp:41
constexpr int DIM
MemoryClass operator*(MemoryClass mc1, MemoryClass mc2)
Return a suitable MemoryClass from a pair of MemoryClasses.
MFEM_ALWAYS_INLINE AutoSIMD< scalar_t, S, A > operator+(const scalar_t &e, const AutoSIMD< scalar_t, S, A > &v)
Definition auto.hpp:238
std::array< real_t, 3 > real3_t
MFEM_HOST_DEVICE DeviceTensor< sizeof...(Dims), T > Reshape(T *ptr, Dims... dims)
Wrap a pointer as a DeviceTensor with automatically deduced template parameters.
Definition dtensor.hpp:138
float real_t
Definition config.hpp:46
void forall(int N, lambda &&body)
Definition forall.hpp:839
MFEM_ALWAYS_INLINE AutoSIMD< scalar_t, S, A > operator/(const scalar_t &e, const AutoSIMD< scalar_t, S, A > &v)
Definition auto.hpp:271
MFEM_HOST_DEVICE real_t norm(const Complex &z)
MFEM_HOST_DEVICE Complex exp(const Complex &q)
MFEM_HOST_DEVICE real_t abs(const Complex &z)
std::complex< real_t > complex_t
Complex number type for device.
cuFloatComplex RealComplex_t
MFEM_HOST_DEVICE void imag(real_t i)
MFEM_HOST_DEVICE Complex(real_t r, real_t i)
MFEM_HOST_DEVICE Complex()=default
MFEM_HOST_DEVICE Complex(real_t r)
MFEM_HOST_DEVICE real_t real() const
MFEM_HOST_DEVICE Complex & operator/=(const U &z)
MFEM_HOST_DEVICE real_t imag() const
MFEM_HOST_DEVICE Complex & operator*=(const U &z)
MFEM_HOST_DEVICE void real(real_t r)
Crank-Nicolson time solver for the Schrodinger equation.
CrankNicolsonTimeBaseSolver(TFiniteElementSpace &fes, real_t hbar, real_t dt, std::function< GMRESSolver()> CreateGMRESSolver, real_t rtol, real_t atol, int maxiter, int print_level)
virtual void Mult(TComplexGridFunction &psi)=0
Options for the Incompressible Schrödinger Flow solver.
const char * vis_keys
const char * device
Options(int argc, char *argv[])
Base class for Schrodinger solver kernels.
void Restrict(const real_t t, const TGridFunction &isJet_in, const real_t omega, const TGridFunction &phase_in)
Restrict the wavefunctions psi1 and psi2.
std::function< void()> SetEssentialTrueDofs
void GradPsiVelocity(const real_t hbar, TGridFunction &ux, TGridFunction &uy, TGridFunction &uz)
VectorFunctionCoefficient Vy
OperatorJacobiSmoother diff_h1_smoother
VectorFunctionCoefficient Vx
void Normalize()
Normalize the wavefunctions psi1 and psi2.
VectorFunctionCoefficient Vz
void GradPsi(Gfn &Grad_nd, Xfn &x_dot_Mm1, Yfn &y_dot_Mm1, Zfn &z_dot_Mm1)
std::function< Mesh()> CreateMesh2D
void Initialize(Vector &phase_r)
Initialize the wavefunctions psi1 and psi2.
GridFunctionCoefficient div_u_coeff
void AddCircularVortex(const real3_t center, const real3_t normal, const real_t radius, const real_t swirling)
Add a circular vortex to the wavefunctions psi1 and psi2.
std::function< Mesh()> CreateMesh3D
SchrodingerBaseKernels(Options &config, std::function< TMesh(Mesh &)> CreateMesh, std::function< OrthoSolver()> CreateOrthoSolver, std::function< CGSolver()> CreateCGSolver)
Base class for visualization.
ParaViewHDFDataCollection dc
const Options::VisData vis_data
const TIncompressibleFlow & isf
TGridFunction * operator()()
Get the visualization data.
void GLVis()
Send the visualization data to GLVis.
VisualizerBase(Options &config, TSchrodingerSolver &solver, const TIncompressibleFlow &isf, std::function< std::string()> prefix)
std::function< void()> vis_fn
std::function< std::string()> vis_prefix
void Save(int cycle, real_t time)
Save the visualization data to a file.
ParaViewDataCollection dc
const TSchrodingerSolver & solver