Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_SpMV.H
Go to the documentation of this file.
1#ifndef AMREX_SPMV_H_
2#define AMREX_SPMV_H_
3#include <AMReX_Config.H>
4
5#include <AMReX_AlgVector.H>
6#include <AMReX_GpuComplex.H>
7#include <AMReX_SpMatrix.H>
8
9#if defined(AMREX_USE_CUDA)
10# include <cusparse.h>
11#elif defined(AMREX_USE_HIP)
12# include <rocsparse/rocsparse.h>
13#elif defined(AMREX_USE_DPCPP)
14# include <oneapi/mkl/spblas.hpp>
15#endif
16
17namespace amrex {
18
19template <typename T>
20void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
21{
22 // xxxxx TODOL We might want to cache the cusparse and rocsparse handles
23
24 // xxxxx TODO: let's assume it's square matrix for now.
25 AMREX_ALWAYS_ASSERT(x.partition() == y.partition() &&
26 x.partition() == A.partition());
27
28 const_cast<SpMatrix<T>&>(A).startComm(x);
29
30 T * AMREX_RESTRICT py = y.data();
31 T const* AMREX_RESTRICT px = x.data();
32 T const* AMREX_RESTRICT mat = A.data();
33 auto const* AMREX_RESTRICT col = A.columnIndex();
34 auto const* AMREX_RESTRICT row = A.rowOffset();
35
36#if defined(AMREX_USE_GPU)
37
38 Long const nrows = A.numLocalRows();
39 Long const ncols = x.numLocalRows();
40 Long const nnz = A.numLocalNonZero();
41
42#if defined(AMREX_USE_CUDA)
43
44 cusparseHandle_t handle;
45 cusparseCreate(&handle);
46 cusparseSetStream(handle, Gpu::gpuStream());
47
48 cudaDataType data_type;
49 if constexpr (std::is_same_v<T,float>) {
50 data_type = CUDA_R_32F;
51 } else if constexpr (std::is_same_v<T,double>) {
52 data_type = CUDA_R_64F;
53 } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
54 data_type = CUDA_C_32F;
55 } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
56 data_type = CUDA_C_64F;
57 } else {
58 amrex::Abort("SpMV: unsupported data type");
59 }
60
61 cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;
62
63 cusparseSpMatDescr_t mat_descr;
64 cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
65 index_type, index_type, CUSPARSE_INDEX_BASE_ZERO, data_type);
66
67 cusparseDnVecDescr_t x_descr;
68 cusparseCreateDnVec(&x_descr, ncols, (void*)px, data_type);
69
70 cusparseDnVecDescr_t y_descr;
71 cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type);
72
73 T alpha = T(1);
74 T beta = T(0);
75
76 std::size_t buffer_size;
77 cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
78 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size);
79
80 auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);
81
82 cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
83 &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer);
84
86
87 cusparseDestroySpMat(mat_descr);
88 cusparseDestroyDnVec(x_descr);
89 cusparseDestroyDnVec(y_descr);
90 cusparseDestroy(handle);
91 The_Arena()->free(pbuffer);
92
93#elif defined(AMREX_USE_HIP)
94
95 rocsparse_handle handle;
96 rocsparse_create_handle(&handle);
97 rocsparse_set_stream(handle, Gpu::gpuStream());
98
99 rocsparse_datatype data_type;
100 if constexpr (std::is_same_v<T,float>) {
101 data_type = rocsparse_datatype_f32_r;
102 } else if constexpr (std::is_same_v<T,double>) {
103 data_type = rocsparse_datatype_f64_r;
104 } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
105 data_type = rocsparse_datatype_f32_c;
106 } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
107 data_type = rocsparse_datatype_f64_c;
108 } else {
109 amrex::Abort("SpMV: unsupported data type");
110 }
111
112 rocsparse_indextype index_type = rocsparse_indextype_i64;
113
114 rocsparse_spmat_descr mat_descr;
115 rocsparse_create_csr_descr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col,
116 (void*)mat, index_type, index_type,
117 rocsparse_index_base_zero, data_type);
118
119 rocsparse_dnvec_descr x_descr;
120 rocsparse_create_dnvec_descr(&x_descr, ncols, (void*)px, data_type);
121
122 rocsparse_dnvec_descr y_descr;
123 rocsparse_create_dnvec_descr(&y_descr, nrows, (void*)py, data_type);
124
125 T alpha = T(1.0);
126 T beta = T(0.0);
127
128 std::size_t buffer_size;
129 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
130 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
131#if (HIP_VERSION_MAJOR >= 6)
132 rocsparse_spmv_stage_buffer_size,
133#endif
134 &buffer_size, nullptr);
135
136 void* pbuffer = (void*)The_Arena()->alloc(buffer_size);
137
138#if (HIP_VERSION_MAJOR >= 6)
139 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
140 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
141 rocsparse_spmv_stage_preprocess, &buffer_size, pbuffer);
142#endif
143
144 rocsparse_spmv(handle, rocsparse_operation_none, &alpha, mat_descr, x_descr,
145 &beta, y_descr, data_type, rocsparse_spmv_alg_default,
146#if (HIP_VERSION_MAJOR >= 6)
147 rocsparse_spmv_stage_compute,
148#endif
149 &buffer_size, pbuffer);
150
152
153 rocsparse_destroy_spmat_descr(mat_descr);
154 rocsparse_destroy_dnvec_descr(x_descr);
155 rocsparse_destroy_dnvec_descr(y_descr);
156 rocsparse_destroy_handle(handle);
157 The_Arena()->free(pbuffer);
158
159#elif defined(AMREX_USE_DPCPP)
160
161 mkl::sparse::matrix_handle_t handle{};
162 mkl::sparse::set_csr_data(Gpu::Device::streamQueue(), handle, nrows, ncols, mkl::index_base::zero,
163 (Long*)row, (Long*)col, (T*)mat);
164 mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
165 T(1), handle, px, T(0), py);
166
167#endif
168
170
171#else
172
173 Long const ny = y.numLocalRows();
174 for (Long i = 0; i < ny; ++i) {
175 T r = 0;
176#ifdef AMREX_USE_OMP
177#pragma omp parallel for reduction(+:r)
178#endif
179 for (Long j = row[i]; j < row[i+1]; ++j) {
180 r += mat[j] * px[col[j]];
181 }
182 py[i] = r;
183 }
184
185#endif
186
187 const_cast<SpMatrix<T>&>(A).finishComm(y);
188}
189
190}
191
192#endif
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
#define AMREX_RESTRICT
Definition AMReX_Extension.H:37
#define AMREX_GPU_ERROR_CHECK()
Definition AMReX_GpuError.H:133
Definition AMReX_AlgVector.H:19
Long numLocalRows() const
Definition AMReX_AlgVector.H:45
AlgPartition const & partition() const
Definition AMReX_AlgVector.H:43
T const * data() const
Definition AMReX_AlgVector.H:53
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition AMReX_SpMatrix.H:19
Long const * rowOffset() const
Definition AMReX_SpMatrix.H:60
T const * data() const
Definition AMReX_SpMatrix.H:56
Long const * columnIndex() const
Definition AMReX_SpMatrix.H:58
Long numLocalRows() const
Definition AMReX_SpMatrix.H:47
AlgPartition const & partition() const
Definition AMReX_SpMatrix.H:45
Long numLocalNonZero() const
Definition AMReX_SpMatrix.H:49
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:237
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:218
Definition AMReX_Amr.cpp:49
void SpMV(AlgVector< T > &y, SpMatrix< T > const &A, AlgVector< T > const &x)
Definition AMReX_SpMV.H:20
void Abort(const std::string &msg)
Print out message to cerr and exit via abort().
Definition AMReX.cpp:230
Arena * The_Arena()
Definition AMReX_Arena.cpp:616