Block-Structured AMR Software Framework
AMReX_LUSolver.H
Go to the documentation of this file.
1 #ifndef AMREX_LU_SOLVER_H_
2 #define AMREX_LU_SOLVER_H_
3 #include <AMReX_Config.H>
4 
5 #include <AMReX_Algorithm.H>
6 #include <AMReX_Array.H>
7 #include <cmath>
8 #include <limits>
9 
10 namespace amrex {
11 
12 // https://en.wikipedia.org/wiki/LU_decomposition
13 
14 template <int N, typename T>
15 class LUSolver
16 {
17 public:
18 
19  LUSolver () = default;
20 
23 
24  void define (Array2D<T, 0, N-1, 0, N-1, Order::C> const& a_mat);
25 
27  void operator() (T* AMREX_RESTRICT x, T const* AMREX_RESTRICT b) const
28  {
29  for (int i = 0; i < N; ++i) {
30  x[i] = b[m_piv(i)];
31  for (int k = 0; k < i; ++k) {
32  x[i] -= m_mat(i,k) * x[k];
33  }
34  }
35 
36  for (int i = N-1; i >= 0; --i) {
37  for (int k = i+1; k < N; ++k) {
38  x[i] -= m_mat(i,k) * x[k];
39  }
40  x[i] *= m_mat(i,i);
41  }
42  }
43 
44  [[nodiscard]] AMREX_GPU_HOST_DEVICE
45  Array2D<T,0,N-1,0,N-1,Order::C> invert () const
46  {
47  Array2D<T,0,N-1,0,N-1,Order::C> IA;
48  for (int j = 0; j < N; ++j) {
49  for (int i = 0; i < N; ++i) {
50  IA(i,j) = (m_piv(i) == j) ? T(1.0) : T(0.0);
51  for (int k = 0; k < i; ++k) {
52  IA(i,j) -= m_mat(i,k) * IA(k,j);
53  }
54  }
55  for (int i = N-1; i >= 0; --i) {
56  for (int k = i+1; k < N; ++k) {
57  IA(i,j) -= m_mat(i,k) * IA(k,j);
58  }
59  IA(i,j) *= m_mat(i,i);
60  }
61  }
62  return IA;
63  }
64 
65  [[nodiscard]] AMREX_GPU_HOST_DEVICE
66  T determinant () const
67  {
68  T det = m_mat(0,0);
69  for (int i = 1; i < N; ++i) {
70  det *= m_mat(i,i);
71  }
72  det = T(1.0) / det;
73  return (m_npivs % 2 == 0) ? det : -det;
74  }
75 
76 private:
77 
79  void define_innard ();
80 
81  Array2D<T, 0, N-1, 0, N-1, Order::C> m_mat;
82  Array1D<int, 0, N-1> m_piv;
83  int m_npivs = 0;
84 };
85 
86 template <int N, typename T>
89  : m_mat(a_mat)
90 {
91  define_innard();
92 }
93 
94 template <int N, typename T>
96 {
97  m_mat = a_mat;
98  define_innard();
99 }
100 
101 template <int N, typename T>
104 {
105  static_assert(N > 1);
106  static_assert(std::is_floating_point_v<T>);
107 
108  for (int i = 0; i < N; ++i) { m_piv(i) = i; }
109  m_npivs = 0;
110 
111  for (int i = 0; i < N; ++i) {
112  T maxA = 0;
113  int imax = i;
114 
115  for (int k = i; k < N; ++k) {
116  auto const absA = std::abs(m_mat(k,i));
117  if (absA > maxA) {
118  maxA = absA;
119  imax = k;
120  }
121  }
122 
123  if (maxA < std::numeric_limits<T>::min()) {
124  amrex::Abort("LUSolver: matrix is degenerate");
125  }
126 
127  if (imax != i) {
128  amrex::Swap(m_piv(i), m_piv(imax));
129  for (int j = 0; j < N; ++j) {
130  amrex::Swap(m_mat(i,j), m_mat(imax,j));
131  }
132  ++m_npivs;
133  }
134 
135  for (int j = i+1; j < N; ++j) {
136  m_mat(j,i) /= m_mat(i,i);
137  for (int k = i+1; k < N; ++k) {
138  m_mat(j,k) -= m_mat(j,i) * m_mat(i,k);
139  }
140  }
141  }
142 
143  for (int i = 0; i < N; ++i) {
144  m_mat(i,i) = T(1) / m_mat(i,i);
145  }
146 }
147 
148 }
149 
150 #endif
#define AMREX_FORCE_INLINE
Definition: AMReX_Extension.H:119
#define AMREX_RESTRICT
Definition: AMReX_Extension.H:37
#define AMREX_GPU_HOST_DEVICE
Definition: AMReX_GpuQualifiers.H:20
amrex::LUSolver
Definition: AMReX_LUSolver.H:16
int m_npivs
Definition: AMReX_LUSolver.H:83
AMREX_GPU_HOST_DEVICE T determinant() const
Definition: AMReX_LUSolver.H:66
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void define_innard()
Definition: AMReX_LUSolver.H:103
Array2D< T, 0, N-1, 0, N-1, Order::C > m_mat
Definition: AMReX_LUSolver.H:81
Array1D< int, 0, N-1 > m_piv
Definition: AMReX_LUSolver.H:82
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void operator()(T *AMREX_RESTRICT x, T const *AMREX_RESTRICT b) const
Definition: AMReX_LUSolver.H:27
AMREX_GPU_HOST_DEVICE Array2D< T, 0, N-1, 0, N-1, Order::C > invert() const
Definition: AMReX_LUSolver.H:45
void define(Array2D< T, 0, N-1, 0, N-1, Order::C > const &a_mat)
Definition: AMReX_LUSolver.H:95
LUSolver()=default
@ min
Definition: AMReX_ParallelReduce.H:18
Definition: AMReX_Amr.cpp:49
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE T abs(const GpuComplex< T > &a_z) noexcept
Return the absolute value of a complex number.
Definition: AMReX_GpuComplex.H:356
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void Swap(T &t1, T &t2) noexcept
Definition: AMReX_Algorithm.H:75
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition: AMReX.cpp:225
const int[]
Definition: AMReX_BLProfiler.cpp:1664
Definition: AMReX_Array.H:161