Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_MLCGSolver.H
Go to the documentation of this file.
1
2#ifndef AMREX_MLCGSOLVER_H_
3#define AMREX_MLCGSOLVER_H_
4#include <AMReX_Config.H>
5
6#include <AMReX_MLLinOp.H>
7
8namespace amrex {
9
// MLCGSolverT: Krylov solver (BiCGStab or CG, selected by Type) applied to the
// linear operator Lp on a single level: amrlev is fixed at 0 and mglev is set by
// the constructor to Lp.NMGLevels(0)-1 (presumably the bottom/coarsest
// multigrid level -- TODO confirm).
// NOTE(review): this Doxygen rendering omits several declarations (class-head
// line, constructor/destructor, copy/move assignment operators, and the Lp,
// solver_type and nghost members); the cross-reference list at the end of the
// page names them.
10template <typename MF>
12{
13public:
14
// Fab and scalar types taken from the linear operator.
15 using FAB = typename MLLinOpT<MF>::FAB;
16 using RT = typename MLLinOpT<MF>::RT;
17
// Krylov method used by solve().
18 enum struct Type { BiCGStab, CG };
19
22
// Copy and move construction are disabled.
23 MLCGSolverT (const MLCGSolverT<MF>& rhs) = delete;
24 MLCGSolverT (MLCGSolverT<MF>&& rhs) = delete;
27
// Select the method used by subsequent solve() calls.
28 void setSolver (Type _typ) noexcept { solver_type = _typ; }
29
// Solve A*solnL = rhsL with the selected method (solve_bicgstab() for
// Type::BiCGStab, solve_cg() otherwise) and return that method's status code
// (0 on convergence). eps_rel/eps_abs are relative/absolute tolerances on the
// max norm of the residual.
37 int solve (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
38
// Verbosity: > 0 prints initial/final residuals, > 2 also prints per-iteration
// residuals.
39 void setVerbose (int _verbose) { verbose = _verbose; }
40 [[nodiscard]] int getVerbose () const { return verbose; }
41
// Maximum number of iterations (default 100).
42 void setMaxIter (int _maxiter) { maxiter = _maxiter; }
43 [[nodiscard]] int getMaxIter () const { return maxiter; }
44
// Prefix string prepended to every printed message.
// NOTE(review): "Identation" misspells "Indentation"; renaming would break callers.
45 void setPrintIdentation (std::string s) { print_ident = std::move(s); }
46
// true: the caller guarantees the solution vector is zero on entry, so the
// initial residual is copied straight from the right-hand side.
// false (default): the incoming solution is used as the initial guess; the
// solver iterates on a correction and adds the saved guess back at the end.
53 void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
54 [[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }
55
// Number of ghost cells (same in every direction) used for the work vectors and
// for all copy/axpy operations inside the solvers.
// NOTE(review): getNGhost() does not modify state and could be const-qualified.
56 void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
57 [[nodiscard]] int getNGhost() {return nghost[0];}
58
// Inner product computed by Lp.xdoty on (amrlev, mglev); local=true is
// forwarded to xdoty so the caller can batch the global reduction itself.
59 [[nodiscard]] RT dotxy (const MF& r, const MF& z, bool local = false);
// Max norm over all components on valid cells; reduced over
// Lp.BottomCommunicator() unless local is true.
60 [[nodiscard]] RT norm_inf (const MF& res, bool local = false);
61 int solve_bicgstab (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
62 int solve_cg (MF& solnL, const MF& rhsL, RT eps_rel, RT eps_abs);
63
// Iteration counter left by the most recent solve (-1 before any solve).
64 [[nodiscard]] int getNumIters () const noexcept { return iter; }
65
66private:
67
// AMR level the solver works on (always 0).
70 const int amrlev = 0;
// Multigrid level the solver works on; set in the constructor.
71 const int mglev;
72 int verbose = 0;
73 int maxiter = 100;
// Iteration counter of the last solve.
75 int iter = -1;
76 bool initial_vec_zeroed = false;
// Prefix for printed messages.
77 std::string print_ident;
78};
79
80template <typename MF>
82 : Lp(_lp), solver_type(_typ), mglev(_lp.NMGLevels(0)-1)
83{}
84
85template <typename MF> MLCGSolverT<MF>::~MLCGSolverT () = default;
86
87template <typename MF>
88int
89MLCGSolverT<MF>::solve (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
90{
91 if (solver_type == Type::BiCGStab) {
92 return solve_bicgstab(sol,rhs,eps_rel,eps_abs);
93 } else {
94 return solve_cg(sol,rhs,eps_rel,eps_abs);
95 }
96}
97
98template <typename MF>
99int
100MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
101{
102 BL_PROFILE("MLCGSolver::bicgstab");
103
104 const int ncomp = nComp(sol);
105
106 MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
107 MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
108 setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
109 setVal(r, RT(0.0));
110
111 MF rh = Lp.make(amrlev, mglev, nghost);
112 MF v = Lp.make(amrlev, mglev, nghost);
113 MF t = Lp.make(amrlev, mglev, nghost);
114
115
116 MF sorig;
117
118 if ( initial_vec_zeroed ) {
119 LocalCopy(r,rhs,0,0,ncomp,nghost);
120 } else {
121 sorig = Lp.make(amrlev, mglev, nghost);
122
123 Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
124
125 LocalCopy(sorig,sol,0,0,ncomp,nghost);
126 setVal(sol, RT(0.0));
127 }
128
129 // Then normalize
130 Lp.normalize(amrlev, mglev, r);
131 LocalCopy(rh, r, 0,0,ncomp,nghost);
132
133 RT rnorm = norm_inf(r);
134 const RT rnorm0 = rnorm;
135
136 if ( verbose > 0 )
137 {
138 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Initial error (error0) = " << rnorm0 << '\n';
139 }
140 int ret = 0;
141 iter = 1;
142 RT rho_1 = 0, alpha = 0, omega = 0;
143
144 if ( rnorm0 == 0 || rnorm0 < eps_abs )
145 {
146 if ( verbose > 0 )
147 {
148 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: niter = 0,"
149 << ", rnorm = " << rnorm
150 << ", eps_abs = " << eps_abs << '\n';
151 }
152 return ret;
153 }
154
155 for (; iter <= maxiter; ++iter)
156 {
157 const RT rho = dotxy(rh,r);
158 if ( rho == 0 )
159 {
160 ret = 1; break;
161 }
162 if ( iter == 1 )
163 {
164 LocalCopy(p,r,0,0,ncomp,nghost);
165 }
166 else
167 {
168 const RT beta = (rho/rho_1)*(alpha/omega);
169 if constexpr (IsMultiFabLike_v<MF>) {
170 // two operations: p += -omega*v; p = r + beta*p
171 // same as: p = r + beta*(p - omega*v)
172 Saxpy_Xpay(p, -omega, v, beta, r, 0, 0, ncomp, nghost);
173 } else {
174 Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
175 Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
176 }
177 }
179 Lp.normalize(amrlev, mglev, v);
180
181 RT rhTv = dotxy(rh,v);
182 if ( rhTv != RT(0.0) )
183 {
184 alpha = rho/rhTv;
185 }
186 else
187 {
188 ret = 2; break;
189 }
190 if constexpr (IsMultiFabLike_v<MF>) {
191 // sol += alpha * p; r += -alpha * v
192 Saxpy_Saxpy(sol, alpha, p, r, -alpha, v, 0, 0, ncomp, nghost);
193 } else {
194 Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
195 Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
196 }
197
198 rnorm = norm_inf(r);
199
201 {
202 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Half Iter "
203 << std::setw(11) << iter
204 << " rel. err. "
205 << rnorm/(rnorm0) << '\n';
206 }
207
208 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
209
211 Lp.normalize(amrlev, mglev, t);
212 //
213 // This is a little funky. I want to elide one of the reductions
214 // in the following two dotxy()s. We do that by calculating the "local"
215 // values and then reducing the two local values at the same time.
216 //
217 RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };
218
219 BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
220 ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
221 BL_PROFILE_VAR_STOP(blp_par);
222
223 if ( tvals[0] != RT(0.0) )
224 {
225 omega = tvals[1]/tvals[0];
226 }
227 else
228 {
229 ret = 3; break;
230 }
231 if constexpr (IsMultiFabLike_v<MF>) {
232 // sol += omega * r; r += -omega * t
233 Saypy_Saxpy(sol, omega, r, -omega, t, 0, 0, ncomp, nghost);
234 } else {
235 Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
236 Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
237 }
238
239 rnorm = norm_inf(r);
240
241 if ( verbose > 2 )
242 {
243 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Iteration "
244 << std::setw(11) << iter
245 << " rel. err. "
246 << rnorm/(rnorm0) << '\n';
247 }
248
249 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
250
251 if ( omega == 0 )
252 {
253 ret = 4; break;
254 }
255 rho_1 = rho;
256 }
257
258 if ( verbose > 0 )
259 {
260 amrex::Print() << print_ident << "MLCGSolver_BiCGStab: Final: Iteration "
261 << std::setw(4) << iter
262 << " rel. err. "
263 << rnorm/(rnorm0) << '\n';
264 }
265
266 if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs)
267 {
269 amrex::Warning("MLCGSolver_BiCGStab:: failed to converge!");
270 }
271 ret = 8;
272 }
273
274 if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
275 {
276 if ( !initial_vec_zeroed ) {
277 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
278 }
279 if (ret == 8) { ret = 9; }
280 }
281 else
282 {
283 setVal(sol, RT(0.0));
284 if ( !initial_vec_zeroed ) {
285 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
286 }
287 }
288
289 return ret;
290}
291
292template <typename MF>
293int
294MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
295{
296 BL_PROFILE("MLCGSolver::cg");
297
298 const int ncomp = nComp(sol);
299
300 MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
301 setVal(p, RT(0.0));
302
303 MF r = Lp.make(amrlev, mglev, nghost);
304 MF q = Lp.make(amrlev, mglev, nghost);
305
306 MF sorig;
307
308 if ( initial_vec_zeroed ) {
309 LocalCopy(r,rhs,0,0,ncomp,nghost);
310 } else {
311 sorig = Lp.make(amrlev, mglev, nghost);
312
313 Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
314
315 LocalCopy(sorig,sol,0,0,ncomp,nghost);
316 setVal(sol, RT(0.0));
317 }
318
319 RT rnorm = norm_inf(r);
320 const RT rnorm0 = rnorm;
321
322 if ( verbose > 0 )
323 {
324 amrex::Print() << print_ident << "MLCGSolver_CG: Initial error (error0) : " << rnorm0 << '\n';
325 }
326
327 RT rho_1 = 0;
328 int ret = 0;
329 iter = 1;
330
331 if ( rnorm0 == 0 || rnorm0 < eps_abs )
332 {
333 if ( verbose > 0 ) {
334 amrex::Print() << print_ident << "MLCGSolver_CG: niter = 0,"
335 << ", rnorm = " << rnorm
336 << ", eps_abs = " << eps_abs << '\n';
337 }
338 return ret;
339 }
340
341 for (; iter <= maxiter; ++iter)
342 {
343 RT rho = dotxy(r,r);
344
345 if ( rho == 0 )
346 {
347 ret = 1; break;
348 }
349 if (iter == 1)
350 {
351 LocalCopy(p,r,0,0,ncomp,nghost);
352 }
353 else
354 {
355 RT beta = rho/rho_1;
356 Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
357 }
359
360 RT alpha;
361 RT pw = dotxy(p,q);
362 if ( pw != RT(0.0))
363 {
364 alpha = rho/pw;
365 }
366 else
367 {
368 ret = 1; break;
369 }
370
371 if ( verbose > 2 )
372 {
373 amrex::Print() << print_ident << "MLCGSolver_cg:"
374 << " iter " << iter
375 << " rho " << rho
376 << " alpha " << alpha << '\n';
377 }
378 if constexpr (IsMultiFabLike_v<MF>) {
379 // sol += alpha * p; r += -alpha * q
380 Saxpy_Saxpy(sol, alpha, p, r, -alpha, q, 0, 0, ncomp, nghost);
381 } else {
382 Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
383 Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
384 }
385 rnorm = norm_inf(r);
386
387 if ( verbose > 2 )
388 {
389 amrex::Print() << print_ident << "MLCGSolver_cg: Iteration"
390 << std::setw(4) << iter
391 << " rel. err. "
392 << rnorm/(rnorm0) << '\n';
393 }
394
395 if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }
396
397 rho_1 = rho;
398 }
399
400 if ( verbose > 0 )
401 {
402 amrex::Print() << print_ident << "MLCGSolver_cg: Final Iteration"
403 << std::setw(4) << iter
404 << " rel. err. "
405 << rnorm/(rnorm0) << '\n';
406 }
407
408 if ( ret == 0 && rnorm > eps_rel*rnorm0 && rnorm > eps_abs )
409 {
411 amrex::Warning("MLCGSolver_cg: failed to converge!");
412 }
413 ret = 8;
414 }
415
416 if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
417 {
418 if ( !initial_vec_zeroed ) {
419 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
420 }
421 if (ret == 8) { ret = 9; }
422 }
423 else
424 {
425 setVal(sol, RT(0.0));
426 if ( !initial_vec_zeroed ) {
427 LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
428 }
429 }
430
431 return ret;
432}
433
434template <typename MF>
435auto
436MLCGSolverT<MF>::dotxy (const MF& r, const MF& z, bool local) -> RT
437{
438 BL_PROFILE_VAR_NS("MLCGSolver::ParallelAllReduce", blp_par);
439 if (!local) { BL_PROFILE_VAR_START(blp_par); }
440 RT result = Lp.xdoty(amrlev, mglev, r, z, local);
441 if (!local) { BL_PROFILE_VAR_STOP(blp_par); }
442 return result;
443}
444
445template <typename MF>
446auto
447MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
448{
449 int ncomp = nComp(res);
450 RT result = norminf(res,0,ncomp,IntVect(0),true);
451 if (!local) {
452 BL_PROFILE("MLCGSolver::ParallelAllReduce");
453 ParallelAllReduce::Max(result, Lp.BottomCommunicator());
454 }
455 return result;
456}
457
459
460}
461
462#endif /*_CGSOLVER_H_*/
#define BL_PROFILE_VAR_START(vname)
Definition AMReX_BLProfiler.H:562
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define BL_PROFILE_VAR_STOP(vname)
Definition AMReX_BLProfiler.H:563
#define BL_PROFILE_VAR(fname, vname)
Definition AMReX_BLProfiler.H:560
#define BL_PROFILE_VAR_NS(fname, vname)
Definition AMReX_BLProfiler.H:561
Definition AMReX_MLCGSolver.H:12
RT dotxy(const MF &r, const MF &z, bool local=false)
Definition AMReX_MLCGSolver.H:436
int verbose
Definition AMReX_MLCGSolver.H:72
int iter
Definition AMReX_MLCGSolver.H:75
MLLinOpT< MF > & Lp
Definition AMReX_MLCGSolver.H:68
void setSolver(Type _typ) noexcept
Definition AMReX_MLCGSolver.H:28
bool getInitSolnZeroed() const
Definition AMReX_MLCGSolver.H:54
void setVerbose(int _verbose)
Definition AMReX_MLCGSolver.H:39
int getNGhost()
Definition AMReX_MLCGSolver.H:57
const int mglev
Definition AMReX_MLCGSolver.H:71
IntVect nghost
Definition AMReX_MLCGSolver.H:74
MLCGSolverT< MF > & operator=(const MLCGSolverT< MF > &rhs)=delete
int getNumIters() const noexcept
Definition AMReX_MLCGSolver.H:64
int solve_cg(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:294
RT norm_inf(const MF &res, bool local=false)
Definition AMReX_MLCGSolver.H:447
const int amrlev
Definition AMReX_MLCGSolver.H:70
bool initial_vec_zeroed
Definition AMReX_MLCGSolver.H:76
void setInitSolnZeroed(bool _sol_zeroed)
Definition AMReX_MLCGSolver.H:53
int solve_bicgstab(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:100
MLCGSolverT(MLCGSolverT< MF > &&rhs)=delete
typename MLLinOpT< MF >::FAB FAB
Definition AMReX_MLCGSolver.H:15
int getMaxIter() const
Definition AMReX_MLCGSolver.H:43
int maxiter
Definition AMReX_MLCGSolver.H:73
void setPrintIdentation(std::string s)
Definition AMReX_MLCGSolver.H:45
int solve(MF &solnL, const MF &rhsL, RT eps_rel, RT eps_abs)
Definition AMReX_MLCGSolver.H:89
std::string print_ident
Definition AMReX_MLCGSolver.H:77
Type
Definition AMReX_MLCGSolver.H:18
typename MLLinOpT< MF >::RT RT
Definition AMReX_MLCGSolver.H:16
void setNGhost(int _nghost)
Definition AMReX_MLCGSolver.H:56
MLCGSolverT(MLLinOpT< MF > &_lp, Type _typ=Type::BiCGStab)
Definition AMReX_MLCGSolver.H:81
int getVerbose() const
Definition AMReX_MLCGSolver.H:40
void setMaxIter(int _maxiter)
Definition AMReX_MLCGSolver.H:42
Type solver_type
Definition AMReX_MLCGSolver.H:69
MLCGSolverT(const MLCGSolverT< MF > &rhs)=delete
Definition AMReX_MLLinOp.H:98
typename FabDataType< MF >::fab_type FAB
Definition AMReX_MLLinOp.H:108
typename FabDataType< MF >::value_type RT
Definition AMReX_MLLinOp.H:109
This class provides the user with a few print options.
Definition AMReX_Print.H:35
void Sum(T &v, MPI_Comm comm)
Definition AMReX_ParallelReduce.H:204
void Max(KeyValuePair< K, V > &vi, MPI_Comm comm)
Definition AMReX_ParallelReduce.H:126
bool IOProcessor() noexcept
Is this CPU the I/O Processor? To get the rank number, call IOProcessorNumber()
Definition AMReX_ParallelDescriptor.H:275
Definition AMReX_Amr.cpp:49
MF::value_type norminf(MF const &mf, int scomp, int ncomp, IntVect const &nghost, bool local=false)
Definition AMReX_FabArrayUtility.H:1977
int nComp(FabArrayBase const &fa)
void Saxpy(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a * src
Definition AMReX_FabArrayUtility.H:1914
IntVect nGrowVect(FabArrayBase const &fa)
void LocalCopy(DMF &dst, SMF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src
Definition AMReX_FabArrayUtility.H:1898
void LocalAdd(MF &dst, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += src
Definition AMReX_FabArrayUtility.H:1906
void Saxpy_Saxpy(MF &dst1, typename MF::value_type a1, MF const &src1, MF &dst2, typename MF::value_type a2, MF const &src2, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst1 += a1 * src1 followed by dst2 += a2 * src2
Definition AMReX_FabArrayUtility.H:1939
void Saypy_Saxpy(MF &dst1, typename MF::value_type a1, MF &dst2, typename MF::value_type a2, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst1 += a1 * dst2 followed by dst2 += a2 * src
Definition AMReX_FabArrayUtility.H:1948
IntVectND< AMREX_SPACEDIM > IntVect
Definition AMReX_BaseFwd.H:30
void Saxpy_Xpay(MF &dst, typename MF::value_type a_saxpy, MF const &src_saxpy, typename MF::value_type a_xpay, MF const &src_xpay, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst += a_saxpy * src_saxpy followed by dst = src_xpay + a_xpay * dst
Definition AMReX_FabArrayUtility.H:1930
void Xpay(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src + a * dst
Definition AMReX_FabArrayUtility.H:1922
void Warning(const std::string &msg)
Print out warning message to cerr.
Definition AMReX.cpp:236
int verbose
Definition AMReX_DistributionMapping.cpp:36
void setVal(MF &dst, typename MF::value_type val)
dst = val
Definition AMReX_FabArrayUtility.H:1875