Block-Structured AMR Software Framework
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
AMReX_CudaGraph.H
Go to the documentation of this file.
1#ifndef AMREX_CUDA_GRAPH_H_
2#define AMREX_CUDA_GRAPH_H_
3#include <AMReX_Config.H>
4
5#if defined(__CUDACC__) && defined(AMREX_USE_CUDA)
6
#include <AMReX.H>
#include <AMReX_Array.H>
#include <AMReX_GpuDevice.H>

#include <type_traits>
#include <utility>
10
11namespace amrex {
12
13struct CopyMemory
14{
15 void* src = nullptr;
16 void* dst = nullptr;
17 Dim3 src_begin = {0,0,0};
18 Dim3 src_end = {0,0,0};
19 Dim3 dst_begin = {0,0,0};
20 Dim3 dst_end = {0,0,0};
21 int scomp = 0;
22 int ncomp = 0;
23
24 template <class T>
26 Array4<T const> getSrc () { return Array4<T const>(static_cast<T const*>(src), src_begin, src_end, scomp+ncomp); }
27
28 template <class T>
30 Array4<T> getDst () { return Array4<T>(static_cast<T*>(dst), dst_begin, dst_end, scomp+ncomp); }
31};
32
33template <typename T, typename U>
34CopyMemory
35makeCopyMemory (Array4<T> const& src, Array4<U> const& dst, int scomp, int ncomp)
36{
37#if __cplusplus < 201402L
38 CopyMemory mem;
39 mem.src = (void*)(src.p);
40 mem.dst = (void*)(dst.p);
41 mem.src_begin = src.begin;
42 mem.src_end = src.end;
43 mem.dst_begin = dst.begin;
44 mem.dst_end = dst.end;
45 mem.scomp = scomp;
46 mem.ncomp = ncomp;
47 return mem;
48
49#else
50
51 return CopyMemory{ (void*)(src.p), (void*)(dst.p), src.begin, src.end, dst.begin, dst.end, scomp, ncomp };
52#endif
53}
54
55// ======================================================================================
56
57template <typename T>
58struct CudaGraph
59{
60 cudaGraphExec_t m_graph;
61 Vector<T> m_parms;
62 T* m_parms_d = nullptr;
63 bool graph_is_ready = false;
64
65 CudaGraph()
66 : m_parms(0)
67 {
68 static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
69 }
70 CudaGraph(int num)
71 : m_parms(num)
72 {
73 static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
74 m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
75 }
76 ~CudaGraph() {
77 The_Arena()->free(m_parms_d);
78
79 if (graph_is_ready)
80 {
81 AMREX_CUDA_SAFE_CALL(cudaGraphExecDestroy(m_graph));
82 }
83 }
84
85 void resize(Long num) {
86 m_parms.resize(num);
87 if (m_parms_d != nullptr)
88 {
89 The_Arena()->free(m_parms_d);
90 }
91 m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
92 }
93 void setGraph(cudaGraphExec_t const& graph) {
94 m_graph = graph;
95 graph_is_ready = true;
96 }
97 void setParams(int idx, T const& a_parm) { m_parms[idx] = a_parm; }
98
99 T* getHostPtr (int idx) { return (m_parms.data() + idx); }
100 T* getDevicePtr (int idx) const { return (m_parms_d + idx); }
101 bool ready() const { return graph_is_ready; }
102
103 void executeGraph (bool synch = true) const {
104 Gpu::Device::executeGraph(m_graph, synch);
105 }
106};
107
108}
109
110#endif
111#endif
#define AMREX_CUDA_SAFE_CALL(call)
Definition AMReX_GpuError.H:73
#define AMREX_GPU_HOST_DEVICE
Definition AMReX_GpuQualifiers.H:20
virtual void free(void *pt)=0
A pure virtual function for freeing the memory pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition AMReX_Amr.cpp:49
Arena * The_Arena()
Definition AMReX_Arena.cpp:616