3 #include <AMReX_Config.H>
59 template <
typename V,
typename M>
64 using RT =
typename M::RT;
82 void solve (V& a_sol, V
const& a_rhs,
RT a_tol_rel,
RT a_tol_abs,
int a_its=-1);
105 void cycle (V& a_xx,
int& a_status,
int& a_itcount,
RT& a_rnorm0);
135 template <
typename V,
typename M>
141 template <
typename V,
typename M>
144 int rs = m_restrtlen;
146 m_hh_1d.resize(std::size_t(rs + 2) * (rs + 1));
147 m_hh =
Table2D<RT>(m_hh_1d.data(), {0,0}, {rs+1,rs});
149 m_hes_1d.resize(std::size_t(rs + 2) * (rs + 1));
150 m_hes =
Table2D<RT>(m_hes_1d.data(), {0,0}, {rs+1,rs});
152 m_grs.resize(rs + 2);
157 template <
typename V,
typename M>
160 if (m_restrtlen != rl) {
167 template <
typename V,
typename M>
174 template <
typename V,
typename M>
186 template <
typename V,
typename M>
189 return (
r < r0*m_rtol) || (
r < m_atol);
192 template <
typename V,
typename M>
201 if (m_v_tmp_rhs ==
nullptr) {
202 m_v_tmp_rhs = std::make_unique<V>(m_linop->makeVecRHS());
204 if (m_v_tmp_lhs ==
nullptr) {
205 m_v_tmp_lhs = std::make_unique<V>(m_linop->makeVecLHS());
208 m_vv.reserve(m_restrtlen+1);
209 for (
int i = 0; i < 2; ++i) {
210 m_vv.emplace_back(m_linop->makeVecRHS());
217 if (a_its < 0) { a_its = m_maxiter; }
221 m_linop->assign(m_vv[0], a_rhs);
222 m_linop->setToZero(a_sol);
226 cycle(a_sol, m_status, m_its, rnorm0);
228 while (m_status == -1 && m_its < a_its) {
229 compute_residual(m_vv[0], a_sol, a_rhs);
230 cycle(a_sol, m_status, m_its, rnorm0);
233 if (m_status == -1 && m_its >= a_its) { m_status = 1; }
241 amrex::Print() <<
"GMRES: Solve Time = " << t1-t0 <<
'\n';
245 template <
typename V,
typename M>
250 m_res = m_linop->norm2(m_vv[0]);
253 if (m_res ==
RT(0.0)) {
258 m_linop->scale(m_vv[0],
RT(1.0)/m_res);
260 if (a_itcount == 0) { a_rnorm0 = m_res; }
262 a_status = converged(a_rnorm0,m_res) ? 0 : -1;
265 while (it < m_restrtlen && a_itcount < m_maxiter)
269 <<
", residual = " << m_res <<
", " << m_res/a_rnorm0
273 if (a_status == 0) {
break; }
275 while (m_vv.size() < it+2) {
276 m_vv.emplace_back(m_linop->makeVecRHS());
279 auto const& vv_it = m_vv[it ];
280 auto & vv_it1 = m_vv[it+1];
282 m_linop->precond(*m_v_tmp_lhs, vv_it);
283 m_linop->apply(vv_it1, *m_v_tmp_lhs);
285 gram_schmidt_orthogonalization(it);
287 auto tt = m_linop->norm2(vv_it1);
289 auto const small =
RT((
sizeof(
RT) == 8) ? 1.e-99 : 1.e-30);
290 bool happyend = (tt < small);
292 m_linop->scale(vv_it1,
RT(1.0)/tt);
298 update_hessenberg(it, happyend, m_res);
302 a_status = converged(a_rnorm0, m_res) ? 0 : -1;
303 if (happyend) {
break; }
306 if ((m_verbose > 1) && (a_status != 0 || a_itcount >= m_maxiter)) {
308 <<
", residual = " << m_res <<
", " << m_res/a_rnorm0
312 build_solution(a_xx, it-1);
315 template <
typename V,
typename M>
322 auto& vv_1 = m_vv[it+1];
326 for (
int j = 0; j <= it; ++j) {
327 m_hh (j,it) =
RT(0.0);
328 m_hes(j,it) =
RT(0.0);
331 for (
int ncnt = 0; ncnt < 2 ; ++ncnt)
333 for (
int j = 0; j <= it; ++j) {
334 lhh[j] = m_linop->dotProduct(vv_1, m_vv[j]);
337 for (
int j = 0; j <= it; ++j) {
338 m_linop->increment(vv_1, m_vv[j], -lhh[j]);
339 m_hh (j,it) += lhh[j];
340 m_hes(j,it) -= lhh[j];
345 template <
typename V,
typename M>
350 for (
int j = 1; j <= it; ++j) {
351 auto tt = m_hh(j-1,it);
352 m_hh(j-1,it) = m_cc[j-1] * tt + m_ss[j-1] * m_hh(j,it);
353 m_hh(j ,it) = m_cc[j-1] * m_hh(j,it) - m_ss[j-1] * tt;
358 auto tt =
std::sqrt(m_hh(it,it)*m_hh(it,it) + m_hh(it+1,it)*m_hh(it+1,it));
359 m_cc[it] = m_hh(it ,it) / tt;
360 m_ss[it] = m_hh(it+1,it) / tt;
361 m_grs[it+1] = - (m_ss[it] * m_grs[it]);
362 m_grs[it ] = m_cc[it] * m_grs[it];
363 m_hh(it,it) = m_cc[it] * m_hh(it,it) + m_ss[it] * m_hh(it+1,it);
372 template <
typename V,
typename M>
377 if (it < 0) {
return; }
379 if (m_hh(it,it) !=
RT(0.0)) {
380 m_grs[it] /= m_hh(it,it);
385 for (
int ii = 1; ii <= it; ++ii) {
388 for (
int j = k+1; j <= it; ++j) {
389 tt -= m_hh(k,j) * m_grs[j];
391 m_grs[k] = tt / m_hh(k,k);
394 m_linop->setToZero(*m_v_tmp_rhs);
395 for (
int ii = 0; ii < it+1; ++ii) {
396 m_linop->increment(*m_v_tmp_rhs, m_vv[ii], m_grs[ii]);
399 m_linop->precond(*m_v_tmp_lhs, *m_v_tmp_rhs);
400 m_linop->increment(a_xx, *m_v_tmp_lhs,
RT(1.0));
403 template <
typename V,
typename M>
407 m_linop->assign(*m_v_tmp_lhs, a_xx);
408 m_linop->apply(*m_v_tmp_rhs, *m_v_tmp_lhs);
409 m_linop->linComb(a_rr,
RT(1.0), a_bb,
RT(-1.0), *m_v_tmp_rhs);
#define BL_PROFILE(a)
Definition: AMReX_BLProfiler.H:551
#define AMREX_ALWAYS_ASSERT(EX)
Definition: AMReX_BLassert.H:50
GMRES.
Definition: AMReX_GMRES.H:61
int getNumIters() const
Gets the number of iterations.
Definition: AMReX_GMRES.H:94
void solve(V &a_sol, V const &a_rhs, RT a_tol_rel, RT a_tol_abs, int a_its=-1)
Solve the linear system.
Definition: AMReX_GMRES.H:193
int m_status
Definition: AMReX_GMRES.H:117
void build_solution(V &a_xx, int it)
Definition: AMReX_GMRES.H:373
void allocate_scratch()
Definition: AMReX_GMRES.H:142
int m_verbose
Definition: AMReX_GMRES.H:114
void define(M &linop)
Definition: AMReX_GMRES.H:168
RT m_rtol
Definition: AMReX_GMRES.H:120
bool converged(RT r0, RT r) const
Definition: AMReX_GMRES.H:187
Vector< RT > m_hh_1d
Definition: AMReX_GMRES.H:122
RT m_atol
Definition: AMReX_GMRES.H:121
int m_maxiter
Definition: AMReX_GMRES.H:115
void cycle(V &a_xx, int &a_status, int &a_itcount, RT &a_rnorm0)
Definition: AMReX_GMRES.H:246
Table2D< RT > m_hes
Definition: AMReX_GMRES.H:125
Table2D< RT > m_hh
Definition: AMReX_GMRES.H:124
void setVerbose(int v)
Sets verbosity.
Definition: AMReX_GMRES.H:85
void compute_residual(V &a_rr, V const &a_xx, V const &a_bb)
Definition: AMReX_GMRES.H:404
Vector< RT > m_cc
Definition: AMReX_GMRES.H:127
M * m_linop
Definition: AMReX_GMRES.H:132
int m_its
Definition: AMReX_GMRES.H:116
int m_restrtlen
Definition: AMReX_GMRES.H:118
void setMaxIters(int niters)
Sets the max number of iterations.
Definition: AMReX_GMRES.H:91
std::unique_ptr< V > m_v_tmp_lhs
Definition: AMReX_GMRES.H:130
Vector< RT > m_hes_1d
Definition: AMReX_GMRES.H:123
void setRestartLength(int rl)
Sets restart length. The default is 30.
Definition: AMReX_GMRES.H:158
void update_hessenberg(int it, bool happyend, RT &res)
Definition: AMReX_GMRES.H:346
void gram_schmidt_orthogonalization(int it)
Definition: AMReX_GMRES.H:316
RT getResidualNorm() const
Gets the 2-norm of the residual.
Definition: AMReX_GMRES.H:100
GMRES()
Definition: AMReX_GMRES.H:136
Vector< V > m_vv
Definition: AMReX_GMRES.H:131
Vector< RT > m_grs
Definition: AMReX_GMRES.H:126
Vector< RT > m_ss
Definition: AMReX_GMRES.H:128
void clear()
Definition: AMReX_GMRES.H:175
int getStatus() const
Gets the solver status.
Definition: AMReX_GMRES.H:97
std::unique_ptr< V > m_v_tmp_rhs
Definition: AMReX_GMRES.H:129
RT m_res
Definition: AMReX_GMRES.H:119
typename M::RT RT
Definition: AMReX_GMRES.H:64
This class provides the user with a few print options.
Definition: AMReX_Print.H:35
@ max
Definition: AMReX_ParallelReduce.H:17
static constexpr int M
Definition: AMReX_OpenBC.H:13
Definition: AMReX_Amr.cpp:49
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE T abs(const GpuComplex< T > &a_z) noexcept
Return the absolute value of a complex number.
Definition: AMReX_GpuComplex.H:356
double second() noexcept
Definition: AMReX_Utility.cpp:922
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE GpuComplex< T > sqrt(const GpuComplex< T > &a_z) noexcept
Return the square root of a complex number.
Definition: AMReX_GpuComplex.H:373