Block-Structured AMR Software Framework
AMReX_SpMV.H
#ifndef AMREX_SPMV_H_
#define AMREX_SPMV_H_
#include <AMReX_Config.H>

#include <AMReX_AlgVector.H>
#include <AMReX_AlgVecUtil.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_SpMatrix.H>

namespace amrex {

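//! Local sparse matrix-vector product y = A*x for a matrix stored in CSR
//! format.  On GPU builds the product is delegated to cuSPARSE, rocSPARSE, or
//! oneMKL; otherwise an (optionally OpenMP-threaded) row loop is used.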
template <typename T>
void SpMV (Long nrows, Long ncols, T* AMREX_RESTRICT py, CsrView<T const> const& A,
           T const* AMREX_RESTRICT px)
{
    T const* AMREX_RESTRICT mat = A.mat;
    auto const* AMREX_RESTRICT col = A.col_index;
    auto const* AMREX_RESTRICT row = A.row_offset;

#if defined(AMREX_USE_GPU)

    Long const nnz = A.nnz;

#if defined(AMREX_USE_CUDA)

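    // cuSPARSE generic SpMV API: wrap the raw CSR arrays and the dense vectors
    // in descriptors, query the required workspace, run the product, then
    // release everything.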
    cusparseHandle_t handle;
    AMREX_CUSPARSE_SAFE_CALL(cusparseCreate(&handle));
    AMREX_CUSPARSE_SAFE_CALL(cusparseSetStream(handle, Gpu::gpuStream()));

    cudaDataType data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = CUDA_R_32F;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = CUDA_R_64F;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = CUDA_C_32F;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = CUDA_C_64F;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;

    cusparseSpMatDescr_t mat_descr;
    AMREX_CUSPARSE_SAFE_CALL
        (cusparseCreateCsr(&mat_descr, nrows, ncols, nnz,
                           (void*)row, (void*)col, (void*)mat,
                           index_type, index_type, CUSPARSE_INDEX_BASE_ZERO,
                           data_type));

    cusparseDnVecDescr_t x_descr;
    AMREX_CUSPARSE_SAFE_CALL(cusparseCreateDnVec(&x_descr, ncols, (void*)px, data_type));

    cusparseDnVecDescr_t y_descr;
    AMREX_CUSPARSE_SAFE_CALL(cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type));

    T alpha = T(1);
    T beta = T(0);

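    // With alpha = 1 and beta = 0, cusparseSpMV computes y = A*x without
    // accumulating into the previous contents of y.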
    std::size_t buffer_size;
    AMREX_CUSPARSE_SAFE_CALL
        (cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 &alpha, mat_descr, x_descr, &beta, y_descr,
                                 data_type, CUSPARSE_SPMV_ALG_DEFAULT,
                                 &buffer_size));

    auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);

    AMREX_CUSPARSE_SAFE_CALL
        (cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                      &alpha, mat_descr, x_descr, &beta, y_descr,
                      data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer));

    // Make sure the SpMV has finished before the descriptors and the
    // temporary buffer are released below.
    Gpu::streamSynchronize();

    AMREX_CUSPARSE_SAFE_CALL(cusparseDestroySpMat(mat_descr));
    AMREX_CUSPARSE_SAFE_CALL(cusparseDestroyDnVec(x_descr));
    AMREX_CUSPARSE_SAFE_CALL(cusparseDestroyDnVec(y_descr));
    AMREX_CUSPARSE_SAFE_CALL(cusparseDestroy(handle));
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_HIP)

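    // rocSPARSE path: mirrors the cuSPARSE branch above.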
    rocsparse_handle handle;
    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_create_handle(&handle));
    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_set_stream(handle, Gpu::gpuStream()));

    rocsparse_datatype data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = rocsparse_datatype_f32_r;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = rocsparse_datatype_f64_r;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = rocsparse_datatype_f32_c;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = rocsparse_datatype_f64_c;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    rocsparse_indextype index_type = rocsparse_indextype_i64;

    rocsparse_spmat_descr mat_descr;
    AMREX_ROCSPARSE_SAFE_CALL(
        rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz,
                                   (void*)row, (void*)col, (void*)mat,
                                   index_type, index_type,
                                   rocsparse_index_base_zero, data_type));

    rocsparse_dnvec_descr x_descr;
    AMREX_ROCSPARSE_SAFE_CALL(
        rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type));

    rocsparse_dnvec_descr y_descr;
    AMREX_ROCSPARSE_SAFE_CALL(
        rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type));

    T alpha = T(1.0);
    T beta = T(0.0);

#if (HIP_VERSION_MAJOR >= 7)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif

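    // On ROCm >= 6, rocsparse_spmv takes an explicit stage argument: one call
    // queries the buffer size, an optional call preprocesses, and a final call
    // computes.  On older ROCm there is no stage argument; a null temporary
    // buffer requests the size and a non-null one performs the computation.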
    std::size_t buffer_size;
    auto err0 =
        rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                       &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                       rocsparse_spmv_stage_buffer_size,
#endif
                       &buffer_size, nullptr);
    AMREX_ROCSPARSE_SAFE_CALL(err0);

    void* pbuffer = (void*)The_Arena()->alloc(buffer_size);

#if (HIP_VERSION_MAJOR >= 6)
    AMREX_ROCSPARSE_SAFE_CALL(
        rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                       &beta, y_descr, data_type, rocsparse_spmv_alg_default,
                       rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer));
#endif

    auto err1 =
        rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                       &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                       rocsparse_spmv_stage_compute,
#endif
                       &buffer_size, pbuffer);
    AMREX_ROCSPARSE_SAFE_CALL(err1);

#if (HIP_VERSION_MAJOR >= 7)
// xxxxx HIP TODO: rocsparse_spmv has been deprecated.
#pragma clang diagnostic pop
#endif

    Gpu::streamSynchronize();

    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_spmat_descr(mat_descr));
    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(x_descr));
    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(y_descr));
    AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_handle(handle));
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_SYCL)

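    // oneMKL sparse BLAS path: attach the CSR arrays to a matrix handle and run
    // gemv on the device queue.  Older oneMKL versions use a set_csr_data
    // overload that does not take the number of nonzeros.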
    mkl::sparse::matrix_handle_t handle{};
    mkl::sparse::init_matrix_handle(&handle);

#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION < 20250300)
    amrex::ignore_unused(nnz);
    mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols,
                              mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
#else
    mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, nnz,
                              mkl::index_base::zero, (Long*)row, (Long*)col, (T*)mat);
#endif
    mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
                      T(1), handle, px, T(0), py);

    auto ev = mkl::sparse::release_matrix_handle(Gpu::Device::streamQueue(), &handle);
    ev.wait();

#endif

    AMREX_GPU_ERROR_CHECK();

#else

    amrex::ignore_unused(ncols);

    // CPU fallback: straightforward CSR row loop, optionally threaded over
    // rows with OpenMP.
#ifdef AMREX_USE_OMP
#pragma omp parallel for
#endif
    for (Long i = 0; i < nrows; ++i) {
        T r = 0;
        for (Long j = row[i]; j < row[i+1]; ++j) {
            r += mat[j] * px[col[j]];
        }
        py[i] = r;
    }

#endif
}

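//! y = A*x for a distributed SpMatrix: communication for the off-diagonal part
//! is started first and overlapped with the local (diagonal block) product.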
template <typename T, template<typename> class AllocM, typename AllocV>
void SpMV (AlgVector<T,AllocV>& y, SpMatrix<T,AllocM> const& A,
           AlgVector<T,AllocV> const& x)
{
    // xxxxx TODO: Is it worth it to cache the cusparse and rocsparse handles?

    const_cast<SpMatrix<T,AllocM>&>(A).startComm_mv(x);

    // Diagonal part
    SpMV<T>(y.numLocalRows(), x.numLocalRows(), y.data(), A.m_csr.const_view(),
            x.data());

    const_cast<SpMatrix<T,AllocM>&>(A).finishComm_mv(y);
}

//! res = b - A*x
template <typename T, template<typename> class AllocM, typename AllocV>
void computeResidual (AlgVector<T,AllocV>& res, SpMatrix<T,AllocM> const& A,
                      AlgVector<T,AllocV> const& x, AlgVector<T,AllocV> const& b)
{
    SpMV(res, A, x);
    Xpay(res, T(-1), b);
}

}

#endif