Block-Structured AMR Software Framework
AMReX_SpMV.H
CSR sparse matrix-vector product y = A*x for AMReX's distributed SpMatrix and AlgVector types, dispatching to cuSPARSE, rocSPARSE, or oneMKL on GPU builds, with an OpenMP-threaded loop as the CPU fallback.
#ifndef AMREX_SPMV_H_
#define AMREX_SPMV_H_
#include <AMReX_Config.H>

#include <AMReX_AlgVector.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_SpMatrix.H>

// Pull in the vendor sparse BLAS that matches the GPU build.
#if defined(AMREX_USE_CUDA)
#  include <cusparse.h>
#elif defined(AMREX_USE_HIP)
#  include <rocsparse/rocsparse.h>
#elif defined(AMREX_USE_DPCPP)
#  include <oneapi/mkl/spblas.hpp>
#endif

namespace amrex {

template <typename T>
void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
{
    // xxxxx TODO: We might want to cache the cusparse and rocsparse handles.

    // xxxxx TODO: Let's assume it's a square matrix for now.
    AMREX_ALWAYS_ASSERT(x.partition() == y.partition() &&
                        x.partition() == A.partition());

    // Start the halo exchange for the off-processor entries of x; the local
    // part of the product is computed while that communication is in flight.
    const_cast<SpMatrix<T>&>(A).startComm(x);

    T      * AMREX_RESTRICT py  = y.data();
    T const* AMREX_RESTRICT px  = x.data();
    T const* AMREX_RESTRICT mat = A.data();           // CSR values
    auto const* AMREX_RESTRICT col = A.columnIndex(); // CSR column indices
    auto const* AMREX_RESTRICT row = A.rowOffset();   // CSR row offsets

#if defined(AMREX_USE_GPU)

    Long const nrows = A.numLocalRows();
    Long const ncols = x.numLocalRows();
    Long const nnz   = A.numLocalNonZero();

#if defined(AMREX_USE_CUDA)

    cusparseHandle_t handle;
    cusparseCreate(&handle);
    cusparseSetStream(handle, Gpu::gpuStream());

    // Map T onto the matching cuSPARSE data type.
    cudaDataType data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = CUDA_R_32F;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = CUDA_R_64F;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = CUDA_C_32F;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = CUDA_C_64F;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;

    cusparseSpMatDescr_t mat_descr;
    cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
                      index_type, index_type, CUSPARSE_INDEX_BASE_ZERO, data_type);

    cusparseDnVecDescr_t x_descr;
    cusparseCreateDnVec(&x_descr, ncols, (void*)px, data_type);

    cusparseDnVecDescr_t y_descr;
    cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type);

    T alpha = T(1);
    T beta  = T(0);

    // Query the size of, then allocate, the external work buffer cuSPARSE needs.
    std::size_t buffer_size;
    cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
                            &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size);

    auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);

    cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
                 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer);

    // Wait for the SpMV to complete before the work buffer is freed.
    Gpu::streamSynchronize();

    cusparseDestroySpMat(mat_descr);
    cusparseDestroyDnVec(x_descr);
    cusparseDestroyDnVec(y_descr);
    cusparseDestroy(handle);
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_HIP)

    rocsparse_handle handle;
    rocsparse_create_handle(&handle);
    rocsparse_set_stream(handle, Gpu::gpuStream());

    // Map T onto the matching rocSPARSE data type.
    rocsparse_datatype data_type;
    if constexpr (std::is_same_v<T,float>) {
        data_type = rocsparse_datatype_f32_r;
    } else if constexpr (std::is_same_v<T,double>) {
        data_type = rocsparse_datatype_f64_r;
    } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
        data_type = rocsparse_datatype_f32_c;
    } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
        data_type = rocsparse_datatype_f64_c;
    } else {
        amrex::Abort("SpMV: unsupported data type");
    }

    rocsparse_indextype index_type = rocsparse_indextype_i64;

    rocsparse_spmat_descr mat_descr;
    rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col,
                               (void*)mat, index_type, index_type,
                               rocsparse_index_base_zero, data_type);

    rocsparse_dnvec_descr x_descr;
    rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type);

    rocsparse_dnvec_descr y_descr;
    rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type);

    T alpha = T(1.0);
    T beta  = T(0.0);

    // HIP 6+ splits rocsparse_spmv into explicit buffer-size, preprocess, and
    // compute stages; earlier versions take no stage argument.
    std::size_t buffer_size;
    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                   rocsparse_spmv_stage_buffer_size,
#endif
                   &buffer_size, nullptr);

    void* pbuffer = (void*)The_Arena()->alloc(buffer_size);

#if (HIP_VERSION_MAJOR >= 6)
    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
                   rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer);
#endif

    rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
                   &beta, y_descr, data_type, rocsparse_spmv_alg_default,
#if (HIP_VERSION_MAJOR >= 6)
                   rocsparse_spmv_stage_compute,
#endif
                   &buffer_size, pbuffer);

    // Wait for the SpMV to complete before the work buffer is freed.
    Gpu::streamSynchronize();

    rocsparse_destroy_spmat_descr(mat_descr);
    rocsparse_destroy_dnvec_descr(x_descr);
    rocsparse_destroy_dnvec_descr(y_descr);
    rocsparse_destroy_handle(handle);
    The_Arena()->free(pbuffer);

#elif defined(AMREX_USE_DPCPP)

    // oneMKL takes the CSR arrays directly through a matrix handle.
    mkl::sparse::matrix_handle_t handle{};
    mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, mkl::index_base::zero,
                              (Long*)row, (Long*)col, (T*)mat);
    mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
                      T(1), handle, px, T(0), py);

#endif

    AMREX_GPU_ERROR_CHECK();

#else

    // CPU fallback: textbook CSR traversal, one row at a time.
    Long const ny = y.numLocalRows();
    for (Long i = 0; i < ny; ++i) {
        T r = 0;
#ifdef AMREX_USE_OMP
#pragma omp parallel for reduction(+:r)
#endif
        for (Long j = row[i]; j < row[i+1]; ++j) {
            r += mat[j] * px[col[j]];
        }
        py[i] = r;
    }

#endif

    // Complete the halo exchange started above and fold the off-processor
    // contributions into y.
    const_cast<SpMatrix<T>&>(A).finishComm(y);
}

}

#endif
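
For readers unfamiliar with the storage scheme, the CPU fallback above is the textbook CSR (compressed sparse row) kernel: row[i]..row[i+1] delimits the nonzeros of row i, col[j] gives the column of the j-th stored value mat[j]. The self-contained sketch below shows the same traversal on a small hand-built matrix; it uses plain std::vector instead of AMReX's AlgVector/SpMatrix, and the csr_spmv helper, the main() driver, and the 4x4 test matrix are illustrative inventions, not part of AMReX.

#include <cstdio>
#include <vector>

// y = A*x for a CSR matrix stored as (row offsets, column indices, values).
void csr_spmv (std::vector<double>& y,
               std::vector<long> const& row,
               std::vector<long> const& col,
               std::vector<double> const& mat,
               std::vector<double> const& x)
{
    long const nrows = (long)y.size();
    for (long i = 0; i < nrows; ++i) {
        double r = 0.0;
        for (long j = row[i]; j < row[i+1]; ++j) {
            r += mat[j] * x[col[j]];
        }
        y[i] = r;
    }
}

int main ()
{
    // Tridiagonal 1D-Laplacian-like matrix, a common SpMV test case:
    //  [ 2 -1  0  0 ]
    //  [-1  2 -1  0 ]
    //  [ 0 -1  2 -1 ]
    //  [ 0  0 -1  2 ]
    std::vector<long>   row = {0, 2, 5, 8, 10};
    std::vector<long>   col = {0,1, 0,1,2, 1,2,3, 2,3};
    std::vector<double> mat = {2.,-1., -1.,2.,-1., -1.,2.,-1., -1.,2.};
    std::vector<double> x   = {1., 1., 1., 1.};
    std::vector<double> y(4);

    csr_spmv(y, row, col, mat, x);
    for (double v : y) { std::printf("%g\n", v); } // prints 1 0 0 1
}

In an actual AMReX application one would instead call amrex::SpMV(y, A, x) on distributed objects whose AlgPartitions match, as the assertion at the top of the function enforces; the GPU branches then compute this same product through the generic CSR SpMV interfaces of cuSPARSE, rocSPARSE, or oneMKL.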