Block-Structured AMR Software Framework
Loading...
Searching...
No Matches
AMReX_CudaGraph.H
Go to the documentation of this file.
1#ifndef AMREX_CUDA_GRAPH_H_
2#define AMREX_CUDA_GRAPH_H_
3#include <AMReX_Config.H>
4
5#if defined(__CUDACC__) && defined(AMREX_USE_CUDA)
6
7#include <AMReX.H>
8#include <AMReX_Array.H>
9#include <AMReX_GpuDevice.H>
10
11namespace amrex {
12
13struct CopyMemory
14{
15 void* src = nullptr;
16 void* dst = nullptr;
17 Dim3 src_begin = {0,0,0};
18 Dim3 src_end = {0,0,0};
19 Dim3 dst_begin = {0,0,0};
20 Dim3 dst_end = {0,0,0};
21 int scomp = 0;
22 int ncomp = 0;
23
24 template <class T>
26 Array4<T const> getSrc () { return Array4<T const>(static_cast<T const*>(src), src_begin, src_end, scomp+ncomp); }
27
28 template <class T>
30 Array4<T> getDst () { return Array4<T>(static_cast<T*>(dst), dst_begin, dst_end, scomp+ncomp); }
31};
32
33template <typename T, typename U>
34CopyMemory
35makeCopyMemory (Array4<T> const& src, Array4<U> const& dst, int scomp, int ncomp)
36{
37 return CopyMemory{ (void*)(src.p), (void*)(dst.p), IntVectShrink<3>(src.begin).dim3(), IntVectShrink<3>(src.end).dim3(),
38 IntVectShrink<3>(dst.begin).dim3(), IntVectShrink<3>(dst.end).dim3(), scomp, ncomp };
39}
40
41// ======================================================================================
42
43template <typename T>
44struct CudaGraph
45{
46 cudaGraphExec_t m_graph;
47 Vector<T> m_parms;
48 T* m_parms_d = nullptr;
49 bool graph_is_ready = false;
50
51 CudaGraph()
52 : m_parms(0)
53 {
54 static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
55 }
56 CudaGraph(int num)
57 : m_parms(num)
58 {
59 static_assert(std::is_trivially_copyable<T>::value, "CudaGraph's T must be trivially copyable");
60 m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
61 }
62 ~CudaGraph() {
63 The_Arena()->free(m_parms_d);
64
65 if (graph_is_ready)
66 {
67 AMREX_CUDA_SAFE_CALL(cudaGraphExecDestroy(m_graph));
68 }
69 }
70
71 void resize(Long num) {
72 m_parms.resize(num);
73 if (m_parms_d != nullptr)
74 {
75 The_Arena()->free(m_parms_d);
76 }
77 m_parms_d = static_cast<T*>( The_Arena()->alloc(sizeof(T)*m_parms.size()) );
78 }
79 void setGraph(cudaGraphExec_t const& graph) {
80 m_graph = graph;
81 graph_is_ready = true;
82 }
83 void setParams(int idx, T const& a_parm) { m_parms[idx] = a_parm; }
84
85 T* getHostPtr (int idx) { return (m_parms.data() + idx); }
86 T* getDevicePtr (int idx) const { return (m_parms_d + idx); }
87 bool ready() const { return graph_is_ready; }
88
89 void executeGraph (bool synch = true) const {
90 Gpu::Device::executeGraph(m_graph, synch);
91 }
92};
93
94}
95
96#endif
97#endif
#define AMREX_CUDA_SAFE_CALL(call)
Definition AMReX_GpuError.H:73
#define AMREX_GPU_HOST_DEVICE
Definition AMReX_GpuQualifiers.H:20
virtual void free(void *pt)=0
A pure virtual function that releases the memory block pointed to by pt (previously obtained from alloc).
virtual void * alloc(std::size_t sz)=0
Arena * The_Arena()
Definition AMReX_Arena.cpp:783
Definition AMReX_Amr.cpp:49