3 #include <AMReX_Config.H>
9 #if defined(AMREX_USE_CUDA)
10 # include <cusparse.h>
11 #elif defined(AMREX_USE_HIP)
12 # include <rocsparse/rocsparse.h>
13 #elif defined(AMREX_USE_DPCPP)
14 # include <oneapi/mkl/spblas.hpp>
// Body fragment of the sparse matrix-vector product routine. The function
// signature, several interior lines and the closing of the function are not
// present in this listing. The visible code selects one of three vendor
// sparse libraries at compile time (cuSPARSE, rocSPARSE, oneMKL) and is
// followed by a plain loop over CSR rows.
36 #if defined(AMREX_USE_GPU)
// The column count handed to the libraries is taken from the local row count
// of x.
// NOTE(review): this equates the matrix column count with x's local size --
// confirm that a square, identically partitioned matrix is intended.
39 Long
const ncols =
x.numLocalRows();
42 #if defined(AMREX_USE_CUDA)
// ---- cuSPARSE path ----
// A handle is created here and destroyed at the end of this branch.
// NOTE(review): none of the cuSPARSE calls visible here have their return
// status checked.
44 cusparseHandle_t handle;
45 cusparseCreate(&handle);
// Select the cudaDataType matching T: float -> CUDA_R_32F,
// double -> CUDA_R_64F. The conditions guarding the two complex assignments
// are not visible in this listing.
48 cudaDataType data_type;
49 if constexpr (std::is_same_v<T,float>) {
50 data_type = CUDA_R_32F;
51 }
else if constexpr (std::is_same_v<T,double>) {
52 data_type = CUDA_R_64F;
54 data_type = CUDA_C_32F;
56 data_type = CUDA_C_64F;
// The same 64-bit index type is used for both the row-offset and the
// column-index arrays in the CSR descriptor below.
61 cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;
// Describe the matrix as CSR with zero-based indexing; row, col and mat are
// passed as untyped pointers.
63 cusparseSpMatDescr_t mat_descr;
64 cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (
void*)row, (
void*)col, (
void*)mat,
65 index_type, index_type, CUSPARSE_INDEX_BASE_ZERO, data_type);
// Dense-vector descriptor for the input x (length ncols).
67 cusparseDnVecDescr_t x_descr;
68 cusparseCreateDnVec(&x_descr, ncols, (
void*)px, data_type);
// Dense-vector descriptor for the output y (length nrows).
70 cusparseDnVecDescr_t y_descr;
71 cusparseCreateDnVec(&y_descr, nrows, (
void*)py, data_type);
// Query the workspace size for the non-transposed SpMV with the default
// algorithm. The definitions of alpha and beta are not visible here.
76 std::size_t buffer_size;
77 cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
78 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size);
// Perform the product, passing pbuffer as the final argument. The allocation
// of pbuffer is not visible in this listing.
82 cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
83 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer);
// Release the descriptors and the handle created above.
87 cusparseDestroySpMat(mat_descr);
88 cusparseDestroyDnVec(x_descr);
89 cusparseDestroyDnVec(y_descr);
90 cusparseDestroy(handle);
93 #elif defined(AMREX_USE_HIP)
// ---- rocSPARSE path ----
// Mirrors the cuSPARSE branch: handle, data type, descriptors, SpMV, cleanup.
// NOTE(review): return statuses of the rocSPARSE calls visible here are not
// checked.
95 rocsparse_handle handle;
96 rocsparse_create_handle(&handle);
// Select the rocsparse_datatype matching T: float -> f32_r, double -> f64_r.
// The conditions guarding the two complex assignments are not visible here.
99 rocsparse_datatype data_type;
100 if constexpr (std::is_same_v<T,float>) {
101 data_type = rocsparse_datatype_f32_r;
102 }
else if constexpr (std::is_same_v<T,double>) {
103 data_type = rocsparse_datatype_f64_r;
105 data_type = rocsparse_datatype_f32_c;
107 data_type = rocsparse_datatype_f64_c;
// The same 64-bit index type is used for both index arrays of the CSR
// descriptor below.
112 rocsparse_indextype index_type = rocsparse_indextype_i64;
// CSR matrix descriptor with zero-based indexing.
114 rocsparse_spmat_descr mat_descr;
115 rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz, (
void*)row, (
void*)col,
116 (
void*)mat, index_type, index_type,
117 rocsparse_index_base_zero, data_type);
// Dense-vector descriptor for the input x (length ncols).
119 rocsparse_dnvec_descr x_descr;
120 rocsparse_create_dnvec_descr(&x_descr, ncols, (
void*)px, data_type);
// Dense-vector descriptor for the output y (length nrows).
122 rocsparse_dnvec_descr y_descr;
123 rocsparse_create_dnvec_descr(&y_descr, nrows, (
void*)py, data_type);
// First rocsparse_spmv call: passes &buffer_size together with a null buffer
// pointer. When HIP_VERSION_MAJOR >= 6 the explicit stage argument
// rocsparse_spmv_stage_buffer_size is also passed.
128 std::size_t buffer_size;
129 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
130 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
131 #
if (HIP_VERSION_MAJOR >= 6)
132 rocsparse_spmv_stage_buffer_size,
134 &buffer_size,
nullptr);
// HIP >= 6 only: a separate call with rocsparse_spmv_stage_preprocess, now
// passing pbuffer. The allocation of pbuffer and the matching #endif are not
// visible in this listing.
138 #if (HIP_VERSION_MAJOR >= 6)
139 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
140 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
141 rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer);
// Final rocsparse_spmv call with pbuffer; on HIP >= 6 it passes
// rocsparse_spmv_stage_compute.
144 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
145 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
146 #
if (HIP_VERSION_MAJOR >= 6)
147 rocsparse_spmv_stage_compute,
149 &buffer_size, pbuffer);
// Release the descriptors and the handle created above.
153 rocsparse_destroy_spmat_descr(mat_descr);
154 rocsparse_destroy_dnvec_descr(x_descr);
155 rocsparse_destroy_dnvec_descr(y_descr);
156 rocsparse_destroy_handle(handle);
159 #elif defined(AMREX_USE_DPCPP)
// ---- oneMKL path ----
// The matrix handle starts value-initialized. The next line is the tail of a
// call whose beginning is not visible in this listing; it passes row, col and
// mat cast to Long*, Long* and T*.
161 mkl::sparse::matrix_handle_t handle{};
163 (Long*)row, (Long*)col, (T*)mat);
// Non-transposed sparse gemv on Gpu::Device::streamQueue(), with scalar
// arguments T(1) and T(0), reading px and writing py.
164 mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
165 T(1), handle, px, T(0), py);
// ---- plain loop over CSR rows ----
// Presumably the non-GPU path; the matching #else is not visible here.
// For each row i in [0, ny), r accumulates mat[j] * px[col[j]] over the
// entries j in [row[i], row[i+1]). The declaration of r and the store of the
// result are not visible in this listing.
174 for (Long i = 0; i < ny; ++i) {
// OpenMP sum-reduction on r for the loop that follows.
177 #pragma omp parallel for reduction(+:r)
179 for (Long j = row[i]; j < row[i+1]; ++j) {
180 r += mat[j] * px[col[j]];
#define AMREX_ALWAYS_ASSERT(EX)
Definition: AMReX_BLassert.H:50
#define AMREX_RESTRICT
Definition: AMReX_Extension.H:37
#define AMREX_GPU_ERROR_CHECK()
Definition: AMReX_GpuError.H:125
Definition: AMReX_AlgVector.H:19
Long numLocalRows() const
Definition: AMReX_AlgVector.H:45
AlgPartition const & partition() const
Definition: AMReX_AlgVector.H:43
T const * data() const
Definition: AMReX_AlgVector.H:53
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition: AMReX_SpMatrix.H:19
Long const * columnIndex() const
Definition: AMReX_SpMatrix.H:50
AlgPartition const & partition() const
Definition: AMReX_SpMatrix.H:37
T const * data() const
Definition: AMReX_SpMatrix.H:48
Long numLocalRows() const
Definition: AMReX_SpMatrix.H:39
Long const * rowOffset() const
Definition: AMReX_SpMatrix.H:52
Long numLocalNonZero() const
Definition: AMReX_SpMatrix.H:41
void streamSynchronize() noexcept
Definition: AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition: AMReX_GpuDevice.H:218
Definition: AMReX_Amr.cpp:49
void SpMV(AlgVector< T > &y, SpMatrix< T > const &A, AlgVector< T > const &x)
Definition: AMReX_SpMV.H:20
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition: AMReX.cpp:225
Arena * The_Arena()
Definition: AMReX_Arena.cpp:609
A host/device complex number type, because std::complex doesn't work in device code with CUDA yet.
Definition: AMReX_GpuComplex.H:29