1#ifndef AMREX_CUDA_GRAPH_H_
2#define AMREX_CUDA_GRAPH_H_
3#include <AMReX_Config.H>
5#if defined(__CUDACC__) && defined(AMREX_USE_CUDA)
17 Dim3 src_begin = {0,0,0};
18 Dim3 src_end = {0,0,0};
19 Dim3 dst_begin = {0,0,0};
20 Dim3 dst_end = {0,0,0};
26 Array4<T const> getSrc () {
return Array4<T const>(
static_cast<T const*
>(src), src_begin, src_end, scomp+ncomp); }
30 Array4<T> getDst () {
return Array4<T>(
static_cast<T*
>(dst), dst_begin, dst_end, scomp+ncomp); }
// makeCopyMemory: packs two Array4 views into a CopyMemory aggregate.
// The initializer order is: src data pointer, dst data pointer, src begin/end,
// dst begin/end, then scomp and ncomp.
// Begin/end IntVects are converted with IntVectShrink<3>(...).dim3()
// (presumably reducing them to 3 dimensions -- confirm).
// The (void*) C-style casts erase the element type. When T or U is
// const-qualified they also discard const. The element type is reapplied
// later by the static_casts in getSrc()/getDst().
// NOTE(review): the positional initializers must stay in sync with
// CopyMemory's member declaration order.
33template <
typename T,
typename U>
35makeCopyMemory (Array4<T>
const& src, Array4<U>
const& dst,
int scomp,
int ncomp)
37 return CopyMemory{ (
void*)(src.p), (
void*)(dst.p), IntVectShrink<3>(src.begin).dim3(), IntVectShrink<3>(src.end).dim3(),
38 IntVectShrink<3>(dst.begin).dim3(), IntVectShrink<3>(dst.end).dim3(), scomp, ncomp };
46 cudaGraphExec_t m_graph;
48 T* m_parms_d =
nullptr;
49 bool graph_is_ready =
false;
54 static_assert(std::is_trivially_copyable<T>::value,
"CudaGraph's T must be trivially copyable");
59 static_assert(std::is_trivially_copyable<T>::value,
"CudaGraph's T must be trivially copyable");
60 m_parms_d =
static_cast<T*
>(
The_Arena()->
alloc(
sizeof(T)*m_parms.size()) );
71 void resize(Long num) {
73 if (m_parms_d !=
nullptr)
77 m_parms_d =
static_cast<T*
>(
The_Arena()->
alloc(
sizeof(T)*m_parms.size()) );
79 void setGraph(cudaGraphExec_t
const& graph) {
81 graph_is_ready =
true;
83 void setParams(
int idx, T
const& a_parm) { m_parms[idx] = a_parm; }
85 T* getHostPtr (
int idx) {
return (m_parms.data() + idx); }
86 T* getDevicePtr (
int idx)
const {
return (m_parms_d + idx); }
87 bool ready()
const {
return graph_is_ready; }
89 void executeGraph (
bool synch =
true)
const {
90 Gpu::Device::executeGraph(m_graph, synch);
#define AMREX_CUDA_SAFE_CALL(call)
Definition AMReX_GpuError.H:73
#define AMREX_GPU_HOST_DEVICE
Definition AMReX_GpuQualifiers.H:20
virtual void free(void *pt)=0
A pure virtual function for freeing the memory block pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Arena * The_Arena()
Definition AMReX_Arena.cpp:783
Definition AMReX_Amr.cpp:49