1#ifndef AMREX_SP_MATRIX_H_
2#define AMREX_SP_MATRIX_H_
3#include <AMReX_Config.H>
43 Long nnz, Long
const* row_index);
56 [[nodiscard]] T
const*
data ()
const {
return m_data.mat.data(); }
89 template <
template <
typename>
class V>
134template <
typename T,
template<
typename>
class Allocator>
136 : m_partition(std::move(partition)),
137 m_row_begin(m_partition[ParallelDescriptor::MyProc()]),
138 m_row_end(m_partition[ParallelDescriptor::MyProc()+1])
140 static_assert(std::is_floating_point<T>::value,
"SpMatrix is for floating point type only");
144template <
typename T,
template<
typename>
class Allocator>
147 m_partition = std::move(partition);
153template <
typename T,
template<
typename>
class Allocator>
157 Long nlocalrows = this->numLocalRows();
158 Long total_nnz = nlocalrows*nnz;
159 m_data.mat.resize(total_nnz);
160 m_data.col_index.resize(total_nnz);
161 m_data.row_offset.resize(nlocalrows+1);
162 m_data.nnz = total_nnz;
164 auto* poffset = m_data.row_offset.data();
167 poffset[lrow] = lrow*nnz;
171template <
typename T,
template<
typename>
class Allocator>
174 Long
const* col_index, Long nnz,
175 Long
const* row_index)
177 m_partition = std::move(partition);
180 Long nlocalrows = this->numLocalRows();
181 m_data.mat.resize(nnz);
182 m_data.col_index.resize(nnz);
183 m_data.row_offset.resize(nlocalrows+1);
187 m_data.col_index.begin());
189 m_data.row_index.begin());
193template <
typename T,
template<
typename>
class Allocator>
202 csr.
mat.resize(m_data.mat.size());
203 csr.
col_index.resize(m_data.col_index.size());
204 csr.
row_offset.resize(m_data.row_offset.size());
208 csr.
nnz = m_data.nnz;
211 auto const& csr = m_data;
216 ofs << m_row_begin <<
" " << m_row_end <<
" " << csr.
nnz <<
"\n";
217 for (Long i = 0, nrows = numLocalRows(); i < nrows; ++i) {
221 for (Long j = 0; j < nnz_row; ++j) {
222 ofs << i+lrow_begin <<
" " << col[j] <<
" " << mat[j] <<
"\n";
227template <
typename T,
template<
typename>
class Allocator>
233 Long nlocalrows = this->numLocalRows();
234 Long rowbegin = this->globalRowBegin();
235 auto* pmat = m_data.mat.data();
236 auto* pcolindex = m_data.col_index.data();
237 auto* prowoffset = m_data.row_offset.data();
240 f(rowbegin+lrow, pcolindex+prowoffset[lrow], pmat+prowoffset[lrow]);
244template <
typename T,
template<
typename>
class Allocator>
247 if (m_diagonal.empty()) {
248 m_diagonal.
define(this->partition());
253 auto offset = m_shifted ? Long(0) : m_row_begin;
254 Long nrows = this->numLocalRows();
258 for (Long j = row[i]; j < row[i+1]; ++j) {
259 if (i == col[j] -
offset) {
270template <
typename T,
template<
typename>
class Allocator>
276 if (this->partition().numActiveProcs() <= 1) {
return; }
278 this->prepare_comm();
284 auto const nrecvs =
int(m_recv_from.size());
288 auto* p_recv = m_recv_buffer;
289 for (
int irecv = 0; irecv < nrecvs; ++irecv) {
290 BL_MPI_REQUIRE(MPI_Irecv(p_recv,
291 m_recv_counts[irecv], mpi_t_type,
292 m_recv_from[irecv], mpi_tag, mpi_comm,
293 &(m_recv_reqs[irecv])));
294 p_recv += m_recv_counts[irecv];
296 AMREX_ASSERT(p_recv == m_recv_buffer + m_total_counts_recv);
299 auto const nsends =
int(m_send_to.size());
307 auto* p_send = m_send_buffer;
308 for (
int isend = 0; isend < nsends; ++isend) {
309 auto count = m_send_counts[isend];
310 BL_MPI_REQUIRE(MPI_Isend(p_send, count, mpi_t_type, m_send_to[isend],
311 mpi_tag, mpi_comm, &(m_send_reqs[isend])));
314 AMREX_ASSERT(p_send == m_send_buffer + m_total_counts_send);
319template <
typename T,
template<
typename>
class Allocator>
322 if (this->numLocalRows() == 0) {
return; }
327 if (this->partition().numActiveProcs() <= 1) {
return; }
329 if ( ! m_recv_reqs.empty()) {
331 BL_MPI_REQUIRE(MPI_Waitall(
int(m_recv_reqs.size()),
333 mpi_statuses.data()));
338 if ( ! m_send_reqs.empty()) {
340 BL_MPI_REQUIRE(MPI_Waitall(
int(m_send_reqs.size()),
342 mpi_statuses.data()));
355template <
typename T,
template<
typename>
class Allocator>
358 if (m_comm_prepared) {
return; }
369 Long all_nnz = m_data.nnz;
372 auto* p_pfsum = pfsum.
data();
373 auto row_begin = m_row_begin;
374 auto row_end = m_row_end;
375 if (m_data.nnz < Long(std::numeric_limits<int>::max())) {
376 auto const* pcol = m_data.col_index.data();
377 local_nnz = Scan::PrefixSum<int>(
int(all_nnz),
379 return (pcol[i] >= row_begin &&
380 pcol[i] < row_end); },
385 auto const* pcol = m_data.col_index.data();
386 local_nnz = Scan::PrefixSum<Long>(all_nnz,
388 return (pcol[i] >= row_begin &&
389 pcol[i] < row_end); },
395 m_data.nnz = local_nnz;
396 Long remote_nnz = all_nnz - local_nnz;
397 m_data_remote.nnz = remote_nnz;
402 if (local_nnz != all_nnz) {
403 m_data_remote.mat.resize(remote_nnz);
404 m_data_remote.col_index.resize(remote_nnz);
407 auto const* pmat = m_data.mat.data();
408 auto const* pcol = m_data.col_index.data();
409 auto* pmat_l = new_mat.
data();
410 auto* pcol_l = new_col.
data();
411 auto* pmat_r = m_data_remote.mat.data();
412 auto* pcol_r = m_data_remote.col_index.data();
415 auto ps = p_pfsum[i];
416 auto local = (pcol[i] >= row_begin &&
419 pmat_l[ps] = pmat[i];
420 pcol_l[ps] = pcol[i] - row_begin;
422 pmat_r[i-ps] = pmat[i];
423 pcol_r[i-ps] = pcol[i];
427 auto noffset = Long(m_data.row_offset.size());
428 auto* pro = m_data.row_offset.data();
429 m_data_remote.row_offset.resize(noffset);
430 auto* pro_r = m_data_remote.row_offset.data();
434 auto ro_l = p_pfsum[pro[i]];
435 pro_r[i] = pro[i] - ro_l;
439 pro_r[i] = remote_nnz;
443 m_data.mat.swap(new_mat);
444 m_data.col_index.swap(new_col);
449 Long old_size = m_data_remote.row_offset.size();
450 m_rtol.resize(old_size-1);
451 auto* p_rtol = m_rtol.data();
453 auto const* p_ro = m_data_remote.row_offset.data();
454 auto* p_tro = trimmed_row_offset.
data();
456 if (old_size < Long(std::numeric_limits<int>::max())) {
458 new_size = Scan::PrefixSum<int>(
int(old_size),
460 if (i+1 < old_size) {
461 return (p_ro[i+1] > p_ro[i]);
469 }
else if (p_ro[i] > p_ro[i-1]) {
472 if ((i+1 < old_size) &&
481 new_size = Scan::PrefixSum<Long>(old_size,
483 if (i+1 < old_size) {
484 return (p_ro[i+1] > p_ro[i]);
492 }
else if (p_ro[i] > p_ro[i-1]) {
495 if ((i+1 < old_size) &&
504 m_rtol.resize(new_size-1);
505 trimmed_row_offset.
resize(new_size);
507 m_rtol.shrink_to_fit();
510 m_data_remote.row_offset.swap(trimmed_row_offset);
514 m_remote_cols.resize(m_data_remote.col_index.size());
516 m_data_remote.col_index.end(),
517 m_remote_cols.begin());
520 auto const& m_remote_cols = m_data_remote.col_index;
523 unique_remote_cols_v.resize(m_remote_cols.size());
524 std::partial_sort_copy(m_remote_cols.begin(),
526 unique_remote_cols_v.begin(),
527 unique_remote_cols_v.end());
530 m_total_counts_recv = Long(unique_remote_cols_v.
size());
533 auto const& rows = this->m_partition.dataVector();
534 auto it = rows.cbegin();
535 for (
auto c : unique_remote_cols_v) {
536 it = std::find_if(it, rows.cend(), [&] (
auto x) { return x > c; });
537 if (it != rows.cend()) {
538 int iproc =
int(std::distance(rows.cbegin(),it)) - 1;
539 unique_remote_cols_vv[iproc].push_back(c);
541 amrex::Abort(
"SpMatrix::prepare_comm: how did this happen?");
553 for (
int iproc = 0; iproc < nprocs; ++iproc) {
554 need_from[iproc] = unique_remote_cols_vv[iproc].empty() ? 0 : 1;
558 BL_MPI_REQUIRE(MPI_Reduce_scatter
559 (need_from.data(), &nsends, reduce_scatter_counts.data(),
560 mpi_int, MPI_SUM, mpi_comm));
565 for (
int iproc = 0; iproc < nprocs; ++iproc) {
566 if ( ! unique_remote_cols_vv[iproc].empty()) {
569 BL_MPI_REQUIRE(MPI_Isend(unique_remote_cols_vv[iproc].data(),
570 int(unique_remote_cols_vv[iproc].size()),
571 mpi_long, iproc, mpi_tag, mpi_comm,
572 &(mpi_requests.back())));
573 m_recv_from.push_back(iproc);
574 m_recv_counts.push_back(
int(unique_remote_cols_vv[iproc].size()));
579 m_total_counts_send = 0;
580 for (
int isend = 0; isend < nsends; ++isend) {
582 BL_MPI_REQUIRE(MPI_Probe(MPI_ANY_SOURCE, mpi_tag, mpi_comm, &mpi_status));
583 int receiver = mpi_status.MPI_SOURCE;
585 BL_MPI_REQUIRE(MPI_Get_count(&mpi_status, mpi_long, &count));
586 m_send_to.push_back(receiver);
587 m_send_counts.push_back(count);
588 send_indices[isend].resize(count);
589 BL_MPI_REQUIRE(MPI_Recv(send_indices[isend].data(), count, mpi_long,
590 receiver, mpi_tag, mpi_comm, &mpi_status));
591 m_total_counts_send += count;
594 m_send_indices.resize(m_total_counts_send);
596 send_indices_all.
reserve(m_total_counts_send);
597 for (
auto const& vl : send_indices) {
603 m_send_indices.begin());
607 BL_MPI_REQUIRE(MPI_Waitall(
int(mpi_requests.
size()), mpi_requests.data(),
608 mpi_statuses.data()));
611 std::map<Long,Long> gtol;
612 for (Long i = 0, N = Long(unique_remote_cols_v.
size()); i < N; ++i) {
613 gtol[unique_remote_cols_v[i]] = i;
616 auto& cols = m_remote_cols;
618 auto& cols = m_data_remote.col_index;
620 for (
auto& c : cols) {
626 m_data_remote.col_index.data());
629 m_comm_prepared =
true;
632template <
typename T,
template<
typename>
class Allocator>
635 auto*
pdst = m_send_buffer;
636 auto* pidx = m_send_indices.data();
637 auto const& vv = v.
view();
638 auto const nsends = Long(m_send_indices.size());
641 pdst[i] = vv(pidx[i]);
645template <
typename T,
template<
typename>
class Allocator>
648 auto const& csr = m_data_remote;
654 auto const* rtol = m_rtol.data();
659 auto const nrr = Long(csr.row_offset.size())-1;
663 for (Long j = row[i]; j < row[i+1]; ++j) {
664 r += mat[j] * px[col[j]];
#define AMREX_ASSERT(EX)
Definition AMReX_BLassert.H:38
#define AMREX_RESTRICT
Definition AMReX_Extension.H:37
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
Array4< int const > offset
Definition AMReX_HypreMLABecLap.cpp:1089
Real * pdst
Definition AMReX_HypreMLABecLap.cpp:1090
static constexpr int MPI_REQUEST_NULL
Definition AMReX_ccse-mpi.H:53
Definition AMReX_AlgPartition.H:14
Long numGlobalRows() const
Definition AMReX_AlgPartition.H:28
Definition AMReX_AlgVector.H:19
T const * data() const
Definition AMReX_AlgVector.H:53
void define(Long global_size)
Definition AMReX_AlgVector.H:128
AMREX_FORCE_INLINE Table1D< T const, Long > view() const
Definition AMReX_AlgVector.H:57
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition AMReX_PODVector.H:262
void reserve(size_type a_capacity)
Definition AMReX_PODVector.H:663
void shrink_to_fit()
Definition AMReX_PODVector.H:672
iterator begin() noexcept
Definition AMReX_PODVector.H:617
iterator end() noexcept
Definition AMReX_PODVector.H:621
void resize(size_type a_new_size)
Definition AMReX_PODVector.H:641
T * data() noexcept
Definition AMReX_PODVector.H:609
void push_back(const T &a_value)
Definition AMReX_PODVector.H:572
Definition AMReX_SpMatrix.H:19
AlgVector< T > m_diagonal
Definition AMReX_SpMatrix.H:102
bool m_comm_prepared
Definition AMReX_SpMatrix.H:128
CSR< DVec > m_data_remote
Definition AMReX_SpMatrix.H:105
void prepare_comm()
Private function, but public for cuda.
Definition AMReX_SpMatrix.H:356
Long globalRowBegin() const
Inclusive global index begin.
Definition AMReX_SpMatrix.H:52
void setVal(F const &f)
Definition AMReX_SpMatrix.H:229
Long * rowOffset()
Definition AMReX_SpMatrix.H:61
Long m_row_end
Definition AMReX_SpMatrix.H:99
Long const * rowOffset() const
Definition AMReX_SpMatrix.H:60
void define(AlgPartition partition, int nnz)
Definition AMReX_SpMatrix.H:145
Gpu::DeviceVector< Long > m_send_indices
Definition AMReX_SpMatrix.H:115
void startComm(AlgVector< T > const &x)
Definition AMReX_SpMatrix.H:271
Long m_total_counts_send
Definition AMReX_SpMatrix.H:122
T * data()
Definition AMReX_SpMatrix.H:57
Vector< int > m_recv_counts
Definition AMReX_SpMatrix.H:118
Long numGlobalRows() const
Definition AMReX_SpMatrix.H:48
Gpu::PinnedVector< Long > m_remote_cols
Definition AMReX_SpMatrix.H:108
SpMatrix & operator=(SpMatrix const &)=delete
T value_type
Definition AMReX_SpMatrix.H:21
T const * data() const
Definition AMReX_SpMatrix.H:56
Long globalRowEnd() const
Exclusive global index end.
Definition AMReX_SpMatrix.H:54
bool m_shifted
Definition AMReX_SpMatrix.H:131
Long const * columnIndex() const
Definition AMReX_SpMatrix.H:58
Vector< int > m_recv_from
Definition AMReX_SpMatrix.H:117
Long m_row_begin
Definition AMReX_SpMatrix.H:98
void finishComm(AlgVector< T > &y)
Definition AMReX_SpMatrix.H:320
void printToFile(std::string const &file) const
Definition AMReX_SpMatrix.H:195
Long m_total_counts_recv
Definition AMReX_SpMatrix.H:126
void unpack_buffer(AlgVector< T > &v)
Definition AMReX_SpMatrix.H:646
Vector< MPI_Request > m_send_reqs
Definition AMReX_SpMatrix.H:120
SpMatrix(SpMatrix const &)=delete
CSR< DVec > m_data
Definition AMReX_SpMatrix.H:100
T * m_recv_buffer
Definition AMReX_SpMatrix.H:125
Vector< MPI_Request > m_recv_reqs
Definition AMReX_SpMatrix.H:124
AlgPartition m_partition
Definition AMReX_SpMatrix.H:97
SpMatrix(SpMatrix &&)=default
friend void SpMV(AlgVector< U > &y, SpMatrix< U > const &A, AlgVector< U > const &x)
Long numLocalRows() const
Definition AMReX_SpMatrix.H:47
AlgPartition const & partition() const
Definition AMReX_SpMatrix.H:45
Long numLocalNonZero() const
Definition AMReX_SpMatrix.H:49
Long * columnIndex()
Definition AMReX_SpMatrix.H:59
T * m_send_buffer
Definition AMReX_SpMatrix.H:121
Vector< int > m_send_to
Definition AMReX_SpMatrix.H:113
AlgVector< T > const & diagonalVector() const
Definition AMReX_SpMatrix.H:245
Vector< int > m_send_counts
Definition AMReX_SpMatrix.H:114
void define_doit(int nnz)
Private function, but public for cuda.
Definition AMReX_SpMatrix.H:155
DVec< Long > m_rtol
Definition AMReX_SpMatrix.H:111
void pack_buffer(AlgVector< T > const &v)
Definition AMReX_SpMatrix.H:633
This class is a thin wrapper around std::vector. Unlike vector, Vector::operator[] provides bound checking when compiled with DEBUG=TRUE.
Definition AMReX_Vector.H:27
Long size() const noexcept
Definition AMReX_Vector.H:50
void copyAsync(HostToDevice, InIter begin, InIter end, OutIter result) noexcept
A host-to-device copy routine. Note this is just a wrapper around memcpy, so it assumes contiguous storage.
Definition AMReX_GpuContainers.H:233
static constexpr DeviceToDevice deviceToDevice
Definition AMReX_GpuContainers.H:100
static constexpr DeviceToHost deviceToHost
Definition AMReX_GpuContainers.H:99
static constexpr HostToDevice hostToDevice
Definition AMReX_GpuContainers.H:98
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:237
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition AMReX_ParallelContext.H:70
int NProcsSub() noexcept
number of ranks in current frame
Definition AMReX_ParallelContext.H:74
int MyProc() noexcept
return the rank number local to the current Parallel Context
Definition AMReX_ParallelDescriptor.H:125
int SeqNum() noexcept
Returns sequential message sequence numbers, usually used as tags for send/recv.
Definition AMReX_ParallelDescriptor.H:613
static constexpr struct amrex::Scan::Type::Exclusive exclusive
static constexpr RetSum retSum
Definition AMReX_Scan.H:29
Definition AMReX_Amr.cpp:49
amrex::ArenaAllocator< T > DefaultAllocator
Definition AMReX_GpuAllocators.H:194
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition AMReX_CTOParallelForImpl.H:191
Arena * The_Comms_Arena()
Definition AMReX_Arena.cpp:676
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition AMReX.H:127
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
const int[]
Definition AMReX_BLProfiler.cpp:1664
void RemoveDuplicates(Vector< T > &vec)
Definition AMReX_Vector.H:208
Definition AMReX_ccse-mpi.H:51
static MPI_Datatype type()
Definition AMReX_SpMatrix.H:90
Long nnz
Definition AMReX_SpMatrix.H:94
V< Long > row_offset
Definition AMReX_SpMatrix.H:93
V< Long > col_index
Definition AMReX_SpMatrix.H:92
V< T > mat
Definition AMReX_SpMatrix.H:91