    explicit operator bool() const noexcept { return flag; }
#if defined(AMREX_USE_GPU)
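// Status codes used by the BlockStatus specializations below: 'x' means the
// block has not published anything yet, 'a' means only the block-local
// aggregate is available, and 'p' means the inclusive prefix up to and
// including the block is available.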
    void write (char a_status, T a_value) {
#if defined(AMREX_USE_CUDA)
        volatile uint64_t tmp;
        reinterpret_cast<STVA<T> volatile&>(tmp).status = a_status;
        reinterpret_cast<STVA<T> volatile&>(tmp).value = a_value;
        reinterpret_cast<uint64_t&>(d.s) = tmp;
#else
        Data<T> tmp;
        tmp.s = {a_status, a_value};
        static_assert(sizeof(unsigned long long) == sizeof(uint64_t),
                      "HIP/SYCL: unsigned long long must be 64 bits");
        Gpu::Atomic::Exch(reinterpret_cast<unsigned long long*>(&d),
                          reinterpret_cast<unsigned long long&>(tmp));
#endif
    }
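    // The single-word specialization packs {status, value} into one 64-bit
    // word so that a block can publish its status and partial sum in a single
    // memory transaction. A minimal illustration of the size assumption this
    // relies on (not part of this header; shown for a 4-byte T such as int):
    //
    //   static_assert(sizeof(STVA<int>) <= 8, "status+value fit in one 64-bit word");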
    STVA<T> read () volatile {
#if defined(AMREX_USE_CUDA)
        volatile uint64_t tmp = reinterpret_cast<uint64_t volatile&>(d);
        return {reinterpret_cast<STVA<T> volatile&>(tmp).status,
                reinterpret_cast<STVA<T> volatile&>(tmp).value};
#else
        static_assert(sizeof(unsigned long long) == sizeof(uint64_t),
                      "HIP/SYCL: unsigned long long must be 64 bits");
        // An atomic add of zero acts as an atomic 64-bit load of the packed word.
        auto tmp = Gpu::Atomic::Add
            (reinterpret_cast<unsigned long long*>(const_cast<Data<T>*>(&d)), 0ull);
        return (*reinterpret_cast<Data<T>*>(&tmp)).s;
#endif
    }
    STVA<T> wait () volatile {
        STVA<T> r;
        do {
            r = read();
#if defined(AMREX_USE_SYCL)
            sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::work_group);
#else
            __threadfence_block();
#endif
        } while (r.status == 'x');
        return r;
    }
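    // wait() spins until the block being polled has published something other
    // than 'x'; the fence between polls keeps each freshly written
    // {status, value} pair visible to the waiting block.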
    void write (char a_status, T a_value) {
        if (a_status == 'a') {
            aggregate = a_value;
        } else {
            inclusive = a_value;
        }
#if defined(AMREX_USE_SYCL)
        sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::device);
#else
        __threadfence();
#endif
        status = a_status;
    }
    STVA<T> read () volatile {
#if defined(AMREX_USE_SYCL)
        constexpr auto mo = sycl::memory_order::relaxed;
        constexpr auto ms = sycl::memory_scope::device;
        constexpr auto as = sycl::access::address_space::global_space;
#endif
        if (status == 'x') {
            return {'x', 0};
        } else if (status == 'a') {
#if defined(AMREX_USE_SYCL)
            sycl::atomic_ref<T,mo,ms,as> ar{const_cast<T&>(aggregate)};
            return {'a', ar.load()};
#else
            return {'a', aggregate};
#endif
        } else {
#if defined(AMREX_USE_SYCL)
            sycl::atomic_ref<T,mo,ms,as> ar{const_cast<T&>(inclusive)};
            return {'p', ar.load()};
#else
            return {'p', inclusive};
#endif
        }
    }
    STVA<T> wait () volatile {
        STVA<T> r;
        do {
            r = read();
#if defined(AMREX_USE_SYCL)
            sycl::atomic_fence(sycl::memory_order::acq_rel, sycl::memory_scope::device);
#else
            __threadfence();
#endif
        } while (r.status == 'x');
        return r;
    }
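    // Unlike the single-word specialization, this variant cannot update status
    // and value with one atomic store, so write() orders the value store before
    // the status store with a fence, and read() branches on the status before
    // loading the corresponding value.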
#if defined(AMREX_USE_SYCL)

#ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
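// PrefixSum_mp: a multi-pass (three-kernel) scan used for SYCL unless
// AMREX_SYCL_NO_MULTIPASS_SCAN is defined. Rough outline of the kernels below
// (an illustrative summary, not additional API):
//
//   kernel 1 (nblocks): scan each block's elements into blockresult[] and
//                       store each block's total in blocksum[]
//   kernel 2 (1 block): inclusive scan of blocksum[] in place; total into *totalsum_p
//   kernel 3 (nblocks): combine blockresult[] with the preceding block's
//                       blocksum[] entry and hand the results to fout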
template <typename T, typename N, typename FIN, typename FOUT, typename TYPE>
T PrefixSum_mp (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum)
{
    if (n <= 0) { return 0; }
    constexpr int nwarps_per_block = 8;
    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
    constexpr int nchunks = 12;
    constexpr int nelms_per_block = nthreads * nchunks;
    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) <
                        static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
    int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;

    std::size_t nbytes_blockresult = Arena::align(sizeof(T)*n);
    std::size_t nbytes_blocksum = Arena::align(sizeof(T)*nblocks);
    std::size_t nbytes_totalsum = Arena::align(sizeof(T));
    auto dp = (char*)(The_Arena()->alloc(nbytes_blockresult + nbytes_blocksum + nbytes_totalsum));
    T* blockresult_p = (T*)dp;
    T* blocksum_p = (T*)(dp + nbytes_blockresult);
    T* totalsum_p = (T*)(dp + nbytes_blockresult + nbytes_blocksum);
    amrex::launch<nthreads>(nblocks, sm, stream,
    [=] AMREX_GPU_DEVICE (Gpu::Handler const& gh) noexcept
    {
        sycl::sub_group const& sg = gh.item->get_sub_group();
        int lane = sg.get_local_id()[0];
        int warp = sg.get_group_id()[0];
        int nwarps = sg.get_group_range()[0];

        int threadIdxx = gh.item->get_local_id(0);
        int blockIdxx = gh.item->get_group_linear_id();
        int blockDimx = gh.item->get_local_range(0);

        T* shared = (T*)(gh.local);

        N ibegin = static_cast<N>(nelms_per_block) * blockIdxx;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);

        T sum_prev_chunk = 0; // sum over the chunks this block has already processed
        for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
            N offset = ibegin + ichunk*blockDimx;
            if (offset >= iend) { break; }

            // warp-level inclusive scan of x
                T s = sycl::shift_group_right(sg, x, i);
                if (lane >= i) { x += s; }

            gh.item->barrier(sycl::access::fence_space::local_space);

            // scan of the per-warp sums
            T y = (lane < nwarps) ? shared[lane] : 0;
                T s = sycl::shift_group_right(sg, y, i);
                if (lane >= i) { y += s; }
            if (lane < nwarps) { shared2[lane] = y; }

            gh.item->barrier(sycl::access::fence_space::local_space);

            T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
            T tmp_out = sum_prev_warp + sum_prev_chunk +
                (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ? x : x-x0);
            sum_prev_chunk += shared2[nwarps-1];

            blockresult_p[offset] = tmp_out;
        }

        if (threadIdxx == 0) {
            blocksum_p[blockIdxx] = sum_prev_chunk;
        }
    });
    amrex::launch<nthreads>(1, sm, stream,
    [=] AMREX_GPU_DEVICE (Gpu::Handler const& gh) noexcept
    {
        sycl::sub_group const& sg = gh.item->get_sub_group();
        int lane = sg.get_local_id()[0];
        int warp = sg.get_group_id()[0];
        int nwarps = sg.get_group_range()[0];

        int threadIdxx = gh.item->get_local_id(0);
        int blockDimx = gh.item->get_local_range(0);

        T* shared = (T*)(gh.local);

        T sum_prev_chunk = 0;

            T s = sycl::shift_group_right(sg, x, i);
            if (lane >= i) { x += s; }

        gh.item->barrier(sycl::access::fence_space::local_space);

        T y = (lane < nwarps) ? shared[lane] : 0;
            T s = sycl::shift_group_right(sg, y, i);
            if (lane >= i) { y += s; }
        if (lane < nwarps) { shared2[lane] = y; }

        gh.item->barrier(sycl::access::fence_space::local_space);

        T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
        T tmp_out = sum_prev_warp + sum_prev_chunk + x;
        sum_prev_chunk += shared2[nwarps-1];

        blocksum_p[offset] = tmp_out;

        if (threadIdxx == 0) {
            *totalsum_p = sum_prev_chunk;
        }
    });
    amrex::launch<nthreads>(nblocks, 0, stream,
    [=] AMREX_GPU_DEVICE (Gpu::Handler const& gh) noexcept
    {
        int threadIdxx = gh.item->get_local_id(0);
        int blockIdxx = gh.item->get_group_linear_id();
        int blockDimx = gh.item->get_local_range(0);

        N ibegin = static_cast<N>(nelms_per_block) * blockIdxx;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
        T prev_sum = (blockIdxx == 0) ? 0 : blocksum_p[blockIdxx-1];
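        // prev_sum is the sum of all elements owned by the preceding blocks;
        // adding it to this block's kernel-1 results yields the final scan values.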
template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
          typename M=std::enable_if_t<std::is_integral_v<N> &&
                                      (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
                                       std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum)
{
    if (n <= 0) { return 0; }
    constexpr int nwarps_per_block = 8;
    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
    constexpr int nchunks = 12;
    constexpr int nelms_per_block = nthreads * nchunks;
    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) <
                        static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
    int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
#ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
    return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, a_ret_sum);
#else

    using BlockStatusT = std::conditional_t<sizeof(detail::STVA<T>) <= 8,
                                            detail::BlockStatus<T,true>, detail::BlockStatus<T,false> >;
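    // Use the lock-free single-word BlockStatus when {status, value} fits in
    // 8 bytes; otherwise fall back to the fence-based two-word specialization.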
    std::size_t nbytes_blockstatus = Arena::align(sizeof(BlockStatusT)*nblocks);
    std::size_t nbytes_blockid = Arena::align(sizeof(unsigned int));
    std::size_t nbytes_totalsum = Arena::align(sizeof(T));
    auto dp = (char*)(The_Arena()->alloc(nbytes_blockstatus + nbytes_blockid + nbytes_totalsum));
    BlockStatusT* AMREX_RESTRICT block_status_p = (BlockStatusT*)dp;
    unsigned int* AMREX_RESTRICT virtual_block_id_p = (unsigned int*)(dp + nbytes_blockstatus);
    T* AMREX_RESTRICT totalsum_p = (T*)(dp + nbytes_blockstatus + nbytes_blockid);

    // Reset every block's status to 'x' (not ready) before launching the scan kernel.
    amrex::ParallelFor(nblocks, [=] AMREX_GPU_DEVICE (int i) noexcept {
        BlockStatusT& block_status = block_status_p[i];
        block_status.set_status('x');
        if (i == 0) {
            *virtual_block_id_p = 0;
        }
    });
    amrex::launch<nthreads>(nblocks, sm, stream,
    [=] AMREX_GPU_DEVICE (Gpu::Handler const& gh) noexcept
    {
        sycl::sub_group const& sg = gh.item->get_sub_group();
        int lane = sg.get_local_id()[0];
        int warp = sg.get_group_id()[0];
        int nwarps = sg.get_group_range()[0];

        int threadIdxx = gh.item->get_local_id(0);
        int blockDimx = gh.item->get_local_range(0);
        int gridDimx = gh.item->get_group_range(0);

        T* shared = (T*)(gh.local);

        int virtual_block_id = 0;
        if (gridDimx > 1) {
            int& virtual_block_id_shared = *((int*)(shared2+nwarps));
            if (threadIdxx == 0) {
                unsigned int bid = Gpu::Atomic::Add(virtual_block_id_p, 1u);
                virtual_block_id_shared = bid;
            }
            gh.item->barrier(sycl::access::fence_space::local_space);
            virtual_block_id = virtual_block_id_shared;
        }
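        // Taking the block id from an atomically incremented counter rather than
        // the hardware group index numbers the blocks in the order they actually
        // start, so a block only ever looks back at blocks that are already running.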
        N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
        BlockStatusT& block_status = block_status_p[virtual_block_id];

        T sum_prev_chunk = 0;

        for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
            N offset = ibegin + ichunk*blockDimx;
            if (offset >= iend) { break; }

            if (std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value && offset == n-1) {

                T s = sycl::shift_group_right(sg, x, i);
                if (lane >= i) { x += s; }

            gh.item->barrier(sycl::access::fence_space::local_space);

            T y = (lane < nwarps) ? shared[lane] : 0;
                T s = sycl::shift_group_right(sg, y, i);
                if (lane >= i) { y += s; }
            if (lane < nwarps) { shared2[lane] = y; }

            gh.item->barrier(sycl::access::fence_space::local_space);

            T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
            tmp_out[ichunk] = sum_prev_warp + sum_prev_chunk +
                (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ? x : x-x0);
            sum_prev_chunk += shared2[nwarps-1];
        }
        if (threadIdxx == 0 && gridDimx > 1) {
            block_status.write((virtual_block_id == 0) ? 'p' : 'a', sum_prev_chunk);
        }

        if (virtual_block_id == 0) {
            for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
                N offset = ibegin + ichunk*blockDimx + threadIdxx;
                if (offset >= iend) { break; }
                fout(offset, tmp_out[ichunk]);
                if (offset == n-1) {
                    *totalsum_p += tmp_out[ichunk];
                }
            }
        } else if (virtual_block_id > 0) {
            T exclusive_prefix = 0;
            BlockStatusT volatile* pbs = block_status_p;

                int iblock = iblock0-lane;
                detail::STVA<T> stva{'p', 0};
                    stva = pbs[iblock].wait();

                unsigned status_bf = (stva.status == 'p') ? (0x1u << lane) : 0;
                    status_bf |= sycl::permute_group_by_xor(sg, status_bf, i);

                bool stop_lookback = status_bf & 0x1u;
                if (stop_lookback == false) {
                    if (status_bf != 0) {
                        if (lane > 0) { x = 0; }
                        unsigned int bit_mask = 0x1u;
                            if (i == lane) { x = y; }
                            if (status_bf & bit_mask) {
                                stop_lookback = true;

                    x += sycl::shift_group_left(sg, x, i);

                if (lane == 0) { exclusive_prefix += x; }
                if (stop_lookback) { break; }

            block_status.write('p', block_status.get_aggregate() + exclusive_prefix);
            shared[0] = exclusive_prefix;

            gh.item->barrier(sycl::access::fence_space::local_space);

            T exclusive_prefix = shared[0];

            for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
                N offset = ibegin + ichunk*blockDimx + threadIdxx;
                if (offset >= iend) { break; }
                T t = tmp_out[ichunk] + exclusive_prefix;
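                // t is the element's final scan value: the within-block result plus
                // the exclusive prefix assembled by the look-back over earlier blocks.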
#elif defined(AMREX_USE_HIP)

template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
          typename M=std::enable_if_t<std::is_integral_v<N> &&
                                      (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
                                       std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = retSum)
{
    if (n <= 0) { return 0; }
    constexpr int nwarps_per_block = 4;
    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
    constexpr int nelms_per_thread = sizeof(T) >= 8 ? 8 : 16;
    constexpr int nelms_per_block = nthreads * nelms_per_thread;
    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) <
                        static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
    int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
    using ScanTileState = rocprim::detail::lookback_scan_state<T>;
    using OrderedBlockId = rocprim::detail::ordered_block_id<unsigned int>;

#if (defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR < 6)) || \
    (defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR == 6) && \
     defined(HIP_VERSION_MINOR) && (HIP_VERSION_MINOR == 0))

    std::size_t nbytes_tile_state = rocprim::detail::align_size
        (ScanTileState::get_storage_size(nblocks));
    std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();

    auto dp = (char*)(The_Arena()->alloc(nbytes_tile_state+nbytes_block_id));

    ScanTileState tile_state = ScanTileState::create(dp, nblocks);

#else

    std::size_t nbytes_tile_state;
    AMREX_HIP_SAFE_CALL(ScanTileState::get_storage_size(nblocks, stream, nbytes_tile_state));
    nbytes_tile_state = rocprim::detail::align_size(nbytes_tile_state);

    std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();

    auto dp = (char*)(The_Arena()->alloc(nbytes_tile_state+nbytes_block_id));

    ScanTileState tile_state;
    AMREX_HIP_SAFE_CALL(ScanTileState::create(tile_state, dp, nblocks, stream));

#endif

    auto ordered_block_id = OrderedBlockId::create
        (reinterpret_cast<OrderedBlockId::id_type*>(dp + nbytes_tile_state));
    amrex::launch<nthreads>((nblocks+nthreads-1)/nthreads, 0, stream, [=] AMREX_GPU_DEVICE ()
    {
        auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);
        auto& scan_bid = const_cast<OrderedBlockId&>(ordered_block_id);
        const unsigned int gid = blockIdx.x*nthreads + threadIdx.x;
        if (gid == 0) { scan_bid.reset(); }
        scan_tile_state.initialize_prefix(gid, nblocks);
    });
    amrex::launch_global<nthreads> <<<nblocks, nthreads, sm, stream>>> (
    [=] AMREX_GPU_DEVICE () noexcept
    {
        using BlockLoad = rocprim::block_load<T, nthreads, nelms_per_thread,
                                              rocprim::block_load_method::block_load_transpose>;
        using BlockScan = rocprim::block_scan<T, nthreads,
                                              rocprim::block_scan_algorithm::using_warp_scan>;
        using BlockExchange = rocprim::block_exchange<T, nthreads, nelms_per_thread>;
        using LookbackScanPrefixOp = rocprim::detail::lookback_scan_prefix_op
            <T, rocprim::plus<T>, ScanTileState>;

        __shared__ struct TempStorage {
            typename OrderedBlockId::storage_type ordered_bid;
            union {
                typename BlockLoad::storage_type load;
                typename BlockExchange::storage_type exchange;
                typename BlockScan::storage_type scan;
            };
        } temp_storage;
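        // rocprim's lookback_scan_state and lookback_scan_prefix_op provide the
        // tile-status bookkeeping for a single-pass scan with decoupled look-back,
        // playing the same role as the hand-written BlockStatus machinery in the
        // generic implementation further below.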
        auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);
        auto& scan_bid = const_cast<OrderedBlockId&>(ordered_block_id);

        auto const virtual_block_id = scan_bid.get(threadIdx.x, temp_storage.ordered_bid);

        N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);

        auto input_begin = rocprim::make_transform_iterator(
            rocprim::make_counting_iterator(N(0)),
            [&] (N i) -> T { return fin(i+ibegin); });
        T data[nelms_per_thread];
        if (static_cast<int>(iend-ibegin) == nelms_per_block) {
            BlockLoad().load(input_begin, data, temp_storage.load);
        } else {
            BlockLoad().load(input_begin, data, iend-ibegin, 0, temp_storage.load);
        }
        constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;

        if (virtual_block_id == 0) {
            T block_agg;
            if constexpr (is_exclusive) {
                BlockScan().exclusive_scan(data, data, T{0}, block_agg, temp_storage.scan);
            } else {
                BlockScan().inclusive_scan(data, data, block_agg, temp_storage.scan);
            }
            if (threadIdx.x == 0) {
                if (nblocks > 1) {
                    scan_tile_state.set_complete(0, block_agg);
                } else if (nblocks == 1 && totalsum_p) {
                    *totalsum_p = block_agg;
                }
            }
        } else {
            T last = data[nelms_per_thread-1];

            LookbackScanPrefixOp prefix_op(virtual_block_id, rocprim::plus<T>(), scan_tile_state);
            if constexpr (is_exclusive) {
                BlockScan().exclusive_scan(data, data, temp_storage.scan, prefix_op,
                                           rocprim::plus<T>());
            } else {
                BlockScan().inclusive_scan(data, data, temp_storage.scan, prefix_op,
                                           rocprim::plus<T>());
            }

            if (iend == n && threadIdx.x == nthreads-1) {
                T tsum = data[nelms_per_thread-1];
            }
        }

        BlockExchange().blocked_to_striped(data, data, temp_storage.exchange);

        for (int i = 0; i < nelms_per_thread; ++i) {
            N offset = ibegin + i*nthreads + threadIdx.x;
            if (offset < iend) { fout(offset, data[i]); }
        }
    });

    T ret = (a_ret_sum) ? *totalsum_p : T(0);
#elif defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)

template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
          typename M=std::enable_if_t<std::is_integral_v<N> &&
                                      (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
                                       std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = retSum)
{
    if (n <= 0) { return 0; }
    constexpr int nwarps_per_block = 8;
    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
    constexpr int nelms_per_thread = sizeof(T) >= 8 ? 4 : 8;
    constexpr int nelms_per_block = nthreads * nelms_per_thread;
    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) <
                        static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
    int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
    using ScanTileState = cub::ScanTileState<T>;
    std::size_t tile_state_size = 0;
    ScanTileState::AllocationSize(nblocks, tile_state_size);

    std::size_t nbytes_tile_state = Arena::align(tile_state_size);
    auto tile_state_p = (char*)(The_Arena()->alloc(nbytes_tile_state));

    ScanTileState tile_state;
    tile_state.Init(nblocks, tile_state_p, tile_state_size);
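    // cub::ScanTileState is sized with AllocationSize() and bound to the
    // externally allocated buffer via Init(); TilePrefixCallbackOp below uses it
    // for the decoupled look-back between tiles.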
    amrex::launch<nthreads>((nblocks+nthreads-1)/nthreads, 0, stream, [=] AMREX_GPU_DEVICE ()
    {
        const_cast<ScanTileState&>(tile_state).InitializeStatus(nblocks);
    });
    amrex::launch_global<nthreads> <<<nblocks, nthreads, sm, stream>>> (
    [=] AMREX_GPU_DEVICE () noexcept
    {
        using BlockLoad = cub::BlockLoad<T, nthreads, nelms_per_thread, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
        using BlockScan = cub::BlockScan<T, nthreads, cub::BLOCK_SCAN_WARP_SCANS>;
        using BlockExchange = cub::BlockExchange<T, nthreads, nelms_per_thread>;
#ifdef AMREX_CUDA_CCCL_VER_GE_3
        using Sum = cuda::std::plus<T>;
#else
        using Sum = cub::Sum;
#endif
        using TilePrefixCallbackOp = cub::TilePrefixCallbackOp<T, Sum, ScanTileState>;

        __shared__ union TempStorage
        {
            typename BlockLoad::TempStorage load;
            typename BlockExchange::TempStorage exchange;
            struct ScanStorage {
                typename BlockScan::TempStorage scan;
                typename TilePrefixCallbackOp::TempStorage prefix;
            } scan_storage;
        } temp_storage;
        auto& scan_tile_state = const_cast<ScanTileState&>(tile_state);

        int virtual_block_id = blockIdx.x;

        N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);

        auto input_lambda = [&] (N i) -> T { return fin(i+ibegin); };
#ifdef AMREX_CUDA_CCCL_VER_GE_3
        thrust::transform_iterator<decltype(input_lambda),thrust::counting_iterator<N> >
            input_begin(thrust::counting_iterator<N>(0), input_lambda);
#else
        cub::TransformInputIterator<T,decltype(input_lambda),cub::CountingInputIterator<N> >
            input_begin(cub::CountingInputIterator<N>(0), input_lambda);
#endif

        T data[nelms_per_thread];
        if (static_cast<int>(iend-ibegin) == nelms_per_block) {
            BlockLoad(temp_storage.load).Load(input_begin, data);
        } else {
            BlockLoad(temp_storage.load).Load(input_begin, data, iend-ibegin, 0);
        }
        constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;

        if (virtual_block_id == 0) {
            T block_agg;
            if constexpr (is_exclusive) {
                BlockScan(temp_storage.scan_storage.scan).ExclusiveSum(data, data, block_agg);
            } else {
                BlockScan(temp_storage.scan_storage.scan).InclusiveSum(data, data, block_agg);
            }
            if (threadIdx.x == 0) {
                if (nblocks > 1) {
                    scan_tile_state.SetInclusive(0, block_agg);
                } else if (nblocks == 1 && totalsum_p) {
                    *totalsum_p = block_agg;
                }
            }
        } else {
            T last = data[nelms_per_thread-1];

            TilePrefixCallbackOp prefix_op(scan_tile_state, temp_storage.scan_storage.prefix,
                                           Sum{}, virtual_block_id);
            if constexpr (is_exclusive) {
                BlockScan(temp_storage.scan_storage.scan).ExclusiveSum(data, data, prefix_op);
            } else {
                BlockScan(temp_storage.scan_storage.scan).InclusiveSum(data, data, prefix_op);
            }

            if (iend == n && threadIdx.x == nthreads-1) {
                T tsum = data[nelms_per_thread-1];
            }
        }

        BlockExchange(temp_storage.exchange).BlockedToStriped(data);

        for (int i = 0; i < nelms_per_thread; ++i) {
            N offset = ibegin + i*nthreads + threadIdx.x;
            if (offset < iend) { fout(offset, data[i]); }
        }
    });

    T ret = (a_ret_sum) ? *totalsum_p : T(0);
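
// Generic CUDA/HIP implementation: a hand-written single-pass scan with
// decoupled look-back built directly on warp shuffles, used when the
// cub/rocprim code paths above are not selected.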
template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
          typename M=std::enable_if_t<std::is_integral_v<N> &&
                                      (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ||
                                       std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value)> >
T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = retSum)
{
    if (n <= 0) { return 0; }
    constexpr int nwarps_per_block = 4;
    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
    constexpr int nchunks = 12;
    constexpr int nelms_per_block = nthreads * nchunks;
    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) <
                        static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
    int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
    std::size_t nbytes_blockstatus = Arena::align(sizeof(BlockStatusT)*nblocks);
    std::size_t nbytes_blockid = Arena::align(sizeof(unsigned int));
    std::size_t nbytes_totalsum = Arena::align(sizeof(T));
    auto dp = (char*)(The_Arena()->alloc(nbytes_blockstatus + nbytes_blockid + nbytes_totalsum));
    BlockStatusT* AMREX_RESTRICT block_status_p = (BlockStatusT*)dp;
    unsigned int* AMREX_RESTRICT virtual_block_id_p = (unsigned int*)(dp + nbytes_blockstatus);
    T* AMREX_RESTRICT totalsum_p = (T*)(dp + nbytes_blockstatus + nbytes_blockid);

    // Reset every block's status to 'x' (not ready) before launching the scan kernel.
    amrex::ParallelFor(nblocks, [=] AMREX_GPU_DEVICE (int i) noexcept {
        BlockStatusT& block_status = block_status_p[i];
        block_status.set_status('x');
        if (i == 0) {
            *virtual_block_id_p = 0;
        }
    });
    amrex::launch<nthreads>(nblocks, sm, stream,
    [=] AMREX_GPU_DEVICE () noexcept
    {
        int virtual_block_id = 0;
        if (gridDim.x > 1) {
            int& virtual_block_id_shared = *((int*)(shared2+nwarps));
            if (threadIdx.x == 0) {
                unsigned int bid = Gpu::Atomic::Add(virtual_block_id_p, 1u);
                virtual_block_id_shared = bid;
            }
            __syncthreads();
            virtual_block_id = virtual_block_id_shared;
        }

        N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
        BlockStatusT& block_status = block_status_p[virtual_block_id];
        T sum_prev_chunk = 0;

        for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
            N offset = ibegin + ichunk*nthreads;
            if (offset >= iend) { break; }

                T s = __shfl_up_sync(0xffffffff, x, i);
                if (lane >= i) { x += s; }

#ifdef AMREX_USE_CUDA
            if (warp == 0 && lane < nwarps) {
                int mask = (1 << nwarps) - 1;
                for (int i = 1; i <= nwarps; i *= 2) {
                    T s = __shfl_up_sync(mask, y, i, nwarps);
                    if (lane >= i) { y += s; }
                }
                shared2[lane] = y;
            }
#else
            if (lane < nwarps) {
                for (int i = 1; i <= nwarps; i *= 2) {
                    T s = __shfl_up(y, i, nwarps);
                    if (lane >= i) { y += s; }
                }
            }
            if (lane < nwarps) {
                shared2[lane] = y;
            }
#endif

            T sum_prev_warp = (warp == 0) ? 0 : shared2[warp-1];
            tmp_out[ichunk] = sum_prev_warp + sum_prev_chunk +
                (std::is_same<std::decay_t<TYPE>,Type::Inclusive>::value ? x : x-x0);
            sum_prev_chunk += shared2[nwarps-1];
        }
        if (threadIdx.x == 0 && gridDim.x > 1) {
            block_status.write((virtual_block_id == 0) ? 'p' : 'a', sum_prev_chunk);
        }

        if (virtual_block_id == 0) {
            for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
                N offset = ibegin + ichunk*nthreads + threadIdx.x;
                if (offset >= iend) { break; }
                fout(offset, tmp_out[ichunk]);
                if (offset == n-1) {
                    *totalsum_p += tmp_out[ichunk];
                }
            }
        } else if (virtual_block_id > 0) {
            T exclusive_prefix = 0;
            BlockStatusT volatile* pbs = block_status_p;

                int iblock = iblock0-lane;
                detail::STVA<T> stva{'p', 0};
                    stva = pbs[iblock].wait();

                unsigned const status_bf = __ballot_sync(0xffffffff, stva.status == 'p');
                bool stop_lookback = status_bf & 0x1u;
                if (stop_lookback == false) {
                    if (status_bf != 0) {
                        if (lane > 0) { x = 0; }
                        unsigned bit_mask = 0x1u;
                            if (i == lane) { x = y; }
                            if (status_bf & bit_mask) {
                                stop_lookback = true;

                    x += __shfl_down_sync(0xffffffff, x, i);

                if (lane == 0) { exclusive_prefix += x; }
                if (stop_lookback) { break; }

            block_status.write('p', block_status.get_aggregate() + exclusive_prefix);
            shared[0] = exclusive_prefix;

            __syncthreads();

            T exclusive_prefix = shared[0];

            for (int ichunk = 0; ichunk < nchunks; ++ichunk) {
                N offset = ibegin + ichunk*nthreads + threadIdx.x;
                if (offset >= iend) { break; }
                T t = tmp_out[ichunk] + exclusive_prefix;
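                // As in the SYCL kernel above, each element's final value is its
                // within-block result plus the exclusive prefix from the look-back.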
template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
T InclusiveSum (N n, T const* in, T * out, RetSum a_ret_sum = retSum)
{
#if defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
    void* d_temp = nullptr;
    std::size_t temp_bytes = 0;

#elif defined(AMREX_USE_HIP)
    void* d_temp = nullptr;
    std::size_t temp_bytes = 0;

#elif defined(AMREX_USE_SYCL) && defined(AMREX_USE_ONEDPL)
    auto policy = oneapi::dpl::execution::make_device_policy(Gpu::Device::streamQueue());
    std::inclusive_scan(policy, in, in+n, out, std::plus<T>(), T(0));
    if (static_cast<Long>(n) <= static_cast<Long>(std::numeric_limits<int>::max())) {
        return PrefixSum<T>(static_cast<int>(n),
                            [=] AMREX_GPU_DEVICE (int i) -> T { return in[i]; },
                            [=] AMREX_GPU_DEVICE (int i, T const& x) { out[i] = x; },
                            Type::inclusive, a_ret_sum);
    } else {
        return PrefixSum<T>(n,
                            [=] AMREX_GPU_DEVICE (N i) -> T { return in[i]; },
                            [=] AMREX_GPU_DEVICE (N i, T const& x) { out[i] = x; },
                            Type::inclusive, a_ret_sum);
    }
template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
T ExclusiveSum (N n, T const* in, T * out, RetSum a_ret_sum = retSum)
{
    if (n <= 0) { return 0; }
#if defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)

    void* d_temp = nullptr;
    std::size_t temp_bytes = 0;

    return in_last + out_last;

#elif defined(AMREX_USE_HIP)

    void* d_temp = nullptr;
    std::size_t temp_bytes = 0;

    return in_last + out_last;

#elif defined(AMREX_USE_SYCL) && defined(AMREX_USE_ONEDPL)

    auto policy = oneapi::dpl::execution::make_device_policy(Gpu::Device::streamQueue());
    std::exclusive_scan(policy, in, in+n, out, T(0), std::plus<T>());

    return in_last + out_last;
    if (static_cast<Long>(n) <= static_cast<Long>(std::numeric_limits<int>::max())) {
        return PrefixSum<T>(static_cast<int>(n),
                            [=] AMREX_GPU_DEVICE (int i) -> T { return in[i]; },
                            [=] AMREX_GPU_DEVICE (int i, T const& x) { out[i] = x; },
                            Type::exclusive, a_ret_sum);
    } else {
        return PrefixSum<T>(n,
                            [=] AMREX_GPU_DEVICE (N i) -> T { return in[i]; },
                            [=] AMREX_GPU_DEVICE (N i, T const& x) { out[i] = x; },
                            Type::exclusive, a_ret_sum);
    }
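    // On this generic path both wrappers forward to PrefixSum with lambdas that
    // read from `in` and write to `out`. Illustrative call (a sketch, not part of
    // this header; `d_in` and `d_out` are hypothetical device pointers of length n):
    //
    //   T total = Scan::ExclusiveSum(n, d_in, d_out);
    //   Scan::InclusiveSum(n, d_in, d_out, Scan::noRetSum); // when the total is not needed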
template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
          typename M=std::enable_if_t<std::is_integral_v<N> &&
                                      (std::is_same_v<std::decay_t<TYPE>,Type::Inclusive> ||
                                       std::is_same_v<std::decay_t<TYPE>,Type::Exclusive>)> >
T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum = retSum)
{
    if (n <= 0) { return 0; }

    for (N i = 0; i < n; ++i) {
template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
T InclusiveSum (N n, T const* in, T * out, RetSum = retSum)
{
#if (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
    std::inclusive_scan(in, in+n, out);
#else
    std::partial_sum(in, in+n, out);
#endif
    return (n > 0) ? out[n-1] : T(0);
}
template <typename N, typename T, typename M=std::enable_if_t<std::is_integral_v<N>> >
T ExclusiveSum (N n, T const* in, T * out, RetSum = retSum)
{
    if (n <= 0) { return 0; }

    auto in_last = in[n-1];
#if (__cplusplus >= 201703L) && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10)
    std::exclusive_scan(in, in+n, out, T(0));
#else
    out[0] = T(0);
    std::partial_sum(in, in+n-1, out+1);
#endif
    return in_last + out[n-1];
}
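
// Worked example (illustrative): for in = {3, 1, 4, 1},
//   InclusiveSum writes out = {3, 4, 8, 9} and returns 9;
//   ExclusiveSum writes out = {0, 3, 4, 8} and returns in_last + out[n-1] = 1 + 8 = 9.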