Block-Structured AMR Software Framework
AMReX_CudaGraph.H
Go to the documentation of this file.
1 #ifndef AMREX_CUDA_GRAPH_H_
2 #define AMREX_CUDA_GRAPH_H_
3 #include <AMReX_Config.H>
4 
5 #if defined(__CUDACC__) && defined(AMREX_USE_CUDA)
6 
7 #include <AMReX.H>
8 #include <AMReX_Array.H>
9 #include <AMReX_GpuDevice.H>
10 
11 namespace amrex {
12 
13 struct CopyMemory
14 {
15  void* src = nullptr;
16  void* dst = nullptr;
17  Dim3 src_begin = {0,0,0};
18  Dim3 src_end = {0,0,0};
19  Dim3 dst_begin = {0,0,0};
20  Dim3 dst_end = {0,0,0};
21  int scomp = 0;
22  int ncomp = 0;
23 
24  template <class T>
26  Array4<T const> getSrc () { return Array4<T const>(static_cast<T const*>(src), src_begin, src_end, scomp+ncomp); }
27 
28  template <class T>
30  Array4<T> getDst () { return Array4<T>(static_cast<T*>(dst), dst_begin, dst_end, scomp+ncomp); }
31 };
32 
33 template <typename T, typename U>
34 CopyMemory
35 makeCopyMemory (Array4<T> const& src, Array4<U> const& dst, int scomp, int ncomp)
36 {
37 #if __cplusplus < 201402L
38  CopyMemory mem;
39  mem.src = (void*)(src.p);
40  mem.dst = (void*)(dst.p);
41  mem.src_begin = src.begin;
42  mem.src_end = src.end;
43  mem.dst_begin = dst.begin;
44  mem.dst_end = dst.end;
45  mem.scomp = scomp;
46  mem.ncomp = ncomp;
47  return mem;
48 
49 #else
50 
51  return CopyMemory{ (void*)(src.p), (void*)(dst.p), src.begin, src.end, dst.begin, dst.end, scomp, ncomp };
52 #endif
53 }
54 
55 // ======================================================================================
56 
57 template <typename T>
58 struct CudaGraph
59 {
60  cudaGraphExec_t m_graph;
61  Vector<T> m_parms;
62  T* m_parms_d = nullptr;
63  bool graph_is_ready = false;
64 
65  CudaGraph()
66  : m_parms(0)
67  {
68  static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
69  }
70  CudaGraph(int num)
71  : m_parms(num)
72  {
73  static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
74  m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
75  }
76  ~CudaGraph() {
77  The_Arena()->free(m_parms_d);
78 
79  if (graph_is_ready)
80  {
81  AMREX_CUDA_SAFE_CALL(cudaGraphExecDestroy(m_graph));
82  }
83  }
84 
85  void resize(Long num) {
86  m_parms.resize(num);
87  if (m_parms_d != nullptr)
88  {
89  The_Arena()->free(m_parms_d);
90  }
91  m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
92  }
93  void setGraph(cudaGraphExec_t const& graph) {
94  m_graph = graph;
95  graph_is_ready = true;
96  }
97  void setParams(int idx, T const& a_parm) { m_parms[idx] = a_parm; }
98 
99  T* getHostPtr (int idx) { return (m_parms.data() + idx); }
100  T* getDevicePtr (int idx) const { return (m_parms_d + idx); }
101  bool ready() const { return graph_is_ready; }
102 
103  void executeGraph (bool synch = true) const {
104  Gpu::Device::executeGraph(m_graph, synch);
105  }
106 };
107 
108 }
109 
110 #endif
111 #endif
#define AMREX_CUDA_SAFE_CALL(call)
Definition: AMReX_GpuError.H:73
#define AMREX_GPU_HOST_DEVICE
Definition: AMReX_GpuQualifiers.H:20
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition: AMReX_Amr.cpp:49
Arena * The_Arena()
Definition: AMReX_Arena.cpp:594