3 #include <AMReX_Config.H>
9 #if defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
10 # include <cub/cub.cuh>
11 #elif defined(AMREX_USE_HIP)
12 # include <rocprim/rocprim.hpp>
13 #elif defined(AMREX_USE_SYCL) && defined(AMREX_USE_ONEDPL)
14 # include <oneapi/dpl/execution>
15 # include <oneapi/dpl/numeric>
20 #include <type_traits>
27 explicit operator bool() const noexcept { return flag; }
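// Note: RetSum is the flag type behind retSum/noRetSum. The GPU PrefixSum paths below only
// return the total when the flag is true (see the `(a_ret_sum) ? *totalsum_p : T(0)` returns
// further down); passing noRetSum lets callers skip the total when it is not needed.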
37 #if defined(AMREX_USE_GPU)
48 template <typename T, bool SINGLE_WORD> struct BlockStatus {};
63 void write (char a_status, T a_value) {
64 #if defined(AMREX_USE_CUDA)
65 volatile uint64_t tmp;
66 reinterpret_cast<STVA<T> volatile&>(tmp).status = a_status;
67 reinterpret_cast<STVA<T> volatile&>(tmp).value = a_value;
68 reinterpret_cast<uint64_t&>(d.s) = tmp;
71 tmp.s = {a_status, a_value};
72 static_assert(sizeof(unsigned long long) == sizeof(uint64_t),
73 "HIP/SYCL: unsigned long long must be 64 bits");
75 reinterpret_cast<unsigned long long&>(tmp));
84 #if defined(AMREX_USE_CUDA)
85 volatile uint64_t tmp = reinterpret_cast<uint64_t volatile&>(d);
86 return {reinterpret_cast<STVA<T> volatile&>(tmp).status,
89 static_assert(sizeof(unsigned long long) == sizeof(uint64_t),
90 "HIP/SYCL: unsigned long long must be 64 bits");
92 (reinterpret_cast<unsigned long long*>(const_cast<Data<T>*>(&d)), 0ull);
93 return (*reinterpret_cast<Data<T>*>(&tmp)).s;
98 void set_status (char a_status) { d.s.status = a_status; }
104 #if defined(AMREX_USE_SYCL)
105 sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::work_group);
107 __threadfence_block();
110 } while (r.status == 'x');
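// The status byte drives the decoupled look-back protocol: 'x' means the block has published
// nothing yet, 'a' means only the block-local aggregate is available, and 'p' means the full
// inclusive prefix is available; wait() spins until the status leaves 'x'.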
115 template <typename T>
123 void write (char a_status, T a_value) {
124 if (a_status == 'a') {
129 #if defined(AMREX_USE_SYCL)
130 sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::device);
142 #if defined(AMREX_USE_SYCL)
143 constexpr auto mo = sycl::memory_order::relaxed;
144 constexpr auto ms = sycl::memory_scope::device;
145 constexpr auto as = sycl::access::address_space::global_space;
149 } else if (status == 'a') {
150 #if defined(AMREX_USE_SYCL)
151 sycl::atomic_ref<T,mo,ms,as> ar{const_cast<T&>(aggregate)};
152 return {'a', ar.load()};
154 return {'a', aggregate};
157 #if defined(AMREX_USE_SYCL)
158 sycl::atomic_ref<T,mo,ms,as> ar{const_cast<T&>(inclusive)};
159 return {'p', ar.load()};
174 #if defined(AMREX_USE_SYCL)
175 sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::device);
179 } while (r.status == 'x');
186 #if defined(AMREX_USE_SYCL)
188 #ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
189 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE>
190 T PrefixSum_mp (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum)
192 if (n <= 0) { return 0; }
193 constexpr int nwarps_per_block = 8;
195 constexpr int nchunks = 12;
196 constexpr int nelms_per_block = nthreads * nchunks;
198 int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
202 std::size_t nbytes_blockresult = Arena::align(sizeof(T)*n);
203 std::size_t nbytes_blocksum = Arena::align(sizeof(T)*nblocks);
208 T* blockresult_p = (T*)dp;
209 T* blocksum_p = (T*)(dp + nbytes_blockresult);
210 T* totalsum_p = (T*)(dp + nbytes_blockresult + nbytes_blocksum);
215 sycl::sub_group const& sg = gh.item->get_sub_group();
216 int lane = sg.get_local_id()[0];
217 int warp = sg.get_group_id()[0];
218 int nwarps = sg.get_group_range()[0];
220 int threadIdxx = gh.item->get_local_id(0);
221 int blockIdxx = gh.item->get_group_linear_id();
222 int blockDimx = gh.item->get_local_range(0);
224 T* shared = (T*)(gh.local);
228 N ibegin = nelms_per_block * blockIdxx;
229 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
234 T sum_prev_chunk = 0;
235 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
236 N offset = ibegin + ichunk*blockDimx;
237 if (offset >= iend) { break; }
244 T s = sycl::shift_group_right(sg, x, i);
245 if (lane >= i) { x += s; }
255 gh.item->barrier(sycl::access::fence_space::local_space);
260 T y = (lane < nwarps) ? shared[lane] : 0;
262 T s = sycl::shift_group_right(sg, y, i);
263 if (lane >= i) { y += s; }
266 if (lane < nwarps) { shared2[lane] = y; }
269 gh.item->barrier(sycl::access::fence_space::local_space);
276 T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
277 T tmp_out = sum_prev_warp + sum_prev_chunk +
278 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ? x : x-x0);
279 sum_prev_chunk += shared2[nwarps-1];
282 blockresult_p[offset] = tmp_out;
287 if (threadIdxx == 0) {
288 blocksum_p[blockIdxx] = sum_prev_chunk;
295 sycl::sub_group const& sg = gh.item->get_sub_group();
296 int lane = sg.get_local_id()[0];
297 int warp = sg.get_group_id()[0];
298 int nwarps = sg.get_group_range()[0];
300 int threadIdxx = gh.item->get_local_id(0);
301 int blockDimx = gh.item->get_local_range(0);
303 T* shared = (T*)(gh.local);
306 T sum_prev_chunk = 0;
311 T s = sycl::shift_group_right(sg, x, i);
312 if (lane >= i) { x += s; }
322 gh.item->barrier(sycl::access::fence_space::local_space);
327 T y = (lane < nwarps) ? shared[lane] : 0;
329 T s = sycl::shift_group_right(sg, y, i);
330 if (lane >= i) { y += s; }
333 if (lane < nwarps) { shared2[lane] = y; }
336 gh.item->barrier(sycl::access::fence_space::local_space);
343 T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
344 T tmp_out = sum_prev_warp + sum_prev_chunk + x;
345 sum_prev_chunk += shared2[nwarps-1];
348 blocksum_p[offset] = tmp_out;
353 if (threadIdxx == 0) {
354 *totalsum_p = sum_prev_chunk;
361 int threadIdxx = gh.item->get_local_id(0);
362 int blockIdxx = gh.item->get_group_linear_id();
363 int blockDimx = gh.item->get_local_range(0);
366 N ibegin = nelms_per_block * blockIdxx;
367 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
368 T prev_sum = (blockIdxx == 0) ? 0 : blocksum_p[blockIdxx-1];
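// PrefixSum_mp is the multi-pass fallback: kernel 1 scans each block's chunks and stores
// per-element results in blockresult_p plus a per-block sum in blocksum_p; kernel 2 scans
// blocksum_p and records the grand total in totalsum_p; kernel 3 adds prev_sum (the scanned
// sum of all earlier blocks) to each block's results and hands them to fout.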
387 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
388 typename M=std::enable_if_t<std::is_integral<N>::value &&
389 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
390 std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
391 T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum)
393 if (n <= 0) { return 0; }
394 constexpr int nwarps_per_block = 8;
396 constexpr int nchunks = 12;
397 constexpr int nelms_per_block = nthreads * nchunks;
399 int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
401 #ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
403 return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, a_ret_sum);
410 using BlockStatusT = std::conditional_t<sizeof(detail::STVA<T>) <= 8,
411 detail::BlockStatus<T,true>, detail::BlockStatus<T,false> >;
413 std::size_t nbytes_blockstatus = Arena::align(sizeof(BlockStatusT)*nblocks);
414 std::size_t nbytes_blockid = Arena::align(sizeof(unsigned int));
420 unsigned int* AMREX_RESTRICT virtual_block_id_p = (unsigned int*)(dp + nbytes_blockstatus);
421 T* AMREX_RESTRICT totalsum_p = (T*)(dp + nbytes_blockstatus + nbytes_blockid);
424 BlockStatusT& block_status = block_status_p[i];
425 block_status.set_status('x');
427 *virtual_block_id_p = 0;
435 sycl::sub_group const& sg = gh.item->get_sub_group();
436 int lane = sg.get_local_id()[0];
437 int warp = sg.get_group_id()[0];
438 int nwarps = sg.get_group_range()[0];
440 int threadIdxx = gh.item->get_local_id(0);
441 int blockDimx = gh.item->get_local_range(0);
442 int gridDimx = gh.item->get_group_range(0);
444 T* shared = (T*)(gh.local);
450 int virtual_block_id = 0;
452 int& virtual_block_id_shared = *((int*)(shared2+nwarps));
453 if (threadIdxx == 0) {
455 virtual_block_id_shared = bid;
457 gh.item->barrier(sycl::access::fence_space::local_space);
458 virtual_block_id = virtual_block_id_shared;
462 N ibegin = nelms_per_block * virtual_block_id;
463 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
464 BlockStatusT& block_status = block_status_p[virtual_block_id];
475 T sum_prev_chunk = 0;
477 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
478 N offset = ibegin + ichunk*blockDimx;
479 if (offset >= iend) { break; }
483 if (std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value && offset == n-1) {
489 T s = sycl::shift_group_right(sg, x, i);
490 if (lane >= i) { x += s; }
500 gh.item->barrier(sycl::access::fence_space::local_space);
505 T y = (lane < nwarps) ? shared[lane] : 0;
507 T s = sycl::shift_group_right(sg, y, i);
508 if (lane >= i) { y += s; }
511 if (lane < nwarps) { shared2[lane] = y; }
514 gh.item->barrier(sycl::access::fence_space::local_space);
521 T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
522 tmp_out[ichunk] = sum_prev_warp + sum_prev_chunk +
523 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ? x : x-x0);
524 sum_prev_chunk += shared2[nwarps-1];
528 if (threadIdxx == 0 && gridDimx > 1) {
529 block_status.write((virtual_block_id == 0) ? 'p' : 'a',
533 if (virtual_block_id == 0) {
534 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
535 N offset = ibegin + ichunk*blockDimx + threadIdxx;
536 if (offset >= iend) { break; }
537 fout(offset, tmp_out[ichunk]);
539 *totalsum_p += tmp_out[ichunk];
542 } else if (virtual_block_id > 0) {
545 T exclusive_prefix = 0;
546 BlockStatusT volatile* pbs = block_status_p;
549 int iblock = iblock0-lane;
550 detail::STVA<T> stva{'p', 0};
552 stva = pbs[iblock].wait();
558 unsigned status_bf = (stva.status == 'p') ? (0x1u << lane) : 0;
560 status_bf |= sycl::permute_group_by_xor(sg, status_bf, i);
563 bool stop_lookback = status_bf & 0x1u;
564 if (stop_lookback == false) {
565 if (status_bf != 0) {
567 if (lane > 0) { x = 0; }
568 unsigned int bit_mask = 0x1u;
571 if (i == lane) { x = y; }
572 if (status_bf & bit_mask) {
573 stop_lookback = true;
580 x += sycl::shift_group_left(sg, x,i);
584 if (lane == 0) { exclusive_prefix += x; }
585 if (stop_lookback) { break; }
589 block_status.write('p', block_status.get_aggregate() + exclusive_prefix);
590 shared[0] = exclusive_prefix;
594 gh.item->barrier(sycl::access::fence_space::local_space);
596 T exclusive_prefix = shared[0];
598 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
599 N offset = ibegin + ichunk*blockDimx + threadIdxx;
600 if (offset >= iend) { break; }
601 T t = tmp_out[ichunk] + exclusive_prefix;
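// End of the single-pass SYCL path: block 0 writes its results and the running total directly,
// while every later block publishes its aggregate ('a') in BlockStatus, performs a
// sub-group-wide look-back over earlier blocks to build exclusive_prefix, upgrades its own
// status to 'p', and then adds exclusive_prefix to tmp_out before calling fout.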
624 #elif defined(AMREX_USE_HIP)
626 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
627 typename M=std::enable_if_t<std::is_integral<N>::value &&
628 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
629 std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
630 T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = retSum)
632 if (n <= 0) { return 0; }
633 constexpr int nwarps_per_block = 4;
635 constexpr int nelms_per_thread = sizeof(T) >= 8 ? 8 : 16;
636 constexpr int nelms_per_block = nthreads * nelms_per_thread;
637 int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
641 using ScanTileState = rocprim::detail::lookback_scan_state<T>;
642 using OrderedBlockId = rocprim::detail::ordered_block_id<unsigned int>;
644 #if (defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR < 6)) || \
645 (defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR == 6) && \
646 defined(HIP_VERSION_MINOR) && (HIP_VERSION_MINOR == 0))
648 std::size_t nbytes_tile_state = rocprim::detail::align_size
649 (ScanTileState::get_storage_size(nblocks));
650 std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();
654 ScanTileState tile_state = ScanTileState::create(dp, nblocks);
658 std::size_t nbytes_tile_state;
659 AMREX_HIP_SAFE_CALL(ScanTileState::get_storage_size(nblocks, stream, nbytes_tile_state));
660 nbytes_tile_state = rocprim::detail::align_size(nbytes_tile_state);
662 std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();
666 ScanTileState tile_state;
667 AMREX_HIP_SAFE_CALL(ScanTileState::create(tile_state, dp, nblocks, stream));
671 auto ordered_block_id = OrderedBlockId::create
672 (reinterpret_cast<OrderedBlockId::id_type*>(dp + nbytes_tile_state));
677 auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);
678 auto& scan_bid = const_cast<OrderedBlockId&>(ordered_block_id);
679 const unsigned int gid = blockIdx.x*nthreads + threadIdx.x;
680 if (gid == 0) { scan_bid.reset(); }
681 scan_tile_state.initialize_prefix(gid, nblocks);
686 amrex::launch_global<nthreads> <<<nblocks, nthreads, sm, stream>>> (
689 using BlockLoad = rocprim::block_load<T, nthreads, nelms_per_thread,
690 rocprim::block_load_method::block_load_transpose>;
691 using BlockScan = rocprim::block_scan<T, nthreads,
692 rocprim::block_scan_algorithm::using_warp_scan>;
693 using BlockExchange = rocprim::block_exchange<T, nthreads, nelms_per_thread>;
694 using LookbackScanPrefixOp = rocprim::detail::lookback_scan_prefix_op
695 <T, rocprim::plus<T>, ScanTileState>;
697 __shared__ struct TempStorage {
698 typename OrderedBlockId::storage_type ordered_bid;
700 typename BlockLoad::storage_type load;
701 typename BlockExchange::storage_type exchange;
702 typename BlockScan::storage_type scan;
707 auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);
708 auto& scan_bid = const_cast<OrderedBlockId&>(ordered_block_id);
710 auto const virtual_block_id = scan_bid.get(threadIdx.x, temp_storage.ordered_bid);
713 N ibegin = nelms_per_block * virtual_block_id;
714 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
716 auto input_begin = rocprim::make_transform_iterator(
717 rocprim::make_counting_iterator(N(0)),
718 [&] (N i) -> T { return fin(i+ibegin); });
720 T data[nelms_per_thread];
721 if (static_cast<int>(iend-ibegin) == nelms_per_block) {
722 BlockLoad().load(input_begin, data, temp_storage.load);
725 BlockLoad().load(input_begin, data, iend-ibegin, 0, temp_storage.load);
730 constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;
732 if (virtual_block_id == 0) {
735 BlockScan().exclusive_scan(data, data, T{0}, block_agg, temp_storage.scan);
737 BlockScan().inclusive_scan(data, data, block_agg, temp_storage.scan);
739 if (threadIdx.x == 0) {
741 scan_tile_state.set_complete(0, block_agg);
742 } else if (nblocks == 1 && totalsum_p) {
743 *totalsum_p = block_agg;
747 T last = data[nelms_per_thread-1];
749 LookbackScanPrefixOp prefix_op(virtual_block_id, rocprim::plus<T>(), scan_tile_state);
751 BlockScan().exclusive_scan(data, data, temp_storage.scan, prefix_op,
754 BlockScan().inclusive_scan(data, data, temp_storage.scan, prefix_op,
758 if (iend == n && threadIdx.x == nthreads-1) {
759 T tsum = data[nelms_per_thread-1];
768 BlockExchange().blocked_to_striped(data, data, temp_storage.exchange);
770 for (int i = 0; i < nelms_per_thread; ++i) {
771 N offset = ibegin + i*nthreads + threadIdx.x;
781 T ret = (a_ret_sum) ? *totalsum_p : T(0);
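// The HIP path leans on rocprim's own look-back machinery: lookback_scan_state plays the role
// of BlockStatus, ordered_block_id hands out virtual block ids, and
// block_load/block_scan/block_exchange do the per-block work, so only the glue code above is
// AMReX-specific.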
787 #elif defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
789 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
790 typename M=std::enable_if_t<std::is_integral<N>::value &&
791 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
792 std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
793 T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = retSum)
795 if (n <= 0) { return 0; }
796 constexpr int nwarps_per_block = 8;
798 constexpr int nelms_per_thread = sizeof(T) >= 8 ? 4 : 8;
799 constexpr int nelms_per_block = nthreads * nelms_per_thread;
800 int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
804 using ScanTileState = cub::ScanTileState<T>;
805 std::size_t tile_state_size = 0;
806 ScanTileState::AllocationSize(nblocks, tile_state_size);
808 std::size_t nbytes_tile_state = Arena::align(tile_state_size);
809 auto tile_state_p = (char*)(The_Arena()->alloc(nbytes_tile_state));
811 ScanTileState tile_state;
812 tile_state.Init(nblocks, tile_state_p, tile_state_size);
818 const_cast<ScanTileState&>(tile_state).InitializeStatus(nblocks);
824 amrex::launch_global<nthreads> <<<nblocks, nthreads, sm, stream>>> (
827 using BlockLoad = cub::BlockLoad<T, nthreads, nelms_per_thread, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
828 using BlockScan = cub::BlockScan<T, nthreads, cub::BLOCK_SCAN_WARP_SCANS>;
829 using BlockExchange = cub::BlockExchange<T, nthreads, nelms_per_thread>;
830 using TilePrefixCallbackOp = cub::TilePrefixCallbackOp<T, cub::Sum, ScanTileState>;
832 __shared__ union TempStorage
834 typename BlockLoad::TempStorage load;
835 typename BlockExchange::TempStorage exchange;
837 typename BlockScan::TempStorage scan;
838 typename TilePrefixCallbackOp::TempStorage prefix;
843 auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);
845 int virtual_block_id = blockIdx.x;
848 N ibegin = nelms_per_block * virtual_block_id;
849 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
851 auto input_lambda = [&] (N i) -> T { return fin(i+ibegin); };
852 cub::TransformInputIterator<T,decltype(input_lambda),cub::CountingInputIterator<N> >
853 input_begin(cub::CountingInputIterator<N>(0), input_lambda);
855 T data[nelms_per_thread];
856 if (static_cast<int>(iend-ibegin) == nelms_per_block) {
857 BlockLoad(temp_storage.load).Load(input_begin, data);
859 BlockLoad(temp_storage.load).Load(input_begin, data, iend-ibegin, 0);
864 constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;
866 if (virtual_block_id == 0) {
869 BlockScan(temp_storage.scan_storeage.scan).ExclusiveSum(data, data, block_agg);
871 BlockScan(temp_storage.scan_storeage.scan).InclusiveSum(data, data, block_agg);
873 if (threadIdx.x == 0) {
875 scan_tile_state.SetInclusive(0, block_agg);
876 } else if (nblocks == 1 && totalsum_p) {
877 *totalsum_p = block_agg;
881 T last = data[nelms_per_thread-1];
883 TilePrefixCallbackOp prefix_op(scan_tile_state, temp_storage.scan_storeage.prefix,
886 BlockScan(temp_storage.scan_storeage.scan).ExclusiveSum(data, data, prefix_op);
888 BlockScan(temp_storage.scan_storeage.scan).InclusiveSum(data, data, prefix_op);
891 if (iend == n && threadIdx.x == nthreads-1) {
892 T tsum = data[nelms_per_thread-1];
901 BlockExchange(temp_storage.exchange).BlockedToStriped(data);
903 for (int i = 0; i < nelms_per_thread; ++i) {
904 N offset = ibegin + i*nthreads + threadIdx.x;
914 T ret = (a_ret_sum) ? *totalsum_p : T(0);
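// The CUDA path mirrors the HIP one with CUB: ScanTileState tracks per-tile status,
// TilePrefixCallbackOp performs the look-back inside BlockScan, and BlockLoad/BlockExchange
// handle the blocked/striped data movement.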
922 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
923 typename M=std::enable_if_t<std::is_integral<N>::value &&
924 (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
925 std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
928 if (n <= 0) { return 0; }
929 constexpr int nwarps_per_block = 4;
931 constexpr int nchunks = 12;
932 constexpr int nelms_per_block = nthreads * nchunks;
934 int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
941 std::size_t nbytes_blockstatus = Arena::align(sizeof(BlockStatusT)*nblocks);
942 std::size_t nbytes_blockid = Arena::align(sizeof(unsigned int));
948 unsigned int* AMREX_RESTRICT virtual_block_id_p = (unsigned int*)(dp + nbytes_blockstatus);
949 T* AMREX_RESTRICT totalsum_p = (T*)(dp + nbytes_blockstatus + nbytes_blockid);
952 BlockStatusT& block_status = block_status_p[i];
953 block_status.set_status('x');
955 *virtual_block_id_p = 0;
974 int virtual_block_id = 0;
976 int& virtual_block_id_shared = *((int*)(shared2+nwarps));
977 if (threadIdx.x == 0) {
978 unsigned int bid = Gpu::Atomic::Add(virtual_block_id_p, 1u);
979 virtual_block_id_shared = bid;
982 virtual_block_id = virtual_block_id_shared;
986 N ibegin = nelms_per_block * virtual_block_id;
987 N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
988 BlockStatusT& block_status = block_status_p[virtual_block_id];
999 T sum_prev_chunk = 0;
1001 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
1002 N offset = ibegin + ichunk*nthreads;
1003 if (offset >= iend) { break; }
1014 T s = __shfl_up_sync(0xffffffff, x, i); )
1015 if (lane >= i) { x += s; }
1030 #ifdef AMREX_USE_CUDA
1031 if (warp == 0 && lane < nwarps) {
1033 int mask = (1 << nwarps) - 1;
1034 for (int i = 1; i <= nwarps; i *= 2) {
1035 T s = __shfl_up_sync(mask, y, i, nwarps);
1036 if (lane >= i) { y += s; }
1043 if (lane < nwarps) {
1046 for (
int i = 1; i <= nwarps; i *= 2) {
1047 T s = __shfl_up(y, i, nwarps);
1048 if (lane >= i) { y += s; }
1050 if (lane < nwarps) {
1063 T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
1064 tmp_out[ichunk] = sum_prev_warp + sum_prev_chunk +
1066 sum_prev_chunk += shared2[nwarps-1];
1070 if (threadIdx.x == 0 && gridDim.x > 1) {
1071 block_status.write((virtual_block_id == 0) ? 'p' : 'a',
1075 if (virtual_block_id == 0) {
1076 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
1077 N offset = ibegin + ichunk*nthreads + threadIdx.x;
1078 if (offset >= iend) { break; }
1079 fout(offset, tmp_out[ichunk]);
1081 *totalsum_p += tmp_out[ichunk];
1084 } else if (virtual_block_id > 0) {
1087 T exclusive_prefix = 0;
1088 BlockStatusT volatile* pbs = block_status_p;
1091 int iblock = iblock0-lane;
1094 stva = pbs[iblock].wait();
1100 unsigned const status_bf = __ballot_sync(0xffffffff, stva.status == 'p'));
1101 bool stop_lookback = status_bf & 0x1u;
1102 if (stop_lookback == false) {
1103 if (status_bf != 0) {
1105 if (lane > 0) { x = 0; }
1107 unsigned bit_mask = 0x1u);
1110 if (i == lane) { x = y; }
1111 if (status_bf & bit_mask) {
1112 stop_lookback = true;
1120 x += __shfl_down_sync(0xffffffff, x, i); )
1124 if (lane == 0) { exclusive_prefix += x; }
1125 if (stop_lookback) { break; }
1129 block_status.write('p', block_status.get_aggregate() + exclusive_prefix);
1130 shared[0] = exclusive_prefix;
1136 T exclusive_prefix = shared[0];
1138 for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
1139 N offset = ibegin + ichunk*nthreads + threadIdx.x;
1140 if (offset >= iend) { break; }
1141 T t = tmp_out[ichunk] + exclusive_prefix;
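// Generic GPU fallback (no cub/rocprim): a hand-written decoupled look-back built from
// BlockStatus, warp shuffles (__shfl_up_sync / __shfl_up) and a ballot over the lanes'
// observed statuses, accumulating exclusive_prefix until a block with status 'p' is found.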
1167 template <typename N, typename T, typename M=std::enable_if_t<std::is_integral<N>::value> >
1170 #if defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
1171 void* d_temp = nullptr;
1172 std::size_t temp_bytes = 0;
1186 #elif defined(AMREX_USE_HIP)
1187 void* d_temp = nullptr;
1188 std::size_t temp_bytes = 0;
1202 #elif defined(AMREX_USE_SYCL) && defined(AMREX_USE_ONEDPL)
1203 auto policy = oneapi::dpl::execution::make_device_policy(Gpu::Device::streamQueue());
1214 return PrefixSum<T>(static_cast<int>(n),
1219 return PrefixSum<T>(n,
1228 template <typename N, typename T, typename M=std::enable_if_t<std::is_integral<N>::value> >
1231 if (n <= 0) { return 0; }
1232 #if defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
1237 void* d_temp = nullptr;
1238 std::size_t temp_bytes = 0;
1251 return in_last+out_last;
1252 #elif defined(AMREX_USE_HIP)
1257 void* d_temp = nullptr;
1258 std::size_t temp_bytes = 0;
1271 return in_last+out_last;
1272 #elif defined(AMREX_USE_SYCL) && defined(AMREX_USE_ONEDPL)
1277 auto policy = oneapi::dpl::execution::make_device_policy(Gpu::Device::streamQueue());
1285 return in_last+out_last;
1288 return PrefixSum<T>(static_cast<int>(n),
1293 return PrefixSum<T>(n,
1303 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
1304 typename M=std::enable_if_t<std::is_integral_v<N> &&
1305 (std::is_same_v<std::decay_t<TYPE>,Type::Inclusive> ||
1306 std::is_same_v<std::decay_t<TYPE>,Type::Exclusive>)> >
1307 T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum = retSum)
1309 if (n <= 0) { return 0; }
1311 for (N i = 0; i < n; ++i) {
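// Usage sketch (illustrative only; `dv` and `nitems` are placeholders, not part of this
// header). PrefixSum drives the scan through the fin/fout callables, so it can scan values
// computed on the fly:
//
//     amrex::Gpu::DeviceVector<int> dv(nitems);
//     int* p = dv.data();
//     int total = amrex::Scan::PrefixSum<int>(nitems,
//         [=] AMREX_GPU_DEVICE (int i) { return p[i]; },        // fin: i-th input value
//         [=] AMREX_GPU_DEVICE (int i, int x) { p[i] = x; },    // fout: store i-th result
//         amrex::Scan::Type::exclusive, amrex::Scan::retSum);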
1324 template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
1327 #if (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
1331 std::partial_sum(in, in+n, out);
1333 return (n > 0) ? out[n-1] : T(0);
1337 template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
1340 if (n <= 0) { return 0; }
1342 auto in_last = in[n-1];
1343 #if (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
1348 std::partial_sum(in, in+n-1, out+1);
1350 return in_last + out[n-1];
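// Usage sketch (illustrative; `d_in`, `d_out`, and `n` are placeholders). With a GPU backend
// the pointers are expected to be device pointers:
//
//     amrex::Long total = amrex::Scan::ExclusiveSum(n, d_in, d_out);    // returns the sum of all n inputs
//     amrex::Scan::InclusiveSum(n, d_in, d_out, amrex::Scan::noRetSum); // skip the total when not needed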
1359 template<class InIter, class OutIter>
1362 #if defined(AMREX_USE_GPU)
1363 auto N = std::distance(begin, end);
1365 OutIter result_end = result;
1366 std::advance(result_end, N);
1368 #elif (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
1372 return std::partial_sum(begin, end, result);
1376 template<class InIter, class OutIter>
1379 #if defined(AMREX_USE_GPU)
1380 auto N = std::distance(begin, end);
1382 OutIter result_end = result;
1383 std::advance(result_end, N);
1385 #elif (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
1389 if (begin == end) { return result; }
1391 typename std::iterator_traits<InIter>::value_type sum = *begin;
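// Usage sketch (illustrative; `v`, `w`, and `nitems` are placeholders). The iterator interface
// is similar to the standard library scans, but on GPU builds the iterators must refer to
// device-accessible memory:
//
//     amrex::Gpu::DeviceVector<int> v(nitems), w(nitems);
//     // fill v ...
//     amrex::Scan::inclusive_scan(v.begin(), v.end(), w.begin());
//     amrex::Scan::exclusive_scan(v.begin(), v.end(), w.begin());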