Block-Structured AMR Software Framework
 
Loading...
Searching...
No Matches
AMReX_MLCGSolver.H
Go to the documentation of this file.
1
2#ifndef AMREX_MLCGSOLVER_H_
3#define AMREX_MLCGSOLVER_H_
4#include <AMReX_Config.H>
5
6#include <AMReX_MLLinOp.H>
7
8namespace amrex {
9
10template <typename MF>
12{
13public:
14
15 using FAB = typename MLLinOpT<MF>::FAB;
16 using RT = typename MLLinOpT<MF>::RT;
17
18 enum struct Type { BiCGStab, CG };
19
22
23 MLCGSolverT (const MLCGSolverT<MF>& rhs) = delete;
24 MLCGSolverT (MLCGSolverT<MF>&& rhs) = delete;
27
28 void setSolver (Type _typ) noexcept { solver_type = _typ; }
29
37 int solve (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
38
39 void setVerbose (int _verbose) { verbose = _verbose; }
40 [[nodiscard]] int getVerbose () const { return verbose; }
41
42 void setMaxIter (int _maxiter) { maxiter = _maxiter; }
43 [[nodiscard]] int getMaxIter () const { return maxiter; }
44
45 void setPrintIdentation (std::string s) { print_ident = std::move(s); }
46
53 void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
54 [[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }
55
56 void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
57 [[nodiscard]] int getNGhost() {return nghost[0];}
58
59 [[nodiscard]] RT dotxy (const MF& r, const MF& z, bool local = false);
60 [[nodiscard]] RT norm_inf (const MF& res, bool local = false);
61 int solve_bicgstab (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
62 int solve_cg (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
63
64 [[nodiscard]] int getNumIters () const noexcept { return iter; }
65
66private:
67
70 const int amrlev = 0;
71 const int mglev;
72 int verbose = 0;
73 int maxiter = 100;
75 int iter = -1;
76 bool initial_vec_zeroed = false;
77 std::string print_ident;
78};
79
80template <typename MF>
82 : Lp(_lp), solver_type(_typ), mglev(_lp.NMGLevels(0)-1)
83{}
84
85template <typename MF> MLCGSolverT<MF>::~MLCGSolverT () = default;
86
87template <typename MF>
88int
89MLCGSolverT<MF>::solve (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
90{
91 if (solver_type == Type::BiCGStab) {
92 return solve_bicgstab(sol,rhs,eps_rel,eps_abs);
93 } else {
94 return solve_cg(sol,rhs,eps_rel,eps_abs);
95 }
96}
97
98template <typename MF>
99int
100MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
101{
102 BL_PROFILE("MLCGSolver::bicgstab");
103
104 const int ncomp = nComp(sol);
105
106 MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
107 MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
108 setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
109 setVal(r, RT(0.0));
110
111 MF rh = Lp.make(amrlev, mglev, nghost);
112 MF v = Lp.make(amrlev, mglev, nghost);
113 MF t = Lp.make(amrlev, mglev, nghost);
114
115
116 MF sorig;
117
118 if ( initial_vec_zeroed ) {
119 LocalCopy(r,rhs,0,0,ncomp,nghost);
120 } else {
121 sorig = Lp.make(amrlev, mglev, nghost);
122
123 Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
124
125 LocalCopy(sorig,sol,0,0,ncomp,nghost);
126 setVal(sol, RT(0.0));
127 }
128
129 // Then normalize
130 Lp.normalize(amrlev, mglev, r);
131 LocalCopy(rh, r, 0,0,ncomp,nghost);
132
133 RT rnorm = norm_inf(r);
134 const RT rnorm0 = rnorm;
135
136 if ( verbose > 0 )
137 {
138 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Initial error (error0) = " << rnorm0 << '\n';
139 }
140 int ret = 0;
141 iter = 1;
142 RT rho_1 = 0, alpha = 0, omega = 0;
143
144 if ( rnorm0 == 0 || rnorm0 < eps_abs )
145 {
146 if ( verbose > 0 )
147 {
148 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: niter = 0,"
149 << ", rnorm = " << rnorm
150 << ", eps_abs = " << eps_abs << '\n';
151 }
152 return ret;
153 }
154
155 for (; iter <= maxiter; ++iter)
156 {
157 const RT rho = dotxy(rh,r);
158 if ( rho == 0 )
159 {
160 ret = 1; break;
161 }
162 if ( iter == 1 )
163 {
164 LocalCopy(p,r,0,0,ncomp,nghost);
165 }
166 else
167 {
168 const RT beta = (rho/rho_1)*(alpha/omega);
169 Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
170 Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
171 }
173 Lp.normalize(amrlev, mglev, v);
174
175 RT rhTv = dotxy(rh,v);
176 if ( rhTv != RT(0.0) )
177 {
178 alpha = rho/rhTv;
179 }
180 else
181 {
182 ret = 2; break;
183 }
184 Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
185 Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
186
187 rnorm = norm_inf(r);
188
190 {
191 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Half Iter "
192 << std::setw(11) << iter
193 << " rel. err. "
194 << rnorm/(rnorm0) << '\n';
195 }
196
197 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
198
200 Lp.normalize(amrlev, mglev, t);
201 //
202 // This is a little funky. I want to elide one of the reductions
203 // in the following two dotxy()s. We do that by calculating the "local"
204 // values and then reducing the two local values at the same time.
205 //
206 RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };
207
208 BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
209 ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
210 BL_PROFILE_VAR_STOP(blp_par);
211
212 if ( tvals[0] != RT(0.0) )
213 {
214 omega = tvals[1]/tvals[0];
215 }
216 else
217 {
218 ret = 3; break;
219 }
220 Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
221 Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
222
223 rnorm = norm_inf(r);
224
225 if ( verbose > 2 )
226 {
227 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Iteration "
228 << std::setw(11) << iter
229 << " rel. err. "
230 << rnorm/(rnorm0) << '\n';
231 }
232
233 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
234
235 if ( omega == 0 )
236 {
237 ret = 4; break;
238 }
239 rho_1 = rho;
240 }
241
242 if ( verbose > 0 )
243 {
244 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Final: Iteration "
245 << std::setw(4) << iter
246 << " rel. err. "
247 << rnorm/(rnorm0) << '\n';
248 }
249
250 if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs)
251 {
253 amrex::Warning("MLCGSolver_BiCGStab:: failed to converge!");
254 }
255 ret = 8;
256 }
257
258 if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
259 {
260 if ( !initial_vec_zeroed ) {
261 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
262 }
263 if (ret == 8) { ret = 9; }
264 }
265 else
266 {
267 setVal(sol, RT(0.0));
268 if ( !initial_vec_zeroed ) {
269 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
270 }
271 }
272
273 return ret;
274}
275
276template <typename MF>
277int
278MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
279{
280 BL_PROFILE("MLCGSolver::cg");
281
282 const int ncomp = nComp(sol);
283
284 MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
285 setVal(p, RT(0.0));
286
287 MF r = Lp.make(amrlev, mglev, nghost);
288 MF q = Lp.make(amrlev, mglev, nghost);
289
290 MF sorig;
291
292 if ( initial_vec_zeroed ) {
293 LocalCopy(r,rhs,0,0,ncomp,nghost);
294 } else {
295 sorig = Lp.make(amrlev, mglev, nghost);
296
297 Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
298
299 LocalCopy(sorig,sol,0,0,ncomp,nghost);
300 setVal(sol, RT(0.0));
301 }
302
303 RT rnorm = norm_inf(r);
304 const RT rnorm0 = rnorm;
305
306 if ( verbose > 0 )
307 {
308 amrex::Print() << print_ident << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n';
309 }
310
311 RT rho_1 = 0;
312 int ret = 0;
313 iter = 1;
314
315 if ( rnorm0 == 0 || rnorm0 < eps_abs )
316 {
317 if ( verbose > 0 ) {
318 amrex::Print() << print_ident << "MLCGSolver_CG: niter = 0,"
319 << ", rnorm = " << rnorm
320 << ", eps_abs = " << eps_abs << '\n';
321 }
322 return ret;
323 }
324
325 for (; iter <= maxiter; ++iter)
326 {
327 RT rho = dotxy(r,r);
328
329 if ( rho == 0 )
330 {
331 ret = 1; break;
332 }
333 if (iter == 1)
334 {
335 LocalCopy(p,r,0,0,ncomp,nghost);
336 }
337 else
338 {
339 RT beta = rho/rho_1;
340 Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
341 }
343
344 RT alpha;
345 RT pw = dotxy(p,q);
346 if ( pw != RT(0.0))
347 {
348 alpha = rho/pw;
349 }
350 else
351 {
352 ret = 1; break;
353 }
354
355 if ( verbose > 2 )
356 {
357 amrex::Print() << print_ident << "MLCGSolver_cg:"
358 << " iter " << iter
359 << " rho " << rho
360 << " alpha " << alpha << '\n';
361 }
362 Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
363 Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
364 rnorm = norm_inf(r);
365
366 if ( verbose > 2 )
367 {
368 amrex::Print() << print_ident << "MLCGSolver_cg: Iteration"
369 << std::setw(4) << iter
370 << " rel. err. "
371 << rnorm/(rnorm0) << '\n';
372 }
373
374 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
375
376 rho_1 = rho;
377 }
378
379 if ( verbose > 0 )
380 {
381 amrex::Print() << print_ident << "MLCGSolver_cg: Final Iteration"
382 << std::setw(4) << iter
383 << " rel. err. "
384 << rnorm/(rnorm0) << '\n';
385 }
386
387 if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
388 {
390 amrex::Warning("MLCGSolver_cg: failed to converge!");
391 }
392 ret = 8;
393 }
394
395 if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
396 {
397 if ( !initial_vec_zeroed ) {
398 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
399 }
400 if (ret == 8) { ret = 9; }
401 }
402 else
403 {
404 setVal(sol, RT(0.0));
405 if ( !initial_vec_zeroed ) {
406 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
407 }
408 }
409
410 return ret;
411}
412
413template <typename MF>
414auto
415MLCGSolverT<MF>::dotxy (const MF& r, const MF& z, bool local) -> RT
416{
417 BL_PROFILE_VAR_NS("MLCGSolver::ParallelAllReduce", blp_par);
418 if (!local) { BL_PROFILE_VAR_START(blp_par); }
419 RT result = Lp.xdoty(amrlev, mglev, r, z, local);
420 if (!local) { BL_PROFILE_VAR_STOP(blp_par); }
421 return result;
422}
423
424template <typename MF>
425auto
426MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
427{
428 int ncomp = nComp(res);
429 RT result = norminf(res,0,ncomp,IntVect(0),true);
430 if (!local) {
431 BL_PROFILE("MLCGSolver::ParallelAllReduce");
432 ParallelAllReduce::Max(result, Lp.BottomCommunicator());
433 }
434 return result;
435}
436
438
439}
440
441#endif /*_CGSOLVER_H_*/
#define BL_PROFILE_VAR_START(vname)
Definition AMReX_BLProfiler.H:562
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define BL_PROFILE_VAR_STOP(vname)
Definition AMReX_BLProfiler.H:563
#define BL_PROFILE_VAR(fname, vname)
Definition AMReX_BLProfiler.H:560
#define BL_PROFILE_VAR_NS(fname, vname)
Definition AMReX_BLProfiler.H:561
Definition AMReX_MLCGSolver.H:12
RT dotxy(const MF &r, const MF &z, bool local=false)
Definition AMReX_MLCGSolver.H:415
int verbose
Definition AMReX_MLCGSolver.H:72
int iter
Definition AMReX_MLCGSolver.H:75
MLLinOpT< MF > & Lp
Definition AMReX_MLCGSolver.H:68
void setSolver(Type _typ) noexcept
Definition AMReX_MLCGSolver.H:28
bool getInitSolnZeroed() const
Definition AMReX_MLCGSolver.H:54
void setVerbose(int _verbose)
Definition AMReX_MLCGSolver.H:39
int getNGhost()
Definition AMReX_MLCGSolver.H:57
const int mglev
Definition AMReX_MLCGSolver.H:71
IntVect nghost
Definition AMReX_MLCGSolver.H:74
MLCGSolverT< MF > & operator=(const MLCGSolverT< MF > &rhs)=delete
int getNumIters() const noexcept
Definition AMReX_MLCGSolver.H:64
int solve_cg(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:278
RT norm_inf(const MF &res, bool local=false)
Definition AMReX_MLCGSolver.H:426
const int amrlev
Definition AMReX_MLCGSolver.H:70
bool initial_vec_zeroed
Definition AMReX_MLCGSolver.H:76
void setInitSolnZeroed(bool _sol_zeroed)
Definition AMReX_MLCGSolver.H:53
int solve_bicgstab(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:100
MLCGSolverT(MLCGSolverT< MF > &&rhs)=delete
typename MLLinOpT< MF >::FAB FAB
Definition AMReX_MLCGSolver.H:15
int getMaxIter() const
Definition AMReX_MLCGSolver.H:43
int maxiter
Definition AMReX_MLCGSolver.H:73
void setPrintIdentation(std::string s)
Definition AMReX_MLCGSolver.H:45
int solve(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:89
std::string print_ident
Definition AMReX_MLCGSolver.H:77
Type
Definition AMReX_MLCGSolver.H:18
typename MLLinOpT< MF >::RT RT
Definition AMReX_MLCGSolver.H:16
void setNGhost(int _nghost)
Definition AMReX_MLCGSolver.H:56
MLCGSolverT(MLLinOpT< MF > &_lp, Type _typ=Type::BiCGStab)
Definition AMReX_MLCGSolver.H:81
int getVerbose() const
Definition AMReX_MLCGSolver.H:40
void setMaxIter(int _maxiter)
Definition AMReX_MLCGSolver.H:42
Type solver_type
Definition AMReX_MLCGSolver.H:69
MLCGSolverT(const MLCGSolverT< MF > &rhs)=delete
Definition AMReX_MLLinOp.H:98
typename FabDataType< MF >::fab_type FAB
Definition AMReX_MLLinOp.H:108
typename FabDataType< MF >::value_type RT
Definition AMReX_MLLinOp.H:109
This class provides the user with a few print options.
Definition AMReX_Print.H:35
void Sum(T &v, MPI_Comm comm)
Definition AMReX_ParallelReduce.H:204
void Max(KeyValuePair< K, V > &vi, MPI_Comm comm)
Definition AMReX_ParallelReduce.H:126
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition AMReX_ParallelDescriptor.H:275
Definition AMReX_Amr.cpp:49
MF::value_type norminf(MF const &mf, int scomp, int ncomp, IntVect const &nghost, bool local=false)
Definition AMReX_FabArrayUtility.H:1883
int nComp(FabArrayBase const &fa)
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition AMReX_FabArrayUtility.H:1847
IntVect nGrowVect(FabArrayBase const &fa)
void LocalCopy(DMF &dst, SMF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src
Definition AMReX_FabArrayUtility.H:1831
void LocalAdd(MF &dst, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += src
Definition AMReX_FabArrayUtility.H:1839
IntVectND< AMREX_SPACEDIM > IntVect
Definition AMReX_BaseFwd.H:30
void Xpay(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src + a * dst
Definition AMReX_FabArrayUtility.H:1855
void Warning(const std::string &msg)
Print out warning message to cerr.
Definition AMReX.cpp:231
int verbose
Definition AMReX_DistributionMapping.cpp:36
void setVal(MF &dst, typename MF::value_type val)
dst = val
Definition AMReX_FabArrayUtility.H:1808