Block-Structured AMR Software Framework
 
AMReX_SpMV.H
#ifndef AMREX_SPMV_H_
#define AMREX_SPMV_H_
#include <AMReX_Config.H>

#include <AMReX_AlgVector.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_SpMatrix.H>

#if defined(AMREX_USE_CUDA)
#  include <cusparse.h>
#elif defined(AMREX_USE_HIP)
#  include <rocsparse/rocsparse.h>
#elif defined(AMREX_USE_SYCL)
#  include <mkl_version.h>
#  include <oneapi/mkl/spblas.hpp>
#endif

namespace amrex {

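// Sparse matrix-vector product: y = A*x.
//
// A is a distributed sparse matrix stored in CSR form (row-offset, column-index,
// and value arrays); x and y are distributed vectors. All three must share the
// same AlgPartition, and the matrix is assumed to be square for now (see the
// TODO below).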
template <typename T>
void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
{
    // xxxxx TODO: We might want to cache the cusparse and rocsparse handles.

    // xxxxx TODO: let's assume it's a square matrix for now.
    AMREX_ALWAYS_ASSERT(x.partition() == y.partition() &&
                        x.partition() == A.partition());

    // Start parallel communication for the matrix-vector product.
    const_cast<SpMatrix<T>&>(A).startComm(x);

    // Raw pointers to the local CSR data (values, column indices, row offsets)
    // and to the local parts of x and y.
    T      * AMREX_RESTRICT py  = y.data();
    T const* AMREX_RESTRICT px  = x.data();
    T const* AMREX_RESTRICT mat = A.data();
    auto const* AMREX_RESTRICT col = A.columnIndex();
    auto const* AMREX_RESTRICT row = A.rowOffset();

#if defined(AMREX_USE_GPU)

    Long const nrows = A.numLocalRows();
    Long const ncols = x.numLocalRows();
    Long const nnz = A.numLocalNonZero();

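    // All three GPU backends follow the same pattern: describe the local CSR
    // block and the dense vectors to the vendor sparse library (cuSPARSE,
    // rocSPARSE, or oneMKL), perform y = A*x with alpha = 1 and beta = 0 on
    // the current AMReX GPU stream/queue, and then release the descriptors
    // and any temporary workspace.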
#if defined(AMREX_USE_CUDA)

    cusparseHandle_t handle;
    cusparseCreate(&handle);
    cusparseSetStream(handle, Gpu::gpuStream());

    // Map T to the matching cuSPARSE data type.
    cudaDataType data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = CUDA_R_32F;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = CUDA_R_64F;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = CUDA_C_32F;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = CUDA_C_64F;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;

    // Wrap the CSR arrays and the dense vectors in cuSPARSE descriptors.
    cusparseSpMatDescr_t mat_descr;
    cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
                      index_type, index_type, CUSPARSE_INDEX_BASE_ZERO, data_type);

    cusparseDnVecDescr_t x_descr;
    cusparseCreateDnVec(&x_descr, ncols, (void*)px, data_type);

    cusparseDnVecDescr_t y_descr;
    cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type);

    T alpha = T(1);
    T beta = T(0);

    // Query the workspace size, allocate it from The_Arena, then run y = A*x.
    std::size_t buffer_size;
    cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
                            &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size);

    auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);

    cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
                 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer);

    Gpu::streamSynchronize();

    cusparseDestroySpMat(mat_descr);
    cusparseDestroyDnVec(x_descr);
    cusparseDestroyDnVec(y_descr);
    cusparseDestroy(handle);
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_HIP)

    rocsparse_handle handle;
    rocsparse_create_handle(&handle);
    rocsparse_set_stream(handle, Gpu::gpuStream());

    // Map T to the matching rocSPARSE data type.
    rocsparse_datatype data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = rocsparse_datatype_f32_r;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = rocsparse_datatype_f64_r;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = rocsparse_datatype_f32_c;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = rocsparse_datatype_f64_c;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    rocsparse_indextype index_type = rocsparse_indextype_i64;

    rocsparse_spmat_descr mat_descr;
    rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col,
                               (void*)mat, index_type, index_type,
                               rocsparse_index_base_zero, data_type);

    rocsparse_dnvec_descr x_descr;
    rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type);

    rocsparse_dnvec_descr y_descr;
    rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type);

    T alpha = T(1.0);
    T beta = T(0.0);

#if (HIP_VERSION_MAJOR >= 7)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif

    // Query the workspace size, then run y = A*x (with explicit preprocess
    // and compute stages on HIP 6+).
    std::size_t buffer_size;
    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                   rocsparse_spmv_stage_buffer_size,
#endif
                   &buffer_size, nullptr);

    void* pbuffer = (void*)The_Arena()->alloc(buffer_size);

#if (HIP_VERSION_MAJOR >= 6)
    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
                   rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer);
#endif

    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                   rocsparse_spmv_stage_compute,
#endif
                   &buffer_size, pbuffer);

#if (HIP_VERSION_MAJOR >= 7)
// xxxxx HIP TODO: rocsparse_spmv has been deprecated.
#pragma clang diagnostic pop
#endif

    Gpu::streamSynchronize();

    rocsparse_destroy_spmat_descr(mat_descr);
    rocsparse_destroy_dnvec_descr(x_descr);
    rocsparse_destroy_dnvec_descr(y_descr);
    rocsparse_destroy_handle(handle);
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_SYCL)

    amrex::ignore_unused(nnz);
    mkl::sparse::matrix_handle_t handle{};
    mkl::sparse::init_matrix_handle(&handle);

#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION < 20250300)
    mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols,
                              mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
#else
    mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, nnz,
                              mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
#endif
    mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
                      T(1), handle, px, T(0), py);

    auto ev = mkl::sparse::release_matrix_handle(Gpu::Device::streamQueue(), &handle);
    ev.wait();

#endif

    AMREX_GPU_ERROR_CHECK();

#else

    // CPU fallback: plain CSR matrix-vector product. For each local row,
    // accumulate the dot product of that row's nonzeros with x.
    Long const ny = y.numLocalRows();
    for (Long i = 0; i < ny; ++i) {
        T r = 0;
#ifdef AMREX_USE_OMP
#pragma omp parallel for reduction(+:r)
#endif
        for (Long j = row[i]; j < row[i+1]; ++j) {
            r += mat[j] * px[col[j]];
        }
        py[i] = r;
    }

#endif

    // Finish the parallel communication started above.
    const_cast<SpMatrix<T>&>(A).finishComm(y);
}

} // namespace amrex

#endif
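
The CPU fallback above is a textbook CSR (compressed sparse row) traversal. For readers unfamiliar with that layout, here is a minimal, self-contained sketch in plain C++; the variable names and the 3x3 example matrix are hypothetical and not part of the AMReX API, but row_offset, col_index, and value play exactly the roles of A.rowOffset(), A.columnIndex(), and A.data() in SpMV.

#include <cassert>
#include <vector>

int main ()
{
    // Hypothetical 3x3 example:
    //     [10  0  2]
    // A = [ 0  3  0]
    //     [ 1  0  4]
    // row_offset[i] .. row_offset[i+1] delimits row i's nonzeros in value/col_index.
    std::vector<long>   row_offset = {0, 2, 3, 5};
    std::vector<long>   col_index  = {0, 2, 1, 0, 2};
    std::vector<double> value      = {10., 2., 3., 1., 4.};

    std::vector<double> x = {1., 2., 3.};
    std::vector<double> y(3, 0.);

    // Same loop structure as the CPU fallback in SpMV: for each row,
    // accumulate the dot product of its nonzeros with x.
    for (long i = 0; i < 3; ++i) {
        double r = 0.;
        for (long j = row_offset[i]; j < row_offset[i+1]; ++j) {
            r += value[j] * x[col_index[j]];
        }
        y[i] = r;
    }

    assert(y[0] == 16. && y[1] == 6. && y[2] == 13.);
    return 0;
}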
Referenced symbols:

  AMREX_ALWAYS_ASSERT(EX): AMReX_BLassert.H:50
  AMREX_RESTRICT: AMReX_Extension.H:37
  AMREX_GPU_ERROR_CHECK(): AMReX_GpuError.H:133
  amrex::AlgVector: AMReX_AlgVector.H:19
  amrex::Arena::free(void* pt): a pure virtual function for deleting the arena memory pointed to by pt
  amrex::Arena::alloc(std::size_t sz): pure virtual allocation
  amrex::SpMatrix: AMReX_SpMatrix.H:19
  amrex::SpMatrix::rowOffset() const: AMReX_SpMatrix.H:60
  amrex::SpMatrix::data() const: AMReX_SpMatrix.H:56
  amrex::SpMatrix::columnIndex() const: AMReX_SpMatrix.H:58
  amrex::SpMatrix::numLocalRows() const: AMReX_SpMatrix.H:47
  amrex::SpMatrix::partition() const: AMReX_SpMatrix.H:45
  amrex::SpMatrix::numLocalNonZero() const: AMReX_SpMatrix.H:49
  amrex::Gpu::streamSynchronize(): AMReX_GpuDevice.H:260
  amrex::Gpu::gpuStream(): AMReX_GpuDevice.H:241
  namespace amrex: AMReX_Amr.cpp:49
  amrex::ignore_unused(const Ts&...): silences compiler warnings about unused variables; AMReX.H:138
  amrex::SpMV(AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x): AMReX_SpMV.H:21
  amrex::Abort(const std::string& msg): prints the message to cerr and exits via abort(); AMReX.cpp:230
  amrex::The_Arena(): AMReX_Arena.cpp:705