Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_SpMV.H
Go to the documentation of this file.
1#ifndef AMREX_SPMV_H_
2#define AMREX_SPMV_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_AlgVector.H>
6#include <AMReX_AlgVecUtil.H>
7#include <AMReX_GpuComplex.H>
8#include <AMReX_SpMatrix.H>
9
10namespace amrex {
11
18template <typename T>
28void SpMV (Long nrows, Long ncols, T* AMREX_RESTRICT py, CsrView<T const> const& A,
29 T const* AMREX_RESTRICT px)
30{
31 T const* AMREX_RESTRICT mat = A.mat;
32 auto const* AMREX_RESTRICT col = A.col_index;
33 auto const* AMREX_RESTRICT row = A.row_offset;
34
35#if defined(AMREX_USE_GPU)
36
37 Long const nnz = A.nnz;
38
39#if defined(AMREX_USE_CUDA)
40
41 cusparseHandle_t handle;
42 AMREX_CUSPARSE_SAFE_CALL(cusparseCreate(&handle));
43 AMREX_CUSPARSE_SAFE_CALL(cusparseSetStream(handle, Gpu::gpuStream()));
44
45 cudaDataType data_type;
46 if constexpr (std::is_same_v<T,float>) {
47 data_type = CUDA_R_32F;
48 } else if constexpr (std::is_same_v<T,double>) {
49 data_type = CUDA_R_64F;
50 } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
51 data_type = CUDA_C_32F;
52 } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
53 data_type = CUDA_C_64F;
54 } else {
55 amrex::Abort("SpMV: unsupported data type");
56 }
57
58 cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;
59
60 cusparseSpMatDescr_t mat_descr;
62 (cusparseCreateCsr(&mat_descr, nrows, ncols, nnz,
63 (void*)row, (void*)col, (void*)mat,
64 index_type, index_type, CUSPARSE_INDEX_BASE_ZERO,
65 data_type));
66
67 cusparseDnVecDescr_t x_descr;
68 AMREX_CUSPARSE_SAFE_CALL(cusparseCreateDnVec(&x_descr, ncols, (void*)px, data_type));
69
70 cusparseDnVecDescr_t y_descr;
71 AMREX_CUSPARSE_SAFE_CALL(cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type));
72
73 T alpha = T(1);
74 T beta = T(0);
75
76 std::size_t buffer_size;
78 (cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
79 &alpha, mat_descr, x_descr, &beta, y_descr,
80 data_type, CUSPARSE_SPMV_ALG_DEFAULT,
81 &buffer_size));
82
83 auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);
84
86 (cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
87 &alpha, mat_descr, x_descr, &beta, y_descr,
88 data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer));
89
91
92 AMREX_CUSPARSE_SAFE_CALL(cusparseDestroySpMat(mat_descr));
93 AMREX_CUSPARSE_SAFE_CALL(cusparseDestroyDnVec(x_descr));
94 AMREX_CUSPARSE_SAFE_CALL(cusparseDestroyDnVec(y_descr));
95 AMREX_CUSPARSE_SAFE_CALL(cusparseDestroy(handle));
96 The_Arena()->free(pbuffer);
97
98#elif defined(AMREX_USE_HIP)
99
100 rocsparse_handle handle;
101 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_create_handle(&handle));
102 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_set_stream(handle, Gpu::gpuStream()));
103
104 rocsparse_datatype data_type;
105 if constexpr (std::is_same_v<T,float>) {
106 data_type = rocsparse_datatype_f32_r;
107 } else if constexpr (std::is_same_v<T,double>) {
108 data_type = rocsparse_datatype_f64_r;
109 } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
110 data_type = rocsparse_datatype_f32_c;
111 } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
112 data_type = rocsparse_datatype_f64_c;
113 } else {
114 amrex::Abort("SpMV: unsupported data type");
115 }
116
117 rocsparse_indextype index_type = rocsparse_indextype_i64;
118
119 rocsparse_spmat_descr mat_descr;
120 AMREX_ROCSPARSE_SAFE_CALL(
121 rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz,
122 (void*)row, (void*)col, (void*)mat,
123 index_type, index_type,
124 rocsparse_index_base_zero, data_type));
125
126 rocsparse_dnvec_descr x_descr;
127 AMREX_ROCSPARSE_SAFE_CALL(
128 rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type));
129
130 rocsparse_dnvec_descr y_descr;
131 AMREX_ROCSPARSE_SAFE_CALL(
132 rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type));
133
134 T alpha = T(1.0);
135 T beta = T(0.0);
136
137#if (HIP_VERSION_MAJOR >= 7)
138
139 rocsparse_spmv_descr spmv_descr;
140 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_create_spmv_descr(&spmv_descr));
141
142
143 rocsparse_error p_error[1] = {};
144
145 const rocsparse_spmv_alg spmv_alg = rocsparse_spmv_alg_csr_adaptive;
146 AMREX_ROCSPARSE_SAFE_CALL(
147 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_alg,
148 &spmv_alg, sizeof(spmv_alg), p_error));
149
150 const rocsparse_operation spmv_operation = rocsparse_operation_none;
151 AMREX_ROCSPARSE_SAFE_CALL(
152 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_operation,
153 &spmv_operation, sizeof(spmv_operation), p_error));
154
155 AMREX_ROCSPARSE_SAFE_CALL(
156 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_scalar_datatype,
157 &data_type, sizeof(data_type), p_error));
158
159 AMREX_ROCSPARSE_SAFE_CALL(
160 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_compute_datatype,
161 &data_type, sizeof(data_type), p_error));
162
163 std::size_t buffer_size = 0;
164 AMREX_ROCSPARSE_SAFE_CALL(
165 rocsparse_v2_spmv_buffer_size(handle, spmv_descr, mat_descr, x_descr, y_descr,
166 rocsparse_v2_spmv_stage_analysis, // buffer size for analysis
167 &buffer_size, p_error));
168
169 void* pbuffer = nullptr;
170 if (buffer_size > 0) {
171 pbuffer = (void*)The_Arena()->alloc(buffer_size);
172 }
173
174 AMREX_ROCSPARSE_SAFE_CALL(
175 rocsparse_v2_spmv(handle, spmv_descr, &alpha, mat_descr, x_descr, &beta, y_descr,
176 rocsparse_v2_spmv_stage_analysis, // analysis stage
177 buffer_size, pbuffer, p_error));
178
179 if (pbuffer) {
180 The_Arena()->free(pbuffer);
181 }
182
183 AMREX_ROCSPARSE_SAFE_CALL(
184 rocsparse_v2_spmv_buffer_size(handle, spmv_descr, mat_descr, x_descr, y_descr,
185 rocsparse_v2_spmv_stage_compute, // buffer size for compute
186 &buffer_size, p_error));
187
188 if (buffer_size > 0) {
189 pbuffer = (void*)The_Arena()->alloc(buffer_size);
190 } else {
191 pbuffer = nullptr;
192 }
193
194 AMREX_ROCSPARSE_SAFE_CALL(
195 rocsparse_v2_spmv(handle, spmv_descr, &alpha, mat_descr, x_descr, &beta, y_descr,
196 rocsparse_v2_spmv_stage_compute, // compute stage
197 buffer_size, pbuffer, p_error));
198
199#elif (HIP_VERSION_MAJOR == 6)
200
201 std::size_t buffer_size = 0;
202 AMREX_ROCSPARSE_SAFE_CALL(
203 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
204 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
205 rocsparse_spmv_stage_buffer_size, // buffer size stage
206 &buffer_size, nullptr));
207
208 void* pbuffer = nullptr;
209 if (buffer_size > 0) {
210 pbuffer = (void*)The_Arena()->alloc(buffer_size);
211 }
212
213 AMREX_ROCSPARSE_SAFE_CALL(
214 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
215 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
216 rocsparse_spmv_stage_preprocess, // preprocess stage
217 &buffer_size, pbuffer));
218
219 AMREX_ROCSPARSE_SAFE_CALL(
220 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
221 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
222 rocsparse_spmv_stage_compute, // compute stage
223 &buffer_size, pbuffer));
224
225#else /* HIP_VERSION_MAJOR < 6 */
226
227 std::size_t buffer_size = 0;
228 AMREX_ROCSPARSE_SAFE_CALL(
229 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
230 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
231 &buffer_size, nullptr));
232
233 void* pbuffer = nullptr;
234 if (buffer_size > 0) {
235 pbuffer = (void*)The_Arena()->alloc(buffer_size);
236 }
237
238 AMREX_ROCSPARSE_SAFE_CALL(
239 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
240 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
241 &buffer_size, pbuffer));
242
243#endif /* HIP_VERSION_MAJOR */
244
246
247#if (HIP_VERSION_MAJOR >= 7)
248 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_error(p_error[0]));
249 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_spmv_descr(spmv_descr));
250#endif
251 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_spmat_descr(mat_descr));
252 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(x_descr));
253 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(y_descr));
254 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_handle(handle));
255 if (pbuffer) {
256 The_Arena()->free(pbuffer);
257 }
258
259#elif defined(AMREX_USE_SYCL)
260
261 mkl::sparse::matrix_handle_t handle{};
262 mkl::sparse::init_matrix_handle(&handle);
263
264#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION < 20250300)
266 mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols,
267 mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
268#else
269 mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, nnz,
270 mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
271#endif
272 mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
273 T(1), handle, px, T(0), py);
274
275 auto ev = mkl::sparse::release_matrix_handle(Gpu::Device::streamQueue(), &handle);
276 ev.wait();
277
278#endif
279
281
282#else
283
285
286#ifdef AMREX_USE_OMP
287#pragma omp parallel for
288#endif
289 for (Long i = 0; i < nrows; ++i) {
290 T r = 0;
291 for (Long j = row[i]; j < row[i+1]; ++j) {
292 r += mat[j] * px[col[j]];
293 }
294 py[i] = r;
295 }
296
297#endif
298}
299
// Distributed sparse matrix-vector product y = A * x on AlgVector/SpMatrix.
// The local (diagonal) block A.m_csr is multiplied by the locally owned
// part of x via the pointer-level SpMV; startComm_mv(x) and
// finishComm_mv(y) bracket that product. Presumably they exchange the
// off-process x entries and add the off-diagonal contribution into y —
// TODO confirm in AMReX_SpMatrix.H.
// NOTE(review): the signature line declaring y and A (header line 308) is
// missing from this listing.
307template <typename T, template<typename> class AllocM, typename AllocV>
309 AlgVector<T,AllocV> const& x)
310{
311 // xxxxx TODO: Is it worth to cache cusparse and rocsparse handles?
312
    // A is taken by const reference, but the communication helpers are
    // non-const members, hence the const_cast.
313 const_cast<SpMatrix<T,AllocM>&>(A).startComm_mv(x);
314
315 // Diagonal part
    // Rows come from y and columns from x: the diagonal block maps local x
    // entries to local y entries.
316 SpMV<T>(y.numLocalRows(), x.numLocalRows(), y.data(), A.m_csr.const_view(),
317 x.data());
318
319 const_cast<SpMatrix<T,AllocM>&>(A).finishComm_mv(y);
320}
321
// Computes the residual res = b - A * x.
// It first forms res = A * x with the distributed SpMV, then applies Xpay
// with a = -1 (dst = src + a * dst), giving res = b - res.
// NOTE(review): the signature lines (header lines 331-332, parameters res,
// A, x, b) are missing from this listing.
330template <typename T, template<typename> class AllocM, typename AllocV>
333{
334 SpMV(res, A, x);
335 Xpay(res, T(-1), b);
336}
337
338}
339
340#endif
#define AMREX_RESTRICT
Definition AMReX_Extension.H:32
#define AMREX_CUSPARSE_SAFE_CALL(call)
Definition AMReX_GpuError.H:101
#define AMREX_GPU_ERROR_CHECK()
Definition AMReX_GpuError.H:151
Distributed dense vector that mirrors the layout of an AlgPartition.
Definition AMReX_AlgVector.H:29
virtual void free(void *pt)=0
A pure virtual function that releases the memory block pointed to by pt back to the arena.
virtual void * alloc(std::size_t sz)=0
Distributed CSR matrix that manages storage and GPU-friendly partitions.
Definition AMReX_SpMatrix.H:61
amrex_long Long
Definition AMReX_INT.H:30
Arena * The_Arena()
Definition AMReX_Arena.cpp:805
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:310
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:291
Definition AMReX_Amr.cpp:49
__host__ __device__ void ignore_unused(const Ts &...)
Suppresses compiler warnings about unused variables.
Definition AMReX.H:139
void computeResidual(AlgVector< T, AllocV > &res, SpMatrix< T, AllocM > const &A, AlgVector< T, AllocV > const &x, AlgVector< T, AllocV > const &b)
Compute the residual res = b - A * x.
Definition AMReX_SpMV.H:331
void Xpay(MF &dst, typename MF::value_type a, MF const &src, int scomp, int dcomp, int ncomp, IntVect const &nghost)
dst = src + a * dst
Definition AMReX_FabArrayUtility.H:1974
void SpMV(Long nrows, Long ncols, T *__restrict__ py, CsrView< T const > const &A, T const *__restrict__ px)
Perform y = A * x using CSR data (GPU/CPU aware).
Definition AMReX_SpMV.H:28
void Abort(const std::string &msg)
Prints the message to cerr and exits via abort().
Definition AMReX.cpp:240
CsrView< T const > const_view() const
Convenience alias for view() const.
Definition AMReX_CSR.H:89
Lightweight non-owning CSR view that can point to host or device buffers.
Definition AMReX_CSR.H:33
T *__restrict__ mat
Definition AMReX_CSR.H:35
Long nnz
Definition AMReX_CSR.H:38
U *__restrict__ row_offset
Definition AMReX_CSR.H:37
U *__restrict__ col_index
Definition AMReX_CSR.H:36