Using TorchScript for a model snapshot

In this section, we will discuss how to get a model snapshot file that we can use in our mobile application. In the previous chapters, we discussed how to save and load model parameters and how to use the ONNX format to share models between frameworks. The PyTorch framework offers one more way to share models between its Python API and C++ API, called TorchScript.

This method uses model tracing: the model is run on an example input, and the operations it executes are recorded into a special type of model definition that the PyTorch engine can execute regardless of the API. In PyTorch 1.2, only the Python API can create such definitions, but we can use the C++ API to load the model and execute it. Also, the mobile version of the PyTorch framework still doesn't allow us to program neural networks with a full-featured C++ API; only the ATen library is available.
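Because tracing records only the operations that actually run for the example input, any Python control flow that depends on tensor values is frozen into the trace. The following minimal sketch illustrates this with a hypothetical toy module, `BranchyModule`, which is not part of this chapter's example:

```python
import torch
import torch.nn as nn


class BranchyModule(nn.Module):
    """Hypothetical toy module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Python-level control flow that inspects tensor values
        if x.sum() > 0:
            return x * 2
        return x + 10


model = BranchyModule().eval()

# Tracing runs the model once on the example input (all ones, so the first
# branch executes) and records only those operations.
# PyTorch may emit a TracerWarning here about converting a tensor to a bool.
traced = torch.jit.trace(model, torch.ones(3))

negative = -torch.ones(3)
print(model(negative))   # eager mode takes the second branch: tensor([9., 9., 9.])
print(traced(negative))  # the trace replays the first branch: tensor([-2., -2., -2.])
```

This is why tracing works well for networks such as ResNet-18, whose forward pass executes the same operations for every input; a model with data-dependent branches may need `torch.jit.script()` instead.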

So, in this example, we are going to use the TorchScript model to perform image classification. To get this model, we need to use the Python API to load the pre-trained model, trace it, and save the model snapshot. The following code shows how to do this with Python:

import torch
import urllib.request
from PIL import Image
from torchvision import transforms

# Download the pretrained model
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
model.eval()

# Download an example image from the PyTorch Hub GitHub repository
url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

# Sample execution: load the image and apply the standard ImageNet preprocessing
input_image = Image.open(filename).convert("RGB")
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)

# create a mini-batch as expected by the model
input_batch = input_tensor.unsqueeze(0)

traced_script_module = torch.jit.trace(model, input_batch)

traced_script_module.save("model.pt")

In this programming sample, we performed the following steps:

  1. We downloaded a pre-trained model with the torch.hub.load() function.
  2. Then, we downloaded an input image with the urllib module.
  3. With the input image acquired, we loaded it with the PIL library and used the torchvision transforms to resize, crop, convert, and normalize it.
  4. Using the unsqueeze() function, we added a batch size dimension to the input tensor.
  5. Then, we used the torch.jit.trace() function to run the loaded model on this input batch and record the executed operations as a TorchScript module.
  6. Finally, we simply saved the script module into a file with the save() method.
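Before moving to the mobile side, it can be useful to check that the saved snapshot reloads and reproduces the eager model's output. The sketch below assumes a small hypothetical stand-in network in place of the downloaded ResNet-18 (so that it runs offline) and a random tensor in place of the preprocessed dog image; it repeats steps 4 to 6 and then reloads the file with torch.jit.load():

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the pre-trained ResNet-18 so the example runs
# offline; it accepts the same 1x3x224x224 input shape.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

image_tensor = torch.rand(3, 224, 224)   # stands in for preprocess(input_image)
input_batch = image_tensor.unsqueeze(0)  # add the batch dimension (step 4)
assert input_batch.shape == (1, 3, 224, 224)

with torch.no_grad():
    # Trace and save the snapshot (steps 5 and 6)
    traced = torch.jit.trace(model, input_batch)
    path = os.path.join(tempfile.mkdtemp(), "model.pt")
    traced.save(path)

    # Reload the snapshot; no Python class definition is needed for this step
    loaded = torch.jit.load(path)
    expected = model(input_batch)
    actual = loaded(input_batch)

print(actual.shape)                                 # torch.Size([1, 10])
print(torch.allclose(expected, actual, atol=1e-6))  # True
```

On the C++ side, LibTorch provides torch::jit::load() for loading the same kind of file.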

Now that we have saved the script module, we can start creating an Android application that will use it for image classification.
