Using PyTorch archive objects

The second approach is to use an object of the torch::serialize::OutputArchive type and write the parameters we want to save into it. The following code shows how to implement the SaveWeights method for our model. This method writes all the parameters and buffers that exist in our module to the archive object and then uses the save_to method to write them to a file:

 void NetImpl::SaveWeights(const std::string& file_name) {
   torch::serialize::OutputArchive archive;
   auto parameters = named_parameters(true /*recurse*/);
   auto buffers = named_buffers(true /*recurse*/);
   for (const auto& param : parameters) {
     if (param.value().defined()) {
       archive.write(param.key(), param.value());
     }
   }
   for (const auto& buffer : buffers) {
     if (buffer.value().defined()) {
       archive.write(buffer.key(), buffer.value(), /*is_buffer*/ true);
     }
   }
   archive.save_to(file_name);
 }

It is important to save buffer tensors too. Buffers can be retrieved from a module with its named_buffers method. These tensors hold intermediate state that modules use during evaluation; for example, the running mean and variance values tracked by a batch normalization module. We need them to resume training if we use serialization to save intermediate checkpoints and our training process is interrupted for some reason.
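
As an illustration, the following minimal sketch (a standalone program assumed to link against LibTorch; the layer size of 4 is arbitrary) prints the buffers that a batch normalization layer exposes through named_buffers:

 #include <torch/torch.h>
 #include <iostream>

 int main() {
   // Batch normalization keeps its running statistics as buffers, not
   // as parameters, so they appear only in named_buffers()
   torch::nn::BatchNorm1d bn(torch::nn::BatchNorm1dOptions(4));
   for (const auto& buffer : bn->named_buffers(/*recurse=*/true)) {
     std::cout << buffer.key() << " : " << buffer.value().sizes() << std::endl;
   }
   // Typical keys are running_mean, running_var, and num_batches_tracked
   return 0;
 }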

To load parameters that have been saved this way, we can use the torch::serialize::InputArchive object. The following code shows how to implement the LoadWeights method for our model:

 void NetImpl::LoadWeights(const std::string& file_name) {
   torch::serialize::InputArchive archive;
   archive.load_from(file_name);
   torch::NoGradGuard no_grad;
   auto parameters = named_parameters(true /*recurse*/);
   auto buffers = named_buffers(true /*recurse*/);
   for (auto& param : parameters) {
     archive.read(param.key(), param.value());
   }
   for (auto& buffer : buffers) {
     archive.read(buffer.key(), buffer.value(), /*is_buffer*/ true);
   }
 }

This method uses the load_from method of the archive object to load parameters from the file. Then, we retrieve the parameters and buffers from our module with the named_parameters and named_buffers methods and fill in their values with the read method of the archive object. Notice that we used an instance of the torch::NoGradGuard class to tell the PyTorch library that we don't perform any model calculations or graph-related operations. This is essential because PyTorch would otherwise construct the computation graph, and any unrelated operations recorded there could lead to errors.
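
To illustrate why the guard matters, here is a minimal sketch (the copy_into helper is hypothetical and introduced only for this example) of copying data into an existing parameter tensor inside a no-gradient scope:

 void copy_into(torch::Tensor& dst, const torch::Tensor& src) {
   // Inside this scope, autograd does not record operations, so the
   // in-place copy does not become part of any computation graph
   torch::NoGradGuard no_grad;
   dst.copy_(src);
 }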

Now, we can use a new instance of our model, model_loaded, with the loaded parameters to evaluate the model on some test data. Note that we need to switch the model to evaluation mode with the eval method. Generated test data values should also be converted into tensor objects with the torch::tensor function and moved to the same computational device that our model uses. The following code shows how we can implement this:

 model_loaded->to(device);
 model_loaded->eval();
 std::cout << "Test: ";
 for (int i = 0; i < 5; ++i) {
   auto x_val = static_cast<float>(i) + 0.1f;
   auto tx = torch::tensor(x_val, torch::dtype(torch::kFloat).device(device));
   tx = (tx - x_mean) / x_std;

   auto ty = torch::tensor(func(x_val), torch::dtype(torch::kFloat).device(device));

   torch::Tensor prediction = model_loaded->forward(tx);

   std::cout << "Target:" << ty << std::endl;
   std::cout << "Prediction:" << prediction << std::endl;
 }

In this section, we looked at two types of serialization in the PyTorch library. The first approach was using the torch::save and torch::load functions, which save and load all the model parameters in one call. The second approach was using objects of the torch::serialize::OutputArchive and torch::serialize::InputArchive types so that we can select which parameters we want to save and load.
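
For reference, the first approach can be as short as the following sketch (assuming our module is wrapped with the TORCH_MODULE macro as Net; the file name is arbitrary):

 Net model;
 // torch::save serializes all parameters and buffers of the module ...
 torch::save(model, "net_params.pt");
 // ... and torch::load restores them into a module of the same type
 Net model_loaded;
 torch::load(model_loaded, "net_params.pt");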

In the next section, we will discuss the ONNX file format, which allows us to share our ML model architecture and model parameters among different frameworks.
