Block-Structured AMR Software Framework
 
AMReX_TagParallelFor.H
#ifndef AMREX_TAG_PARALLELFOR_H_
#define AMREX_TAG_PARALLELFOR_H_
#include <AMReX_Config.H>

#include <AMReX_Arena.H>
#include <AMReX_Array4.H>
#include <AMReX_Box.H>
#include <AMReX_GpuLaunch.H>
#include <AMReX_Vector.H>
#include <limits>
#include <utility>

namespace amrex {

template <class T>
struct Array4PairTag {
    Array4<T      > dfab;
    Array4<T const> sfab;
    Box dbox;

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return dbox; }
};

template <class T0, class T1=T0>
struct Array4CopyTag {
    Array4<T0      > dfab;
    int dindex;
    Array4<T1 const> sfab;
    Box dbox;
    Dim3 offset; // sbox.smallEnd() - dbox.smallEnd()

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return dbox; }
};

template <class T0, class T1=T0>
struct Array4MaskCopyTag {
    Array4<T0      > dfab;
    Array4<T1 const> sfab;
    Array4<int> mask;
    Box dbox;
    Dim3 offset; // sbox.smallEnd() - dbox.smallEnd()

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return dbox; }
};

template <class T>
struct Array4Tag {
    Array4<T> dfab;

    AMREX_GPU_HOST_DEVICE
    Box box () const noexcept { return Box(dfab); }
};

template <class T>
struct Array4BoxTag {
    Array4<T> dfab;
    Box dbox;

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return dbox; }
};

template <class T>
struct Array4BoxValTag {
    Array4<T> dfab;
    Box dbox;
    T val;

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return dbox; }
};

template <class T>
struct Array4BoxOrientationTag {
    Array4<T> fab;
    Box bx;
    Orientation face;

    AMREX_GPU_HOST_DEVICE
    Box const& box() const noexcept { return bx; }
};

template <class T>
struct Array4BoxOffsetTag {
    Array4<T> fab;
    Box bx;
    Dim3 offset;

    AMREX_GPU_HOST_DEVICE
    Box const& box() const noexcept { return bx; }
};

template <class T>
struct VectorTag {
    T* p;
    Long m_size;

    AMREX_GPU_HOST_DEVICE
    Long size () const noexcept { return m_size; }
};

template <class T>
struct CommRecvBufTag { // for unpacking recv buffer
    Array4<T> dfab;
    std::ptrdiff_t poff;
    Box bx;

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return bx; }
};

template <class T>
struct CommSendBufTag { // for packing send buffer
    Array4<T const> sfab;
    std::ptrdiff_t poff;
    Box bx;

    AMREX_GPU_HOST_DEVICE
    Box const& box () const noexcept { return bx; }
};

namespace detail {

    template <typename T>
    std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, Long>
    get_tag_size (T const& tag) noexcept
    {
        AMREX_ASSERT(tag.box().numPts() < Long(std::numeric_limits<int>::max()));
        return static_cast<int>(tag.box().numPts());
    }

    template <typename T>
    std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, Long>
    get_tag_size (T const& tag) noexcept
    {
        AMREX_ASSERT(tag.size() < Long(std::numeric_limits<int>::max()));
        return tag.size();
    }

    template <typename T>
    constexpr
    std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, bool>
    is_box_tag (T const&) { return true; }

    template <typename T>
    constexpr
    std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, bool>
    is_box_tag (T const&) { return false; }

}

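// TagVector stages a host Vector of tags in a single pinned-memory buffer.
// In GPU builds the buffer also holds an aligned array of per-tag warp
// offsets (the exclusive prefix sum of ceil(tag_size / warp_size)) and is
// copied asynchronously to device memory, so one kernel launch can serve
// all tags.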
template <class TagType>
struct TagVector {

    char* h_buffer = nullptr;
    char* d_buffer = nullptr;
    TagType* d_tags = nullptr;
    int* d_nwarps = nullptr;
    int ntags = 0;
    int ntotwarps = 0;
    int nblocks = 0;
    bool defined = false;
    static constexpr int nthreads = 256;

    TagVector () = default;

    TagVector (Vector<TagType> const& tags) {
        define(tags);
    }

    ~TagVector () {
        if (defined) {
            undefine();
        }
    }

    TagVector (const TagVector& other) = delete;
    TagVector& operator= (const TagVector& other) = delete;
    TagVector (TagVector&& other) noexcept
        : h_buffer{other.h_buffer},
          d_buffer{other.d_buffer},
          d_tags{other.d_tags},
          d_nwarps{other.d_nwarps},
          ntags{other.ntags},
          ntotwarps{other.ntotwarps},
          nblocks{other.nblocks},
          defined{other.defined}
    {
        other.h_buffer = nullptr;
        other.d_buffer = nullptr;
        other.d_tags = nullptr;
        other.d_nwarps = nullptr;
        other.ntags = 0;
        other.ntotwarps = 0;
        other.nblocks = 0;
        other.defined = false;
    }
    TagVector& operator= (TagVector&& other) noexcept {
        if (this == &other) {
            return *this;
        }
        undefine();
        h_buffer = other.h_buffer;
        other.h_buffer = nullptr;
        d_buffer = other.d_buffer;
        other.d_buffer = nullptr;
        d_tags = other.d_tags;
        other.d_tags = nullptr;
        d_nwarps = other.d_nwarps;
        other.d_nwarps = nullptr;
        ntags = other.ntags;
        other.ntags = 0;
        ntotwarps = other.ntotwarps;
        other.ntotwarps = 0;
        nblocks = other.nblocks;
        other.nblocks = 0;
        defined = other.defined;
        other.defined = false;
        return *this;
    }

    [[nodiscard]] bool is_defined () const { return defined; }

    void define (Vector<TagType> const& tags) {
        if (defined) {
            undefine();
        }

        ntags = tags.size();
        if (ntags == 0) {
            defined = true;
            return;
        }

#ifdef AMREX_USE_GPU
        Long l_ntotwarps = 0;
        ntotwarps = 0;
        Vector<int> nwarps;
        nwarps.reserve(ntags+1);
        for (int i = 0; i < ntags; ++i)
        {
            auto& tag = tags[i];
            nwarps.push_back(ntotwarps);
            auto nw = (detail::get_tag_size(tag) + Gpu::Device::warp_size-1) /
                Gpu::Device::warp_size;
            l_ntotwarps += nw;
            ntotwarps += static_cast<int>(nw);
        }
        nwarps.push_back(ntotwarps);

        std::size_t sizeof_tags = ntags*sizeof(TagType);
        std::size_t offset_nwarps = Arena::align(sizeof_tags);
        std::size_t sizeof_nwarps = (ntags+1)*sizeof(int);
        std::size_t total_buf_size = offset_nwarps + sizeof_nwarps;

        h_buffer = (char*)The_Pinned_Arena()->alloc(total_buf_size);
        d_buffer = (char*)The_Arena()->alloc(total_buf_size);

        std::memcpy(h_buffer, tags.data(), sizeof_tags);
        std::memcpy(h_buffer+offset_nwarps, nwarps.data(), sizeof_nwarps);
        Gpu::htod_memcpy_async(d_buffer, h_buffer, total_buf_size);

        d_tags = reinterpret_cast<TagType*>(d_buffer);
        d_nwarps = reinterpret_cast<int*>(d_buffer+offset_nwarps);

        constexpr int nwarps_per_block = nthreads/Gpu::Device::warp_size;
        nblocks = (ntotwarps + nwarps_per_block-1) / nwarps_per_block;

        defined = true;

        amrex::ignore_unused(l_ntotwarps);
        AMREX_ALWAYS_ASSERT(l_ntotwarps+nwarps_per_block-1 < Long(std::numeric_limits<int>::max()));
#else
        std::size_t sizeof_tags = ntags*sizeof(TagType);
        h_buffer = (char*)The_Pinned_Arena()->alloc(sizeof_tags);

        std::memcpy(h_buffer, tags.data(), sizeof_tags);

        d_tags = reinterpret_cast<TagType*>(h_buffer);

        defined = true;
#endif
    }

    void undefine () {
        if (defined) {
            Gpu::streamSynchronize();
            The_Pinned_Arena()->free(h_buffer);
            The_Arena()->free(d_buffer);
            h_buffer = nullptr;
            d_buffer = nullptr;
            d_tags = nullptr;
            d_nwarps = nullptr;
            ntags = 0;
            ntotwarps = 0;
            nblocks = 0;
            defined = false;
        }
    }
};

namespace detail {

#ifdef AMREX_USE_GPU

template <typename T, typename F>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, void>
tagparfor_call_f (int icell, T const& tag, F&& f) noexcept
{
    int ncells = tag.box().numPts();
    const auto len = amrex::length(tag.box());
    const auto lo  = amrex::lbound(tag.box());
    int k =  icell /   (len.x*len.y);
    int j = (icell - k*(len.x*len.y)) /   len.x;
    int i = (icell - k*(len.x*len.y)) - j*len.x;
    i += lo.x;
    j += lo.y;
    k += lo.z;
    f(icell, ncells, i, j, k, tag);
}

template <typename T, typename F>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, void>
tagparfor_call_f (int i, T const& tag, F&& f) noexcept
{
    int N = tag.size();
    f(i, N, tag);
}

template <class TagType, class F>
void
ParallelFor_doit (TagVector<TagType> const& tv, F const& f)
{
    AMREX_ALWAYS_ASSERT(tv.is_defined());

    if (tv.ntags == 0) { return; }

    const auto d_tags = tv.d_tags;
    const auto d_nwarps = tv.d_nwarps;
    const auto ntags = tv.ntags;
    const auto ntotwarps = tv.ntotwarps;
    constexpr auto nthreads = TagVector<TagType>::nthreads;

    amrex::launch<nthreads>(tv.nblocks, Gpu::gpuStream(),
#ifdef AMREX_USE_SYCL
    [=] AMREX_GPU_DEVICE (sycl::nd_item<1> const& item) noexcept
    [[sycl::reqd_work_group_size(nthreads)]]
    [[sycl::reqd_sub_group_size(Gpu::Device::warp_size)]]
#else
    [=] AMREX_GPU_DEVICE () noexcept
#endif
    {
#ifdef AMREX_USE_SYCL
        std::size_t g_tid = item.get_global_id(0);
#else
        auto g_tid = std::size_t(blockDim.x)*blockIdx.x + threadIdx.x;
#endif
        auto g_wid = int(g_tid / Gpu::Device::warp_size);
        if (g_wid >= ntotwarps) { return; }

        int tag_id = amrex::bisect(d_nwarps, 0, ntags, g_wid);

        int b_wid = g_wid - d_nwarps[tag_id]; // b_wid'th warp on this box
#ifdef AMREX_USE_SYCL
        int lane = item.get_local_id(0) % Gpu::Device::warp_size;
#else
        int lane = threadIdx.x % Gpu::Device::warp_size;
#endif
        int icell = b_wid*Gpu::Device::warp_size + lane;

        tagparfor_call_f(icell, d_tags[tag_id], f);
    });
}

#else // ifdef AMREX_USE_GPU

template <class TagType, class F>
void
ParallelFor_doit (TagVector<TagType> const& tv, F const& f)
{
    // Note: this CPU version may not have optimal performance, because
    // the loop over ncomp is the innermost instead of the outermost and
    // there is no load balancing or splitting of tags.
    AMREX_ALWAYS_ASSERT(tv.is_defined());

    constexpr bool is_box = is_box_tag(TagType{});

    if (tv.ntags == 0) { return; }

    const auto d_tags = tv.d_tags;
    const auto ntags = tv.ntags;

#ifdef AMREX_USE_OMP
#pragma omp parallel for
#endif
    for (int itag = 0; itag < ntags; ++itag) {

        const auto& t = d_tags[itag];

        if constexpr (is_box) {
            const auto lo = amrex::lbound(t.box());
            const auto hi = amrex::ubound(t.box());

            for (int k = lo.z; k <= hi.z; ++k) {
                for (int j = lo.y; j <= hi.y; ++j) {
                    AMREX_PRAGMA_SIMD
                    for (int i = lo.x; i <= hi.x; ++i) {
                        f(0, 1, i, j, k, t);
                    }
                }
            }
        } else {
            const auto size = t.size();

            AMREX_PRAGMA_SIMD
            for (int i = 0; i < size; ++i) {
                f(i, size, t);
            }
        }
    }
}

#endif

}

template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>>
ParallelFor (TagVector<TagType> const& tv, int ncomp, F const& f)
{
    detail::ParallelFor_doit(tv,
    [=] AMREX_GPU_DEVICE (
        int icell, int ncells, int i, int j, int k, TagType const& tag) noexcept
    {
        if (icell < ncells) {
            for (int n = 0; n < ncomp; ++n) {
                f(i,j,k,n,tag);
            }
        }
    });
}

template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>, void>
ParallelFor (TagVector<TagType> const& tv, F const& f)
{
    detail::ParallelFor_doit(tv,
    [=] AMREX_GPU_DEVICE (
        int icell, int ncells, int i, int j, int k, TagType const& tag) noexcept
    {
        if (icell < ncells) {
            f(i,j,k,tag);
        }
    });
}

template <class TagType, class F>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<TagType>().size())> >, void>
ParallelFor (TagVector<TagType> const& tv, F const& f)
{
    detail::ParallelFor_doit(tv,
    [=] AMREX_GPU_DEVICE (
        int icell, int ncells, TagType const& tag) noexcept
    {
        if (icell < ncells) {
            f(icell,tag);
        }
    });
}

template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>>
ParallelFor (Vector<TagType> const& tags, int ncomp, F && f)
{
    TagVector<TagType> tv{tags};
    ParallelFor(tv, ncomp, std::forward<F>(f));
}

template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
    TagVector<TagType> tv{tags};
    ParallelFor(tv, std::forward<F>(f));
}

template <class TagType, class F>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<TagType>().size())> >, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
    TagVector<TagType> tv{tags};
    ParallelFor(tv, std::forward<F>(f));
}

}

#endif
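
A minimal usage sketch (not part of the header): the snippet below shows one
way to drive the tag-based ParallelFor defined above. It assumes an existing
MultiFab named mf; the function name zero_component and the local variable
names are illustrative only, not AMReX API.

#include <AMReX_MultiFab.H>
#include <AMReX_TagParallelFor.H>

void zero_component (amrex::MultiFab& mf)
{
    using namespace amrex;

    // Collect one lightweight Array4Tag per box of the MultiFab.
    Vector<Array4Tag<Real>> tags;
    for (MFIter mfi(mf); mfi.isValid(); ++mfi) {
        tags.push_back(Array4Tag<Real>{mf.array(mfi)});
    }

    // A single fused launch covers all boxes: (i,j,k) run over each tag's
    // box() (here the whole Array4, including ghost cells) and n over the
    // ncomp argument (1 here).
    ParallelFor(tags, 1,
        [=] AMREX_GPU_DEVICE (int i, int j, int k, int n,
                              Array4Tag<Real> const& tag) noexcept
        {
            tag.dfab(i,j,k,n) = Real(0.0);
        });
}

Collecting one small tag per box and issuing a single launch is the same
pattern the CommSendBufTag and CommRecvBufTag types above support for packing
and unpacking communication buffers, where launching one kernel per box can
be far more expensive.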