How to deploy PyTorch Lightning models to production
A complete guide to serving PyTorch Lightning models at scale.
By Caleb Kaiser, Cortex Labs
Looking at the machine learning landscape, one of the major trends is the proliferation of projects focused on applying software engineering principles to machine learning. Cortex, for example, recreates the experience of deploying serverless functions, but with inference pipelines. DVC, similarly, implements modern version control and CI/CD pipelines, but for ML.
PyTorch Lightning has a similar philosophy, only applied to training. The framework provides a Python wrapper for PyTorch that lets data scientists and engineers write clean, manageable, and performant training code.
As people who built an entire deployment platform in part because we hated writing boilerplate, we’re huge fans of PyTorch Lightning. In that spirit, I’ve put together this guide to deploying PyTorch Lightning models to production. In the process, we’re going to look at a few different options for exporting PyTorch Lightning models for inclusion in your inference pipelines.
Every way to deploy a PyTorch Lightning model for inference
There are three ways to export a PyTorch Lightning model for serving:
- Saving the model as a PyTorch checkpoint
- Converting the model to ONNX
- Exporting the model to TorchScript
We can serve all three with Cortex.
1. Package and deploy PyTorch Lightning modules directly
Starting with the simplest approach, let’s deploy a PyTorch Lightning model without any conversion steps.
The PyTorch Lightning Trainer, a class that abstracts away boilerplate training code (think training and validation steps), has a built-in save_checkpoint() method that saves your model as a .ckpt file. To save your model as a checkpoint, call trainer.save_checkpoint() with a file path at the end of your training script.
Now, before we get into serving this checkpoint, it’s important to note that while I keep saying “PyTorch Lightning model,” PyTorch Lightning is a wrapper around PyTorch — the project’s README literally says “PyTorch Lightning is just organized PyTorch.” The exported model, therefore, is a normal PyTorch model, and can be served accordingly.
With a saved checkpoint, we can serve the model pretty easily in Cortex. If you’re unfamiliar with Cortex, the Cortex documentation is a quick way to get up to speed, but the short overview of the deployment process is:
- We write a prediction API for our model in Python
- We define our API’s infrastructure and behavior in YAML
- We deploy the API with a command from the CLI
Our prediction API will use Cortex’s Python Predictor class, with an __init__() method to initialize our API and load the model, and a predict() method to serve predictions when queried.
Pretty simple. We reuse some of our training code, add a little inference logic, and that’s it. One thing to note: if you upload your model to S3 (recommended), you’ll need to add some logic for accessing it.
Next, we configure our infrastructure in a YAML file.
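To keep every example in Python, here is a sketch that builds the API spec as plain data and writes it out with PyYAML; the comment shows the YAML it produces. The field names (name, predictor.type, predictor.path, compute.cpu) follow the description in the next paragraph and Cortex releases from around the time of writing, and the API name plcheckpoint is a hypothetical placeholder, so check the configuration reference for your Cortex version before relying on them.

```python
import yaml  # PyYAML


def checkpoint_api_spec(name: str = "plcheckpoint", cpu: int = 1) -> list:
    """One Cortex API, wrapped in a list to match the top-level list in cortex.yaml."""
    return [
        {
            "name": name,                                             # the API's name
            "predictor": {"type": "python", "path": "predictor.py"},  # where the prediction API lives
            "compute": {"cpu": cpu},                                  # CPU to allocate
        }
    ]


def write_spec(spec: list, path: str = "cortex.yaml") -> None:
    # sort_keys=False keeps the fields in the order written above.
    with open(path, "w") as f:
        yaml.safe_dump(spec, f, sort_keys=False)


# Resulting cortex.yaml:
# - name: plcheckpoint
#   predictor:
#     type: python
#     path: predictor.py
#   compute:
#     cpu: 1

if __name__ == "__main__":
    write_spec(checkpoint_api_spec())
```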
Again, simple. We give our API a name, tell Cortex where our prediction API is, and allocate some CPU.
Next, we deploy it with $ cortex deploy.
Note that we can also deploy to a cluster spun up and managed by Cortex.
With all deployments, Cortex containerizes our API and exposes it as a web service. With cloud deployments, Cortex configures load balancing, autoscaling, monitoring, updating, and many other infrastructure features.
And that’s it! We now have a live web API serving predictions from our model on request.
2. Export to ONNX and serve via ONNX Runtime
Now that we’ve deployed a vanilla PyTorch checkpoint, let’s complicate things a bit.
PyTorch Lightning recently added a convenient abstraction for exporting models to ONNX (previously, you could use PyTorch’s built-in conversion functions, though they required a bit more boilerplate). To export your model to ONNX, just call the model’s to_onnx() method in your training script, passing a file path and an input sample.
Note that your input sample should mimic the shape of your actual model input.
Once you’ve exported an ONNX model, you can serve it using Cortex’s ONNX Predictor. The code looks much the same, and the process is identical.
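Here is a hedged sketch of an ONNX prediction API. It assumes the onnx_client object Cortex passes to __init__() exposes a predict() method that accepts a NumPy array (check your Cortex version’s docs for the exact input and return types), a single-output classifier exported with a (1, 4) float32 input, and the same hypothetical {"features": [...]} request body. LocalOnnxClient is a hypothetical helper, not part of Cortex, that wraps ONNX Runtime so you can exercise the predictor on your own machine.

```python
import numpy as np


class ONNXPredictor:
    def __init__(self, onnx_client, config):
        # Cortex supplies the client for the model referenced in the API's YAML.
        self.client = onnx_client

    def predict(self, payload):
        # Assumed request body: {"features": [f1, f2, f3, f4]} -> float32 array of shape (1, 4).
        model_input = np.asarray(payload["features"], dtype=np.float32)[np.newaxis, :]
        outputs = self.client.predict(model_input)
        # Flatten so this works whether the client returns the logits array itself
        # or a one-element list of output arrays (valid for a single-output model).
        logits = np.asarray(outputs).reshape(-1)
        return {"class": int(logits.argmax())}


class LocalOnnxClient:
    """Hypothetical local stand-in for Cortex's onnx_client, backed by ONNX Runtime."""

    def __init__(self, model_path: str):
        import onnxruntime as ort  # imported here so ONNXPredictor itself only needs NumPy

        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        self.input_name = self.session.get_inputs()[0].name

    def predict(self, model_input):
        # run(None, ...) returns a list with one array per model output.
        return self.session.run(None, {self.input_name: model_input})
```

Locally, ONNXPredictor(LocalOnnxClient("model.onnx"), {}).predict({"features": [0.1, 0.2, 0.3, 0.4]}) should return a dict with a single "class" key.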
Basically the same. The only difference is that instead of initializing the model directly, we access it through the onnx_client, which is an ONNX Runtime container Cortex spins up for serving our model.
Our YAML also looks pretty similar.
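Written the same way, as Python data dumped with PyYAML, an ONNX API spec might look like the following sketch. The onnx predictor type, the model_path field, and the monitoring block with model_type are assumptions based on Cortex releases from around the time of writing, and the API name and S3 location are hypothetical placeholders; verify each field against the configuration reference for your Cortex version.

```python
import yaml  # PyYAML


def onnx_api_spec(model_path: str = "s3://your-bucket/model.onnx") -> list:
    """Hypothetical ONNX API spec; field names are assumptions, not a verified schema."""
    return [
        {
            "name": "plonnx",
            "predictor": {
                "type": "onnx",            # selects the ONNX Predictor instead of the Python one
                "path": "predictor.py",
                "model_path": model_path,  # ONNX-specific: where the .onnx file is loaded from
            },
            "monitoring": {"model_type": "classification"},  # the monitoring flag
            "compute": {"cpu": 1},
        }
    ]


if __name__ == "__main__":
    with open("cortex.yaml", "w") as f:
        yaml.safe_dump(onnx_api_spec(), f, sort_keys=False)
```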
I added a monitoring flag here just to show how easy it is to configure, and there are some ONNX specific fields, but otherwise it’s the same YAML.
Finally, we deploy by using the same $ cortex deploy command as before, and our ONNX API is live.
3. Serialize with TorchScript’s JIT compiler
For a final deployment, we’re going to export our PyTorch Lightning model to TorchScript and serve it using PyTorch’s JIT compiler. To export the model, call the model’s to_torchscript() method in your training script, passing a file path to save it to.
The Python API for this is nearly identical to the vanilla PyTorch example, except that the model is loaded with torch.jit.load().
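A minimal sketch of that predictor, assuming the model was saved to a file such as model.pt (the "model_path" config key is hypothetical) and the same assumed {"features": [...]} request body. It imports only torch: a TorchScript file carries its own compiled code, so the LightningModule class doesn’t need to be importable at serving time.

```python
import torch


class PythonPredictor:
    def __init__(self, config):
        # Hypothetical config key, defaulting to a local file written by to_torchscript().
        model_path = config.get("model_path", "model.pt")
        self.model = torch.jit.load(model_path, map_location="cpu")
        self.model.eval()

    def predict(self, payload):
        # Assumed request body: {"features": [f1, f2, f3, f4]}; add a batch dimension of 1.
        x = torch.tensor(payload["features"], dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(x)
        return {"class": int(logits.argmax(dim=1).item())}
```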
The YAML stays the same as before, and the CLI command, of course, is unchanged. If we want, we can even update our previous PyTorch API to use the new model by replacing our old predictor.py script with the new one and running $ cortex deploy again.
Cortex automatically performs a rolling update here, in which a new API is spun up and then swapped with the old API, preventing any downtime between model updates.
And that’s all there is to it. Now you have a fully operational prediction API for realtime inference, serving predictions from a TorchScript model.
So, which method should you use?
The obvious question is which method performs best. The truth is that there isn’t a straightforward answer, as it depends on your model.
For Transformer models like BERT and GPT-2, ONNX can offer incredible optimizations (we measured a 40x improvement in throughput on CPUs). For other models, TorchScript likely performs better than vanilla PyTorch, though that too comes with caveats, as not all models export to TorchScript cleanly.
Fortunately, with how easy it is to deploy using any option, you can test all three in parallel and see which performs best for your particular API.
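One way to run that comparison locally, before deploying anything, is a small timing harness like this sketch. It assumes each option has been wrapped in a callable that takes one request payload (for example, the predict() method of each predictor) and reports median single-request latency in milliseconds. That is a rough proxy rather than the throughput figure quoted above, and a fair comparison should use realistic inputs on the same hardware you deploy to.

```python
import statistics
import time
from typing import Any, Callable, Dict


def median_latency_ms(predict: Callable[[Any], Any], payload: Any,
                      warmup: int = 5, runs: int = 50) -> float:
    """Median wall-clock time of predict(payload), in milliseconds."""
    for _ in range(warmup):  # let caches and lazy initialization warm up first
        predict(payload)
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        predict(payload)
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


def compare(predictors: Dict[str, Callable[[Any], Any]], payload: Any,
            **kwargs) -> Dict[str, float]:
    """Time every candidate on the same payload; fastest first."""
    results = {name: median_latency_ms(fn, payload, **kwargs)
               for name, fn in predictors.items()}
    return dict(sorted(results.items(), key=lambda item: item[1]))
```

For instance, compare({"checkpoint": ckpt.predict, "onnx": onnx.predict, "torchscript": ts.predict}, {"features": [0.1, 0.2, 0.3, 0.4]}), where ckpt, onnx, and ts are hypothetical instances of the three predictors you built.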
Original. Reposted with permission.