Block-Structured AMR Software Framework
AMReX_MLCGSolver.H
Go to the documentation of this file.
1 
2 #ifndef AMREX_MLCGSOLVER_H_
3 #define AMREX_MLCGSOLVER_H_
4 #include <AMReX_Config.H>
5 
6 #include <AMReX_MLLinOp.H>
7 
8 namespace amrex {
9 
10 template <typename MF>
12 {
13 public:
14 
15  using FAB = typename MLLinOpT<MF>::FAB;
16  using RT = typename MLLinOpT<MF>::RT;
17 
18  enum struct Type { BiCGStab, CG };
19 
22 
23  MLCGSolverT (const MLCGSolverT<MF>& rhs) = delete;
24  MLCGSolverT (MLCGSolverT<MF>&& rhs) = delete;
27 
28  void setSolver (Type _typ) noexcept { solver_type = _typ; }
29 
37  int solve (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
38 
39  void setVerbose (int _verbose) { verbose = _verbose; }
40  [[nodiscard]] int getVerbose () const { return verbose; }
41 
42  void setMaxIter (int _maxiter) { maxiter = _maxiter; }
43  [[nodiscard]] int getMaxIter () const { return maxiter; }
44 
45 
52  void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
53  [[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }
54 
55  void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
56  [[nodiscard]] int getNGhost() {return nghost[0];}
57 
58  [[nodiscard]] RT dotxy (const MF& r, const MF& z, bool local = false);
59  [[nodiscard]] RT norm_inf (const MF& res, bool local = false);
60  int solve_bicgstab (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
61  int solve_cg (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
62 
63  [[nodiscard]] int getNumIters () const noexcept { return iter; }
64 
65 private:
66 
69  const int amrlev = 0;
70  const int mglev;
71  int verbose = 0;
72  int maxiter = 100;
74  int iter = -1;
75  bool initial_vec_zeroed = false;
76 };
77 
78 template <typename MF>
80  : Lp(_lp), solver_type(_typ), mglev(_lp.NMGLevels(0)-1)
81 {}
82 
83 template <typename MF> MLCGSolverT<MF>::~MLCGSolverT () = default;
84 
85 template <typename MF>
86 int
87 MLCGSolverT<MF>::solve (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
88 {
89  if (solver_type == Type::BiCGStab) {
90  return solve_bicgstab(sol,rhs,eps_rel,eps_abs);
91  } else {
92  return solve_cg(sol,rhs,eps_rel,eps_abs);
93  }
94 }
95 
96 template <typename MF>
97 int
98 MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
99 {
100  BL_PROFILE("MLCGSolver::bicgstab");
101 
102  const int ncomp = nComp(sol);
103 
104  MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
105  MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
106  setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
107  setVal(r, RT(0.0));
108 
109  MF rh = Lp.make(amrlev, mglev, nghost);
110  MF v = Lp.make(amrlev, mglev, nghost);
111  MF t = Lp.make(amrlev, mglev, nghost);
112 
113 
114  MF sorig;
115 
116  if ( initial_vec_zeroed ) {
117  LocalCopy(r,rhs,0,0,ncomp,nghost);
118  } else {
119  sorig = Lp.make(amrlev, mglev, nghost);
120 
121  Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
122 
123  LocalCopy(sorig,sol,0,0,ncomp,nghost);
124  setVal(sol, RT(0.0));
125  }
126 
127  // Then normalize
128  Lp.normalize(amrlev, mglev, r);
129  LocalCopy(rh, r, 0,0,ncomp,nghost);
130 
131  RT rnorm = norm_inf(r);
132  const RT rnorm0 = rnorm;
133 
134  if ( verbose > 0 )
135  {
136  amrex::Print() << "MLCGSolver_BiCGStab: Initial error (error0) = " << rnorm0 << '\n';
137  }
138  int ret = 0;
139  iter = 1;
140  RT rho_1 = 0, alpha = 0, omega = 0;
141 
142  if ( rnorm0 == 0 || rnorm0 < eps_abs )
143  {
144  if ( verbose > 0 )
145  {
146  amrex::Print() << "MLCGSolver_BiCGStab: niter = 0,"
147  << ", rnorm = " << rnorm
148  << ", eps_abs = " << eps_abs << '\n';
149  }
150  return ret;
151  }
152 
153  for (; iter <= maxiter; ++iter)
154  {
155  const RT rho = dotxy(rh,r);
156  if ( rho == 0 )
157  {
158  ret = 1; break;
159  }
160  if ( iter == 1 )
161  {
162  LocalCopy(p,r,0,0,ncomp,nghost);
163  }
164  else
165  {
166  const RT beta = (rho/rho_1)*(alpha/omega);
167  Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
168  Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
169  }
171  Lp.normalize(amrlev, mglev, v);
172 
173  RT rhTv = dotxy(rh,v);
174  if ( rhTv != RT(0.0) )
175  {
176  alpha = rho/rhTv;
177  }
178  else
179  {
180  ret = 2; break;
181  }
182  Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
183  Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
184 
185  rnorm = norm_inf(r);
186  rnorm = norm_inf(r);
187 
189  {
190  amrex::Print() << "MLCGSolver_BiCGStab: Half Iter "
191  << std::setw(11) << iter
192  << " rel. err. "
193  << rnorm/(rnorm0) << '\n';
194  }
195 
196  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
197 
199  Lp.normalize(amrlev, mglev, t);
200  //
201  // This is a little funky. I want to elide one of the reductions
202  // in the following two dotxy()s. We do that by calculating the "local"
203  // values and then reducing the two local values at the same time.
204  //
205  RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };
206 
207  BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
208  ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
209  BL_PROFILE_VAR_STOP(blp_par);
210 
211  if ( tvals[0] != RT(0.0) )
212  {
213  omega = tvals[1]/tvals[0];
214  }
215  else
216  {
217  ret = 3; break;
218  }
219  Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
220  Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
221 
222  rnorm = norm_inf(r);
223 
224  if ( verbose > 2 )
225  {
226  amrex::Print() << "MLCGSolver_BiCGStab: Iteration "
227  << std::setw(11) << iter
228  << " rel. err. "
229  << rnorm/(rnorm0) << '\n';
230  }
231 
232  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
233 
234  if ( omega == 0 )
235  {
236  ret = 4; break;
237  }
238  rho_1 = rho;
239  }
240 
241  if ( verbose > 0 )
242  {
243  amrex::Print() << "MLCGSolver_BiCGStab: Final: Iteration "
244  << std::setw(4) << iter
245  << " rel. err. "
246  << rnorm/(rnorm0) << '\n';
247  }
248 
249  if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs)
250  {
251  if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) {
252  amrex::Warning("MLCGSolver_BiCGStab:: failed to converge!");
253  }
254  ret = 8;
255  }
256 
257  if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
258  {
259  if ( !initial_vec_zeroed ) {
260  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
261  }
262  if (ret == 8) { ret = 9; }
263  }
264  else
265  {
266  setVal(sol, RT(0.0));
267  if ( !initial_vec_zeroed ) {
268  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
269  }
270  }
271 
272  return ret;
273 }
274 
275 template <typename MF>
276 int
277 MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
278 {
279  BL_PROFILE("MLCGSolver::cg");
280 
281  const int ncomp = nComp(sol);
282 
283  MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
284  setVal(p, RT(0.0));
285 
286  MF r = Lp.make(amrlev, mglev, nghost);
287  MF q = Lp.make(amrlev, mglev, nghost);
288 
289  MF sorig;
290 
291  if ( initial_vec_zeroed ) {
292  LocalCopy(r,rhs,0,0,ncomp,nghost);
293  } else {
294  sorig = Lp.make(amrlev, mglev, nghost);
295 
296  Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
297 
298  LocalCopy(sorig,sol,0,0,ncomp,nghost);
299  setVal(sol, RT(0.0));
300  }
301 
302  RT rnorm = norm_inf(r);
303  const RT rnorm0 = rnorm;
304 
305  if ( verbose > 0 )
306  {
307  amrex::Print() << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n';
308  }
309 
310  RT rho_1 = 0;
311  int ret = 0;
312  iter = 1;
313 
314  if ( rnorm0 == 0 || rnorm0 < eps_abs )
315  {
316  if ( verbose > 0 ) {
317  amrex::Print() << "MLCGSolver_CG: niter = 0,"
318  << ", rnorm = " << rnorm
319  << ", eps_abs = " << eps_abs << '\n';
320  }
321  return ret;
322  }
323 
324  for (; iter <= maxiter; ++iter)
325  {
326  RT rho = dotxy(r,r);
327 
328  if ( rho == 0 )
329  {
330  ret = 1; break;
331  }
332  if (iter == 1)
333  {
334  LocalCopy(p,r,0,0,ncomp,nghost);
335  }
336  else
337  {
338  RT beta = rho/rho_1;
339  Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
340  }
342 
343  RT alpha;
344  RT pw = dotxy(p,q);
345  if ( pw != RT(0.0))
346  {
347  alpha = rho/pw;
348  }
349  else
350  {
351  ret = 1; break;
352  }
353 
354  if ( verbose > 2 )
355  {
356  amrex::Print() << "MLCGSolver_cg:"
357  << " iter " << iter
358  << " rho " << rho
359  << " alpha " << alpha << '\n';
360  }
361  Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
362  Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
363  rnorm = norm_inf(r);
364 
365  if ( verbose > 2 )
366  {
367  amrex::Print() << "MLCGSolver_cg: Iteration"
368  << std::setw(4) << iter
369  << " rel. err. "
370  << rnorm/(rnorm0) << '\n';
371  }
372 
373  if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
374 
375  rho_1 = rho;
376  }
377 
378  if ( verbose > 0 )
379  {
380  amrex::Print() << "MLCGSolver_cg: Final Iteration"
381  << std::setw(4) << iter
382  << " rel. err. "
383  << rnorm/(rnorm0) << '\n';
384  }
385 
386  if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
387  {
388  if ( verbose > 0 && ParallelDescriptor::IOProcessor() ) {
389  amrex::Warning("MLCGSolver_cg: failed to converge!");
390  }
391  ret = 8;
392  }
393 
394  if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
395  {
396  if ( !initial_vec_zeroed ) {
397  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
398  }
399  if (ret == 8) { ret = 9; }
400  }
401  else
402  {
403  setVal(sol, RT(0.0));
404  if ( !initial_vec_zeroed ) {
405  LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
406  }
407  }
408 
409  return ret;
410 }
411 
412 template <typename MF>
413 auto
414 MLCGSolverT<MF>::dotxy (const MF& r, const MF& z, bool local) -> RT
415 {
416  BL_PROFILE_VAR_NS("MLCGSolver::ParallelAllReduce", blp_par);
417  if (!local) { BL_PROFILE_VAR_START(blp_par); }
418  RT result = Lp.xdoty(amrlev, mglev, r, z, local);
419  if (!local) { BL_PROFILE_VAR_STOP(blp_par); }
420  return result;
421 }
422 
423 template <typename MF>
424 auto
425 MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
426 {
427  int ncomp = nComp(res);
428  RT result = norminf(res,0,ncomp,IntVect(0),true);
429  if (!local) {
430  BL_PROFILE("MLCGSolver::ParallelAllReduce");
431  ParallelAllReduce::Max(result, Lp.BottomCommunicator());
432  }
433  return result;
434 }
435 
437 
438 }
439 
440 #endif /*_CGSOLVER_H_*/
#define BL_PROFILE_VAR_START(vname)
Definition: AMReX_BLProfiler.H:562
#define BL_PROFILE(a)
Definition: AMReX_BLProfiler.H:551
#define BL_PROFILE_VAR_STOP(vname)
Definition: AMReX_BLProfiler.H:563
#define BL_PROFILE_VAR(fname, vname)
Definition: AMReX_BLProfiler.H:560
#define BL_PROFILE_VAR_NS(fname, vname)
Definition: AMReX_BLProfiler.H:561
Definition: AMReX_MLCGSolver.H:12
RT dotxy(const MF &r, const MF &z, bool local=false)
Definition: AMReX_MLCGSolver.H:414
int verbose
Definition: AMReX_MLCGSolver.H:71
int iter
Definition: AMReX_MLCGSolver.H:74
MLLinOpT< MF > & Lp
Definition: AMReX_MLCGSolver.H:67
MLCGSolverT< MF > & operator=(const MLCGSolverT< MF > &rhs)=delete
void setSolver(Type _typ) noexcept
Definition: AMReX_MLCGSolver.H:28
bool getInitSolnZeroed() const
Definition: AMReX_MLCGSolver.H:53
void setVerbose(int _verbose)
Definition: AMReX_MLCGSolver.H:39
int getNGhost()
Definition: AMReX_MLCGSolver.H:56
const int mglev
Definition: AMReX_MLCGSolver.H:70
IntVect nghost
Definition: AMReX_MLCGSolver.H:73
int getNumIters() const noexcept
Definition: AMReX_MLCGSolver.H:63
int solve_cg(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:277
RT norm_inf(const MF &res, bool local=false)
Definition: AMReX_MLCGSolver.H:425
const int amrlev
Definition: AMReX_MLCGSolver.H:69
bool initial_vec_zeroed
Definition: AMReX_MLCGSolver.H:75
void setInitSolnZeroed(bool _sol_zeroed)
Definition: AMReX_MLCGSolver.H:52
int solve_bicgstab(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:98
MLCGSolverT(MLCGSolverT< MF > &&rhs)=delete
typename MLLinOpT< MF >::FAB FAB
Definition: AMReX_MLCGSolver.H:15
int getMaxIter() const
Definition: AMReX_MLCGSolver.H:43
int maxiter
Definition: AMReX_MLCGSolver.H:72
int solve(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition: AMReX_MLCGSolver.H:87
Type
Definition: AMReX_MLCGSolver.H:18
typename MLLinOpT< MF >::RT RT
Definition: AMReX_MLCGSolver.H:16
void setNGhost(int _nghost)
Definition: AMReX_MLCGSolver.H:55
MLCGSolverT(MLLinOpT< MF > &_lp, Type _typ=Type::BiCGStab)
Definition: AMReX_MLCGSolver.H:79
int getVerbose() const
Definition: AMReX_MLCGSolver.H:40
void setMaxIter(int _maxiter)
Definition: AMReX_MLCGSolver.H:42
Type solver_type
Definition: AMReX_MLCGSolver.H:68
MLCGSolverT(const MLCGSolverT< MF > &rhs)=delete
Definition: AMReX_MLLinOp.H:98
typename FabDataType< MF >::fab_type FAB
Definition: AMReX_MLLinOp.H:108
typename FabDataType< MF >::value_type RT
Definition: AMReX_MLLinOp.H:109
This class provides the user with a few print options.
Definition: AMReX_Print.H:35
void Sum(T &v, MPI_Comm comm)
Definition: AMReX_ParallelReduce.H:204
void Max(KeyValuePair< K, V > &vi, MPI_Comm comm)
Definition: AMReX_ParallelReduce.H:126
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition: AMReX_ParallelDescriptor.H:275
Definition: AMReX_Amr.cpp:49
MF::value_type norminf(MF const &mf, int scomp, int ncomp, IntVect const &nghost, bool local=false)
Definition: AMReX_FabArrayUtility.H:1682
int nComp(FabArrayBase const &fa)
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition: AMReX_FabArrayUtility.H:1646
IntVect nGrowVect(FabArrayBase const &fa)
void LocalCopy(DMF &dst, SMF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src
Definition: AMReX_FabArrayUtility.H:1630
void LocalAdd(MF &dst, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += src
Definition: AMReX_FabArrayUtility.H:1638
IntVectND< AMREX_SPACEDIM > IntVect
Definition: AMReX_BaseFwd.H:30
void Xpay(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src + a * dst
Definition: AMReX_FabArrayUtility.H:1654
void Warning(const std::string &msg)
Print out warning message to cerr.
Definition: AMReX.cpp:227
int verbose
Definition: AMReX_DistributionMapping.cpp:36
void setVal(MF &dst, typename MF::value_type val)
dst = val
Definition: AMReX_FabArrayUtility.H:1607