1 #ifndef AMREX_SP_MATRIX_H_
2 #define AMREX_SP_MATRIX_H_
3 #include <AMReX_Config.H>
13 #include <type_traits>
/// Read-only accessor: returns the pointer produced by m_data.mat.data(),
/// typed as pointer-to-const T. The object itself is not modified (const member).
48 [[nodiscard]] T
const*
data ()
const {
return m_data.mat.data(); }
81 template <
template <
typename>
class V>
126 template <
typename T,
template<
typename>
class Allocator>
128 : m_partition(std::move(partition)),
129 m_row_begin(m_partition[ParallelDescriptor::
MyProc()]),
130 m_row_end(m_partition[ParallelDescriptor::
MyProc()+1])
132 static_assert(std::is_floating_point<T>::value,
"SpMatrix is for floating point type only");
136 template <
typename T,
template<
typename>
class Allocator>
139 m_partition = std::move(partition);
145 template <
typename T,
template<
typename>
class Allocator>
149 Long nlocalrows = this->numLocalRows();
150 Long total_nnz = nlocalrows*nnz;
151 m_data.mat.resize(total_nnz);
152 m_data.col_index.resize(total_nnz);
153 m_data.row_offset.resize(nlocalrows+1);
154 m_data.nnz = total_nnz;
156 auto* poffset = m_data.row_offset.data();
159 poffset[lrow] = lrow*nnz;
163 template <
typename T,
template<
typename>
class Allocator>
172 csr.
mat.resize(m_data.mat.size());
173 csr.
col_index.resize(m_data.col_index.size());
174 csr.
row_offset.resize(m_data.row_offset.size());
178 csr.
nnz = m_data.nnz;
181 auto const& csr = m_data;
186 ofs << m_row_begin <<
" " << m_row_end <<
" " << csr.
nnz <<
"\n";
187 for (Long i = 0, nrows = numLocalRows(); i < nrows; ++i) {
191 for (Long j = 0; j < nnz_row; ++j) {
192 ofs << i+lrow_begin <<
" " << col[j] <<
" " << mat[j] <<
"\n";
197 template <
typename T,
template<
typename>
class Allocator>
198 template <
typename F>
203 Long nlocalrows = this->numLocalRows();
204 Long rowbegin = this->globalRowBegin();
205 auto* pmat = m_data.mat.data();
206 auto* pcolindex = m_data.col_index.data();
207 auto* prowoffset = m_data.row_offset.data();
210 f(rowbegin+lrow, pcolindex+prowoffset[lrow], pmat+prowoffset[lrow]);
214 template <
typename T,
template<
typename>
class Allocator>
217 if (m_diagonal.empty()) {
218 m_diagonal.define(this->partition());
223 auto offset = m_shifted ? Long(0) : m_row_begin;
224 Long nrows = this->numLocalRows();
228 for (Long j = row[i]; j < row[i+1]; ++j) {
229 if (i == col[j] -
offset) {
240 template <
typename T,
template<
typename>
class Allocator>
243 #ifndef AMREX_USE_MPI
246 if (this->partition().numActiveProcs() <= 1) {
return; }
248 this->prepare_comm();
254 auto const nrecvs =
int(m_recv_from.size());
258 auto* p_recv = m_recv_buffer;
259 for (
int irecv = 0; irecv < nrecvs; ++irecv) {
260 BL_MPI_REQUIRE(MPI_Irecv(p_recv,
261 m_recv_counts[irecv], mpi_t_type,
262 m_recv_from[irecv], mpi_tag, mpi_comm,
263 &(m_recv_reqs[irecv])));
264 p_recv += m_recv_counts[irecv];
266 AMREX_ASSERT(p_recv == m_recv_buffer + m_total_counts_recv);
269 auto const nsends =
int(m_send_to.size());
277 auto* p_send = m_send_buffer;
278 for (
int isend = 0; isend < nsends; ++isend) {
279 auto count = m_send_counts[isend];
280 BL_MPI_REQUIRE(MPI_Isend(p_send, count, mpi_t_type, m_send_to[isend],
281 mpi_tag, mpi_comm, &(m_send_reqs[isend])));
284 AMREX_ASSERT(p_send == m_send_buffer + m_total_counts_send);
289 template <
typename T,
template<
typename>
class Allocator>
292 if (this->numLocalRows() == 0) {
return; }
294 #ifndef AMREX_USE_MPI
297 if (this->partition().numActiveProcs() <= 1) {
return; }
299 if ( ! m_recv_reqs.empty()) {
301 BL_MPI_REQUIRE(MPI_Waitall(
int(m_recv_reqs.size()),
303 mpi_statuses.data()));
308 if ( ! m_send_reqs.empty()) {
310 BL_MPI_REQUIRE(MPI_Waitall(
int(m_send_reqs.size()),
312 mpi_statuses.data()));
325 template <
typename T,
template<
typename>
class Allocator>
328 if (m_comm_prepared) {
return; }
339 Long all_nnz = m_data.nnz;
342 auto* p_pfsum = pfsum.
data();
343 auto row_begin = m_row_begin;
344 auto row_end = m_row_end;
346 auto const* pcol = m_data.col_index.data();
347 local_nnz = Scan::PrefixSum<int>(
int(all_nnz),
349 return (pcol[i] >= row_begin &&
350 pcol[i] < row_end); },
355 auto const* pcol = m_data.col_index.data();
356 local_nnz = Scan::PrefixSum<Long>(all_nnz,
358 return (pcol[i] >= row_begin &&
359 pcol[i] < row_end); },
365 m_data.nnz = local_nnz;
366 Long remote_nnz = all_nnz - local_nnz;
367 m_data_remote.nnz = remote_nnz;
372 if (local_nnz != all_nnz) {
373 m_data_remote.mat.resize(remote_nnz);
374 m_data_remote.col_index.resize(remote_nnz);
377 auto const* pmat = m_data.mat.data();
378 auto const* pcol = m_data.col_index.data();
379 auto* pmat_l = new_mat.
data();
380 auto* pcol_l = new_col.
data();
381 auto* pmat_r = m_data_remote.mat.data();
382 auto* pcol_r = m_data_remote.col_index.data();
385 auto ps = p_pfsum[i];
386 auto local = (pcol[i] >= row_begin &&
389 pmat_l[ps] = pmat[i];
390 pcol_l[ps] = pcol[i] - row_begin;
392 pmat_r[i-ps] = pmat[i];
393 pcol_r[i-ps] = pcol[i];
397 auto noffset = Long(m_data.row_offset.size());
398 auto* pro = m_data.row_offset.data();
399 m_data_remote.row_offset.resize(noffset);
400 auto* pro_r = m_data_remote.row_offset.data();
404 auto ro_l = p_pfsum[pro[i]];
405 pro_r[i] = pro[i] - ro_l;
409 pro_r[i] = remote_nnz;
413 m_data.mat.swap(new_mat);
414 m_data.col_index.swap(new_col);
419 Long old_size = m_data_remote.row_offset.size();
420 m_rtol.resize(old_size-1);
421 auto* p_rtol = m_rtol.data();
423 auto const* p_ro = m_data_remote.row_offset.data();
424 auto* p_tro = trimmed_row_offset.
data();
428 new_size = Scan::PrefixSum<int>(
int(old_size),
430 if (i+1 < old_size) {
431 return (p_ro[i+1] > p_ro[i]);
439 }
else if (p_ro[i] > p_ro[i-1]) {
442 if ((i+1 < old_size) &&
451 new_size = Scan::PrefixSum<Long>(old_size,
453 if (i+1 < old_size) {
454 return (p_ro[i+1] > p_ro[i]);
462 }
else if (p_ro[i] > p_ro[i-1]) {
465 if ((i+1 < old_size) &&
474 m_rtol.resize(new_size-1);
475 trimmed_row_offset.
resize(new_size);
477 m_rtol.shrink_to_fit();
480 m_data_remote.row_offset.swap(trimmed_row_offset);
484 m_remote_cols.resize(m_data_remote.col_index.size());
486 m_data_remote.col_index.end(),
487 m_remote_cols.begin());
490 auto const& m_remote_cols = m_data_remote.col_index;
493 unique_remote_cols_v.resize(m_remote_cols.size());
494 std::partial_sort_copy(m_remote_cols.begin(),
496 unique_remote_cols_v.begin(),
497 unique_remote_cols_v.end());
500 m_total_counts_recv = Long(unique_remote_cols_v.
size());
503 auto const& rows = this->m_partition.dataVector();
504 auto it = rows.cbegin();
505 for (
auto c : unique_remote_cols_v) {
506 it = std::find_if(it, rows.cend(), [&] (
auto x) { return x > c; });
507 if (it != rows.cend()) {
508 int iproc =
int(std::distance(rows.cbegin(),it)) - 1;
509 unique_remote_cols_vv[iproc].push_back(c);
511 amrex::Abort(
"SpMatrix::prepare_comm: how did this happen?");
523 for (
int iproc = 0; iproc < nprocs; ++iproc) {
524 need_from[iproc] = unique_remote_cols_vv[iproc].empty() ? 0 : 1;
528 BL_MPI_REQUIRE(MPI_Reduce_scatter
529 (need_from.data(), &nsends, reduce_scatter_counts.data(),
530 mpi_int, MPI_SUM, mpi_comm));
535 for (
int iproc = 0; iproc < nprocs; ++iproc) {
536 if ( ! unique_remote_cols_vv[iproc].empty()) {
539 BL_MPI_REQUIRE(MPI_Isend(unique_remote_cols_vv[iproc].data(),
540 int(unique_remote_cols_vv[iproc].
size()),
541 mpi_long, iproc, mpi_tag, mpi_comm,
542 &(mpi_requests.back())));
543 m_recv_from.push_back(iproc);
544 m_recv_counts.push_back(
int(unique_remote_cols_vv[iproc].
size()));
549 m_total_counts_send = 0;
550 for (
int isend = 0; isend < nsends; ++isend) {
552 BL_MPI_REQUIRE(MPI_Probe(MPI_ANY_SOURCE, mpi_tag, mpi_comm, &mpi_status));
553 int receiver = mpi_status.MPI_SOURCE;
555 BL_MPI_REQUIRE(MPI_Get_count(&mpi_status, mpi_long, &count));
556 m_send_to.push_back(receiver);
557 m_send_counts.push_back(count);
558 send_indices[isend].resize(count);
559 BL_MPI_REQUIRE(MPI_Recv(send_indices[isend].data(), count, mpi_long,
560 receiver, mpi_tag, mpi_comm, &mpi_status));
561 m_total_counts_send += count;
564 m_send_indices.resize(m_total_counts_send);
566 send_indices_all.
reserve(m_total_counts_send);
567 for (
auto const& vl : send_indices) {
573 m_send_indices.begin());
577 BL_MPI_REQUIRE(MPI_Waitall(
int(mpi_requests.
size()), mpi_requests.data(),
578 mpi_statuses.data()));
581 std::map<Long,Long> gtol;
582 for (Long i = 0, N = Long(unique_remote_cols_v.
size()); i < N; ++i) {
583 gtol[unique_remote_cols_v[i]] = i;
586 auto& cols = m_remote_cols;
588 auto& cols = m_data_remote.col_index;
590 for (
auto& c : cols) {
596 m_data_remote.col_index.data());
599 m_comm_prepared =
true;
602 template <
typename T,
template<
typename>
class Allocator>
605 auto*
pdst = m_send_buffer;
606 auto* pidx = m_send_indices.data();
607 auto const& vv = v.
view();
608 auto const nsends = Long(m_send_indices.size());
611 pdst[i] = vv(pidx[i]);
615 template <
typename T,
template<
typename>
class Allocator>
618 auto const& csr = m_data_remote;
624 auto const* rtol = m_rtol.data();
629 auto const nrr = Long(csr.row_offset.size())-1;
633 for (Long j = row[i]; j < row[i+1]; ++j) {
634 r += mat[j] * px[col[j]];
#define AMREX_ASSERT(EX)
Definition: AMReX_BLassert.H:38
#define AMREX_RESTRICT
Definition: AMReX_Extension.H:37
#define AMREX_GPU_DEVICE
Definition: AMReX_GpuQualifiers.H:18
Array4< int const > offset
Definition: AMReX_HypreMLABecLap.cpp:1089
Real * pdst
Definition: AMReX_HypreMLABecLap.cpp:1090
static constexpr int MPI_REQUEST_NULL
Definition: AMReX_ccse-mpi.H:53
Definition: AMReX_AlgPartition.H:14
Long numGlobalRows() const
Definition: AMReX_AlgPartition.H:28
Definition: AMReX_AlgVector.H:19
T const * data() const
Definition: AMReX_AlgVector.H:53
AMREX_FORCE_INLINE Table1D< T const, Long > view() const
Definition: AMReX_AlgVector.H:57
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition: AMReX_PODVector.H:246
void reserve(size_type a_capacity)
Definition: AMReX_PODVector.H:647
void shrink_to_fit()
Definition: AMReX_PODVector.H:656
T * data() noexcept
Definition: AMReX_PODVector.H:593
iterator begin() noexcept
Definition: AMReX_PODVector.H:601
iterator end() noexcept
Definition: AMReX_PODVector.H:605
void resize(size_type a_new_size)
Definition: AMReX_PODVector.H:625
void push_back(const T &a_value)
Definition: AMReX_PODVector.H:556
Definition: AMReX_SpMatrix.H:19
AlgVector< T > m_diagonal
Definition: AMReX_SpMatrix.H:94
bool m_comm_prepared
Definition: AMReX_SpMatrix.H:120
CSR< DVec > m_data_remote
Definition: AMReX_SpMatrix.H:97
void prepare_comm()
Private function, but public for cuda.
Definition: AMReX_SpMatrix.H:326
Long const * columnIndex() const
Definition: AMReX_SpMatrix.H:50
Long globalRowBegin() const
Inclusive global index begin.
Definition: AMReX_SpMatrix.H:44
AlgPartition const & partition() const
Definition: AMReX_SpMatrix.H:37
void setVal(F const &f)
Definition: AMReX_SpMatrix.H:199
T const * data() const
Definition: AMReX_SpMatrix.H:48
Long m_row_end
Definition: AMReX_SpMatrix.H:91
void define(AlgPartition partition, int nnz)
Definition: AMReX_SpMatrix.H:137
Gpu::DeviceVector< Long > m_send_indices
Definition: AMReX_SpMatrix.H:107
void startComm(AlgVector< T > const &x)
Definition: AMReX_SpMatrix.H:241
Long m_total_counts_send
Definition: AMReX_SpMatrix.H:114
Vector< int > m_recv_counts
Definition: AMReX_SpMatrix.H:110
Long numGlobalRows() const
Definition: AMReX_SpMatrix.H:40
Gpu::PinnedVector< Long > m_remote_cols
Definition: AMReX_SpMatrix.H:100
T value_type
Definition: AMReX_SpMatrix.H:21
T * data()
Definition: AMReX_SpMatrix.H:49
Long * columnIndex()
Definition: AMReX_SpMatrix.H:51
Long globalRowEnd() const
Exclusive global index end.
Definition: AMReX_SpMatrix.H:46
bool m_shifted
Definition: AMReX_SpMatrix.H:123
Vector< int > m_recv_from
Definition: AMReX_SpMatrix.H:109
Long m_row_begin
Definition: AMReX_SpMatrix.H:90
void finishComm(AlgVector< T > &y)
Definition: AMReX_SpMatrix.H:290
void printToFile(std::string const &file) const
Definition: AMReX_SpMatrix.H:165
Long m_total_counts_recv
Definition: AMReX_SpMatrix.H:118
void unpack_buffer(AlgVector< T > &v)
Definition: AMReX_SpMatrix.H:616
Vector< MPI_Request > m_send_reqs
Definition: AMReX_SpMatrix.H:112
SpMatrix(SpMatrix const &)=delete
CSR< DVec > m_data
Definition: AMReX_SpMatrix.H:92
T * m_recv_buffer
Definition: AMReX_SpMatrix.H:117
Vector< MPI_Request > m_recv_reqs
Definition: AMReX_SpMatrix.H:116
AlgPartition m_partition
Definition: AMReX_SpMatrix.H:89
SpMatrix(SpMatrix &&)=default
friend void SpMV(AlgVector< U > &y, SpMatrix< U > const &A, AlgVector< U > const &x)
Long numLocalRows() const
Definition: AMReX_SpMatrix.H:39
Long const * rowOffset() const
Definition: AMReX_SpMatrix.H:52
Long numLocalNonZero() const
Definition: AMReX_SpMatrix.H:41
T * m_send_buffer
Definition: AMReX_SpMatrix.H:113
Vector< int > m_send_to
Definition: AMReX_SpMatrix.H:105
AlgVector< T > const & diagonalVector() const
Definition: AMReX_SpMatrix.H:215
Vector< int > m_send_counts
Definition: AMReX_SpMatrix.H:106
Long * rowOffset()
Definition: AMReX_SpMatrix.H:53
SpMatrix & operator=(SpMatrix const &)=delete
void define_doit(int nnz)
Private function, but public for cuda.
Definition: AMReX_SpMatrix.H:147
DVec< Long > m_rtol
Definition: AMReX_SpMatrix.H:103
void pack_buffer(AlgVector< T > const &v)
Definition: AMReX_SpMatrix.H:603
Long size() const noexcept
Definition: AMReX_Vector.H:50
AMREX_GPU_HOST_DEVICE Long size(T const &b) noexcept
integer version
Definition: AMReX_GpuRange.H:26
void copyAsync(HostToDevice, InIter begin, InIter end, OutIter result) noexcept
A host-to-device copy routine. Note this is just a wrapper around memcpy, so it assumes contiguous storage.
Definition: AMReX_GpuContainers.H:233
static constexpr DeviceToHost deviceToHost
Definition: AMReX_GpuContainers.H:99
static constexpr HostToDevice hostToDevice
Definition: AMReX_GpuContainers.H:98
void streamSynchronize() noexcept
Definition: AMReX_GpuDevice.H:237
int MyProc()
Definition: AMReX_MPMD.cpp:117
MPI_Comm CommunicatorSub() noexcept
sub-communicator for current frame
Definition: AMReX_ParallelContext.H:70
int NProcsSub() noexcept
number of ranks in current frame
Definition: AMReX_ParallelContext.H:74
int MyProc() noexcept
return the rank number local to the current Parallel Context
Definition: AMReX_ParallelDescriptor.H:125
int SeqNum() noexcept
Returns sequential message sequence numbers, usually used as tags for send/recv.
Definition: AMReX_ParallelDescriptor.H:613
static constexpr struct amrex::Scan::Type::Exclusive exclusive
static constexpr RetSum retSum
Definition: AMReX_Scan.H:29
static int f(amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data)
Definition: AMReX_SundialsIntegrator.H:44
@ max
Definition: AMReX_ParallelReduce.H:17
Definition: AMReX_Amr.cpp:49
std::enable_if_t< std::is_integral_v< T > > ParallelFor(TypeList< CTOs... > ctos, std::array< int, sizeof...(CTOs)> const &runtime_options, T N, F &&f)
Definition: AMReX_CTOParallelForImpl.H:200
amrex::ArenaAllocator< T > DefaultAllocator
Definition: AMReX_GpuAllocators.H:194
Arena * The_Comms_Arena()
Definition: AMReX_Arena.cpp:669
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void ignore_unused(const Ts &...)
This shuts up the compiler about unused variables.
Definition: AMReX.H:111
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition: AMReX.cpp:225
const int[]
Definition: AMReX_BLProfiler.cpp:1664
void RemoveDuplicates(Vector< T > &vec)
Definition: AMReX_Vector.H:190
Definition: AMReX_ccse-mpi.H:51
static MPI_Datatype type()
Definition: AMReX_SpMatrix.H:82
Long nnz
Definition: AMReX_SpMatrix.H:86
V< Long > row_offset
Definition: AMReX_SpMatrix.H:85
V< Long > col_index
Definition: AMReX_SpMatrix.H:84
V< T > mat
Definition: AMReX_SpMatrix.H:83