Block-Structured AMR Software Framework
AMReX_CSR.H
#ifndef AMREX_CSR_H_
#define AMREX_CSR_H_
#include <AMReX_Config.H>

#include <AMReX_Gpu.H>
#include <AMReX_INT.H>
#include <AMReX_OpenMP.H>

#if defined(AMREX_USE_CUDA)
#include <cub/cub.cuh> // for Clang
#endif

#include <algorithm>
#include <climits>
#include <limits>
#include <type_traits>

namespace amrex {

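// CsrView is a non-owning view of a CSR matrix: raw pointers to the stored
// values, column indices, and row offsets, plus the number of non-zeros and
// rows. For const element types the index pointers are const as well.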
template <typename T>
struct CsrView {
    using U = std::conditional_t<std::is_const_v<T>, Long const, Long>;
    T* AMREX_RESTRICT mat = nullptr;
    U* AMREX_RESTRICT col_index = nullptr;
    U* AMREX_RESTRICT row_offset = nullptr;
    Long nnz = 0;
    Long nrows = 0;
};

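// CSR stores a sparse matrix in compressed sparse row format using the
// container template V (e.g., Gpu::DeviceVector or Gpu::PinnedVector):
// the entries of row r occupy positions [row_offset[r], row_offset[r+1])
// of col_index and mat, and row_offset has nrows()+1 entries.
//
// For example, the matrix
//     [ 5 0 1 ]
//     [ 0 2 0 ]
// is stored as
//     row_offset = {0, 2, 3}
//     col_index  = {0, 2, 1}
//     mat        = {5, 1, 2}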
template <typename T, template <typename> class V>
struct CSR {
    V<T> mat;
    V<Long> col_index;
    V<Long> row_offset;
    Long nnz = 0;

    [[nodiscard]] Long nrows () const {
        return row_offset.empty() ? Long(0) : Long(row_offset.size())-1;
    }

    void resize (Long num_rows, Long num_non_zeros) {
        mat.resize(num_non_zeros);
        col_index.resize(num_non_zeros);
        row_offset.resize(num_rows+1);
        nnz = num_non_zeros;
    }

    [[nodiscard]] CsrView<T> view () {
        return CsrView<T>{mat.data(), col_index.data(), row_offset.data(),
                          nnz, Long(row_offset.size())-1};
    }

    [[nodiscard]] CsrView<T const> view () const {
        return CsrView<T const>{mat.data(), col_index.data(), row_offset.data(),
                                nnz, Long(row_offset.size())-1};
    }

    [[nodiscard]] CsrView<T const> const_view () const {
        return CsrView<T const>{mat.data(), col_index.data(), row_offset.data(),
                                nnz, Long(row_offset.size())-1};
    }

    void sort ();

    void sort_on_host ();
};

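// A minimal usage sketch (assumes a host-side container such as
// Gpu::PinnedVector and AMReX's Real floating-point type):
//
//     CSR<Real, Gpu::PinnedVector> A;
//     A.resize(2, 3);                           // 2 rows, 3 stored entries
//     A.row_offset[0] = 0; A.row_offset[1] = 2; A.row_offset[2] = 3;
//     A.col_index[0] = 2; A.col_index[1] = 0;   // row 0 entered out of order
//     A.col_index[2] = 1;                       // row 1
//     A.mat[0] = 1.0; A.mat[1] = 5.0; A.mat[2] = 2.0;
//     A.sort_on_host();                         // row 0 becomes {0, 2} / {5.0, 1.0}
//
// (For a device-resident matrix, e.g. V = Gpu::DeviceVector, use sort() instead.)
//
// duplicateCSR below copies a CSR matrix between containers (host/device in
// either direction, or device to device); the direction is selected by the
// tag c, and the element copies are asynchronous (Gpu::copyAsync).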
template <typename C, typename T, template<typename> class AD, template<typename> class AS,
          std::enable_if_t<std::is_same_v<C,Gpu::HostToDevice> ||
                           std::is_same_v<C,Gpu::DeviceToHost> ||
                           std::is_same_v<C,Gpu::DeviceToDevice>, int> = 0>
void duplicateCSR (C c, CSR<T,AD>& dst, CSR<T,AS> const& src)
{
    dst.mat.resize(src.mat.size());
    dst.col_index.resize(src.col_index.size());
    dst.row_offset.resize(src.row_offset.size());
    Gpu::copyAsync(c, src.mat.begin(), src.mat.end(), dst.mat.begin());
    Gpu::copyAsync(c, src.col_index.begin(), src.col_index.end(), dst.col_index.begin());
    Gpu::copyAsync(c, src.row_offset.begin(), src.row_offset.end(), dst.row_offset.begin());
    dst.nnz = src.nnz;
}

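// sort(): sort the column indices (and the matching values) within each row
// into ascending order. On CUDA/HIP builds, one warp is assigned to each row;
// rows that fit in the warp's registers are sorted with a warp-level merge
// sort, and longer rows trigger a device-wide segmented radix sort fallback.
// SYCL builds currently copy the matrix to the host, sort there, and copy it
// back. CPU builds simply call sort_on_host().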
template <typename T, template <typename> class V>
void CSR<T,V>::sort ()
{
    if (nnz <= 0) { return; }

#ifdef AMREX_USE_GPU

#if defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)

    // The function is synchronous. If that is no longer the case, we might
    // need to update SpMatrix::define.

    constexpr int nthreads = 256;
    constexpr int nwarps_per_block = nthreads / Gpu::Device::warp_size;

    AMREX_ALWAYS_ASSERT((nrows()+nwarps_per_block-1) < Long(std::numeric_limits<int>::max()));

    auto nr = int(nrows());
    int nblocks = (nr + nwarps_per_block-1) / nwarps_per_block;
    auto const& stream = Gpu::gpuStream();

    auto* pmat = mat.data();
    auto* pcol = col_index.data();
    auto* prow = row_offset.data();

    Gpu::Buffer<int> needs_fallback({0});
    auto* d_needs_fallback = needs_fallback.data();

    amrex::launch_global<nthreads><<<nblocks, nthreads, 0, stream>>>
    ([=] AMREX_GPU_DEVICE () noexcept
    {
        int wid = int(threadIdx.x)/Gpu::Device::warp_size;
        int r = int(blockIdx.x)*nwarps_per_block + wid;
        if (r >= nr) return;

        Long const b = prow[r];
        Long const e = prow[r+1];
        auto const len = int(e - b);

        if (len <= 1) return;

        int lane = int(threadIdx.x) - wid * Gpu::Device::warp_size;

        bool sorted = true;
        for (Long i = lane + 1; i < len; i += Gpu::Device::warp_size) {
            sorted = sorted && (pcol[b+i-1] <= pcol[b+i]);
        }
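        // If every lane finds its entries already in order, the whole row is
        // sorted and the warp can return early.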
#if defined(AMREX_USE_CUDA)
        if (__all_sync(0xffffffff, sorted)) { return; }
#else
        if (__all(sorted)) { return; }
#endif

        constexpr int ITEMS_PER_THREAD = AMREX_HIP_OR_CUDA(2,4);
        constexpr int ITEMS_PER_WARP = Gpu::Device::warp_size * ITEMS_PER_THREAD;

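        // Short rows fit in the warp's registers: each lane loads up to
        // ITEMS_PER_THREAD (key, value) pairs, padding with the maximum Long
        // so padded slots sort to the end, and the warp sorts them in place.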
        if (len <= ITEMS_PER_WARP)
        {
#if defined(AMREX_USE_CUDA)
            using WarpSort = cub::WarpMergeSort<Long, ITEMS_PER_THREAD, Gpu::Device::warp_size, T>;
            __shared__ typename WarpSort::TempStorage temp_storage[nwarps_per_block];
#elif defined(AMREX_USE_HIP)
            using WarpSort = rocprim::warp_sort<Long, Gpu::Device::warp_size, T>;
            __shared__ typename WarpSort::storage_type temp_storage[nwarps_per_block];
#endif

            Long keys[ITEMS_PER_THREAD];
            T values[ITEMS_PER_THREAD];

            #pragma unroll
            for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
                int idx = lane * ITEMS_PER_THREAD + i;
                if (idx < len) {
                    keys[i] = pcol[b + idx];
                    values[i] = pmat[b + idx];
                } else {
                    keys[i] = std::numeric_limits<Long>::max();
                    values[i] = T{};
                }
            }

            AMREX_HIP_OR_CUDA(
                WarpSort{}.sort(keys, values, temp_storage[wid]),
                WarpSort(temp_storage[wid]).Sort(
                    keys, values, [](Long x, Long y) {return x < y;}));

            #pragma unroll
            for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
                int idx = lane * ITEMS_PER_THREAD + i;
                if (idx < len) {
                    pcol[b + idx] = keys[i];
                    pmat[b + idx] = values[i];
                }
            }
        } else {
            if (lane == 0) {
                Gpu::Atomic::AddNoRet(d_needs_fallback, 1);
            }
        }
    });

    auto* h_needs_fallback = needs_fallback.copyToHost();

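    // At least one row was too long for the in-register warp sort. Fall back
    // to a device-wide segmented radix sort over all rows (one segment per
    // row); rows that are already sorted are unaffected.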
    if (*h_needs_fallback)
    {
        V<Long> col_index_out(col_index.size());
        V<T> mat_out(mat.size());
        auto* d_col_out = col_index_out.data();
        auto* d_val_out = mat_out.data();

        std::size_t temp_bytes = 0;

        AMREX_GPU_SAFE_CALL(AMREX_HIP_OR_CUDA(
            rocprim::segmented_radix_sort_pairs,
            cub::DeviceSegmentedRadixSort::SortPairs)
            (nullptr, temp_bytes, pcol, d_col_out, pmat, d_val_out,
             nnz, nr, prow, prow+1, 0, int(sizeof(Long)*CHAR_BIT),
             stream));

        auto* d_temp = (void*) The_Arena()->alloc(temp_bytes);

        AMREX_GPU_SAFE_CALL(AMREX_HIP_OR_CUDA(
            rocprim::segmented_radix_sort_pairs,
            cub::DeviceSegmentedRadixSort::SortPairs)
            (d_temp, temp_bytes, pcol, d_col_out, pmat, d_val_out,
             nnz, nr, prow, prow+1, 0, int(sizeof(Long)*CHAR_BIT),
             stream));

        std::swap(col_index, col_index_out);
        std::swap(mat, mat_out);

        Gpu::streamSynchronize();
        The_Arena()->free(d_temp);
    }

    // To verify either path, one can print the matrix and check that each
    // row's column indices are in ascending order.

    AMREX_GPU_ERROR_CHECK();

#elif defined(AMREX_USE_SYCL)

    // xxxxx TODO SYCL: Let's not worry about performance for now.
    CSR<T,Gpu::PinnedVector> h_csr;
    duplicateCSR(Gpu::deviceToHost, h_csr, *this);
    Gpu::streamSynchronize();
    h_csr.sort_on_host();
    duplicateCSR(Gpu::hostToDevice, *this, h_csr);
    Gpu::streamSynchronize();

#endif

#else

    sort_on_host();

#endif
}

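// sort_on_host(): OpenMP-parallel, per-row sort on the host. Rows with at most
// SMALL entries are sorted with an insertion sort on stack arrays; longer rows
// sort an index permutation with std::sort and then gather the results.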
template <typename T, template <typename> class V>
void CSR<T,V>::sort_on_host ()
{
    if (nnz <= 0) { return; }

    constexpr int SMALL = 128;

    Long nr = nrows();

#ifdef AMREX_USE_OMP
#pragma omp parallel
#endif
    {
        V<Long> lcols;
        V<T   > lvals;
        V<int > perm;

        Long scols[SMALL];
        T svals[SMALL];

#ifdef AMREX_USE_OMP
#pragma omp for
#endif
        for (Long r = 0; r < nr; ++r) {
            Long const b = row_offset[r  ];
            Long const e = row_offset[r+1];
            auto const len = int(e - b);

            if (len <= 1) { continue; }

            bool sorted = true;
            for (int i = 1; i < len; ++i) {
                if (col_index[b+i-1] > col_index[b+i]) {
                    sorted = false;
                    break;
                }
            }
            if (sorted) { continue; }

            if (len <= SMALL) {
                // Insertion sort using arrays on the stack
                for (int i = 0; i < len; ++i) {
                    scols[i] = col_index[b+i];
                    svals[i] = mat      [b+i];
                }
                for (int i = 1; i < len; ++i) {
                    auto c = scols[i];
                    auto v = svals[i];
                    auto j = i;
                    while (j > 0 && scols[j-1] > c) {
                        scols[j] = scols[j-1];
                        svals[j] = svals[j-1];
                        --j;
                    }
                    scols[j] = c;
                    svals[j] = v;
                }
                for (int i = 0; i < len; ++i) {
                    col_index[b+i] = scols[i];
                    mat      [b+i] = svals[i];
                }
            } else {
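                // Longer rows: sort an index permutation, then gather the
                // columns and values in sorted order.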
                lcols.resize(len);
                lvals.resize(len);
                perm.resize(len);

                for (int i = 0; i < len; ++i) {
                    lcols[i] = col_index[b+i];
                    lvals[i] = mat      [b+i];
                    perm [i] = i;
                }

                std::sort(perm.begin(), perm.end(),
                          [&] (int i0, int i1) {
                              return lcols[i0] < lcols[i1];
                          });

                for (int out = 0; out < len; ++out) {
                    auto const in = perm[out];
                    col_index[b+out] = lcols[in];
                    mat      [b+out] = lvals[in];
                }
            }
        }
    }
}

}

#endif