ML/PYTORCH
The overall goal of this machine learning tutorial is to accelerate computationally expensive point-wise kernels/routines within an AMReX simulation.
This tutorial demonstrates how to interface a pre-trained PyTorch machine learning model to an AMReX simulation by querying inputs from and supplying outputs to an AMReX MultiFab.
PyTorch is a commonly used machine learning package with a C++ API library called LibTorch.
Located in the directory amrex-tutorials/ExampleCodes/ML/PYTORCH, this example uses a machine learning model to solve a radioactive beta decay problem.
Here we use a 1-input, 2-output model to illustrate the interface between the PyTorch model and a MultiFab.
Beta Decay Reaction
In this example, the machine learning model is a regression model pre-trained to solve a two-component ODE system describing beta decay.
In the context of the PyTorch model, the input is a time step dt and the output is the two-component solution of the ODE system at time t = dt.
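To make the 1-input, 2-output shape concrete, the following is a minimal sketch of a batched evaluation in plain C++, with no LibTorch dependency. The names placeholder_model and evaluate_batch are hypothetical, and the formula inside placeholder_model is an arbitrary stand-in chosen for illustration; it is not the tutorial's trained network or its ODE solution. The only point is the layout: N values of dt go in, and N pairs of outputs come out.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the trained network: maps one time step dt to two
// outputs. The formula is an arbitrary placeholder, NOT the beta decay solution.
std::array<double, 2> placeholder_model(double dt) {
    assert(dt >= 0.0);  // a time step is assumed to be non-negative
    return {1.0 - dt, dt};
}

// Evaluate a batch of N inputs (conceptually an N x 1 array) and return an
// N x 2 array stored row-major: out[2*i + c] is output c for sample i.
std::vector<double> evaluate_batch(const std::vector<double>& dt_batch) {
    std::vector<double> out(2 * dt_batch.size());
    for (std::size_t i = 0; i < dt_batch.size(); ++i) {
        const std::array<double, 2> x = placeholder_model(dt_batch[i]);
        out[2 * i + 0] = x[0];
        out[2 * i + 1] = x[1];
    }
    return out;
}
```

Calling evaluate_batch with three dt values returns six numbers, two per input, in the same order as the inputs.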
Pre-trained Model
The TorchScript model that is included in this example is located at ML/PYTORCH/Exec/model.pt.
If you wish to change the model, edit the model_file parameter in inputs.
Running an AMReX application with a PyTorch model
To begin, we initialize a MultiFab with data representing different dt values, copy this data into a PyTorch tensor, call the pre-trained model to compute the outputs, and finally load the result back into a MultiFab.
The model can be evaluated on the CPU or GPU.
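The copy-in and copy-out steps can be sketched without AMReX or LibTorch as follows. This is an illustration under stated assumptions, not the tutorial's source code: Grid is a hypothetical stand-in for the data in one box of a MultiFab, stored with i varying fastest, then j, then the component index (assumed to mirror the Fortran-style ordering of AMReX data), and gather_inputs and scatter_outputs are hypothetical helper names. The model call that sits between the two steps is deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one box of MultiFab data: ncomp components on an
// nx x ny grid, stored with i fastest, then j, then the component index n.
struct Grid {
    int nx, ny, ncomp;
    std::vector<double> data;
    Grid(int nx_, int ny_, int ncomp_)
        : nx(nx_), ny(ny_), ncomp(ncomp_),
          data(static_cast<std::size_t>(nx_) * ny_ * ncomp_, 0.0) {}
    std::size_t index(int i, int j, int n) const {
        return (static_cast<std::size_t>(n) * ny + j) * nx + i;
    }
};

// Copy-in: flatten a single-component dt grid into an N x 1 batch,
// one row per cell, with cell number j*nx + i.
std::vector<double> gather_inputs(const Grid& dt) {
    assert(dt.ncomp == 1);
    std::vector<double> batch;
    batch.reserve(static_cast<std::size_t>(dt.nx) * dt.ny);
    for (int j = 0; j < dt.ny; ++j)
        for (int i = 0; i < dt.nx; ++i)
            batch.push_back(dt.data[dt.index(i, j, 0)]);
    return batch;
}

// Copy-out: unpack an N x 2 row-major result (out[2*cell + c]) into a
// two-component grid using the same cell numbering as gather_inputs.
Grid scatter_outputs(const std::vector<double>& out, int nx, int ny) {
    assert(out.size() == static_cast<std::size_t>(2) * nx * ny);
    Grid x(nx, ny, 2);
    for (int j = 0; j < ny; ++j) {
        for (int i = 0; i < nx; ++i) {
            const std::size_t cell = static_cast<std::size_t>(j) * nx + i;
            x.data[x.index(i, j, 0)] = out[2 * cell + 0];
            x.data[x.index(i, j, 1)] = out[2 * cell + 1];
        }
    }
    return x;
}
```

The detail to watch is the layout change: the batch is interleaved per cell (both outputs of a cell are adjacent), while the grid keeps each component contiguous, so the copy-out step has to reorder rather than copy the buffer verbatim.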
Below is a step-by-step guide to running an AMReX program that uses a PyTorch model. The model must have been saved as TorchScript; in this example the TorchScript file is model.pt. For more information on TorchScript, please see the TorchScript intro tutorial.
Before compiling, either a CPU or CUDA version of LibTorch (the PyTorch C++ library) must be downloaded into ML/PYTORCH/. To download the CPU-only version of libtorch and rename it to libtorch_cpu:

    wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcpu.zip
    unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cpu.zip
    mv libtorch libtorch_cpu

Similarly, the CUDA 11.8 version of libtorch can be downloaded and renamed to libtorch_cuda:

    wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu118.zip
    unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cu118.zip
    mv libtorch libtorch_cuda

You can also check the PyTorch website to download the latest version of LibTorch.
Go to ML/PYTORCH/Exec to compile the executable. Run make, optionally with USE_CUDA=TRUE; this should produce an executable named, e.g., main2d.gnu.MPI.CUDA.ex.
Then you can run the example, e.g., ./main2d.gnu.MPI.CUDA.ex inputs or mpiexec -n 4 ./main2d.gnu.MPI.ex inputs. There will be two plotfiles, plt_inputs (containing dt) and plt_outputs (containing X_0 and X_1 at the final time).