Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_LUSolver.H
Go to the documentation of this file.
1#ifndef AMREX_LU_SOLVER_H_
2#define AMREX_LU_SOLVER_H_
#include <AMReX_Config.H>

#include <AMReX_Algorithm.H>
#include <AMReX_Array.H>
#include <cmath>
#include <limits>
#include <type_traits>
9
10namespace amrex {
11
12// https://en.wikipedia.org/wiki/LU_decomposition
13
14template <int N, typename T>
16{
17public:
18
19 LUSolver () = default;
20
23
25
27 void operator() (T* AMREX_RESTRICT x, T const* AMREX_RESTRICT b) const
28 {
29 for (int i = 0; i < N; ++i) {
30 x[i] = b[m_piv(i)];
31 for (int k = 0; k < i; ++k) {
32 x[i] -= m_mat(i,k) * x[k];
33 }
34 }
35
36 for (int i = N-1; i >= 0; --i) {
37 for (int k = i+1; k < N; ++k) {
38 x[i] -= m_mat(i,k) * x[k];
39 }
40 x[i] *= m_mat(i,i);
41 }
42 }
43
44 [[nodiscard]] AMREX_GPU_HOST_DEVICE
45 Array2D<T,0,N-1,0,N-1,Order::C> invert () const
46 {
47 Array2D<T,0,N-1,0,N-1,Order::C> IA;
48 for (int j = 0; j < N; ++j) {
49 for (int i = 0; i < N; ++i) {
50 IA(i,j) = (m_piv(i) == j) ? T(1.0) : T(0.0);
51 for (int k = 0; k < i; ++k) {
52 IA(i,j) -= m_mat(i,k) * IA(k,j);
53 }
54 }
55 for (int i = N-1; i >= 0; --i) {
56 for (int k = i+1; k < N; ++k) {
57 IA(i,j) -= m_mat(i,k) * IA(k,j);
58 }
59 IA(i,j) *= m_mat(i,i);
60 }
61 }
62 return IA;
63 }
64
65 [[nodiscard]] AMREX_GPU_HOST_DEVICE
66 T determinant () const
67 {
68 T det = m_mat(0,0);
69 for (int i = 1; i < N; ++i) {
70 det *= m_mat(i,i);
71 }
72 det = T(1.0) / det;
73 return (m_npivs % 2 == 0) ? det : -det;
74 }
75
76private:
77
80
81 Array2D<T, 0, N-1, 0, N-1, Order::C> m_mat;
82 Array1D<int, 0, N-1> m_piv;
83 int m_npivs = 0;
84};
85
86template <int N, typename T>
89 : m_mat(a_mat)
90{
92}
93
94template <int N, typename T>
96{
97 m_mat = a_mat;
98 define_innard();
99}
100
101template <int N, typename T>
104{
105 static_assert(N > 1);
106 static_assert(std::is_floating_point_v<T>);
107
108 for (int i = 0; i < N; ++i) { m_piv(i) = i; }
109 m_npivs = 0;
110
111 for (int i = 0; i < N; ++i) {
112 T maxA = 0;
113 int imax = i;
114
115 for (int k = i; k < N; ++k) {
116 auto const absA = std::abs(m_mat(k,i));
117 if (absA > maxA) {
118 maxA = absA;
119 imax = k;
120 }
121 }
122
123 if (maxA < std::numeric_limits<T>::min()) {
124 amrex::Abort("LUSolver: matrix is degenerate");
125 }
126
127 if (imax != i) {
128 amrex::Swap(m_piv(i), m_piv(imax));
129 for (int j = 0; j < N; ++j) {
130 amrex::Swap(m_mat(i,j), m_mat(imax,j));
131 }
132 ++m_npivs;
133 }
134
135 for (int j = i+1; j < N; ++j) {
136 m_mat(j,i) /= m_mat(i,i);
137 for (int k = i+1; k < N; ++k) {
138 m_mat(j,k) -= m_mat(j,i) * m_mat(i,k);
139 }
140 }
141 }
142
143 for (int i = 0; i < N; ++i) {
144 m_mat(i,i) = T(1) / m_mat(i,i);
145 }
146}
147
148}
149
150#endif
#define AMREX_FORCE_INLINE
Definition AMReX_Extension.H:119
#define AMREX_RESTRICT
Definition AMReX_Extension.H:37
#define AMREX_GPU_HOST_DEVICE
Definition AMReX_GpuQualifiers.H:20
Definition AMReX_LUSolver.H:16
int m_npivs
Definition AMReX_LUSolver.H:83
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE LUSolver(Array2D< T, 0, N-1, 0, N-1, Order::C > const &a_mat)
Definition AMReX_LUSolver.H:88
AMREX_GPU_HOST_DEVICE T determinant() const
Definition AMReX_LUSolver.H:66
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void define_innard()
Definition AMReX_LUSolver.H:103
Array2D< T, 0, N-1, 0, N-1, Order::C > m_mat
Definition AMReX_LUSolver.H:81
Array1D< int, 0, N-1 > m_piv
Definition AMReX_LUSolver.H:82
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void operator()(T *AMREX_RESTRICT x, T const *AMREX_RESTRICT b) const
Definition AMReX_LUSolver.H:27
void define(Array2D< T, 0, N-1, 0, N-1, Order::C > const &a_mat)
Definition AMReX_LUSolver.H:95
AMREX_GPU_HOST_DEVICE Array2D< T, 0, N-1, 0, N-1, Order::C > invert() const
Definition AMReX_LUSolver.H:45
LUSolver()=default
Definition AMReX_Amr.cpp:49
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void Swap(T &t1, T &t2) noexcept
Definition AMReX_Algorithm.H:75
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
const int[]
Definition AMReX_BLProfiler.cpp:1664
Definition AMReX_Array.H:161
Definition AMReX_Array.H:282