Block-Structured AMR Software Framework
 
Loading...
Searching...
No Matches
AMReX_GMRES.H
Go to the documentation of this file.
1#ifndef AMREX_GMRES_H_
2#define AMREX_GMRES_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_BLProfiler.H>
6#include <AMReX_Print.H>
7#include <AMReX_TableData.H>
8#include <AMReX_Vector.H>
9#include <cmath>
10#include <limits>
11#include <memory>
12
13namespace amrex {
14
79template <typename V, typename M>
80class GMRES
81{
82public:
83
84 using RT = typename M::RT; // double or float
85
87
91 void define (M& linop);
92
102 void solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its=-1);
103
105 void setVerbose (int v) { m_verbose = v; }
106
108 void setRestartLength (int rl);
109
111 void setMaxIters (int niters) { m_maxiter = niters; }
112
114 [[nodiscard]] int getNumIters () const { return m_its; }
115
117 [[nodiscard]] int getStatus () const { return m_status; }
118
120 [[nodiscard]] RT getResidualNorm () const { return m_res; }
121
122private:
123 void clear ();
125 void cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0);
126 void build_solution (V& a_xx, int it);
127 void compute_residual (V& a_rr, V const& a_xx, V const& a_bb);
128
129 bool converged (RT r0, RT r) const;
130
132 void update_hessenberg (int it, bool happyend, RT& res);
133
134 int m_verbose = 0;
135 int m_maxiter = 2000;
136 int m_its = 0;
137 int m_status = -1;
138 int m_restrtlen = 30;
139 RT m_res = std::numeric_limits<RT>::max();
149 std::unique_ptr<V> m_v_tmp_rhs;
150 std::unique_ptr<V> m_v_tmp_lhs;
152 M* m_linop = nullptr;
153};
154
155template <typename V, typename M>
157{
158 allocate_scratch();
159}
160
161template <typename V, typename M>
163{
164 int rs = m_restrtlen;
165
166 m_hh_1d.resize(std::size_t(rs + 2) * (rs + 1));
167 m_hh = Table2D<RT>(m_hh_1d.data(), {0,0}, {rs+1,rs}); // (0:rs+1,0:rs)
168
169 m_hes_1d.resize(std::size_t(rs + 2) * (rs + 1));
170 m_hes = Table2D<RT>(m_hes_1d.data(), {0,0}, {rs+1,rs}); // (0:rs+1,0:rs)
171
172 m_grs.resize(rs + 2);
173 m_cc.resize(rs + 1);
174 m_ss.resize(rs + 1);
175}
176
177template <typename V, typename M>
179{
180 if (m_restrtlen != rl) {
181 m_restrtlen = rl;
182 allocate_scratch();
183 m_vv.clear();
184 }
185}
186
187template <typename V, typename M>
188void GMRES<V,M>::define (M& linop)
189{
190 clear();
191 m_linop = &linop;
192}
193
194template <typename V, typename M>
196{
197 m_its = 0;
198 m_status = -1;
199 m_res = std::numeric_limits<RT>::max();
200 m_v_tmp_rhs.reset();
201 m_v_tmp_lhs.reset();
202 m_vv.clear();
203 m_linop = nullptr;
204}
205
206template <typename V, typename M>
207bool GMRES<V,M>::converged (RT r0, RT r) const
208{
209 return (r < r0*m_rtol) || (r < m_atol);
210}
211
212template <typename V, typename M>
213void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its)
214{
215 BL_PROFILE("GMRES::solve()");
216
217 AMREX_ALWAYS_ASSERT(m_linop != nullptr);
218
219 auto t0 = amrex::second();
220
221 if (m_v_tmp_rhs == nullptr) {
222 m_v_tmp_rhs = std::make_unique<V>(m_linop->makeVecRHS());
223 }
224 if (m_v_tmp_lhs == nullptr) {
225 m_v_tmp_lhs = std::make_unique<V>(m_linop->makeVecLHS());
226 }
227 if (m_vv.empty()) {
228 m_vv.reserve(m_restrtlen+1);
229 for (int i = 0; i < 2; ++i) { // to save space, start with just 2
230 m_vv.emplace_back(m_linop->makeVecRHS());
231 }
232 }
233
234 m_rtol = a_tol_rel;
235 m_atol = a_tol_abs;
236
237 if (a_its < 0) { a_its = m_maxiter; }
238
239 auto rnorm0 = RT(0);
240
241 m_linop->assign(m_vv[0], a_rhs);
242 m_linop->setToZero(a_sol);
243
244 m_its = 0;
245 m_status = -1;
246 cycle(a_sol, m_status, m_its, rnorm0);
247
248 while (m_status == -1 && m_its < a_its) {
249 compute_residual(m_vv[0], a_sol, a_rhs);
250 cycle(a_sol, m_status, m_its, rnorm0);
251 }
252
253 if (m_status == -1 && m_its >= a_its) { m_status = 1; }
254
255 m_v_tmp_rhs.reset();
256 m_v_tmp_lhs.reset();
257 m_vv.clear();
258
259 auto t1 = amrex::second();
260 if (m_verbose > 0) {
261 amrex::Print() << "GMRES: Solve Time = " << t1-t0 << '\n';
262 }
263}
264
265template <typename V, typename M>
266void GMRES<V,M>::cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0)
267{
268 BL_PROFILE("GMRES::cycle()");
269
270 m_res = m_linop->norm2(m_vv[0]);
271 m_grs[0] = m_res;
272
273 if (m_res == RT(0.0)) {
274 a_status = 0;
275 return;
276 }
277
278 m_linop->scale(m_vv[0], RT(1.0)/m_res);
279
280 if (a_itcount == 0) { a_rnorm0 = m_res; }
281
282 a_status = converged(a_rnorm0,m_res) ? 0 : -1;
283
284 int it = 0;
285 while (it < m_restrtlen && a_itcount < m_maxiter)
286 {
287 if (m_verbose > 1) {
288 amrex::Print() << "GMRES: iter = " << a_itcount
289 << ", residual = " << m_res << ", " << m_res/a_rnorm0
290 << " (rel.)\n";
291 }
292
293 if (a_status == 0) { break; }
294
295 while (m_vv.size() < it+2) {
296 m_vv.emplace_back(m_linop->makeVecRHS());
297 }
298
299 auto const& vv_it = m_vv[it ];
300 auto & vv_it1 = m_vv[it+1];
301
302 m_linop->precond(*m_v_tmp_lhs, vv_it);
303 m_linop->apply(vv_it1, *m_v_tmp_lhs);
304
305 gram_schmidt_orthogonalization(it);
306
307 auto tt = m_linop->norm2(vv_it1);
308
309 auto const sml = RT((sizeof(RT) == 8) ? 1.e-99 : 1.e-30);
310 bool happyend = (tt < sml);
311 if (!happyend) {
312 m_linop->scale(vv_it1, RT(1.0)/tt);
313 }
314
315 m_hh (it+1,it) = tt;
316 m_hes(it+1,it) = tt;
317
318 update_hessenberg(it, happyend, m_res);
319
320 ++it;
321 ++a_itcount;
322 a_status = converged(a_rnorm0, m_res) ? 0 : -1;
323 if (happyend) { break; }
324 }
325
326 if ((m_verbose > 1) && (a_status != 0 || a_itcount >= m_maxiter)) {
327 amrex::Print() << "GMRES: iter = " << a_itcount
328 << ", residual = " << m_res << ", " << m_res/a_rnorm0
329 << " (rel.)\n";
330 }
331
332 build_solution(a_xx, it-1);
333}
334
335template <typename V, typename M>
337{
338 // Two unmodified Gram-Schmidt Orthogonalization
339
340 BL_PROFILE("GMRES::GramSchmidt");
341
342 auto& vv_1 = m_vv[it+1];
343
344 Vector<RT> lhh(it+1);
345
346 for (int j = 0; j <= it; ++j) {
347 m_hh (j,it) = RT(0.0);
348 m_hes(j,it) = RT(0.0);
349 }
350
351 for (int ncnt = 0; ncnt < 2 ; ++ncnt)
352 {
353 for (int j = 0; j <= it; ++j) {
354 lhh[j] = m_linop->dotProduct(vv_1, m_vv[j]);
355 }
356
357 for (int j = 0; j <= it; ++j) {
358 m_linop->increment(vv_1, m_vv[j], -lhh[j]);
359 m_hh (j,it) += lhh[j];
360 m_hes(j,it) -= lhh[j];
361 }
362 }
363}
364
365template <typename V, typename M>
366void GMRES<V,M>::update_hessenberg (int it, bool happyend, RT& res)
367{
368 BL_PROFILE("GMRES::update_hessenberg()");
369
370 for (int j = 1; j <= it; ++j) {
371 auto tt = m_hh(j-1,it);
372 m_hh(j-1,it) = m_cc[j-1] * tt + m_ss[j-1] * m_hh(j,it);
373 m_hh(j ,it) = m_cc[j-1] * m_hh(j,it) - m_ss[j-1] * tt;
374 }
375
376 if (!happyend)
377 {
378 auto tt = std::sqrt(m_hh(it,it)*m_hh(it,it) + m_hh(it+1,it)*m_hh(it+1,it));
379 m_cc[it] = m_hh(it ,it) / tt;
380 m_ss[it] = m_hh(it+1,it) / tt;
381 m_grs[it+1] = - (m_ss[it] * m_grs[it]);
382 m_grs[it ] = m_cc[it] * m_grs[it];
383 m_hh(it,it) = m_cc[it] * m_hh(it,it) + m_ss[it] * m_hh(it+1,it);
384 res = std::abs(m_grs[it+1]);
385 }
386 else
387 {
388 res = RT(0.0);
389 }
390}
391
392template <typename V, typename M>
393void GMRES<V,M>::build_solution (V& a_xx, int const it)
394{
395 BL_PROFILE("GMRES:build_solution()");
396
397 if (it < 0) { return; }
398
399 if (m_hh(it,it) != RT(0.0)) {
400 m_grs[it] /= m_hh(it,it);
401 } else {
402 m_grs[it] = RT(0.0);
403 }
404
405 for (int ii = 1; ii <= it; ++ii) {
406 int k = it - ii;
407 auto tt = m_grs[k];
408 for (int j = k+1; j <= it; ++j) {
409 tt -= m_hh(k,j) * m_grs[j];
410 }
411 m_grs[k] = tt / m_hh(k,k);
412 }
413
414 m_linop->setToZero(*m_v_tmp_rhs);
415 for (int ii = 0; ii < it+1; ++ii) {
416 m_linop->increment(*m_v_tmp_rhs, m_vv[ii], m_grs[ii]);
417 }
418
419 m_linop->precond(*m_v_tmp_lhs, *m_v_tmp_rhs);
420 m_linop->increment(a_xx, *m_v_tmp_lhs, RT(1.0));
421}
422
423template <typename V, typename M>
424void GMRES<V,M>::compute_residual (V& a_rr, V const& a_xx, V const& a_bb)
425{
426 BL_PROFILE("GMRES::compute_residual()");
427 m_linop->assign(*m_v_tmp_lhs, a_xx);
428 m_linop->apply(*m_v_tmp_rhs, *m_v_tmp_lhs);
429 m_linop->linComb(a_rr, RT(1.0), a_bb, RT(-1.0), *m_v_tmp_rhs);
430}
431
432}
433#endif
#define BL_PROFILE(a)
Definition AMReX_BLProfiler.H:551
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
GMRES.
Definition AMReX_GMRES.H:81
int getNumIters() const
Gets the number of iterations.
Definition AMReX_GMRES.H:114
void solve(V &a_sol, V const &a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its=-1)
Solve the linear system.
Definition AMReX_GMRES.H:213
int m_status
Definition AMReX_GMRES.H:137
void build_solution(V &a_xx, int it)
Definition AMReX_GMRES.H:393
void allocate_scratch()
Definition AMReX_GMRES.H:162
int m_verbose
Definition AMReX_GMRES.H:134
void define(M &linop)
Definition AMReX_GMRES.H:188
RT m_rtol
Definition AMReX_GMRES.H:140
bool converged(RT r0, RT r) const
Definition AMReX_GMRES.H:207
Vector< RT > m_hh_1d
Definition AMReX_GMRES.H:142
RT m_atol
Definition AMReX_GMRES.H:141
int m_maxiter
Definition AMReX_GMRES.H:135
void cycle(V &a_xx, int &a_status, int &a_itcount, RT &a_rnorm0)
Definition AMReX_GMRES.H:266
Table2D< RT > m_hes
Definition AMReX_GMRES.H:145
Table2D< RT > m_hh
Definition AMReX_GMRES.H:144
void setVerbose(int v)
Sets verbosity.
Definition AMReX_GMRES.H:105
void compute_residual(V &a_rr, V const &a_xx, V const &a_bb)
Definition AMReX_GMRES.H:424
Vector< RT > m_cc
Definition AMReX_GMRES.H:147
M * m_linop
Definition AMReX_GMRES.H:152
int m_its
Definition AMReX_GMRES.H:136
int m_restrtlen
Definition AMReX_GMRES.H:138
void setMaxIters(int niters)
Sets the max number of iterations.
Definition AMReX_GMRES.H:111
std::unique_ptr< V > m_v_tmp_lhs
Definition AMReX_GMRES.H:150
Vector< RT > m_hes_1d
Definition AMReX_GMRES.H:143
void setRestartLength(int rl)
Sets restart length. The default is 30.
Definition AMReX_GMRES.H:178
void update_hessenberg(int it, bool happyend, RT &res)
Definition AMReX_GMRES.H:366
void gram_schmidt_orthogonalization(int it)
Definition AMReX_GMRES.H:336
RT getResidualNorm() const
Gets the 2-norm of the residual.
Definition AMReX_GMRES.H:120
GMRES()
Definition AMReX_GMRES.H:156
Vector< V > m_vv
Definition AMReX_GMRES.H:151
Vector< RT > m_grs
Definition AMReX_GMRES.H:146
Vector< RT > m_ss
Definition AMReX_GMRES.H:148
void clear()
Definition AMReX_GMRES.H:195
int getStatus() const
Gets the solver status.
Definition AMReX_GMRES.H:117
std::unique_ptr< V > m_v_tmp_rhs
Definition AMReX_GMRES.H:149
RT m_res
Definition AMReX_GMRES.H:139
typename M::RT RT
Definition AMReX_GMRES.H:84
This class provides the user with a few print options.
Definition AMReX_Print.H:35
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition AMReX_Vector.H:28
Definition AMReX_Amr.cpp:49
double second() noexcept
Definition AMReX_Utility.cpp:940
Definition AMReX_TableData.H:93