Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_GMRES.H
Go to the documentation of this file.
1#ifndef AMREX_GMRES_H_
2#define AMREX_GMRES_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_BLProfiler.H>
6#include <AMReX_Print.H>
7#include <AMReX_TableData.H>
8#include <AMReX_Vector.H>
9#include <cmath>
10#include <limits>
11#include <memory>
12
13namespace amrex {
14
85template <typename V, typename M>
86class GMRES
87{
88public:
89
90 using RT = typename M::RT; // double or float
91
94
103 void define (M& linop);
104
114 void solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its=-1);
115
117 void setVerbose (int v) { m_verbose = v; }
118
124 void setRestartLength (int rl);
125
131 void setMaxIters (int niters) { m_maxiter = niters; }
132
134 [[nodiscard]] int getNumIters () const { return m_its; }
135
137 [[nodiscard]] int getStatus () const { return m_status; }
138
140 [[nodiscard]] RT getResidualNorm () const { return m_res; }
141
142private:
143 void clear ();
144 void allocate_scratch ();
145 void cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0);
146 void build_solution (V& a_xx, int it);
147 void compute_residual (V& a_rr, V const& a_xx, V const& a_bb);
148
149 bool converged (RT r0, RT r) const;
150
151 void gram_schmidt_orthogonalization (int it);
152 void update_hessenberg (int it, bool happyend, RT& res);
153
154 int m_verbose = 0;
155 int m_maxiter = 2000;
156 int m_its = 0;
157 int m_status = -1;
158 int m_restrtlen = 30;
159 RT m_res = std::numeric_limits<RT>::max();
160 RT m_rtol = RT(0);
161 RT m_atol = RT(0);
162 Vector<RT> m_hh_1d;
163 Vector<RT> m_hes_1d;
164 Table2D<RT> m_hh;
165 Table2D<RT> m_hes;
166 Vector<RT> m_grs;
167 Vector<RT> m_cc;
168 Vector<RT> m_ss;
169 std::unique_ptr<V> m_v_tmp_rhs;
170 std::unique_ptr<V> m_v_tmp_lhs;
171 Vector<V> m_vv;
172 M* m_linop = nullptr;
173};
174
175template <typename V, typename M>
177{
178 allocate_scratch();
179}
180
181template <typename V, typename M>
183{
184 int rs = m_restrtlen;
185
186 m_hh_1d.resize(std::size_t(rs + 2) * (rs + 1));
187 m_hh = Table2D<RT>(m_hh_1d.data(), {0,0}, {rs+1,rs}); // (0:rs+1,0:rs)
188
189 m_hes_1d.resize(std::size_t(rs + 2) * (rs + 1));
190 m_hes = Table2D<RT>(m_hes_1d.data(), {0,0}, {rs+1,rs}); // (0:rs+1,0:rs)
191
192 m_grs.resize(rs + 2);
193 m_cc.resize(rs + 1);
194 m_ss.resize(rs + 1);
195}
196
197template <typename V, typename M>
199{
200 if (m_restrtlen != rl) {
201 m_restrtlen = rl;
202 allocate_scratch();
203 m_vv.clear();
204 }
205}
206
207template <typename V, typename M>
208void GMRES<V,M>::define (M& linop)
209{
210 clear();
211 m_linop = &linop;
212}
213
214template <typename V, typename M>
215void GMRES<V,M>::clear ()
216{
217 m_its = 0;
218 m_status = -1;
219 m_res = std::numeric_limits<RT>::max();
220 m_v_tmp_rhs.reset();
221 m_v_tmp_lhs.reset();
222 m_vv.clear();
223 m_linop = nullptr;
224}
225
226template <typename V, typename M>
227bool GMRES<V,M>::converged (RT r0, RT r) const
228{
229 return (r < r0*m_rtol) || (r < m_atol);
230}
231
232template <typename V, typename M>
233void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its)
234{
235 BL_PROFILE("GMRES::solve()");
236
237 AMREX_ALWAYS_ASSERT(m_linop != nullptr);
238
239 auto t0 = amrex::second();
240
241 if (m_v_tmp_rhs == nullptr) {
242 m_v_tmp_rhs = std::make_unique<V>(m_linop->makeVecRHS());
243 }
244 if (m_v_tmp_lhs == nullptr) {
245 m_v_tmp_lhs = std::make_unique<V>(m_linop->makeVecLHS());
246 }
247 if (m_vv.empty()) {
248 m_vv.reserve(m_restrtlen+1);
249 for (int i = 0; i < 2; ++i) { // to save space, start with just 2
250 m_vv.emplace_back(m_linop->makeVecRHS());
251 }
252 }
253
254 m_rtol = a_tol_rel;
255 m_atol = a_tol_abs;
256
257 if (a_its < 0) { a_its = m_maxiter; }
258
259 auto rnorm0 = RT(0);
260
261 m_linop->assign(m_vv[0], a_rhs);
262 m_linop->setToZero(a_sol);
263
264 m_its = 0;
265 m_status = -1;
266 cycle(a_sol, m_status, m_its, rnorm0);
267
268 while (m_status == -1 && m_its < a_its) {
269 compute_residual(m_vv[0], a_sol, a_rhs);
270 cycle(a_sol, m_status, m_its, rnorm0);
271 }
272
273 if (m_status == -1 && m_its >= a_its) { m_status = 1; }
274
275 m_v_tmp_rhs.reset();
276 m_v_tmp_lhs.reset();
277 m_vv.clear();
278
279 auto t1 = amrex::second();
280 if (m_verbose > 0) {
281 amrex::Print() << "GMRES: Solve Time = " << t1-t0 << '\n';
282 }
283}
284
285template <typename V, typename M>
286void GMRES<V,M>::cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0)
287{
288 BL_PROFILE("GMRES::cycle()");
289
290 m_res = m_linop->norm2(m_vv[0]);
291 m_grs[0] = m_res;
292
293 if (m_res == RT(0.0)) {
294 a_status = 0;
295 return;
296 }
297
298 m_linop->scale(m_vv[0], RT(1.0)/m_res);
299
300 if (a_itcount == 0) { a_rnorm0 = m_res; }
301
302 a_status = converged(a_rnorm0,m_res) ? 0 : -1;
303
304 int it = 0;
305 while (it < m_restrtlen && a_itcount < m_maxiter)
306 {
307 if (m_verbose > 1) {
308 amrex::Print() << "GMRES: iter = " << a_itcount
309 << ", residual = " << m_res << ", " << m_res/a_rnorm0
310 << " (rel.)\n";
311 }
312
313 if (a_status == 0) { break; }
314
315 while (m_vv.size() < it+2) {
316 m_vv.emplace_back(m_linop->makeVecRHS());
317 }
318
319 auto const& vv_it = m_vv[it ];
320 auto & vv_it1 = m_vv[it+1];
321
322 m_linop->precond(*m_v_tmp_lhs, vv_it);
323 m_linop->apply(vv_it1, *m_v_tmp_lhs);
324
325 gram_schmidt_orthogonalization(it);
326
327 auto tt = m_linop->norm2(vv_it1);
328
329 auto const sml = RT((sizeof(RT) == 8) ? 1.e-99 : 1.e-30);
330 bool happyend = (tt < sml);
331 if (!happyend) {
332 m_linop->scale(vv_it1, RT(1.0)/tt);
333 }
334
335 m_hh (it+1,it) = tt;
336 m_hes(it+1,it) = tt;
337
338 update_hessenberg(it, happyend, m_res);
339
340 ++it;
341 ++a_itcount;
342 a_status = converged(a_rnorm0, m_res) ? 0 : -1;
343 if (happyend) { break; }
344 }
345
346 if ((m_verbose > 1) && (a_status != 0 || a_itcount >= m_maxiter)) {
347 amrex::Print() << "GMRES: iter = " << a_itcount
348 << ", residual = " << m_res << ", " << m_res/a_rnorm0
349 << " (rel.)\n";
350 }
351
352 build_solution(a_xx, it-1);
353}
354
355template <typename V, typename M>
356void GMRES<V,M>::gram_schmidt_orthogonalization (int const it)
357{
358 // Two unmodified Gram-Schmidt Orthogonalization
359
360 BL_PROFILE("GMRES::GramSchmidt");
361
362 auto& vv_1 = m_vv[it+1];
363
364 Vector<RT> lhh(it+1);
365
366 for (int j = 0; j <= it; ++j) {
367 m_hh (j,it) = RT(0.0);
368 m_hes(j,it) = RT(0.0);
369 }
370
371 for (int ncnt = 0; ncnt < 2 ; ++ncnt)
372 {
373 for (int j = 0; j <= it; ++j) {
374 lhh[j] = m_linop->dotProduct(vv_1, m_vv[j]);
375 }
376
377 for (int j = 0; j <= it; ++j) {
378 m_linop->increment(vv_1, m_vv[j], -lhh[j]);
379 m_hh (j,it) += lhh[j];
380 m_hes(j,it) -= lhh[j];
381 }
382 }
383}
384
385template <typename V, typename M>
386void GMRES<V,M>::update_hessenberg (int it, bool happyend, RT& res)
387{
388 BL_PROFILE("GMRES::update_hessenberg()");
389
390 for (int j = 1; j <= it; ++j) {
391 auto tt = m_hh(j-1,it);
392 m_hh(j-1,it) = m_cc[j-1] * tt + m_ss[j-1] * m_hh(j,it);
393 m_hh(j ,it) = m_cc[j-1] * m_hh(j,it) - m_ss[j-1] * tt;
394 }
395
396 if (!happyend)
397 {
398 auto tt = std::sqrt(m_hh(it,it)*m_hh(it,it) + m_hh(it+1,it)*m_hh(it+1,it));
399 m_cc[it] = m_hh(it ,it) / tt;
400 m_ss[it] = m_hh(it+1,it) / tt;
401 m_grs[it+1] = - (m_ss[it] * m_grs[it]);
402 m_grs[it ] = m_cc[it] * m_grs[it];
403 m_hh(it,it) = m_cc[it] * m_hh(it,it) + m_ss[it] * m_hh(it+1,it);
404 res = std::abs(m_grs[it+1]);
405 }
406 else
407 {
408 res = RT(0.0);
409 }
410}
411
412template <typename V, typename M>
413void GMRES<V,M>::build_solution (V& a_xx, int const it)
414{
415 BL_PROFILE("GMRES:build_solution()");
416
417 if (it < 0) { return; }
418
419 if (m_hh(it,it) != RT(0.0)) {
420 m_grs[it] /= m_hh(it,it);
421 } else {
422 m_grs[it] = RT(0.0);
423 }
424
425 for (int ii = 1; ii <= it; ++ii) {
426 int k = it - ii;
427 auto tt = m_grs[k];
428 for (int j = k+1; j <= it; ++j) {
429 tt -= m_hh(k,j) * m_grs[j];
430 }
431 m_grs[k] = tt / m_hh(k,k);
432 }
433
434 m_linop->setToZero(*m_v_tmp_rhs);
435 for (int ii = 0; ii < it+1; ++ii) {
436 m_linop->increment(*m_v_tmp_rhs, m_vv[ii], m_grs[ii]);
437 }
438
439 m_linop->precond(*m_v_tmp_lhs, *m_v_tmp_rhs);
440 m_linop->increment(a_xx, *m_v_tmp_lhs, RT(1.0));
441}
442
443template <typename V, typename M>
444void GMRES<V,M>::compute_residual (V& a_rr, V const& a_xx, V const& a_bb)
445{
446 BL_PROFILE("GMRES::compute_residual()");
447 m_linop->assign(*m_v_tmp_lhs, a_xx);
448 m_linop->apply(*m_v_tmp_rhs, *m_v_tmp_lhs);
449 m_linop->linComb(a_rr, RT(1.0), a_bb, RT(-1.0), *m_v_tmp_rhs);
450}
451
452}
453#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
GMRES.
Definition AMReX_GMRES.H:87
int getNumIters() const
Number of iterations executed by the last solve().
Definition AMReX_GMRES.H:134
void solve(V &a_sol, V const &a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its=-1)
Solve the linear system.
Definition AMReX_GMRES.H:233
void define(M &linop)
Bind the solver to a linear operator.
Definition AMReX_GMRES.H:208
void setVerbose(int v)
Set verbosity level v (0 = silent).
Definition AMReX_GMRES.H:117
void setMaxIters(int niters)
Cap the number of iterations performed by solve().
Definition AMReX_GMRES.H:131
void setRestartLength(int rl)
Set the Krylov restart length.
Definition AMReX_GMRES.H:198
RT getResidualNorm() const
Final residual 2-norm from the last solve().
Definition AMReX_GMRES.H:140
GMRES()
Construct a GMRES solver with the default restart length.
Definition AMReX_GMRES.H:176
int getStatus() const
Status flag from the last solve() (0 success, >0 failure).
Definition AMReX_GMRES.H:137
typename M::RT RT
Definition AMReX_GMRES.H:90
This class provides the user with a few print options.
Definition AMReX_Print.H:35
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition AMReX_Vector.H:28
Definition AMReX_Amr.cpp:49
double second() noexcept
Definition AMReX_Utility.cpp:940
Definition AMReX_TableData.H:93