Block-Structured AMR Software Framework
AMReX_GMRES_MLMG.H
Go to the documentation of this file.
1 #ifndef AMREX_GMRES_MLMG_H_
2 #define AMREX_GMRES_MLMG_H_
3 #include <AMReX_Config.H>
4 
5 #include <AMReX_GMRES.H>
6 #include <AMReX_MLMG.H>
7 #include <utility>
8 
9 namespace amrex {
10 
/// GMRESMLMGT: solve using GMRES with multigrid (MLMG) as preconditioner.
/// The wrapper itself provides the vector-space operations (make/assign/
/// dot/norm/lincomb), the operator application and the preconditioner that
/// the GMRES object (m_gmres) calls back into.
///
/// NOTE(review): this copy is missing source lines 19 (the `class GMRESMLMGT`
/// head), 24 (presumably the `GM` alias used by getGMRES()), and 101-103, 106.
/// The symbol list for this file names those members as `GM m_gmres` (101),
/// `MG * m_mlmg` (102), `MLLinOpT< MF > * m_linop` (103) and
/// `int m_precond_niters` (106). Restore them before compiling.
18 template <typename MF>
20 {
21 public:
22  using MG = MLMGT<MF>;
23  using RT = typename MG::RT; // double or float
25 
26  explicit GMRESMLMGT (MG& mlmg);
27 
/// Solve the linear system L(a_sol) = a_rhs to relative tolerance a_tol_rel
/// or absolute tolerance a_tol_abs.
36  void solve (MF& a_sol, MF const& a_rhs, RT a_tol_rel, RT a_tol_abs);
37 
/// Sets verbosity (forwarded to the GMRES object).
39  void setVerbose (int v) { m_gmres.setVerbose(v); }
40 
/// Sets the max number of iterations (forwarded to the GMRES object).
42  void setMaxIters (int niters) { m_gmres.setMaxIters(niters); }
43 
/// Gets the number of iterations performed by the last GMRES solve.
45  [[nodiscard]] int getNumIters () const { return m_gmres.getNumIters(); }
46 
/// Gets the 2-norm of the residual from the last GMRES solve.
48  [[nodiscard]] RT getResidualNorm () const { return m_gmres.getResidualNorm(); }
49 
/// Get the underlying GMRES object.
51  GM& getGMRES () { return m_gmres; }
52 
/// Set MLMG's multiplicative property of zero. When true, solve() passes the
/// caller's solution straight to GMRES. When false, it solves for a
/// correction against the residual of the current solution.
63  void setPropertyOfZero (bool b) { m_prop_zero = b; }
64 
/// Make a vector without ghost cells (RHS-shaped).
66  MF makeVecRHS () const;
67 
/// Make a vector with one ghost cell and set the ghost cells to zero
/// (LHS/solution-shaped).
69  MF makeVecLHS () const;
70 
/// 2-norm of mf using the linear operator's inner product.
71  RT norm2 (MF const& mf) const;
72 
/// mf *= scale_factor (valid region, all components).
73  static void scale (MF& mf, RT scale_factor);
74 
/// Inner product <mf1, mf2> using the linear operator's inner product.
75  RT dotProduct (MF const& mf1, MF const& mf2) const;
76 
/// lhs = 0
78  static void setToZero (MF& lhs);
79 
/// lhs = rhs
81  static void assign (MF& lhs, MF const& rhs);
82 
/// lhs += a*rhs
84  static void increment (MF& lhs, MF const& rhs, RT a);
85 
/// lhs = a*rhs_a + b*rhs_b
87  static void linComb (MF& lhs, RT a, MF const& rhs_a, RT b, MF const& rhs_b);
88 
/// lhs = L(rhs)
90  void apply (MF& lhs, MF const& rhs) const;
91 
/// Apply the preconditioner: MLMG V-cycles when enabled, otherwise lhs = rhs.
92  void precond (MF& lhs, MF const& rhs) const;
93 
/// Control whether or not to use MLMG as preconditioner; returns the
/// previous setting.
95  bool usePrecond (bool new_flag) { return std::exchange(m_use_precond, new_flag); }
96 
/// Set the number of MLMG preconditioner iterations per GMRES iteration.
98  void setPrecondNumIters (int precond_niters) { m_precond_niters = precond_niters; }
99 
100 private:
104  bool m_use_precond = true;
105  bool m_prop_zero = false;
107 };
108 
109 template <typename MF>
111  : m_mlmg(&mlmg), m_linop(&mlmg.getLinOp())
112 {
113  AMREX_ALWAYS_ASSERT_WITH_MESSAGE(m_linop->NAMRLevels() == 1,
114  "Only support single level solve");
115  m_mlmg->setVerbose(0);
118  m_gmres.define(*this);
119 }
120 
121 template <typename MF>
122 auto GMRESMLMGT<MF>::makeVecRHS () const -> MF
123 {
124  return m_linop->make(0, 0, IntVect(0));
125 }
126 
127 template <typename MF>
128 auto GMRESMLMGT<MF>::makeVecLHS () const -> MF
129 {
130  auto mf = m_linop->make(0, 0, IntVect(1));
131  setBndry(mf, RT(0), 0, nComp(mf));
132  return mf;
133 }
134 
135 template <typename MF>
136 auto GMRESMLMGT<MF>::norm2 (MF const& mf) const -> RT
137 {
138  auto r = m_linop->xdoty(0, 0, mf, mf, false);
139  return std::sqrt(r);
140 }
141 
142 template <typename MF>
143 void GMRESMLMGT<MF>::scale (MF& mf, RT scale_factor)
144 {
145  Scale(mf, scale_factor, 0, nComp(mf), 0);
146 }
147 
148 template <typename MF>
149 auto GMRESMLMGT<MF>::dotProduct (MF const& mf1, MF const& mf2) const -> RT
150 {
151  return m_linop->xdoty(0, 0, mf1, mf2, false);
152 }
153 
154 template <typename MF>
156 {
157  setVal(lhs, RT(0.0));
158 }
159 
160 template <typename MF>
161 void GMRESMLMGT<MF>::assign (MF& lhs, MF const& rhs)
162 {
163  LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
164 }
165 
166 template <typename MF>
167 void GMRESMLMGT<MF>::increment (MF& lhs, MF const& rhs, RT a)
168 {
169  Saxpy(lhs, a, rhs, 0, 0, nComp(lhs), IntVect(0));
170 }
171 
172 template <typename MF>
173 void GMRESMLMGT<MF>::linComb (MF& lhs, RT a, MF const& rhs_a, RT b, MF const& rhs_b)
174 {
175  LinComb(lhs, a, rhs_a, 0, b, rhs_b, 0, 0, nComp(lhs), IntVect(0));
176 }
177 
/// lhs = L(rhs): apply the linear operator at its first level indices.
178 template <typename MF>
179 void GMRESMLMGT<MF>::apply (MF& lhs, MF const& rhs) const
180 {
// const_cast: the operator's apply() takes its input by non-const reference
// (presumably because it may fill rhs's ghost cells — TODO confirm).
// NOTE(review): source lines 182-183 (the remaining arguments of this call)
// are missing from this copy, so the call below is incomplete as shown.
181  m_linop->apply(0, 0, lhs, const_cast<MF&>(rhs),
184 }
185 
186 template <typename MF>
187 void GMRESMLMGT<MF>::precond (MF& lhs, MF const& rhs) const
188 {
189  if (m_use_precond) {
190  m_mlmg->prepareMGcycle();
191 
192  for (int icycle = 0; icycle < m_precond_niters; ++icycle) {
193  if (icycle == 0) {
194  LocalCopy(m_mlmg->res[0][0], rhs, 0, 0, nComp(rhs), IntVect(0));
195  } else {
196  m_mlmg->computeResOfCorrection(0,0);
197  LocalCopy(m_mlmg->res[0][0], m_mlmg->rescor[0][0], 0, 0, nComp(rhs), IntVect(0));
198  }
199 
200  m_mlmg->mgVcycle(0,0);
201 
202  if (icycle == 0) {
203  LocalCopy(lhs, m_mlmg->cor[0][0], 0, 0, nComp(rhs), IntVect(0));
204  } else {
205  increment(lhs, m_mlmg->cor[0][0], RT(1));
206  }
207  }
208  } else {
209  LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
210  }
211 }
212 
213 template <typename MF>
214 void GMRESMLMGT<MF>::solve (MF& a_sol, MF const& a_rhs, RT a_tol_rel, RT a_tol_abs)
215 {
216  if (m_prop_zero) {
217  auto rhs = makeVecRHS();
218  assign(rhs, a_rhs);
219  m_linop->setDirichletNodesToZero(0,0,rhs);
220  m_gmres.solve(a_sol, rhs, a_tol_rel, a_tol_abs);
221  } else {
222  auto res = makeVecRHS();
223  m_mlmg->apply({&res}, {&a_sol}); // res = L(sol)
224  increment(res, a_rhs, RT(-1)); // res = L(sol) - rhs
225  auto cor = makeVecLHS();
226  m_linop->setDirichletNodesToZero(0,0,res);
227  m_gmres.solve(cor, res, a_tol_rel, a_tol_abs); // L(cor) = res
228  increment(a_sol, cor, RT(-1)); // sol = sol - cor
229  }
230 }
231 
233 
234 }
235 
236 #endif
#define AMREX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition: AMReX_BLassert.H:49
Solve using GMRES with multigrid as preconditioner.
Definition: AMReX_GMRES_MLMG.H:20
void solve(MF &a_sol, MF const &a_rhs, RT a_tol_rel, RT a_tol_abs)
Solve the linear system.
Definition: AMReX_GMRES_MLMG.H:214
MF makeVecRHS() const
Make MultiFab without ghost cells.
Definition: AMReX_GMRES_MLMG.H:122
int m_precond_niters
Definition: AMReX_GMRES_MLMG.H:106
GMRESMLMGT(MG &mlmg)
Definition: AMReX_GMRES_MLMG.H:110
GM m_gmres
Definition: AMReX_GMRES_MLMG.H:101
static void increment(MF &lhs, MF const &rhs, RT a)
lhs += a*rhs
Definition: AMReX_GMRES_MLMG.H:167
int getNumIters() const
Gets the number of iterations.
Definition: AMReX_GMRES_MLMG.H:45
RT dotProduct(MF const &mf1, MF const &mf2) const
Definition: AMReX_GMRES_MLMG.H:149
static void assign(MF &lhs, MF const &rhs)
lhs = rhs
Definition: AMReX_GMRES_MLMG.H:161
void setPrecondNumIters(int precond_niters)
Set the number of MLMG preconditioner iterations per GMRES iteration.
Definition: AMReX_GMRES_MLMG.H:98
GM & getGMRES()
Get the GMRES object.
Definition: AMReX_GMRES_MLMG.H:51
void setVerbose(int v)
Sets verbosity.
Definition: AMReX_GMRES_MLMG.H:39
void setPropertyOfZero(bool b)
Set MLMG's multiplicative property of zero.
Definition: AMReX_GMRES_MLMG.H:63
typename MG::RT RT
Definition: AMReX_GMRES_MLMG.H:23
void apply(MF &lhs, MF const &rhs) const
lhs = L(rhs)
Definition: AMReX_GMRES_MLMG.H:179
static void linComb(MF &lhs, RT a, MF const &rhs_a, RT b, MF const &rhs_b)
lhs = a*rhs_a + b*rhs_b
Definition: AMReX_GMRES_MLMG.H:173
MF makeVecLHS() const
Make MultiFab with ghost cells and set ghost cells to zero.
Definition: AMReX_GMRES_MLMG.H:128
static void setToZero(MF &lhs)
lhs = 0
Definition: AMReX_GMRES_MLMG.H:155
MLLinOpT< MF > * m_linop
Definition: AMReX_GMRES_MLMG.H:103
bool m_use_precond
Definition: AMReX_GMRES_MLMG.H:104
static void scale(MF &mf, RT scale_factor)
Definition: AMReX_GMRES_MLMG.H:143
MG * m_mlmg
Definition: AMReX_GMRES_MLMG.H:102
RT norm2(MF const &mf) const
Definition: AMReX_GMRES_MLMG.H:136
RT getResidualNorm() const
Gets the 2-norm of the residual.
Definition: AMReX_GMRES_MLMG.H:48
bool usePrecond(bool new_flag)
Control whether or not to use MLMG as preconditioner.
Definition: AMReX_GMRES_MLMG.H:95
bool m_prop_zero
Definition: AMReX_GMRES_MLMG.H:105
void setMaxIters(int niters)
Sets the max number of iterations.
Definition: AMReX_GMRES_MLMG.H:42
void precond(MF &lhs, MF const &rhs) const
Definition: AMReX_GMRES_MLMG.H:187
int getNumIters() const
Gets the number of iterations.
Definition: AMReX_GMRES.H:113
void define(M &linop)
Definition: AMReX_GMRES.H:187
void setVerbose(int v)
Sets verbosity.
Definition: AMReX_GMRES.H:104
void setMaxIters(int niters)
Sets the max number of iterations.
Definition: AMReX_GMRES.H:110
RT getResidualNorm() const
Gets the 2-norm of the residual.
Definition: AMReX_GMRES.H:119
Definition: AMReX_MLLinOp.H:98
Definition: AMReX_MLMG.H:12
void setBottomVerbose(int v) noexcept
Definition: AMReX_MLMG.H:129
void setVerbose(int v) noexcept
Definition: AMReX_MLMG.H:117
typename MLLinOpT< MF >::RT RT
Definition: AMReX_MLMG.H:27
void prepareForGMRES()
Definition: AMReX_MLMG.H:1132
Definition: AMReX_Amr.cpp:49
int nComp(FabArrayBase const &fa)
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition: AMReX_FabArrayUtility.H:1646
void LocalCopy(DMF &dst, SMF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src
Definition: AMReX_FabArrayUtility.H:1630
IntVectND< AMREX_SPACEDIM > IntVect
Definition: AMReX_BaseFwd.H:30
void LinComb(MF &dst, typename MF::value_type a, MF const &src_a, int acomp, typename MF::value_type b, MF const &src_b, int bcomp, int dcomp, int ncomp, IntVect const &nghost)
dst = a*src_a + b*src_b
Definition: AMReX_FabArrayUtility.H:1662
void setBndry(MF &dst, typename MF::value_type val, int scomp, int ncomp)
dst = val in ghost cells.
Definition: AMReX_FabArrayUtility.H:1614
void Scale(MF &dst, typename MF::value_type val, int scomp, int ncomp, int nghost)
dst *= val
Definition: AMReX_FabArrayUtility.H:1621
void setVal(MF &dst, typename MF::value_type val)
dst = val
Definition: AMReX_FabArrayUtility.H:1607
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE GpuComplex< T > sqrt(const GpuComplex< T > &a_z) noexcept
Return the square root of a complex number.
Definition: AMReX_GpuComplex.H:373