1 #ifndef AMREX_CUDA_GRAPH_H_
2 #define AMREX_CUDA_GRAPH_H_
3 #include <AMReX_Config.H>
5 #if defined(__CUDACC__) && defined(AMREX_USE_CUDA)
// Begin/end index triples recorded for the source and destination arrays.
// getSrc()/getDst() pass the matching begin/end pair to the Array4 constructor.
// Every bound defaults to {0,0,0}.
17 Dim3 src_begin = {0,0,0};
18 Dim3 src_end = {0,0,0};
19 Dim3 dst_begin = {0,0,0};
20 Dim3 dst_end = {0,0,0};
26 Array4<T const> getSrc () {
return Array4<T const>(
static_cast<T const*
>(src), src_begin, src_end, scomp+ncomp); }
30 Array4<T> getDst () {
return Array4<T>(
static_cast<T*
>(dst), dst_begin, dst_end, scomp+ncomp); }
// Package the data pointers and bounds of two Array4 views, together with a
// component start (scomp) and count (ncomp), into a CopyMemory value.
// NOTE(review): this listing omits several original lines (return type, braces,
// the declaration of `mem`, the remaining field assignments and the
// #else/#endif); the comments below describe only what is visible.
33 template <
typename T,
typename U>
35 makeCopyMemory (Array4<T>
const& src, Array4<U>
const& dst,
int scomp,
int ncomp)
// Branch taken when compiling as a standard older than C++14: the fields of
// `mem` are assigned one at a time.
37 #if __cplusplus < 201402L
// The element pointers are stored as untyped void*.
39 mem.src = (
void*)(src.p);
40 mem.dst = (
void*)(dst.p);
// Copy the index bounds of both views.
41 mem.src_begin = src.begin;
42 mem.src_end = src.end;
43 mem.dst_begin = dst.begin;
44 mem.dst_end = dst.end;
// Presumably the alternative (#else) branch -- confirm: build the CopyMemory
// in a single brace-initialised expression from the same eight values.
51 return CopyMemory{ (
void*)(src.p), (
void*)(dst.p), src.begin, src.end, dst.begin, dst.end, scomp, ncomp };
// Executable graph handle; executeGraph() hands it to Gpu::Device::executeGraph.
60 cudaGraphExec_t m_graph;
// Buffer of T obtained from The_Arena()->alloc, sized from m_parms.size();
// null until an allocation has been made. Returned (offset) by getDevicePtr().
62 T* m_parms_d =
nullptr;
// Reported by ready(); set to true by setGraph().
63 bool graph_is_ready =
false;
// Compile-time guard: reject a T that is not trivially copyable.
// NOTE(review): the declarations enclosing these statements are not visible in
// this listing (presumably two constructors) -- confirm.
68 static_assert(std::is_trivially_copyable<T>::value,
"CudaGraph's T must be trivially copyable");
// Same guard, repeated in a second location.
73 static_assert(std::is_trivially_copyable<T>::value,
"CudaGraph's T must be trivially copyable");
// Request room for m_parms.size() objects of type T from The_Arena().
74 m_parms_d =
static_cast<T*
>(
The_Arena()->
alloc(
sizeof(T)*m_parms.size()) );
// Re-create the m_parms_d buffer so that it matches the size of m_parms.
// NOTE(review): lines between the visible ones are missing from this listing.
// How `num` is applied to m_parms, and what is done when m_parms_d is non-null
// (presumably The_Arena()->free), is not shown -- confirm.
85 void resize(Long num) {
// Only an already-allocated buffer needs handling before reallocation.
87 if (m_parms_d !=
nullptr)
// Fresh allocation sized from the current m_parms.size().
91 m_parms_d =
static_cast<T*
>(
The_Arena()->
alloc(
sizeof(T)*m_parms.size()) );
// Accept an executable graph handle and mark this object as ready.
// NOTE(review): the statement that stores `graph` is not visible in this
// listing (presumably an assignment to m_graph) -- confirm.
93 void setGraph(cudaGraphExec_t
const& graph) {
// From here on ready() returns true.
95 graph_is_ready =
true;
97 void setParams(
int idx, T
const& a_parm) { m_parms[idx] = a_parm; }
99 T* getHostPtr (
int idx) {
return (m_parms.data() + idx); }
100 T* getDevicePtr (
int idx)
const {
return (m_parms_d + idx); }
101 bool ready()
const {
return graph_is_ready; }
// Run the stored graph by forwarding m_graph to Gpu::Device::executeGraph.
// `synch` defaults to true and is passed through unchanged.
// NOTE(review): presumably it requests a synchronisation after the launch --
// confirm in Gpu::Device.
103 void executeGraph (
bool synch =
true)
const {
104 Gpu::Device::executeGraph(m_graph, synch);
#define AMREX_CUDA_SAFE_CALL(call)
Definition: AMReX_GpuError.H:73
#define AMREX_GPU_HOST_DEVICE
Definition: AMReX_GpuQualifiers.H:20
virtual void free(void *pt)=0
A pure virtual function for freeing the memory block pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition: AMReX_Amr.cpp:49
Arena * The_Arena()
Definition: AMReX_Arena.cpp:609