8 #ifndef AMREX_HYPRE_SOLVER_H_
9 #define AMREX_HYPRE_SOLVER_H_
17 #include "_hypre_utilities.h"
63 template <
class Marker,
class Filler>
72 std::string a_options_namespace =
"hypre");
74 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
75 std::is_same_v<
typename MF::value_type,
76 HYPRE_Real>,
int> = 0>
89 HYPRE_Real rel_tol, HYPRE_Real abs_tol,
int max_iter);
103 template <
class Marker>
104 #ifdef AMREX_USE_CUDA
105 std::enable_if_t<IsCallable<Marker,int,int,int,int,int>::value>
107 std::enable_if_t<IsCallableR<bool,Marker,int,int,int,int,int>::value>
113 template <
class Filler,
116 HYPRE_Int&, HYPRE_Int*,
117 HYPRE_Real*>::value,
int> FOO = 0>
120 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
121 std::is_same_v<
typename MF::value_type,
122 HYPRE_Real>,
int> = 0>
126 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
127 std::is_same_v<
typename MF::value_type,
128 HYPRE_Real>,
int> = 0>
162 HYPRE_IJMatrix
m_A =
nullptr;
163 HYPRE_IJVector
m_b =
nullptr;
164 HYPRE_IJVector
m_x =
nullptr;
168 template <
class Marker,
class Filler>
177 std::string a_options_namespace)
178 : m_nvars (
int(a_index_type.
size())),
179 m_index_type (a_index_type),
183 m_verbose (a_verbose),
184 m_options_namespace(std::move(a_options_namespace))
198 for (
int ivar = 0; ivar <
m_nvars; ++ivar) {
204 nrows_max +=
m_grids[ivar].numPts();
209 "Need to configure Hypre with --enable-bigint");
212 for (
int ivar = 0; ivar <
m_nvars; ++ivar) {
234 if (nrows_allprocs.
size() > 1) {
235 MPI_Allgather(&
m_nrows_proc,
sizeof(HYPRE_Int), MPI_CHAR,
236 nrows_allprocs.data(),
sizeof(HYPRE_Int), MPI_CHAR,
m_comm);
243 HYPRE_Int proc_begin = 0;
244 for (
int i = 0; i < myproc; ++i) {
245 proc_begin += nrows_allprocs[i];
248 HYPRE_Int proc_end = proc_begin;
250 for (
int ivar = 0; ivar <
m_nvars; ++ivar) {
260 HYPRE_Int ilower = proc_begin;
261 HYPRE_Int iupper = proc_end-1;
274 template <
class Marker>
275 #ifdef AMREX_USE_CUDA
276 std::enable_if_t<IsCallable<Marker,int,int,int,int,int>::value>
278 std::enable_if_t<IsCallableR<bool,Marker,int,int,int,int,int>::value>
287 int boxno = mfi.LocalIndex();
289 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
293 m_cell_offset[mfi].resize(npts_tot);
295 int* p_cell_offset = m_cell_offset[mfi].data();
296 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
298 auto const& lid = m_local_id[ivar].array(mfi);
299 auto const& owner = m_owner_mask[ivar]->const_array(mfi);
301 const auto npts =
static_cast<int>(bx.
numPts());
302 int npts_box = amrex::Scan::PrefixSum<int>(npts,
306 int id = (owner ( cell.
x,cell.
y,cell.
z ) &&
307 marker(boxno,cell.
x,cell.
y,cell.
z,ivar)) ? 1 : 0;
308 lid(cell.
x,cell.
y,cell.
z) = id;
314 if (lid(cell.
x,cell.
y,cell.
z)) {
315 lid(cell.
x,cell.
y,cell.
z) = ps;
316 p_cell_offset[ps] =
offset;
318 lid(cell.
x,cell.
y,cell.
z) = std::numeric_limits<int>::lowest();
322 m_nrows_grid[ivar][mfi] = npts_box;
323 npts_tot += npts_box;
324 p_cell_offset += npts_box;
326 m_cell_offset[mfi].resize(npts_tot);
335 int boxno = mfi.LocalIndex();
336 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
338 auto const& lid = m_local_id[ivar].array(mfi);
339 auto const& owner = m_owner_mask[ivar]->const_array(mfi);
343 for (
int k = lo.z; k <= hi.z; ++k) {
344 for (
int j = lo.y; j <= hi.y; ++j) {
345 for (
int i = lo.x; i <= hi.x; ++i) {
346 if (owner(i,j,k) && marker(boxno,i,j,k,ivar)) {
349 lid(i,j,k) = std::numeric_limits<int>::lowest();
352 m_nrows_grid[ivar][mfi] = id;
360 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
361 nrows += m_nrows_grid[ivar][mfi];
363 m_nrows[mfi] = nrows;
364 m_nrows_proc += nrows;
378 using AtomicInt = std::conditional_t<
sizeof(HYPRE_Int) == 4,
379 HYPRE_Int,
unsigned long long>;
383 if constexpr (std::is_same<HYPRE_Int, AtomicInt>()) {
384 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
385 p_global_id.push_back(&(m_global_id[ivar]));
388 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
389 global_id_raii[ivar].define(m_global_id[ivar].
boxArray(),
392 p_global_id.push_back(&(global_id_raii[ivar]));
397 #pragma omp parallel if (Gpu::notInLaunchRegion())
400 auto& rows_vec = m_global_id_vec[mfi];
401 rows_vec.resize(m_nrows[mfi]);
404 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
405 HYPRE_Int
const os = m_id_offset[ivar][mfi];
406 Box bx = mfi.validbox();
409 auto const& lid = m_local_id[ivar].const_array(mfi);
410 HYPRE_Int* rows = rows_vec.data() + nrows;
411 nrows += m_nrows_grid[ivar][mfi];
414 if (lid.contains(i,j,k) && lid(i,j,k) >= 0) {
415 const auto id = lid(i,j,k) + os;
416 rows[lid(i,j,k)] = id;
417 gid(i,j,k) =
static_cast<AtomicInt
>(id);
419 gid(i,j,k) =
static_cast<AtomicInt
>
426 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
428 m_geom.periodicity());
429 p_global_id[ivar]->FillBoundary(m_geom.periodicity());
431 if constexpr (!std::is_same<HYPRE_Int, AtomicInt>()) {
432 auto const& dst = m_global_id[ivar].arrays();
433 auto const& src = p_global_id[ivar]->const_arrays();
437 dst[
b](i,j,k) =
static_cast<HYPRE_Int
>(src[
b](i,j,k));
445 template <
typename T>
451 auto* p_cols_tmp = cols_tmp.
data();
452 auto* p_mat_tmp = mat_tmp.
data();
453 auto const* p_cols = cols.
data();
454 auto const* p_mat = mat.
data();
455 const auto N = Long(cols.
size());
456 Scan::PrefixSum<T>(N,
459 return static_cast<T
>(p_cols[i] >= 0);
463 if (p_cols[i] >= 0) {
464 p_cols_tmp[s] = p_cols[i];
465 p_mat_tmp[s] = p_mat[i];
476 template <
class Filler,
478 Array4<HYPRE_Int const>
const*,
479 HYPRE_Int&, HYPRE_Int*,
480 HYPRE_Real*>::value,
int> FOO>
492 for (
MFIter mfi(m_local_id[0],mfitinfo); mfi.
isValid(); ++mfi)
494 int boxno = mfi.LocalIndex();
495 const HYPRE_Int nrows = m_nrows[mfi];
500 HYPRE_Int* ncols = ncols_vec.
data();
503 cols_vec.
resize(Long(nrows)*MSS, -1);
504 HYPRE_Int* cols = cols_vec.
data();
507 mat_vec.
resize(Long(nrows)*MSS);
508 HYPRE_Real* mat = mat_vec.
data();
511 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
512 gid_v[ivar] = m_global_id[ivar].const_array(mfi);
517 (gid_v.data(), gid_v.
size());
518 auto const* pgid = gid_buf.
data();
519 auto const* p_cell_offset = m_cell_offset[mfi].data();
521 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
522 const HYPRE_Int nrows_var = m_nrows_grid[ivar][mfi];
525 ntot += Reduce::Sum<Long>(nrows_var,
529 filler(boxno, cell.
x, cell.
y, cell.
z, ivar, pgid,
534 p_cell_offset += nrows_var;
536 cols += Long(nrows_var)*MSS;
537 mat += Long(nrows_var)*MSS;
543 detail::pack_matrix_gpu<Long>(cols_tmp, mat_tmp, cols_vec, mat_vec);
545 detail::pack_matrix_gpu<int>(cols_tmp, mat_tmp, cols_vec, mat_vec);
548 auto* pgid = gid_v.data();
549 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
550 if (m_nrows_grid[ivar][mfi] > 0) {
551 auto const& lid = m_local_id[ivar].const_array(mfi);
553 [=,&ncols,&cols,&mat] (
int i,
int j,
int k)
555 if (lid(i,j,k) >= 0) {
556 filler(boxno, i, j, k, ivar, pgid, *ncols, cols, mat);
566 const auto& rows_vec = m_global_id_vec[mfi];
567 HYPRE_Int
const* rows = rows_vec.data();
570 HYPRE_IJMatrixSetValues(m_A, nrows, ncols_vec.
data(), rows,
572 Gpu::hypreSynchronize();
575 HYPRE_IJMatrixAssemble(m_A);
579 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
580 std::is_same_v<
typename MF::value_type,
581 HYPRE_Real>,
int> FOO>
585 HYPRE_Real rel_tol, HYPRE_Real abs_tol,
int max_iter)
591 HYPRE_IJVectorInitialize(m_b);
592 HYPRE_IJVectorInitialize(m_x);
594 load_vectors(a_soln, a_rhs);
596 HYPRE_IJVectorAssemble(m_x);
597 HYPRE_IJVectorAssemble(m_b);
599 m_hypre_ij->solve(rel_tol, abs_tol, max_iter);
601 get_solution(a_soln);
605 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
606 std::is_same_v<
typename MF::value_type,
607 HYPRE_Real>,
int> FOO>
621 const HYPRE_Int nrows = m_nrows[mfi];
628 auto* xp = xvec.
data();
629 auto* bp = bvec.
data();
631 HYPRE_Int
const* rows = m_global_id_vec[mfi].data();
634 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
635 if (m_nrows_grid[ivar][mfi] > 0) {
636 auto const& xfab = a_soln[ivar]->const_array(mfi);
637 auto const& bfab = a_rhs [ivar]->const_array(mfi);
638 auto const& lid = m_local_id[ivar].const_array(mfi);
644 if (lid(i,j,k) >= 0) {
645 x[lid(i,j,k)] = xfab(i,j,k);
646 b[lid(i,j,k)] = bfab(i,j,k);
649 offset += m_nrows_grid[ivar][mfi];
654 HYPRE_IJVectorSetValues(m_x, nrows, rows, xp);
655 HYPRE_IJVectorSetValues(m_b, nrows, rows, bp);
656 Gpu::hypreSynchronize();
662 template <class MF, std::enable_if_t<IsFabArray<MF>::value &&
663 std::is_same_v<
typename MF::value_type,
664 HYPRE_Real>,
int> FOO>
676 const HYPRE_Int nrows = m_nrows[mfi];
681 auto* xp = xvec.
data();
683 HYPRE_Int
const* rows = m_global_id_vec[mfi].data();
685 HYPRE_IJVectorGetValues(m_x, nrows, rows, xp);
686 Gpu::hypreSynchronize();
689 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
690 if (m_nrows_grid[ivar][mfi] > 0) {
691 auto const& xfab = a_soln[ivar]->array(mfi);
692 auto const& lid = m_local_id[ivar].const_array(mfi);
697 if (lid(i,j,k) >= 0) {
698 xfab(i,j,k) =
x[lid(i,j,k)];
701 offset += m_nrows_grid[ivar][mfi];
708 for (
int ivar = 0; ivar < m_nvars; ++ivar) {
710 m_geom.periodicity());
#define BL_PROFILE(a)
Definition: AMReX_BLProfiler.H:551
#define AMREX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
Definition: AMReX_BLassert.H:49
#define AMREX_ASSERT(EX)
Definition: AMReX_BLassert.H:38
#define AMREX_GPU_DEVICE
Definition: AMReX_GpuQualifiers.H:18
Array4< int const > offset
Definition: AMReX_HypreMLABecLap.cpp:1089
int MPI_Comm
Definition: AMReX_ccse-mpi.H:47
static constexpr int MPI_COMM_NULL
Definition: AMReX_ccse-mpi.H:55
A collection of Boxes stored in an Array.
Definition: AMReX_BoxArray.H:549
AMREX_GPU_HOST_DEVICE IntVectND< dim > atOffset(Long offset) const noexcept
Given the offset, compute IntVectND<dim>
Definition: AMReX_Box.H:1009
AMREX_GPU_HOST_DEVICE BoxND & grow(int i) noexcept
Definition: AMReX_Box.H:627
AMREX_GPU_HOST_DEVICE BoxND & convert(IndexTypeND< dim > typ) noexcept
Convert the BoxND from the current type into the argument type. This may change the BoxND coordinates...
Definition: AMReX_Box.H:912
AMREX_GPU_HOST_DEVICE Long numPts() const noexcept
Returns the number of points contained in the BoxND.
Definition: AMReX_Box.H:346
Calculates the distribution of FABs to MPI processes.
Definition: AMReX_DistributionMapping.H:41
Rectangular problem domain geometry.
Definition: AMReX_Geometry.H:73
Periodicity periodicity() const noexcept
Definition: AMReX_Geometry.H:355
Definition: AMReX_GpuBuffer.H:17
T const * data() const noexcept
Definition: AMReX_GpuBuffer.H:65
Solve Ax = b using HYPRE's generic IJ matrix format where A is a sparse matrix specified using the co...
Definition: AMReX_HypreSolver.H:34
HypreSolver(Vector< IndexType > const &a_index_type, IntVect const &a_nghost, Geometry const &a_geom, BoxArray const &a_grids, DistributionMapping const &a_dmap, Marker &&a_marker, Filler &&a_filler, int a_verbose=0, std::string a_options_namespace="hypre")
Definition: AMReX_HypreSolver.H:169
LayoutData< HYPRE_Int > m_nrows
Definition: AMReX_HypreSolver.H:156
IntVect m_nghost
Definition: AMReX_HypreSolver.H:135
Vector< LayoutData< HYPRE_Int > > m_nrows_grid
Definition: AMReX_HypreSolver.H:154
int getNumIters() const
Definition: AMReX_HypreSolver.H:91
Geometry m_geom
Definition: AMReX_HypreSolver.H:136
LayoutData< Gpu::DeviceVector< HYPRE_Int > > m_global_id_vec
Definition: AMReX_HypreSolver.H:148
std::enable_if_t< IsCallable< Marker, int, int, int, int, int >::value > fill_local_id(Marker const &marker)
Definition: AMReX_HypreSolver.H:280
std::unique_ptr< HypreIJIface > m_hypre_ij
Definition: AMReX_HypreSolver.H:159
int m_nvars
Definition: AMReX_HypreSolver.H:133
Vector< LayoutData< HYPRE_Int > > m_id_offset
Definition: AMReX_HypreSolver.H:155
void fill_global_id()
Definition: AMReX_HypreSolver.H:370
MPI_Comm m_comm
Definition: AMReX_HypreSolver.H:143
void get_solution(Vector< MF * > const &a_soln)
Definition: AMReX_HypreSolver.H:666
HYPRE_IJVector getx() const
Definition: AMReX_HypreSolver.H:99
void load_vectors(Vector< MF * > const &a_soln, Vector< MF const * > const &a_rhs)
Definition: AMReX_HypreSolver.H:609
HYPRE_IJMatrix getA() const
Definition: AMReX_HypreSolver.H:97
std::string m_options_namespace
Definition: AMReX_HypreSolver.H:141
void fill_matrix(Filler const &filler)
Definition: AMReX_HypreSolver.H:482
HYPRE_IJMatrix m_A
Definition: AMReX_HypreSolver.H:162
void solve(Vector< MF * > const &a_soln, Vector< MF const * > const &a_rhs, HYPRE_Real rel_tol, HYPRE_Real abs_tol, int max_iter)
Definition: AMReX_HypreSolver.H:583
HYPRE_Int m_nrows_proc
Definition: AMReX_HypreSolver.H:157
LayoutData< Gpu::DeviceVector< int > > m_cell_offset
Definition: AMReX_HypreSolver.H:151
Vector< FabArray< BaseFab< HYPRE_Int > > > m_global_id
Definition: AMReX_HypreSolver.H:147
HYPRE_Real getFinalResidualNorm() const
Definition: AMReX_HypreSolver.H:93
Vector< std::unique_ptr< iMultiFab > > m_owner_mask
Definition: AMReX_HypreSolver.H:145
HYPRE_IJVector m_b
Definition: AMReX_HypreSolver.H:163
Vector< IndexType > m_index_type
Definition: AMReX_HypreSolver.H:134
Vector< BoxArray > m_grids
Definition: AMReX_HypreSolver.H:137
Vector< iMultiFab > m_local_id
Definition: AMReX_HypreSolver.H:146
int m_verbose
Definition: AMReX_HypreSolver.H:140
DistributionMapping m_dmap
Definition: AMReX_HypreSolver.H:138
HYPRE_IJVector m_x
Definition: AMReX_HypreSolver.H:164
HYPRE_IJVector getb() const
Definition: AMReX_HypreSolver.H:98
a one-thingy-per-box distributed object
Definition: AMReX_LayoutData.H:13
void define(const BoxArray &a_grids, const DistributionMapping &a_dm)
Definition: AMReX_LayoutData.H:25
Definition: AMReX_MFIter.H:57
bool isValid() const noexcept
Is the iterator valid i.e. is it associated with a FAB?
Definition: AMReX_MFIter.H:141
Definition: AMReX_PODVector.H:246
size_type size() const noexcept
Definition: AMReX_PODVector.H:575
T * data() noexcept
Definition: AMReX_PODVector.H:593
void clear() noexcept
Definition: AMReX_PODVector.H:573
void resize(size_type a_new_size)
Definition: AMReX_PODVector.H:625
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound che...
Definition: AMReX_Vector.H:27
Long size() const noexcept
Definition: AMReX_Vector.H:50
AMREX_GPU_HOST_DEVICE Long size(T const &b) noexcept
integer version
Definition: AMReX_GpuRange.H:26
void streamSynchronize() noexcept
Definition: AMReX_GpuDevice.H:237
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition: AMReX_ParallelContext.H:70
int MyProcSub() noexcept
my sub-rank in current frame
Definition: AMReX_ParallelContext.H:76
int NProcsSub() noexcept
number of ranks in current frame
Definition: AMReX_ParallelContext.H:74
static constexpr struct amrex::Scan::Type::Exclusive exclusive
static constexpr RetSum noRetSum
Definition: AMReX_Scan.H:30
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void swap(T &a, T &b) noexcept
Definition: AMReX_algoim_K.H:113
@ max
Definition: AMReX_ParallelReduce.H:17
void pack_matrix_gpu(Gpu::DeviceVector< HYPRE_Int > &cols_tmp, Gpu::DeviceVector< HYPRE_Real > mat_tmp, Gpu::DeviceVector< HYPRE_Int > &cols, Gpu::DeviceVector< HYPRE_Real > &mat)
Definition: AMReX_HypreSolver.H:446
Definition: AMReX_Amr.cpp:49
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition: AMReX_CTOParallelForImpl.H:200
DistributionMapping const & DistributionMap(FabArrayBase const &fa)
IntVect nGrowVect(FabArrayBase const &fa)
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE BoxND< dim > convert(const BoxND< dim > &b, const IntVectND< dim > &typ) noexcept
Returns a BoxND with different type.
Definition: AMReX_Box.H:1435
void OverrideSync(FabArray< FAB > &fa, FabArray< IFAB > const &msk, const Periodicity &period)
Definition: AMReX_FabArrayUtility.H:1323
std::unique_ptr< iMultiFab > OwnerMask(FabArrayBase const &mf, const Periodicity &period, const IntVect &ngrow)
Definition: AMReX_iMultiFab.cpp:637
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 ubound(Array4< T > const &a) noexcept
Definition: AMReX_Array4.H:315
BoxArray const & boxArray(FabArrayBase const &fa)
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 lbound(Array4< T > const &a) noexcept
Definition: AMReX_Array4.H:308
const int[]
Definition: AMReX_BLProfiler.cpp:1664
AMREX_GPU_HOST_DEVICE AMREX_ATTRIBUTE_FLATTEN_FOR void Loop(Dim3 lo, Dim3 hi, F const &f) noexcept
Definition: AMReX_Loop.H:125
Definition: AMReX_FabArrayCommI.H:841
Definition: AMReX_Array4.H:61
Definition: AMReX_Dim3.H:12
int x
Definition: AMReX_Dim3.H:12
int z
Definition: AMReX_Dim3.H:12
int y
Definition: AMReX_Dim3.H:12
Test if a given type T is callable with arguments of type Args...
Definition: AMReX_TypeTraits.H:201
Definition: AMReX_MFIter.H:20
MFItInfo & UseDefaultStream() noexcept
Definition: AMReX_MFIter.H:50
MFItInfo & DisableDeviceSync() noexcept
Definition: AMReX_MFIter.H:38