35#if defined(AMREX_USE_GPU)
39#if defined(AMREX_USE_CUDA)
41 cusparseHandle_t handle;
45 cudaDataType data_type;
46 if constexpr (std::is_same_v<T,float>) {
47 data_type = CUDA_R_32F;
48 }
else if constexpr (std::is_same_v<T,double>) {
49 data_type = CUDA_R_64F;
50 }
else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
51 data_type = CUDA_C_32F;
52 }
else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
53 data_type = CUDA_C_64F;
58 cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;
60 cusparseSpMatDescr_t mat_descr;
62 (cusparseCreateCsr(&mat_descr, nrows, ncols, nnz,
63 (
void*)row, (
void*)col, (
void*)mat,
64 index_type, index_type, CUSPARSE_INDEX_BASE_ZERO,
67 cusparseDnVecDescr_t x_descr;
70 cusparseDnVecDescr_t y_descr;
76 std::size_t buffer_size;
78 (cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
79 &alpha, mat_descr, x_descr, &beta, y_descr,
80 data_type, CUSPARSE_SPMV_ALG_DEFAULT,
86 (cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
87 &alpha, mat_descr, x_descr, &beta, y_descr,
88 data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer));
98#elif defined(AMREX_USE_HIP)
100 rocsparse_handle handle;
101 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_create_handle(&handle));
102 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_set_stream(handle,
Gpu::gpuStream()));
104 rocsparse_datatype data_type;
105 if constexpr (std::is_same_v<T,float>) {
106 data_type = rocsparse_datatype_f32_r;
107 }
else if constexpr (std::is_same_v<T,double>) {
108 data_type = rocsparse_datatype_f64_r;
109 }
else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
110 data_type = rocsparse_datatype_f32_c;
111 }
else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
112 data_type = rocsparse_datatype_f64_c;
117 rocsparse_indextype index_type = rocsparse_indextype_i64;
119 rocsparse_spmat_descr mat_descr;
120 AMREX_ROCSPARSE_SAFE_CALL(
121 rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz,
122 (
void*)row, (
void*)col, (
void*)mat,
123 index_type, index_type,
124 rocsparse_index_base_zero, data_type));
126 rocsparse_dnvec_descr x_descr;
127 AMREX_ROCSPARSE_SAFE_CALL(
128 rocsparse_create_dnvec_descr(&x_descr, ncols, (
void*)px, data_type));
130 rocsparse_dnvec_descr y_descr;
131 AMREX_ROCSPARSE_SAFE_CALL(
132 rocsparse_create_dnvec_descr(&y_descr, nrows, (
void*)py, data_type));
137#if (HIP_VERSION_MAJOR >= 7)
139 rocsparse_spmv_descr spmv_descr;
140 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_create_spmv_descr(&spmv_descr));
143 rocsparse_error p_error[1] = {};
145 const rocsparse_spmv_alg spmv_alg = rocsparse_spmv_alg_csr_adaptive;
146 AMREX_ROCSPARSE_SAFE_CALL(
147 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_alg,
148 &spmv_alg,
sizeof(spmv_alg), p_error));
150 const rocsparse_operation spmv_operation = rocsparse_operation_none;
151 AMREX_ROCSPARSE_SAFE_CALL(
152 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_operation,
153 &spmv_operation,
sizeof(spmv_operation), p_error));
155 AMREX_ROCSPARSE_SAFE_CALL(
156 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_scalar_datatype,
157 &data_type,
sizeof(data_type), p_error));
159 AMREX_ROCSPARSE_SAFE_CALL(
160 rocsparse_spmv_set_input(handle, spmv_descr, rocsparse_spmv_input_compute_datatype,
161 &data_type,
sizeof(data_type), p_error));
163 std::size_t buffer_size = 0;
164 AMREX_ROCSPARSE_SAFE_CALL(
165 rocsparse_v2_spmv_buffer_size(handle, spmv_descr, mat_descr, x_descr, y_descr,
166 rocsparse_v2_spmv_stage_analysis,
167 &buffer_size, p_error));
169 void* pbuffer =
nullptr;
170 if (buffer_size > 0) {
174 AMREX_ROCSPARSE_SAFE_CALL(
175 rocsparse_v2_spmv(handle, spmv_descr, &alpha, mat_descr, x_descr, &beta, y_descr,
176 rocsparse_v2_spmv_stage_analysis,
177 buffer_size, pbuffer, p_error));
183 AMREX_ROCSPARSE_SAFE_CALL(
184 rocsparse_v2_spmv_buffer_size(handle, spmv_descr, mat_descr, x_descr, y_descr,
185 rocsparse_v2_spmv_stage_compute,
186 &buffer_size, p_error));
188 if (buffer_size > 0) {
194 AMREX_ROCSPARSE_SAFE_CALL(
195 rocsparse_v2_spmv(handle, spmv_descr, &alpha, mat_descr, x_descr, &beta, y_descr,
196 rocsparse_v2_spmv_stage_compute,
197 buffer_size, pbuffer, p_error));
199#elif (HIP_VERSION_MAJOR == 6)
201 std::size_t buffer_size = 0;
202 AMREX_ROCSPARSE_SAFE_CALL(
203 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
204 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
205 rocsparse_spmv_stage_buffer_size,
206 &buffer_size,
nullptr));
208 void* pbuffer =
nullptr;
209 if (buffer_size > 0) {
213 AMREX_ROCSPARSE_SAFE_CALL(
214 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
215 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
216 rocsparse_spmv_stage_preprocess,
217 &buffer_size, pbuffer));
219 AMREX_ROCSPARSE_SAFE_CALL(
220 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
221 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
222 rocsparse_spmv_stage_compute,
223 &buffer_size, pbuffer));
227 std::size_t buffer_size = 0;
228 AMREX_ROCSPARSE_SAFE_CALL(
229 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
230 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
231 &buffer_size,
nullptr));
233 void* pbuffer =
nullptr;
234 if (buffer_size > 0) {
238 AMREX_ROCSPARSE_SAFE_CALL(
239 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
240 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
241 &buffer_size, pbuffer));
247#if (HIP_VERSION_MAJOR >= 7)
248 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_error(p_error[0]));
249 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_spmv_descr(spmv_descr));
251 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_spmat_descr(mat_descr));
252 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(x_descr));
253 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_dnvec_descr(y_descr));
254 AMREX_ROCSPARSE_SAFE_CALL(rocsparse_destroy_handle(handle));
259#elif defined(AMREX_USE_SYCL)
261 mkl::sparse::matrix_handle_t handle{};
262 mkl::sparse::init_matrix_handle(&handle);
264#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION < 20250300)
266 mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols,
267 mkl::index_base::zero, (
Long*)row, (
Long*)col, (T*)mat);
269 mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, nnz,
270 mkl::index_base::zero, (
Long*)row, (
Long*)col, (T*)mat);
272 mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
273 T(1), handle, px, T(0), py);
275 auto ev = mkl::sparse::release_matrix_handle(Gpu::Device::streamQueue(), &handle);
287#pragma omp parallel for
289 for (
Long i = 0; i < nrows; ++i) {
291 for (
Long j = row[i]; j < row[i+1]; ++j) {
292 r += mat[j] * px[col[j]];