Block-Structured AMR Software Framework
AMReX_MLCGSolver.H
Go to the documentation of this file.
1 
2 #ifndef AMREX_MLCGSOLVER_H_
3 #define AMREX_MLCGSOLVER_H_
4 #include <AMReX_Config.H>
5 
6 #include <AMReX_MLLinOp.H>
7 
8 namespace amrex {
9 
10 template <typename MF>
12 {
13 public:
14 
15  using FAB = typename MLLinOpT<MF>::FAB;
16  using RT = typename MLLinOpT<MF>::RT;
17 
18  enum struct Type { BiCGStab, CG };
19 
22 
23  MLCGSolverT (const MLCGSolverT<MF>& rhs) = delete;
24  MLCGSolverT (MLCGSolverT<MF>&& rhs) = delete;
27 
28  void setSolver (Type _typ) noexcept { solver_type = _typ; }
29 
37  int solve (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
38 
39  void setVerbose (int _verbose) { verbose = _verbose; }
40  [[nodiscard]] int getVerbose () const { return verbose; }
41 
42  void setMaxIter (int _maxiter) { maxiter = _maxiter; }
43  [[nodiscard]] int getMaxIter () const { return maxiter; }
44 
45 
52  void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
53  [[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }
54 
55  void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
56  [[nodiscard]] int getNGhost() {return nghost[0];}
57 
58  [[nodiscard]] RT dotxy (const MF& r, const MF& z, bool local = false);
59  [[nodiscard]] RT norm_inf (const MF& res, bool local = false);
60  int solve_bicgstab (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
61  int solve_cg (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
62 
63  [[nodiscard]] int getNumIters () const noexcept { return iter; }
64 
65 private:
66 
69  const int amrlev = 0;
70  const int mglev;
71  int verbose = 0;
72  int maxiter = 100;
74  int iter = -1;
75  bool initial_vec_zeroed = false;
76 };
77 
78 template <typename MF>
80  : Lp(_lp), solver_type(_typ), mglev(_lp.NMGLevels(0)-1)
81 {}
82 
83 template <typename MF> MLCGSolverT<MF>::~MLCGSolverT () = default;
84 
85 template <typename MF>
86 int
87 MLCGSolverT<MF>::solve (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
88 {
89  if (solver_type == Type::BiCGStab) {
90  return solve_bicgstab(sol,rhs,eps_rel,eps_abs);
91  } else {
92  return solve_cg(sol,rhs,eps_rel,eps_abs);
93  }
94 }
95 
96 template <typename MF>
97 int
98 MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
99 {
100  BL_PROFILE("MLCGSolver::bicgstab");
101 
102  const int ncomp = nComp(sol);
103 
104  MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
105  MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
106  setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
107  setVal(r, RT(0.0));
108 
109  MF rh = Lp.make(amrlev, mglev, nghost);
110  MF v = Lp.make(amrlev, mglev, nghost);
111  MF t = Lp.make(amrlev, mglev, nghost);
112 
113 
114  MF sorig;
115 
116  if ( initial_vec_zeroed ) {
117  LocalCopy(r,rhs,0,0,ncomp,nghost);
118  } else {
119  sorig = Lp.make(amrlev, mglev, nghost);
120 
121  Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
122 
123  LocalCopy(sorig,sol,0,0,ncomp,nghost);
124  setVal(sol, RT(0.0));
125  }
126 
127  // Then normalize
128  Lp.normalize(amrlev, mglev, r);
129  LocalCopy(rh, r, 0,0,ncomp,nghost);
130 
131  RT rnorm = norm_inf(r);
132  const RT rnorm0 = rnorm;
133 
134  if ( verbose > 0 )
135  {
136  amrex::Print() << "MLCGSolver_BiCGStab: Initial error (error0) = " << rnorm0 << '\n';
137  }
138  int ret = 0;
139  iter = 1;
140  RT rho_1 = 0, alpha = 0, omega = 0;
141 
142  if ( rnorm0 == 0 || rnorm0 < eps_abs )
143  {
144  if ( verbose > 0 )
145  {
146  amrex::Print() << "MLCGSolver_BiCGStab: niter = 0,"
147  << ", rnorm = " << rnorm
148  << ", eps_abs = " << eps_abs << '\n';
149  }
150  return ret;
151  }
152 
153  for (; iter <= maxiter; ++iter)
154  {
155  const RT rho = dotxy(rh,r);
156  if ( rho == 0 )
157  {
158  ret = 1; break;
159  }
160  if ( iter == 1 )
161  {
162  LocalCopy(p,r,0,0,ncomp,nghost);
163  }
164  else
165  {
166  const RT beta = (rho/rho_1)*(alpha/omega);
167  Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
168  Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
169  }
171  Lp.normalize(amrlev, mglev, v);
172 
173  RT rhTv = dotxy(rh,v);
174  if ( rhTv != RT(0.0) )
175  {
176  alpha = rho/rhTv;
177  }
178  else
179  {
180  ret = 2; break;
181  }
182  Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
183  Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
184 
185  rnorm = norm_inf(r);
186 
188  {
189  amrex::Print() << "MLCGSolver_BiCGStab: Half Iter "
190  << std::setw(11) << iter
191  << " rel. err. "
192  << rnorm/(rnorm0) << '\n';
193  }
194 
195  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
196 
198  Lp.normalize(amrlev, mglev, t);
199  //
200  // This is a little funky. I want to elide one of the reductions
201  // in the following two dotxy()s. We do that by calculating the "local"
202  // values and then reducing the two local values at the same time.
203  //
204  RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };
205 
206  BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
207  ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
208  BL_PROFILE_VAR_STOP(blp_par);
209 
210  if ( tvals[0] != RT(0.0) )
211  {
212  omega = tvals[1]/tvals[0];
213  }
214  else
215  {
216  ret = 3; break;
217  }
218  Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
219  Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
220 
221  rnorm = norm_inf(r);
222 
223  if ( verbose > 2 )
224  {
225  amrex::Print() << "MLCGSolver_BiCGStab: Iteration "
226  << std::setw(11) << iter
227  << " rel. err. "
228  << rnorm/(rnorm0) << '\n';
229  }
230 
231  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
232 
233  if ( omega == 0 )
234  {
235  ret = 4; break;
236  }
237  rho_1 = rho;
238  }
239 
240  if ( verbose > 0 )
241  {
242  amrex::Print() << "MLCGSolver_BiCGStab: Final: Iteration "
243  << std::setw(4) << iter
244  << " rel. err. "
245  << rnorm/(rnorm0) << '\n';
246  }
247 
248  if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs)
249  {
250  if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) {
251  amrex::Warning("MLCGSolver_BiCGStab:: failed to converge!");
252  }
253  ret = 8;
254  }
255 
256  if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
257  {
258  if ( !initial_vec_zeroed ) {
259  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
260  }
261  if (ret == 8) { ret = 9; }
262  }
263  else
264  {
265  setVal(sol, RT(0.0));
266  if ( !initial_vec_zeroed ) {
267  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
268  }
269  }
270 
271  return ret;
272 }
273 
274 template <typename MF>
275 int
276 MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
277 {
278  BL_PROFILE("MLCGSolver::cg");
279 
280  const int ncomp = nComp(sol);
281 
282  MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
283  setVal(p, RT(0.0));
284 
285  MF r = Lp.make(amrlev, mglev, nghost);
286  MF q = Lp.make(amrlev, mglev, nghost);
287 
288  MF sorig;
289 
290  if ( initial_vec_zeroed ) {
291  LocalCopy(r,rhs,0,0,ncomp,nghost);
292  } else {
293  sorig = Lp.make(amrlev, mglev, nghost);
294 
295  Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
296 
297  LocalCopy(sorig,sol,0,0,ncomp,nghost);
298  setVal(sol, RT(0.0));
299  }
300 
301  RT rnorm = norm_inf(r);
302  const RT rnorm0 = rnorm;
303 
304  if ( verbose > 0 )
305  {
306  amrex::Print() << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n';
307  }
308 
309  RT rho_1 = 0;
310  int ret = 0;
311  iter = 1;
312 
313  if ( rnorm0 == 0 || rnorm0 < eps_abs )
314  {
315  if ( verbose > 0 ) {
316  amrex::Print() << "MLCGSolver_CG: niter = 0,"
317  << ", rnorm = " << rnorm
318  << ", eps_abs = " << eps_abs << '\n';
319  }
320  return ret;
321  }
322 
323  for (; iter <= maxiter; ++iter)
324  {
325  RT rho = dotxy(r,r);
326 
327  if ( rho == 0 )
328  {
329  ret = 1; break;
330  }
331  if (iter == 1)
332  {
333  LocalCopy(p,r,0,0,ncomp,nghost);
334  }
335  else
336  {
337  RT beta = rho/rho_1;
338  Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
339  }
341 
342  RT alpha;
343  RT pw = dotxy(p,q);
344  if ( pw != RT(0.0))
345  {
346  alpha = rho/pw;
347  }
348  else
349  {
350  ret = 1; break;
351  }
352 
353  if ( verbose > 2 )
354  {
355  amrex::Print() << "MLCGSolver_cg:"
356  << " iter " << iter
357  << " rho " << rho
358  << " alpha " << alpha << '\n';
359  }
360  Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
361  Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
362  rnorm = norm_inf(r);
363 
364  if ( verbose > 2 )
365  {
366  amrex::Print() << "MLCGSolver_cg: Iteration"
367  << std::setw(4) << iter
368  << " rel. err. "
369  << rnorm/(rnorm0) << '\n';
370  }
371 
372  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
373 
374  rho_1 = rho;
375  }
376 
377  if ( verbose > 0 )
378  {
379  amrex::Print() << "MLCGSolver_cg: Final Iteration"
380  << std::setw(4) << iter
381  << " rel. err. "
382  << rnorm/(rnorm0) << '\n';
383  }
384 
385  if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
386  {
387  if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) {
388  amrex::Warning("MLCGSolver_cg: failed to converge!");
389  }
390  ret = 8;
391  }
392 
393  if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
394  {
395  if ( !initial_vec_zeroed ) {
396  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
397  }
398  if (ret == 8) { ret = 9; }
399  }
400  else
401  {
402  setVal(sol, RT(0.0));
403  if ( !initial_vec_zeroed ) {
404  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
405  }
406  }
407 
408  return ret;
409 }
410 
411 template <typename MF>
412 auto
413 MLCGSolverT<MF>::dotxy (const MF& r, const MF& z, bool local) -> RT
414 {
415  BL_PROFILE_VAR_NS("MLCGSolver::ParallelAllReduce", blp_par);
416  if (!local) { BL_PROFILE_VAR_START(blp_par); }
417  RT result = Lp.xdoty(amrlev, mglev, r, z, local);
418  if (!local) { BL_PROFILE_VAR_STOP(blp_par); }
419  return result;
420 }
421 
422 template <typename MF>
423 auto
424 MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
425 {
426  int ncomp = nComp(res);
427  RT result = norminf(res,0,ncomp,IntVect(0),true);
428  if (!local) {
429  BL_PROFILE("MLCGSolver::ParallelAllReduce");
430  ParallelAllReduce::Max(result, Lp.BottomCommunicator());
431  }
432  return result;
433 }
434 
436 
437 }
438 
439 #endif /* AMREX_MLCGSOLVER_H_ */
#define BL_PROFILE_VAR_START(vname)
Definition: AMReX_BLProfiler.H:562
#define BL_PROFILE(a)
Definition: AMReX_BLProfiler.H:551
#define BL_PROFILE_VAR_STOP(vname)
Definition: AMReX_BLProfiler.H:563
#define BL_PROFILE_VAR(fname, vname)
Definition: AMReX_BLProfiler.H:560
#define BL_PROFILE_VAR_NS(fname, vname)
Definition: AMReX_BLProfiler.H:561
Definition: AMReX_MLCGSolver.H:12
RT dotxy(const MF &r, const MF &z, bool local=false)
Definition: AMReX_MLCGSolver.H:413
int verbose
Definition: AMReX_MLCGSolver.H:71
int iter
Definition: AMReX_MLCGSolver.H:74
MLLinOpT< MF > & Lp
Definition: AMReX_MLCGSolver.H:67
MLCGSolverT< MF > & operator=(const MLCGSolverT< MF > &rhs)=delete
void setSolver(Type _typ) noexcept
Definition: AMReX_MLCGSolver.H:28
bool getInitSolnZeroed() const
Definition: AMReX_MLCGSolver.H:53
void setVerbose(int _verbose)
Definition: AMReX_MLCGSolver.H:39
int getNGhost()
Definition: AMReX_MLCGSolver.H:56
const int mglev
Definition: AMReX_MLCGSolver.H:70
IntVect nghost
Definition: AMReX_MLCGSolver.H:73
int getNumIters() const noexcept
Definition: AMReX_MLCGSolver.H:63
int solve_cg(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:276
RT norm_inf(const MF &res, bool local=false)
Definition: AMReX_MLCGSolver.H:424
const int amrlev
Definition: AMReX_MLCGSolver.H:69
bool initial_vec_zeroed
Definition: AMReX_MLCGSolver.H:75
void setInitSolnZeroed(bool _sol_zeroed)
Definition: AMReX_MLCGSolver.H:52
int solve_bicgstab(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:98
MLCGSolverT(MLCGSolverT< MF > &&rhs)=delete
typename MLLinOpT< MF >::FAB FAB
Definition: AMReX_MLCGSolver.H:15
int getMaxIter() const
Definition: AMReX_MLCGSolver.H:43
int maxiter
Definition: AMReX_MLCGSolver.H:72
int solve(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:87
Type
Definition: AMReX_MLCGSolver.H:18
typename MLLinOpT< MF >::RT RT
Definition: AMReX_MLCGSolver.H:16
void setNGhost(int _nghost)
Definition: AMReX_MLCGSolver.H:55
MLCGSolverT(MLLinOpT< MF > &_lp, Type _typ=Type::BiCGStab)
Definition: AMReX_MLCGSolver.H:79
int getVerbose() const
Definition: AMReX_MLCGSolver.H:40
void setMaxIter(int _maxiter)
Definition: AMReX_MLCGSolver.H:42
Type solver_type
Definition: AMReX_MLCGSolver.H:68
MLCGSolverT(const MLCGSolverT< MF > &rhs)=delete
Definition: AMReX_MLLinOp.H:98
typename FabDataType< MF >::fab_type FAB
Definition: AMReX_MLLinOp.H:108
typename FabDataType< MF >::value_type RT
Definition: AMReX_MLLinOp.H:109
This class provides the user with a few print options.
Definition: AMReX_Print.H:35
void Sum(T &v, MPI_Comm comm)
Definition: AMReX_ParallelReduce.H:204
void Max(KeyValuePair< K, V > &vi, MPI_Comm comm)
Definition: AMReX_ParallelReduce.H:126
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition: AMReX_ParallelDescriptor.H:275
Definition: AMReX_Amr.cpp:49
MF::value_type norminf(MF const &mf, int scomp, int ncomp, IntVect const &nghost, bool local=false)
Definition: AMReX_FabArrayUtility.H:1682
int nComp(FabArrayBase const &fa)
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition: AMReX_FabArrayUtility.H:1646
IntVect nGrowVect(FabArrayBase const &fa)
void LocalCopy(DMF &dst, SMF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src
Definition: AMReX_FabArrayUtility.H:1630
void LocalAdd(MF &dst, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += src
Definition: AMReX_FabArrayUtility.H:1638
IntVectND< AMREX_SPACEDIM > IntVect
Definition: AMReX_BaseFwd.H:30
void Xpay(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src + a * dst
Definition: AMReX_FabArrayUtility.H:1654
void Warning(const std::string &msg)
Print out warning message to cerr.
Definition: AMReX.cpp:231
int verbose
Definition: AMReX_DistributionMapping.cpp:36
void setVal(MF &dst, typename MF::value_type val)
dst = val
Definition: AMReX_FabArrayUtility.H:1607