Chapter 3. Making Transformers Efficient in Production

In the previous chapters, you’ve seen how Transformers can be fine-tuned to produce great results on a wide range of tasks. However, in many situations accuracy (or whatever metric you’re optimizing for) is not enough; your state-of-the-art model is not very useful if it’s too slow or large to meet the business requirements of your application. An obvious alternative is to train a faster and more compact model, but the reduction in model capacity is often accompanied by a degradation in performance. So what can you do when you need a fast, compact, yet highly accurate model?

In this chapter we will explore four complementary techniques that can be used to speed up the predictions and reduce the memory footprint of your Transformer models: knowledge distillation, quantization, pruning, and graph optimization with the Open Neural Network Exchange (ONNX) format and ONNX Runtime (ORT). We’ll also see how some of these techniques can be combined to produce significant performance gains. For example, this was the approach taken by the Roblox engineering team in their article “How We Scaled Bert To Serve 1+ Billion Daily Requests on CPUs”. As shown in Figure 3-1, they found that combining knowledge distillation and quantization improved the latency and throughput of their BERT classifier by over a factor of 30!

Figure 3-1. How Roblox scaled BERT with knowledge distillation, dynamic padding, and weight quantization (photo courtesy of Roblox employees Quoc N. Le and Kip Kaehler)

To illustrate the benefits and trade-offs associated with each technique, we’ll use intent detection as a case study since it’s an important component of text-based assistants, where low latencies are critical for maintaining a conversation in real-time. Along the way we’ll learn how to create custom trainers, perform efficient hyperparameter search, and gain a sense for what it takes to implement cutting-edge research with Transformers. Let’s dive in!

Intent Detection as a Case Study

Let’s suppose that we’re trying to build a text-based assistant for our company’s call center so that customers can request the balance of their account or make bookings without needing to speak with a human agent. In order to understand the goals of a customer, our assistant will need to be able to classify a wide variety of natural language text into a set of predefined actions or intents. For example, a customer may send a message about an upcoming trip

Hey, I’d like to rent a vehicle from Nov 1st to Nov 15th in Paris and I need a 15 passenger van

and our intent classifier could automatically categorize this as a Car Rental intent, which then triggers an action and response. To be robust in a production environment, our classifier will also need to be able to handle out-of-scope queries like those shown in the second and third cases of Figure 3-2, where a customer makes a query that doesn’t belong to any of the predefined intents and the system should yield a fallback response. For example, in the second case of Figure 3-2, a customer asks a question about sport which is out-of-scope and the text-assistant mistakenly classifies it as one of the known in-scope intents, which is fed to a downstream component that returns the payday response. In the third case, the text-assistant has been trained to detect out-of-scope queries (usually labelled as a separate class) and informs the customer about which topics they can respond to.

Figure 3-2. Three exchanges between a human (right) and a text-based assistant (left) for personal finance (courtesy of Stefan Larson et al.).

As a baseline we’ve fine-tuned a BERT-base model that achieves around 94% accuracy on the CLINC150 dataset [1]. This dataset includes 22,500 in-scope queries across 150 intents and 10 domains like banking and travel, and also includes 1,200 out-of-scope queries that belong to an oos intent class. In practice we would also gather our own in-house dataset, but using public data is a great way to iterate quickly and generate preliminary results.

To get started, let’s download our fine-tuned model from the Hugging Face Hub and wrap it in a pipeline for text classification:

from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          TextClassificationPipeline)

bert_ckpt = "lewtun/bert-base-uncased-finetuned-clinc"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)
bert_model = (AutoModelForSequenceClassification
              .from_pretrained(bert_ckpt).to("cpu"))
pipe = TextClassificationPipeline(model=bert_model, tokenizer=bert_tokenizer)

Here we’ve set the model’s device to cpu since our text-assistant will need to operate in an environment where queries are processed and responded to in real-time. Although we could use a GPU to run inference, this would be expensive if the machine is idle for extended periods of time or would require us to batch the incoming queries which introduces additional complexity. In general, the choice between running inference on a CPU or GPU depends on the application and a balance between simplicity and cost. Several of the compression techniques covered in this chapter work equally well on a GPU, so you can easily adapt them to your own GPU-powered applications if needed.

Now that we have a pipeline, we can pass a query to get the predicted intent and confidence score from the model:

query = """Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in
Paris and I need a 15 passenger van"""
pipe(query)
[{'label': 'car_rental', 'score': 0.5490033626556396}]

Great, the car_rental intent makes sense so let’s now look at creating a benchmark that we can use to evaluate the performance of our baseline model.

Creating a Performance Benchmark

Like any other machine learning model, deploying Transformers in production environments involves a trade-off among several constraints, the most common being [2]:

Model performance

How well does our model perform on a well-crafted test set that reflects production data? This is especially important when the cost of making errors is large (and best mitigated with a human-in-the-loop) or when we need to run inference on millions of examples, and small improvements to the model metrics can translate into large gains in aggregate.

Latency

How fast can our model deliver predictions? We usually care about latency in real-time environments that deal with a lot of traffic, like how Stack Overflow needed a classifier to quickly detect unwelcome comments on their website.

Memory

How can we deploy billion-parameter models like GPT-2 or T5 that require gigabytes of disk storage and RAM? Memory plays an especially important role in mobile or edge devices, where a model has to generate predictions without access to a powerful cloud server.

Failing to address these constraints can have a negative impact on the user experience of your application, or more commonly, lead to ballooning costs from running expensive cloud servers that may only need to handle a few requests. To explore how each of these constraints can be optimized with various compression techniques, let’s begin by creating a simple benchmark that measures each quantity for a given pipeline and test set. A skeleton of what we’ll need is given by the following class:

class PerformanceBenchmark:
    def __init__(self, pipeline, dataset, optim_type="BERT baseline"):
        self.pipeline = pipeline
        self.dataset = dataset
        self.optim_type = optim_type

    def compute_accuracy(self):
        pass

    def compute_size(self):
        pass

    def time_pipeline(self):
        pass

    def run_benchmark(self):
        metrics = {}
        metrics[self.optim_type] = self.compute_size()
        metrics[self.optim_type].update(self.time_pipeline())
        metrics[self.optim_type].update(self.compute_accuracy())
        return metrics

In this class, we’ve defined the optim_type parameter to keep track of the different optimization techniques that we’ll cover in this chapter. We’ll use the run_benchmark function to collect all the metrics in a dictionary, with keys given by optim_type.
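To make the layout of that dictionary concrete, here is a minimal sketch that uses a stand-in class with hard-coded values. The name StubBenchmark and all of the numbers are hypothetical placeholders rather than part of the chapter's code; only the merging logic in run_benchmark is copied from the skeleton.

```python
class StubBenchmark:
    """Hypothetical stand-in that mimics the shape of PerformanceBenchmark."""

    def __init__(self, optim_type="BERT baseline"):
        self.optim_type = optim_type

    def compute_size(self):
        return {"size_mb": 420.0}  # placeholder value

    def time_pipeline(self):
        return {"time_avg_ms": 50.0, "time_std_ms": 5.0}  # placeholder values

    def compute_accuracy(self):
        return {"accuracy": 0.9}  # placeholder value

    def run_benchmark(self):
        # Same merging logic as the skeleton: one inner dict per optim_type
        metrics = {}
        metrics[self.optim_type] = self.compute_size()
        metrics[self.optim_type].update(self.time_pipeline())
        metrics[self.optim_type].update(self.compute_accuracy())
        return metrics


stub_metrics = StubBenchmark().run_benchmark()
# Results from a second technique can be merged in under their own key
stub_metrics.update(StubBenchmark(optim_type="Distillation").run_benchmark())
print(stub_metrics)
```

Because each technique is stored under its own key, results from later sections can be added with a plain dict.update call without overwriting earlier entries.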

Let’s now put some flesh on this class by computing the model accuracy on the test set. First we need some data to test on, so let’s download the CLINC150 dataset that was used to fine-tune our baseline model. We can get the dataset from the Hub with the Datasets library as follows:

from datasets import load_dataset

clinc = load_dataset("clinc_oos", "plus")
clinc
DatasetDict({
    train: Dataset({
        features: ['text', 'intent'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['text', 'intent'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['text', 'intent'],
        num_rows: 5500
    })
})

Each example in the CLINC150 dataset consists of a query in the text column and its corresponding intent. We’ll use the test set to benchmark our models, so let’s take a look at one of the dataset’s examples:

clinc["test"][42]
{'intent': 133, 'text': 'transfer $100 from my checking to saving account'}

The intents are provided as IDs, but we can easily get the mapping to strings (and vice versa) by accessing the Dataset.features attribute:

intents = clinc["test"].features["intent"]
intents.int2str(clinc["test"][42]["intent"])
'transfer'

Now that we have a basic understanding of the contents in the CLINC150 dataset, let’s implement the compute_accuracy function. Since the dataset is balanced across the intent classes, we’ll use accuracy as our metric which we can load from Datasets as follows:

from datasets import load_metric

accuracy_score = load_metric('accuracy')
accuracy_score
Metric(name: "accuracy", features: {'predictions': Value(dtype='int32',
 > id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions: Predicted labels, as returned by a model.
    references: Ground truth labels.
    normalize: If False, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.
    sample_weight: Sample weights.
Returns:
    accuracy: Accuracy score.
""", stored examples: 0)

The metric’s description tells us that we need to provide the predictions and references (i.e. the ground truth labels) as integers, so we can use the pipeline to extract the predictions from the text field and then use the ClassLabel.str2int function to map the prediction to its corresponding ID. The following code collects all the predictions and labels in lists before returning the accuracy on the dataset. Let’s also add it to our PerformanceBenchmark class:

def compute_accuracy(self):
    preds, labels = [], []
    for example in self.dataset:
        pred = self.pipeline(example["text"])[0]["label"]
        label = example["intent"]
        preds.append(intents.str2int(pred))
        labels.append(label)
    accuracy = accuracy_score.compute(predictions=preds, references=labels)
    print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
    return accuracy

PerformanceBenchmark.compute_accuracy = compute_accuracy

Next, let’s compute the size of our model by using the torch.save function from PyTorch to serialize the model to disk. Under the hood, torch.save uses Python’s pickle module and can be used to save anything from models to tensors to ordinary Python objects. In PyTorch, the recommended way to save a model is by using its state_dict, which is a Python dictionary that maps each layer in a model to its learnable parameters (i.e. weights and biases). Let’s see what is stored in the state_dict of our baseline model:

list(pipe.model.state_dict().items())[42]
('bert.encoder.layer.2.attention.self.value.weight',
 tensor([[-1.0526e-02, -3.2215e-02,  2.2097e-02,  ..., -6.0953e-03,
           4.6521e-03,  2.9844e-02],
         [-1.4964e-02, -1.0915e-02,  5.2396e-04,  ...,  3.2047e-05,
          -2.6890e-02, -2.1943e-02],
         [-2.9640e-02, -3.7842e-03, -1.2582e-02,  ..., -1.0917e-02,
           3.1152e-02, -9.7786e-03],
         ...,
         [-1.5116e-02, -3.3226e-02,  4.2063e-02,  ..., -5.2652e-03,
           1.1093e-02,  2.9703e-03],
         [-3.6809e-02,  5.6848e-02, -2.6544e-02,  ..., -4.0114e-02,
           6.7487e-03,  1.0511e-03],
         [-2.4961e-02,  1.4747e-03, -5.4271e-02,  ...,  2.0004e-02,
           2.3981e-02, -4.2880e-02]]))

We can clearly see that each key-value pair corresponds to a specific layer and tensor in BERT. So if we save our model with

torch.save(model.state_dict(), PATH)

we can then use the Path.stat function from Python’s pathlib module to get information about the underlying files. In particular Path(PATH).stat().st_size will give us the model size in bytes, so let’s put this all together in the compute_size function and add it to PerformanceBenchmark:

import torch
from pathlib import Path

def compute_size(self):
    state_dict = self.pipeline.model.state_dict()
    tmp_path = Path("model.pt")
    torch.save(state_dict, tmp_path)
    # Calculate size in megabytes
    size_mb = Path(tmp_path).stat().st_size / (1024 * 1024)
    # Delete temporary file
    tmp_path.unlink()
    print(f"Model size (MB) - {size_mb:.2f}")
    return {"size_mb": size_mb}

PerformanceBenchmark.compute_size = compute_size
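As a rough sanity check on the number that compute_size reports, the size of the saved weights should be close to the number of parameters multiplied by the bytes used per parameter. The sketch below assumes 32-bit floats (4 bytes each) and an approximate count of 110 million parameters for a BERT-base classifier; both are assumptions for illustration, and the estimate ignores the small overhead of the serialization format.

```python
def estimate_size_mb(num_params, bytes_per_param=4):
    """Estimate the serialized size of a model's weights in megabytes."""
    return num_params * bytes_per_param / (1024 * 1024)


# Assumed, approximate parameter count for a BERT-base classifier
approx_bert_base_params = 110_000_000
print(f"float32 estimate (MB) - {estimate_size_mb(approx_bert_base_params):.2f}")
# With one byte per parameter the same weights would be about four times smaller
print(f"int8 estimate (MB) - {estimate_size_mb(approx_bert_base_params, 1):.2f}")
```

For a real PyTorch model, the parameter count can be obtained with sum(p.numel() for p in model.parameters()).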

Finally, let’s implement the time_pipeline function so that we can measure the average latency per query. For this application, latency refers to the time it takes to feed a text query to the pipeline and return the predicted intent from the model. Under the hood, the pipeline also tokenizes the text, but this is around 1,000 times faster than generating the predictions and thus adds a negligible contribution to the overall latency. A simple way to measure the time of a code snippet is to use the perf_counter function from Python’s time module. This function has a better time resolution than the time.time function and so is well suited for getting precise results.

We can use perf_counter to time our pipeline by passing our test query and calculating the time difference in milliseconds between the start and end:

from time import perf_counter

for _ in range(3):
    start_time = perf_counter()
    _ = pipe(query)
    latency = perf_counter() - start_time
    print(f"Latency (ms) - {1000 * latency:.3f}")
Latency (ms) - 64.923
Latency (ms) - 47.636
Latency (ms) - 47.344

These results exhibit quite some spread in the latencies and suggest that timing a single pass through the pipeline can give wildly different results each time we rerun the code. So instead, we’ll collect the latencies over many runs and then use the resulting distribution to calculate the mean and standard deviation, which will give us an idea about the spread in values. The following code does what we need and includes a phase to warm up the CPU before performing the actual timed run:

import numpy as np

def time_pipeline(self, query="What is the pin number for my account?"):
    latencies = []
    # Warmup
    for _ in range(10):
        _ = self.pipeline(query)
    # Timed run
    for _ in range(100):
        start_time = perf_counter()
        _ = self.pipeline(query)
        latency = perf_counter() - start_time
        latencies.append(latency)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

PerformanceBenchmark.time_pipeline = time_pipeline
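The same warm-up and timed-run pattern can be tried without a model. The following is a minimal sketch using only the standard library, where time_fn and dummy_workload are hypothetical names and the workload is an arbitrary stand-in for a call such as pipe(query). It also reports the median, which is less sensitive than the mean to a few slow outliers.

```python
import statistics
from time import perf_counter


def time_fn(fn, n_warmup=10, n_runs=100):
    """Time repeated calls to fn and return summary statistics in ms."""
    for _ in range(n_warmup):  # warm-up calls are not timed
        fn()
    latencies_ms = []
    for _ in range(n_runs):
        start_time = perf_counter()
        fn()
        latencies_ms.append(1000 * (perf_counter() - start_time))
    return {
        "time_avg_ms": statistics.mean(latencies_ms),
        # Population standard deviation, matching np.std's default
        "time_std_ms": statistics.pstdev(latencies_ms),
        "time_median_ms": statistics.median(latencies_ms),
    }


def dummy_workload():
    # Hypothetical stand-in for a call such as pipe(query)
    return sum(i * i for i in range(10_000))


stats = time_fn(dummy_workload)
print(f"Average latency (ms) - {stats['time_avg_ms']:.2f} "
      f"+- {stats['time_std_ms']:.2f}")
print(f"Median latency (ms) - {stats['time_median_ms']:.2f}")
```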

Benchmarking Our Baseline Model

Now that our PerformanceBenchmark is complete, let’s give it a spin! For the baseline model we just need to pass the pipeline and dataset we wish to perform the benchmark on, and we’ll collect the results in the perf_metrics dictionary to keep track of each model’s performance:

pb = PerformanceBenchmark(pipe, clinc["test"])
perf_metrics = pb.run_benchmark()
Model size (MB) - 418.17
Average latency (ms) - 46.05 +- 10.13
Accuracy on test set - 0.867

Now that we have a reference point, let’s look at our first compression technique: knowledge distillation.

Note

The average latency values will differ depending on what type of hardware you are running on. For the purposes of this chapter, what’s important is the relative difference in latencies between models. Once we have determined the best performing model we can then explore different backends to reduce the absolute latency if needed.

Making Models Smaller via Knowledge Distillation

Knowledge distillation is a general-purpose method for training a smaller student model to mimic the behavior of a slower, larger, but better performing teacher. Originally introduced in 2006 [3] in the context of ensemble models, it was later popularized in a famous 2015 paper [4] by Geoff Hinton, Oriol Vinyals, and Jeff Dean, who generalized the method to deep neural networks and applied it to image classification and automatic speech recognition.

Given the trend shown in Figure 3-3 towards pretraining language models with ever-increasing parameter counts (the largest [5] at over one trillion parameters at the time of writing this book!), knowledge distillation has also become a popular strategy to compress these huge models and make them more suitable for building practical applications.

Figure 3-3. Parameter counts of several recent pretrained language models.

Knowledge Distillation for Fine-tuning

So how is knowledge actually “distilled” or transferred from the teacher to the student during training? For supervised tasks like fine-tuning, the main idea is to augment the ground truth labels with a distribution of “soft probabilities” from the teacher which provide complementary information for the student to learn from. For example, if our BERT-base classifier assigns high probabilities to multiple intents, then this could be a sign that these intents lie close to each other in the feature space. By training the student to mimic these probabilities, the goal is to distill some of this “dark knowledge” [6] that the teacher has learnt; knowledge which is not available from the labels alone.

Mathematically, the way this works is as follows. Suppose we feed an input sequence x to the teacher to generate a vector of logits $\mathbf{z}(x) = [z_1(x), \ldots, z_N(x)]$. We can convert these logits into probabilities by applying a softmax function

$$\frac{\exp(z_i(x))}{\sum_j \exp(z_j(x))},$$

but this isn’t quite what we want because in many cases the teacher will assign a high probability to one class, with all other class probabilities close to zero. When that happens, the teacher doesn’t provide much additional information beyond the ground truth labels, so instead we “soften” the probabilities by scaling the logits with a positive temperature hyperparameter T before applying the softmax:

$$p_i(x) = \frac{\exp(z_i(x)/T)}{\sum_j \exp(z_j(x)/T)}.$$

As shown in Figure 3-4, higher values of T produce a softer probability distribution over the classes and reveal much more information about the decision boundary that the teacher has learned for each training example. When T=1 we recover the original softmax distribution.
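The effect of the temperature is easy to see numerically. The following is a minimal pure-Python sketch of the softened softmax; the function name and the three example logits are made up for illustration.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [6.0, 2.0, 1.0]  # made-up logits for three classes
for T in (1, 2, 5):
    probs = softmax_with_temperature(example_logits, temperature=T)
    print(f"T={T}: {[round(p, 3) for p in probs]}")
```

As T grows, the probability mass moves away from the top class towards the others, while the ranking of the classes stays the same.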

Figure 3-4. Comparison of a hard label which is one-hot encoded (left), softmax probabilities (middle) and softened class probabilities (right).

Since the student also produces softened probabilities $q_i(x)$ of its own, we can use the Kullback-Leibler (KL) divergence

$$D_{KL}(p, q) = \sum_i p_i(x) \log \frac{p_i(x)}{q_i(x)},$$

to measure the difference between the two probability distributions and thereby define a knowledge distillation loss:

$$L_{KD} = T^2 D_{KL} = T^2 \sum_i p_i(x) \log \frac{p_i(x)}{q_i(x)},$$

where $T^2$ is a normalization factor to account for the fact that the magnitude of the gradients produced by soft labels scales as $1/T^2$. For classification tasks, the student loss is then a weighted average of the distillation loss with the usual cross-entropy loss $L_{CE}$ of the ground truth labels:

$$L_{student} = \alpha L_{CE} + (1 - \alpha) L_{KD},$$

where α is a hyperparameter that controls the relative strength of each loss. A diagram of the whole process is shown in Figure 3-5 and the temperature is set to 1 at inference time to recover the standard softmax probabilities.
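The pieces above can be put together for a single training example. This is a minimal pure-Python sketch under stated assumptions: the logits are made up, and the helper names softmax, kl_divergence, and student_loss are illustrative rather than part of any library.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    max_scaled = max(scaled)
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """D_KL(p, q) = sum_i p_i * log(p_i / q_i)."""
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)


def student_loss(logits_stu, logits_tea, label, alpha=0.5, temperature=2.0):
    # Cross-entropy against the ground truth label uses T=1
    loss_ce = -math.log(softmax(logits_stu)[label])
    # Distillation term compares the softened teacher and student distributions
    p = softmax(logits_tea, temperature)
    q = softmax(logits_stu, temperature)
    loss_kd = temperature ** 2 * kl_divergence(p, q)
    return alpha * loss_ce + (1 - alpha) * loss_kd


# Made-up logits for a three-class problem where class 0 is correct
logits_tea = [5.0, 2.0, -1.0]
logits_stu = [2.0, 1.5, 0.0]
print(f"alpha=1.0: {student_loss(logits_stu, logits_tea, 0, alpha=1.0):.4f}")
print(f"alpha=0.5: {student_loss(logits_stu, logits_tea, 0, alpha=0.5):.4f}")
```

Setting alpha to 1 switches the distillation term off entirely, and a student whose logits match the teacher's has a distillation loss of zero.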

Figure 3-5. Cartoon of the knowledge distillation process.
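The T squared normalization in the distillation loss can also be checked numerically. Differentiating the KL divergence with respect to a student logit gives (q_i - p_i)/T, and for large T the difference q_i - p_i itself shrinks roughly like 1/T, so the gradient falls off as 1/T squared. The sketch below uses made-up logits and a hypothetical helper kd_gradient to show that the gradient multiplied by T squared settles to a roughly constant value.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    max_scaled = max(scaled)
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_gradient(logits_stu, logits_tea, temperature):
    """Gradient of D_KL(p, q) with respect to each student logit."""
    p = softmax(logits_tea, temperature)
    q = softmax(logits_stu, temperature)
    return [(q_i - p_i) / temperature for p_i, q_i in zip(p, q)]


logits_tea = [2.0, 0.0, -2.0]  # made-up teacher logits
logits_stu = [1.0, 0.0, -1.0]  # made-up student logits
for T in (1, 10, 100):
    grad = kd_gradient(logits_stu, logits_tea, T)
    print(f"T={T}: grad[0]={grad[0]:.6f}, T^2 * grad[0]={T ** 2 * grad[0]:.4f}")
```

Without the rescaling, raising the temperature would silently shrink the contribution of the distillation term relative to the cross-entropy term.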

Knowledge Distillation for Pretraining

Knowledge distillation can also be used during pretraining to create a general-purpose student that can be subsequently fine-tuned on downstream tasks. In this case, the teacher is a pretrained language model like BERT which transfers its knowledge about masked-language-modeling to the student. For example, in the DistilBERT paper [7], the masked-language-modeling loss $L_{mlm}$ is augmented with a term from knowledge distillation and a cosine embedding loss $L_{cos} = 1 - \cos(h_s, h_t)$ to align the directions of the hidden state vectors between the teacher and student:

$$L_{DistilBERT} = \alpha L_{mlm} + \beta L_{KD} + \gamma L_{cos}.$$
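To see what the cosine embedding term measures, here is a minimal pure-Python sketch for a single pair of hidden state vectors. The function name and the two-dimensional vectors are made up for illustration; real hidden states have hundreds of dimensions.

```python
import math


def cosine_embedding_loss(h_s, h_t):
    """L_cos = 1 - cos(h_s, h_t) for two hidden state vectors."""
    dot = sum(a * b for a, b in zip(h_s, h_t))
    norm_s = math.sqrt(sum(a * a for a in h_s))
    norm_t = math.sqrt(sum(b * b for b in h_t))
    return 1.0 - dot / (norm_s * norm_t)


print(cosine_embedding_loss([1.0, 2.0], [2.0, 4.0]))   # same direction
print(cosine_embedding_loss([1.0, 0.0], [0.0, 1.0]))   # orthogonal
print(cosine_embedding_loss([1.0, 0.0], [-1.0, 0.0]))  # opposite direction
```

The loss depends only on the angle between the vectors, not on their lengths, so it is smallest when the student's hidden states point in the same direction as the teacher's.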

Since we already have a fine-tuned BERT-base model, let’s see how we can use knowledge distillation to fine-tune a smaller and faster model. To do that we’ll need a way to augment the cross-entropy loss with an $L_{KD}$ term; fortunately we can do this by creating our own trainer! Let’s take a look at how to do this in the next section.

Creating a Knowledge Distillation Trainer

To implement knowledge distillation we need to add a few things to the Trainer base class:

  • The new hyperparameters α and T, which control the relative weight of the distillation loss and how much the probability distribution of the labels should be smoothed.

  • The fine-tuned teacher model, which in our case is BERT-base.

  • A new loss function that combines the cross-entropy loss with the knowledge distillation loss.

Adding the new hyperparameters is quite simple since we just need to subclass TrainingArguments and include them as new attributes:

from transformers import TrainingArguments

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

For the trainer itself, we want a new loss function, so we subclass Trainer and override the compute_loss function to include the knowledge distillation loss term $L_{KD}$:

import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # return_outputs and **kwargs keep the signature compatible with
        # Trainer versions that pass extra arguments to compute_loss
        outputs_stu = model(**inputs)
        # Extract cross-entropy loss and logits from student
        loss_ce = outputs_stu.loss
        logits_stu = outputs_stu.logits
        # Extract logits from teacher
        with torch.no_grad():
            outputs_tea = self.teacher_model(**inputs)
            logits_tea = outputs_tea.logits
        # Soften probabilities and compute distillation loss
        loss_fct = nn.KLDivLoss(reduction="batchmean")
        loss_kd = self.args.temperature ** 2 * loss_fct(
            F.log_softmax(logits_stu / self.args.temperature, dim=-1),
            F.softmax(logits_tea / self.args.temperature, dim=-1))
        # Return weighted student loss
        loss = self.args.alpha * loss_ce + (1. - self.args.alpha) * loss_kd
        return (loss, outputs_stu) if return_outputs else loss

Let’s unpack this code a bit. When we instantiate DistillationTrainer we pass a teacher_model argument with a teacher that has already been fine-tuned on our task. Next, in the compute_loss function we extract the logits from the student and teacher, scale them by the temperature and then normalize them with a softmax before passing them to PyTorch’s nn.KLDivLoss function for computing the KL divergence. Since nn.KLDivLoss expects the inputs in the form of log-probabilities, we’ve used the F.log_softmax function to normalize the student’s logits, while the teacher’s logits are converted to probabilities with a standard softmax. The reduction=batchmean argument in nn.KLDivLoss specifies that we average the losses over the batch dimension.
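To illustrate what the batchmean reduction does, the following pure-Python sketch computes the same kind of KL loss by hand for a toy batch and contrasts it with an element-wise mean. The helper names and logits are made up, and the temperature is left at 1 to keep the example short; it is a sketch of the reduction, not a replacement for nn.KLDivLoss.

```python
import math


def log_softmax(logits):
    max_logit = max(logits)
    log_z = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return [z - log_z for z in logits]


def kl_div_batch(log_q_batch, p_batch, reduction="batchmean"):
    """Pure-Python sketch of a KL loss with "batchmean" or "mean" reduction."""
    total = 0.0
    num_elements = 0
    for log_q, p in zip(log_q_batch, p_batch):
        for log_q_i, p_i in zip(log_q, p):
            if p_i > 0:
                total += p_i * (math.log(p_i) - log_q_i)
            num_elements += 1
    if reduction == "batchmean":
        return total / len(p_batch)  # divide by the batch size
    return total / num_elements      # "mean": divide by every element


# Toy batch of two examples with three classes each
logits_stu = [[2.0, 1.0, 0.0], [0.5, 0.5, 3.0]]
logits_tea = [[3.0, 1.0, -1.0], [0.0, 1.0, 4.0]]
log_q = [log_softmax(row) for row in logits_stu]                     # student
p = [[math.exp(v) for v in log_softmax(row)] for row in logits_tea]  # teacher
batchmean = kl_div_batch(log_q, p, "batchmean")
elementwise_mean = kl_div_batch(log_q, p, "mean")
print(f"batchmean={batchmean:.4f}, mean={elementwise_mean:.4f}")
```

Dividing by the batch size gives the average KL divergence per example, which matches the mathematical definition, whereas dividing by every element would shrink the loss by a further factor equal to the number of classes.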

Choosing a Good Student Initialization

Now that we have our custom trainer, the first question you might have is which pretrained language model should we pick for the student? In general we should pick a smaller model for the student to reduce the latency and memory footprint, and a good rule of thumb from the literature [8] is that knowledge distillation works best when the teacher and student are of the same model type. One possible reason for this is that different model types, say BERT and RoBERTa, can have different output embedding spaces, which hinders the ability of the student to mimic the teacher. In our case study, the teacher is BERT-base, so DistilBERT is a natural candidate to initialize the student since it has 40% fewer parameters and has been shown to achieve strong results on downstream tasks.

First we’ll need to tokenize and encode our queries, so let’s instantiate the tokenizer from DistilBERT and create a simple function to take care of the preprocessing:

student_ckpt = "distilbert-base-uncased"
student_tokenizer = AutoTokenizer.from_pretrained(student_ckpt)

def tokenize_text(batch, tokenizer):
    return tokenizer(batch["text"], truncation=True)

clinc_enc = clinc.map(tokenize_text, batched=True, remove_columns=["text"],
                      fn_kwargs={"tokenizer": student_tokenizer})
clinc_enc = clinc_enc.rename_column("intent", "labels")

Here we’ve removed the text column since we no longer need it and we’ve also used the fn_kwargs argument to specify which tokenizer should be used in the tokenize_text function. We’ve also renamed the intent column to labels so it can be automatically detected by the trainer. Now that our texts are processed, the next thing to do is instantiate DistilBERT for fine-tuning. Since we will be doing multiple runs with the trainer, we’ll use a function to initialize the model with each new run:

import torch
from transformers import AutoConfig

num_labels = intents.num_classes
id2label = bert_model.config.id2label
label2id = bert_model.config.label2id
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

student_config = (AutoConfig
                  .from_pretrained(student_ckpt, num_labels=num_labels,
                                   id2label=id2label, label2id=label2id))

def student_init():
    return (AutoModelForSequenceClassification
            .from_pretrained(student_ckpt, config=student_config).to(device))

Here we’ve also specified the number of classes our model should expect, and used the baseline model’s configuration to provide the mappings id2label and label2id between ID and intent name. Next, we need to define the metrics to track during training. As we did in the performance benchmark, we’ll use accuracy as the main metric so we can reuse our accuracy_score function in the compute_metrics function that we’ll include in the trainer:

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)

In this function, the predictions from the sequence modeling head come in the form of logits, so we use the np.argmax function to find the most confident class prediction and compare that against the ground truth labels.
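The same logic can be tried on a toy batch without NumPy or a trained model. In this minimal sketch the logits and labels are made up, and argmax and accuracy_from_logits are hypothetical helper names.

```python
def argmax(values):
    """Index of the largest entry, mirroring np.argmax for a 1D list."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy_from_logits(logits_batch, labels):
    # Pick the most confident class per example, then compare with the labels
    predictions = [argmax(logits) for logits in logits_batch]
    num_correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": num_correct / len(labels)}


toy_logits = [[0.1, 2.3, -1.0], [1.5, 0.2, 0.3], [-0.5, 0.0, 0.4], [0.9, 1.1, 1.0]]
toy_labels = [1, 0, 2, 2]
print(accuracy_from_logits(toy_logits, toy_labels))  # 3 of 4 correct -> 0.75
```

Note that no softmax is needed here, since the largest logit always corresponds to the largest probability.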

Finally, we just need to define the training arguments. To warm up, we’ll set α=1 to see how well DistilBERT performs without any signal from the teacher [9]:

batch_size = 48

student_training_args = DistillationTrainingArguments(
    output_dir="checkpoints", evaluation_strategy = "epoch", num_train_epochs=5,
    learning_rate=2e-5, per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size, alpha=1, weight_decay=0.01)

Next we load the teacher model, instantiate the trainer and start fine-tuning:

teacher_checkpoint = "lewtun/bert-base-uncased-finetuned-clinc"
teacher_model = (AutoModelForSequenceClassification
                 .from_pretrained(teacher_checkpoint, num_labels=num_labels)
                 .to(device))

distil_trainer = DistillationTrainer(model_init=student_init,
    teacher_model=teacher_model, args=student_training_args,
    train_dataset=clinc_enc['train'], eval_dataset=clinc_enc['validation'],
    compute_metrics=compute_metrics, tokenizer=student_tokenizer)
distil_trainer.train();

Epoch  Training Loss  Validation Loss  Accuracy  Runtime   Samples Per Second
1      4.309400       3.318003         0.702903  0.970700  3193.619000
2      2.659500       1.904174         0.843871  0.979200  3165.834000
3      1.573200       1.178305         0.894194  0.988600  3135.649000
4      1.026700       0.873162         0.911613  0.987400  3139.536000
5      0.805600       0.785567         0.917742  1.019300  3041.436000

The accuracy on the validation set looks quite good compared to the 94% that the BERT-base teacher achieves. Now that we’ve fine-tuned DistilBERT, we can wrap it in a TextClassificationPipeline and run it through our performance benchmark:

pipe = TextClassificationPipeline(
    model=distil_trainer.model.to("cpu"), tokenizer=distil_trainer.tokenizer)
optim_type = "DistilBERT"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())
Model size (MB) - 255.89
Average latency (ms) - 24.13 +- 10.06
Accuracy on test set - 0.856

To compare these results against our baseline, let’s create a scatter plot of the accuracy against the latency, with the radius of each point corresponding to the size of the model. The following function does what we need and marks the current optimization type as a dashed circle to aid the comparison to previous results:

import matplotlib.pyplot as plt
import pandas as pd

def plot_metrics(perf_metrics, current_optim_type):
    df = pd.DataFrame.from_dict(perf_metrics, orient='index')

    for idx in df.index:
        df_opt = df.loc[idx]
        if idx == current_optim_type:
            plt.scatter(df_opt["time_avg_ms"], df_opt["accuracy"] * 100,
                        alpha=0.5, s=df_opt["size_mb"], label=idx,
                        marker='$\u25CC$')
        else:
            plt.scatter(df_opt["time_avg_ms"], df_opt["accuracy"] * 100,
                        s=df_opt["size_mb"], label=idx, alpha=0.5)

    legend = plt.legend(bbox_to_anchor=(1,1))
    # Legend.legendHandles was renamed to legend_handles in Matplotlib 3.7
    for handle in legend.legendHandles:
        handle.set_sizes([20])

    plt.ylim(80,90)
    plt.xlim(5, 53)
    plt.ylabel("Accuracy (%)")
    plt.xlabel("Average latency (ms)")
    plt.show()

plot_metrics(perf_metrics, optim_type)

From the plot we can see that by using a smaller model we’ve managed to decrease the average latency by almost a factor of two. And all this at the price of just over a 1% reduction in accuracy! Let’s see if we can close that last gap by including the distillation loss from the teacher and finding good values for α and T.

Finding Good Hyperparameters with Optuna

So what values of α and T should we pick? We could do a grid search over the 2D parameter space but a much better alternative is to use Optuna [10], which is an optimization framework designed for just this type of task. Optuna formulates the search problem in terms of an objective function that is optimized through multiple trials. For example, suppose we wished to minimize Rosenbrock’s “banana function”

$$f(x, y) = (1 - x)^2 + 100 (y - x^2)^2$$

which is a famous test case for optimization frameworks. As shown in Figure 3-6, the function gets its name from the curved contours and has a global minimum at (x,y)=(1,1). Finding the valley is an easy optimization problem, but converging to the global minimum is not.

Figure 3-6. Plot of the Rosenbrock function of two variables
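The location of the global minimum, and the reason the valley is easy to find but hard to follow, can be checked directly. The sketch below evaluates the Rosenbrock function and its analytic gradient in plain Python; the helper names are illustrative.

```python
def rosenbrock(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def rosenbrock_grad(x, y):
    """Analytic gradient (df/dx, df/dy) of the Rosenbrock function."""
    df_dx = -2 * (1 - x) - 400 * x * (y - x ** 2)
    df_dy = 200 * (y - x ** 2)
    return df_dx, df_dy


# At (1, 1) both the function value and the gradient vanish
print(rosenbrock(1, 1), rosenbrock_grad(1, 1))
# Points on the valley floor y = x**2 have small values even far from (1, 1)
for x in (-1.0, 0.0, 0.5):
    print(x, rosenbrock(x, x ** 2))
```

Along the parabola y = x**2 the second term is zero and only the much smaller (1 - x)**2 term remains, which is why an optimizer reaches the valley quickly but then makes slow progress towards (1, 1).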

In Optuna, we can find the minimum of f(x,y) by defining an objective function that returns the value of f(x,y):

def objective(trial):
    x = trial.suggest_float("x", -2, 2)
    y = trial.suggest_float("y", -2, 2)
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

The trial.suggest_float method specifies the parameter range to sample uniformly from; Optuna also provides suggest_int and suggest_categorical for integer and categorical parameters, respectively. Optuna collects multiple trials as a study, so to create one we call optuna.create_study and then pass the objective function to study.optimize as follows:

import optuna

study = optuna.create_study()
study.optimize(objective, n_trials=1000)

Once the study is completed, we can then find the best parameters as follows:

study.best_params
{'x': 1.0294476378522224, 'y': 1.0595653705812769}

We see that with 1,000 trials, Optuna has managed to find values for x and y that are reasonably close to the global minimum. To use Optuna in Transformers, we use a similar logic by first defining the hyperparameter space that we wish to optimize over. In addition to α and T, we’ll include the number of training epochs as follows:

def hp_space(trial):
    return {"num_train_epochs": trial.suggest_int("num_train_epochs", 5, 10),
        "alpha": trial.suggest_float("alpha", 0, 1),
        "temperature": trial.suggest_int("temperature", 2, 20)}

Running the hyperparameter search with the Trainer is then quite simple; we just need to specify the number of trials to run and a direction to optimize for. Since we want the best possible accuracy, we pick direction="maximize" in the Trainer.hyperparameter_search function and pass the hyperparameter search space as follows:

best_run = distil_trainer.hyperparameter_search(
    n_trials=20, direction="maximize", hp_space=hp_space)
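One thing to keep in mind is that, unless told otherwise, the search maximizes the sum of all evaluation metrics, which may include runtime statistics as well as the accuracy. The hyperparameter_search method also accepts a compute_objective callable that maps the metrics dictionary to a single number. The sketch below assumes that the accuracy is reported under the key eval_accuracy, which is how the Trainer usually prefixes evaluation metrics; the function name and the metric values are made up for illustration.

```python
def accuracy_objective(metrics):
    """Return only the evaluation accuracy as the quantity to maximize."""
    return metrics["eval_accuracy"]

# Made-up metrics in the shape the Trainer typically reports them
example_metrics = {
    "eval_loss": 0.39,
    "eval_accuracy": 0.94,
    "eval_runtime": 1.0,
    "eval_samples_per_second": 3070.0,
}
objective_value = accuracy_objective(example_metrics)
print(objective_value)
```

Such a function could then be passed to the search via compute_objective=accuracy_objective, so that throughput numbers no longer dominate the objective.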

The hyperparameter_search method returns a BestRun object which contains the value of the objective that was maximized (by default the sum of all metrics) and the hyperparameters it used for that run:

best_run
BestRun(run_id='4', objective=3080.872670967742,
 > hyperparameters={'num_train_epochs': 8, 'alpha': 0.31235083318309453,
 > 'temperature': 16})

This value of α tells us that most of the training signal is coming from the knowledge distillation term. Let’s update our trainer with these values and run the final training run:

for k,v in best_run.hyperparameters.items():
    setattr(distil_trainer.args, k, v)

distil_trainer.train();

Epoch  Training Loss  Validation Loss  Accuracy  Runtime   Samples Per Second
1      1.608300       2.977128         0.714516  0.981500  3158.565000
2      0.904200       1.566405         0.877419  0.981000  3159.929000
3      0.509100       0.881892         0.915806  0.988000  3137.623000
4      0.317700       0.594229         0.932581  1.024500  3025.740000
5      0.230200       0.475622         0.934839  0.995800  3113.162000
6      0.189800       0.419630         0.939032  1.014300  3056.174000
7      0.170300       0.394079         0.943226  1.012700  3061.031000
8      0.161400       0.386891         0.942258  1.009100  3072.173000

Remarkably, we’ve been able to train the student to match the accuracy of the teacher, despite it having almost half the number of parameters! Let’s save the model for future use:

distil_trainer.save_model("models/distilbert-base-uncased-distilled-clinc")

Benchmarking Our Distilled Model

Now that we have an accurate student, let’s create a pipeline and redo our benchmark to see how we perform on the test set:

pipe = TextClassificationPipeline(
    model=distil_trainer.model.to("cpu"), tokenizer=distil_trainer.tokenizer)
optim_type = "Distillation"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())
Model size (MB) - 255.89
Average latency (ms) - 24.58 +- 7.66
Accuracy on test set - 0.871

To put these results in context, let’s also visualise them with our plot_metrics function:

plot_metrics(perf_metrics, optim_type)

As expected, the model size and latency remain essentially unchanged compared to the DistilBERT benchmark, but the accuracy has improved and even surpassed the performance of the teacher! We can actually compress our distilled model even further using a technique known as quantization. That’s the topic for the next section.

Making Models Faster with Quantization

We’ve now seen that with knowledge distillation we can reduce the computational and memory cost of running inference by transferring the information from a teacher into a smaller student. Quantization takes a different approach; instead of reducing the number of computations, it makes them much more efficient by representing the weights and activations with low-precision data types like 8-bit integer (INT8) instead of the usual 32-bit floating-point (FP32). By reducing the number of bits, the resulting model requires less memory storage, and operations like matrix multiplication can be performed much faster with integer arithmetic. Remarkably, these performance gains can be realized with little to no loss in accuracy!

So what does it mean to quantize the weights or activations of a neural network? The basic idea is that we can “discretize” the floating-point values f in each tensor by mapping their range $[f_{\min}, f_{\max}]$ into a smaller one $[q_{\min}, q_{\max}]$ of fixed-point numbers q, and linearly distributing all values in between. Mathematically, this mapping is described by the following equation

$$f = \left(\frac{f_{\max} - f_{\min}}{q_{\max} - q_{\min}}\right)(q - Z) = S(q - Z),$$

where the scale factor S is a positive floating-point number and the constant Z has the same type as q and is called the zero-point because it corresponds to the quantized value of the floating-point value f=0. Note that the map needs to be affine13 so that we get back floating-point numbers when we dequantize the fixed-point ones. An illustration of the conversion is shown in Figure 3-8.

Mapping floating-point numbers to 8-bit integers
Figure 3-8. Quantizing floating-point numbers as unsigned 8-bit integers (courtesy of Manas Sahni).
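To make the mapping concrete, here is a minimal NumPy sketch of affine quantization to unsigned 8-bit integers. The helper names affine_params, quantize, and dequantize are hypothetical, the toy values are arbitrary, and the zero-point is obtained by solving f = S(q - Z) for the integer that represents f = 0.

```python
import numpy as np

def affine_params(f_min, f_max, q_min=0, q_max=255):
    """Scale S and zero-point Z for mapping [f_min, f_max] onto [q_min, q_max]."""
    scale = (f_max - f_min) / (q_max - q_min)
    zero_point = int(round(q_min - f_min / scale))
    return scale, zero_point

def quantize(f, scale, zero_point, q_min=0, q_max=255):
    """Invert f = S(q - Z), round to the nearest integer, and clamp."""
    q = np.round(f / scale + zero_point)
    return np.clip(q, q_min, q_max).astype(np.uint8)

def dequantize(q, scale, zero_point):
    """Apply f = S(q - Z) to recover approximate floating-point values."""
    return scale * (q.astype(np.float32) - zero_point)

values = np.array([-0.5, -0.1, 0.0, 0.25, 1.0], dtype=np.float32)
scale, zero_point = affine_params(float(values.min()), float(values.max()))
q = quantize(values, scale, zero_point)
recovered = dequantize(q, scale, zero_point)
max_error = float(np.abs(recovered - values).max())
print(q, recovered, max_error)
```

In this sketch the value 0.0 is represented exactly by the zero-point, and every other value is recovered to within half a scale step.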

Now, one of the main reasons why Transformers (and deep neural networks more generally) are prime candidates for quantization is that the weights and activations tend to take values in relatively small ranges. This means we don’t have to squeeze the whole range of possible FP32 numbers into, say, the $2^8 = 256$ numbers represented by INT8. To see this, let’s pick out one of the attention weight matrices from our BERT-base model and plot the frequency distribution of the values:

import matplotlib.pyplot as plt

state_dict = bert_model.state_dict()
weights = state_dict["bert.encoder.layer.0.attention.output.dense.weight"]
plt.hist(weights.flatten().numpy(), bins=250, range=(-0.3,0.3));

As we can see, the values of the weights are concentrated in the small range [-0.1, 0.1] around zero. Now, suppose we want to quantize this tensor as a signed 8-bit integer. In that case, the range of possible values for our integers is $[q_{\min}, q_{\max}] = [-128, 127]$. We take the zero-point to coincide with the zero of FP32 and calculate the scale factor according to the previous equation:

zero_point = 0
scale = (weights.max() - weights.min()) / (127 - (-128))

To obtain the quantized tensor, we just need to invert the mapping q=f/S+Z, clamp the values, round them to the nearest integer, and represent the result in the torch.int8 data type using the Tensor.char function:

(weights / scale + zero_point).clamp(-128, 127).round().char()
tensor([[  2,  -1,   1,  ...,  -2,  -6,   9],
        [  7,   2,  -4,  ...,  -3,   5,  -3],
        [-15,  -8,   5,  ...,   3,   0,  -2],
        ...,
        [ 11,  -1,  12,  ...,  -2,   0,  -3],
        [ -2,  -6, -13,  ...,  11,  -3, -10],
        [-12,   5,  -3,  ...,   7,  -3,  -1]], dtype=torch.int8)

Great, we’ve just quantized our first tensor! In PyTorch we can simplify the conversion by using the quantize_per_tensor function together with the quantized data type torch.qint8, which is optimized for integer arithmetic operations:

from torch import quantize_per_tensor

dtype = torch.qint8
quantized_weights = quantize_per_tensor(weights, scale, zero_point, dtype)
quantized_weights.int_repr()
tensor([[  2,  -1,   1,  ...,  -2,  -6,   9],
        [  7,   2,  -4,  ...,  -3,   5,  -3],
        [-15,  -8,   5,  ...,   3,   0,  -2],
        ...,
        [ 11,  -1,  12,  ...,  -2,   0,  -3],
        [ -2,  -6, -13,  ...,  11,  -3, -10],
        [-12,   5,  -3,  ...,   7,  -3,  -1]], dtype=torch.int8)
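The size of the rounding error can also be checked numerically. The sketch below repeats the manual recipe in NumPy on randomly generated stand-in weights (an assumption, since the real BERT weights are not loaded here) and measures how far the dequantized values are from the originals. Note that with the zero-point fixed at 0, a few extreme values may be clipped if the weight range is not symmetric.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a weight matrix; the standard deviation is an arbitrary choice
weights = rng.normal(0.0, 0.04, size=(768, 768)).astype(np.float32)

zero_point = 0
scale = float(weights.max() - weights.min()) / (127 - (-128))

# Quantize: invert the mapping, round, clamp, and store as signed 8-bit integers
q = np.clip(np.round(weights / scale + zero_point), -128, 127).astype(np.int8)

# Dequantize and compare against the original values
dequantized = scale * (q.astype(np.float32) - zero_point)
errors = np.abs(dequantized - weights)
print(f"scale = {scale:.6f}, mean error = {errors.mean():.6f}")
```

Values that are not clipped end up at most half a scale step away from their original value.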

If we dequantize this tensor, we can visualize the frequency distribution to see the effect that rounding has had on our original values:

from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes,mark_inset

# Create histogram
fig, ax = plt.subplots()
ax.hist(quantized_weights.dequantize().flatten().numpy(),
         bins=250, range=(-0.3,0.3));
# Create zoom inset
axins = zoomed_inset_axes(ax, 5, loc='upper right')
axins.hist(quantized_weights.dequantize().flatten().numpy(),
         bins=250, range=(-0.3,0.3));
x1, x2, y1, y2 = 0.05, 0.1, 500, 2500
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)
axins.axes.xaxis.set_visible(False)
axins.axes.yaxis.set_visible(False)
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
plt.show()

This shows very clearly the discretization that’s induced by only mapping some of the weight values precisely and rounding the rest. To round out our little analysis, let’s compare how long it takes to compute the multiplication of two weight tensors with FP32 and INT8 values. For the FP32 tensors we can multiply them using PyTorch’s nifty @ operator:

%%timeit
weights @ weights
9.76 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

For the quantized tensors we need the QFunctional wrapper class so that we can perform operations with the special torch.qint8 data type:

from torch.nn.quantized import QFunctional

q_fn = QFunctional()

This class supports various elementary operations like addition and in our case we can time the multiplication of our quantized tensors as follows:

%%timeit
q_fn.mul(quantized_weights, quantized_weights)
107 µs ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

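Compared to our FP32 computation, using the INT8 tensors is almost 100 times faster! Bear in mind that this comparison is only indicative, since q_fn.mul multiplies the tensors element-wise whereas the @ operator performs a full matrix multiplication. Even larger gains can be obtained by using dedicated backends for running quantized operators efficiently, and as of this book’s writing PyTorch supports: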

  • x86 CPUs with AVX2 support or higher

  • ARM CPUs (typically found in mobile/embedded devices)

Since INT8 numbers have four times fewer bits than FP32, quantization also reduces the memory storage requirements by up to a factor of four. In our simple example we can verify this by comparing the underlying storage size of our weight tensor and its quantized cousin by using the Tensor.storage method and the getsizeof function from Python’s sys module:

import sys

sys.getsizeof(weights.storage()) / sys.getsizeof(quantized_weights.storage())
3.999715196311114

For a full-scale Transformer, the actual compression rate depends on which layers are quantized and as we’ll see in the next section it is only the linear layers that typically get quantized.
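A back-of-the-envelope estimate shows why the overall compression depends on how much of the model is quantized. The function below is a hypothetical helper, not a library call; it assumes 4 bytes per FP32 parameter and 1 byte per INT8 parameter, and it ignores the small overhead of storing scales and zero-points.

```python
def estimated_size_mb(num_params, quantized_fraction):
    """Rough model size when a fraction of the parameters is stored as INT8."""
    bytes_fp32, bytes_int8 = 4, 1
    bytes_per_param = ((1 - quantized_fraction) * bytes_fp32
                       + quantized_fraction * bytes_int8)
    return num_params * bytes_per_param / (1024 * 1024)

num_params = 1024 * 1024  # toy model with about one million parameters
for fraction in (0.0, 0.5, 1.0):
    print(f"{fraction:.0%} quantized: {estimated_size_mb(num_params, fraction):.2f} MB")
```

Only when every parameter is quantized does the estimate reach the full factor of four; quantizing half of them gives a reduction of less than a factor of two.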

So what’s the catch with quantization? Changing the precision for all computations in our model introduces small disturbances at each point in the model’s computational graph which can compound and affect the model’s performance. There are several ways to quantize a model which all have pros and cons. In the following section we will briefly introduce them.

Quantization Strategies

Dynamic Quantization

When using dynamic quantization, nothing is changed during training and the adaptations are only performed during inference. As with all the quantization methods we will discuss, the weights of the model are converted to INT8 ahead of inference time. In addition to the weights, the model’s activations are also quantized. This approach is called dynamic because the quantization of the activations happens on the fly. This means that all the matrix multiplications can be calculated with highly optimized INT8 functions. Of all the quantization methods discussed here, dynamic quantization is the simplest one. However, with dynamic quantization the activations are written to and read from memory in floating-point format. This conversion between integer and floating-point formats can be a performance bottleneck. The next section discusses a method that addresses this issue.

Static Quantization

Instead of computing the quantization of the activations on the fly, we can avoid the conversion to floating point by precomputing the quantization scheme of the activations. Static quantization achieves this by observing the activation patterns on a representative sample of the data ahead of inference time. The ideal quantization scheme is calculated and then saved. This enables us to skip the conversion between INT8 and FP32 values and speeds up the computations further. However, this requires access to a good data sample and introduces an additional step in the pipeline, since we now need to train the model and then determine the quantization scheme before we can perform inference. There is also one aspect that static quantization does not address: the discrepancy between the precision during training and inference, which leads to a drop in the model’s metrics (e.g., accuracy). This can be improved by adapting the training loop, as discussed in the next section.
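The calibration step can be pictured with a small NumPy sketch. The RangeObserver class below is a hypothetical helper written for illustration (it is not the observer class that ships with PyTorch), and the random batches are stand-ins for activations collected on representative data.

```python
import numpy as np

class RangeObserver:
    """Hypothetical helper that tracks the running range of observed activations."""

    def __init__(self):
        self.f_min = float("inf")
        self.f_max = float("-inf")

    def observe(self, activations):
        self.f_min = min(self.f_min, float(activations.min()))
        self.f_max = max(self.f_max, float(activations.max()))

    def quantization_params(self, q_min=0, q_max=255):
        scale = (self.f_max - self.f_min) / (q_max - q_min)
        zero_point = int(round(q_min - self.f_min / scale))
        return scale, zero_point

rng = np.random.default_rng(0)
observer = RangeObserver()
for _ in range(10):  # calibration pass over a representative sample
    batch = rng.normal(0.0, 1.0, size=(32, 64)).astype(np.float32)
    observer.observe(batch)

# These values are computed once and then reused for every inference call
scale, zero_point = observer.quantization_params()
print(scale, zero_point)
```

With dynamic quantization, the same range computation would instead be repeated on each batch at inference time.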

Quantization Aware Training

The effect of quantization can be effectively simulated during training by “fake” quantization of the FP32 values. Instead of using INT8 values during training, the FP32 values are rounded to mimic the effect of quantization. This is done during both the forward and backward pass, and it improves performance in terms of model metrics over static and dynamic quantization.
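A minimal sketch of the “fake” quantization operation is shown below, using NumPy and a hypothetical fake_quantize helper. It only illustrates the forward rounding; how gradients are passed through the rounding step during training is left out.

```python
import numpy as np

def fake_quantize(x, scale, zero_point=0, q_min=-128, q_max=127):
    """Snap FP32 values onto the INT8 grid but keep them stored as FP32."""
    q = np.clip(np.round(x / scale + zero_point), q_min, q_max)
    return (scale * (q - zero_point)).astype(np.float32)

x = np.array([-0.30, -0.013, 0.0, 0.004, 0.27], dtype=np.float32)
x_fq = fake_quantize(x, scale=0.01)
print(x_fq)
```

The output still has a floating-point data type, but every value is now a multiple of the scale, which is what the model gets to see and adapt to during training.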

Quantizing Transformers in PyTorch

The main bottleneck for running inference with Transformers is the compute and memory bandwidth associated with the enormous number of weights in these models. For this reason, dynamic quantization is currently the best approach for Transformer-based models in NLP. In smaller computer vision models the limiting factor is the memory bandwidth of the activations, which is why static quantization is generally used there, with quantization aware training reserved for cases where the performance drops are too significant.

Implementing dynamic quantization in PyTorch is quite simple and can be done with a single line of code:

import torch.nn as nn
from torch.quantization import quantize_dynamic

model_ckpt = "models/distilbert-base-uncased-distilled-clinc"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt).to("cpu"))

model_quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

Here we pass to quantize_dynamic the full-precision model and specify the set of PyTorch layer classes in that model that we want to quantize. The dtype argument specifies the target precision and can be torch.float16 or torch.qint8.

Benchmarking Our Quantized Model

With our model now quantized, let’s pass it through the benchmark and visualise the results:

pipe = TextClassificationPipeline(model=model_quantized, tokenizer=tokenizer)
optim_type = "Distillation + quantization"
pb = PerformanceBenchmark(pipe, clinc["test"], optim_type=optim_type)
perf_metrics.update(pb.run_benchmark())
plot_metrics(perf_metrics, optim_type)

Wow, the quantized model is almost half the size of our distilled one and twice as fast! Let’s see if we can push our optimization to the limit with a powerful framework called ONNX.

Optimizing Inference with ONNX and the ONNX Runtime

ONNX is an open standard that defines a common set of operators and a common file format to represent deep learning models in a wide variety of frameworks, including PyTorch and TensorFlow.14 When a model is exported to the ONNX format, these operators are used to construct a computational graph (often called an intermediate representation) which represents the flow of data through the neural network. An example of such a graph for BERT-base is shown in Figure 3-9, where each node receives some input, applies an operation like “Add” or “Squeeze”, and then feeds the output to the next set of nodes.

Example ONNX graph
Figure 3-9. A section of the ONNX graph for BERT-base, visualized in Netron

By exposing a graph with standardized operators and data types, ONNX makes it easy to switch between frameworks. For example, a model trained in PyTorch can be exported to ONNX format and then imported in TensorFlow (and vice versa).

Where ONNX really shines is when it is coupled with a dedicated accelerator like the ONNX Runtime, or ORT for short. ORT provides tools to optimize the ONNX graph through techniques like operator fusion and constant folding,15 and defines an interface to Execution Providers that allow you to run the model on different types of hardware. This is a powerful abstraction and Figure 3-10 shows the high-level architecture for the ONNX and ORT ecosystem.

Architecture of the ONNX and ONNX Runtime ecosystem
Figure 3-10. Architecture of the ONNX and ONNX Runtime ecosystem (courtesy of the ONNX Runtime team)

To see ORT in action, the first thing we need to do is convert our distilled model into the ONNX format. Transformers has a built-in function called convert_graph_to_onnx.convert that simplifies the process by performing the following steps:

  • Initializes the model as a Pipeline

  • Runs dummy inputs through the pipeline so that ONNX can record the computational graph

  • Defines dynamic axes to handle dynamic sequence lengths

  • Saves the graph with network parameters

To use this function, we first need to set some OpenMP environment variables for ONNX:

from psutil import cpu_count

%env OMP_NUM_THREADS={cpu_count()}
%env OMP_WAIT_POLICY=ACTIVE
env: OMP_NUM_THREADS=8
env: OMP_WAIT_POLICY=ACTIVE

OpenMP is an API designed for developing highly parallelized applications. The OMP_NUM_THREADS environment variable sets the number of threads to use for parallel computations in the ONNX Runtime, while OMP_WAIT_POLICY=ACTIVE specifies that waiting threads should be active (i.e., using CPU processor cycles).
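The %env syntax is an IPython magic command, so it only works in a notebook. In a plain Python script, one way to get the same effect is to set the variables through os.environ, as in the sketch below. It uses os.cpu_count from the standard library instead of psutil, and it assumes that the variables should be set before the runtime is imported, which is the safe ordering since such settings are typically read at startup.

```python
import os

num_threads = os.cpu_count() or 1  # cpu_count() can return None
os.environ["OMP_NUM_THREADS"] = str(num_threads)
os.environ["OMP_WAIT_POLICY"] = "ACTIVE"

# Import the runtime only after the variables are set, for example:
# import onnxruntime
print(os.environ["OMP_NUM_THREADS"], os.environ["OMP_WAIT_POLICY"])
```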

Next, let’s convert our distilled model to the ONNX format. Here we need to specify the argument pipeline_name="sentiment-analysis" since convert wraps the model in a Transformers pipeline during the conversion. We use the sentiment-analysis argument since this is the name of the text classification pipeline in Transformers. In addition to the model_ckpt we also pass the tokenizer to initialize the pipeline:

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

onnx_model_path = Path("onnx/model.onnx")
convert(framework="pt", model=model_ckpt, tokenizer=tokenizer,
        output=onnx_model_path, opset=12, pipeline_name="sentiment-analysis")
ONNX opset version set to: 12
Loading pipeline (model: models/distilbert-base-uncased-distilled-clinc,
 > tokenizer: PreTrainedTokenizerFast(name_or_path='models/distilbert-base-
 > uncased-distilled-clinc', vocab_size=30522, model_max_len=512, is_fast=True,
 > padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token':
 > '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token':
 > '[MASK]'}))
Creating folder onnx
Using framework PyTorch: 1.5.0
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch'}
Ensuring inputs are in correct order
head_mask is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask']

ONNX uses operator sets to group together immutable operator specifications, so opset=12 corresponds to a specific version of the ONNX library.

Now that we have our model saved, we need to create an inference session to feed inputs to the model:

from onnxruntime import (GraphOptimizationLevel, InferenceSession,
                         SessionOptions)

def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
    options = SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    session = InferenceSession(str(model_path), options, providers=[provider])
    session.disable_fallback()
    return session

onnx_model = create_model_for_provider(onnx_model_path)

Let’s test this out with an example from the test set. Since the output from the convert function tells us that ONNX expects just the input_ids and attention_mask as inputs, we need to drop the label column from our sample:

inputs = clinc_enc["test"][:1]
del inputs["labels"]
logits_onnx = onnx_model.run(None, inputs)[0]
logits_onnx.shape
(1, 151)

As expected, by specifying the sentiment-analysis pipeline name we get the class logits as the output so we can easily get the predicted label by taking the argmax:

np.argmax(logits_onnx)
61

which indeed agrees with the ground truth label:

clinc_enc["test"][0]["labels"]
61

Since we cannot use the TextClassificationPipeline class to wrap our ONNX model, we’ll create our own class that mimics the core behaviour:

from scipy.special import softmax

class OnnxPipeline:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, query):
        model_inputs = self.tokenizer(query, return_tensors="pt")
        inputs_onnx = {k: v.cpu().detach().numpy()
                       for k, v in model_inputs.items()}
        logits = self.model.run(None, inputs_onnx)[0][0, :]
        probs = softmax(logits)
        pred_idx = np.argmax(probs).item()
        return [{"label": intents.int2str(pred_idx), "score": probs[pred_idx]}]

We can then test this on our simple query to see if we recover the car_rental intent:

pipe = OnnxPipeline(onnx_model, tokenizer)
pipe(query)
[{'label': 'car_rental', 'score': 0.8440852}]

Great, our pipeline works well so the next step is to create a performance benchmark for ONNX models. Here we can build on the work we did with the PerformanceBenchmark class by simply overriding the compute_size function and leaving the compute_accuracy and time_pipeline functions intact. The reason we need to override the compute_size function is that we cannot rely on the state_dict and torch.save to measure a model’s size since onnx_model is technically an ONNX InferenceSession object which doesn’t have access to the attributes of PyTorch’s nn.Module. In any case, the resulting logic is simple and can be implemented as follows:

class OnnxPerformanceBenchmark(PerformanceBenchmark):
    def __init__(self, *args, model_path, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_path = model_path

    def compute_size(self):
        size_mb = Path(self.model_path).stat().st_size / (1024 * 1024)
        print(f"Model size (MB) - {size_mb:.2f}")
        return {"size_mb": size_mb}

With our new benchmark, let’s see how our distilled model performs when converted to ONNX format:

optim_type = "Distillation + ORT"
pb = OnnxPerformanceBenchmark(pipe, clinc["test"], optim_type,
                              model_path="onnx/model.onnx")
perf_metrics.update(pb.run_benchmark())
Model size (MB) - 255.89
Average latency (ms) - 10.54 +- 2.20
Accuracy on test set - 0.871
plot_metrics(perf_metrics, optim_type)

Remarkably, converting to the ONNX format and using the ONNX Runtime has more than halved the average latency of our distilled model (and is almost five times faster than our baseline)! Let’s see if we can squeeze out a bit more performance by applying some Transformer-specific optimizations.

Optimizing for Transformer Architectures

We’ve just seen that the ONNX Runtime was very good at optimizing our distilled model out of the box. However, the ONNX Runtime library also offers an optimizer module that contains some Transformer-specific optimizations that we can try to see if the model is fully optimized or not. To use the optimizer module we first need to define some optimization options that are specific to our model. In our case, DistilBERT belongs to the bert model type so we need to use the BertOptimizationOptions class from onnxruntime_tools:

from onnxruntime_tools.transformers.onnx_model_bert import (
    BertOptimizationOptions)

model_type = "bert"
opt_options = BertOptimizationOptions(model_type)
opt_options.enable_embed_layer_norm = False

Here we’ve disabled the norm optimization on the embedding layer to get better model size compression. Now that we’ve specified the model options, we can then run optimizer.optimize_model to optimize the ONNX model specifically for BERT-like architectures:

from onnxruntime_tools import optimizer

opt_model = optimizer.optimize_model(
    "onnx/model.onnx", model_type, num_heads=12, hidden_size=768,
    optimization_options=opt_options)
opt_model.save_model_to_file("onnx/model.opt.onnx")

Here we’ve specified the number of heads and hidden size in our DistilBERT model. The last thing to do is create an inference session for our optimized model, wrap it in a pipeline and run it through our benchmark:

onnx_model_opt = create_model_for_provider("onnx/model.opt.onnx")
pipe = OnnxPipeline(onnx_model_opt, tokenizer)
optim_type = "Distillation + ORT (optimized)"
pb = OnnxPerformanceBenchmark(pipe, clinc["test"], optim_type,
                              model_path="onnx/model.opt.onnx")
perf_metrics.update(pb.run_benchmark())
Model size (MB) - 255.86
Average latency (ms) - 11.22 +- 3.52
Accuracy on test set - 0.871
plot_metrics(perf_metrics, optim_type)

Okay, it seems that our original ORT optimization was already close to the optimal one for this architecture. Let’s now see what happens if we add quantization to the mix. Similar to PyTorch, ORT offers three ways to quantize a model: dynamic, static, and quantization aware training. As we did with PyTorch, we’ll apply dynamic quantization to our distilled model. In ORT, the quantization is applied through the quantize_dynamic function which requires a path to the ONNX model to quantize, a target path to save the quantized model to, and the data type to reduce the weights to:

from onnxruntime.quantization import quantize_dynamic, QuantType

model_input = "onnx/model.onnx"
model_output = "onnx/model.quant.onnx"
quantize_dynamic(model_input, model_output, weight_type=QuantType.QInt8)

Now that the model is quantized, let’s run it through our benchmark:

onnx_quantized_model = create_model_for_provider(model_output)
pipe = OnnxPipeline(onnx_quantized_model, tokenizer)
optim_type = "Distillation + ORT (quantized)"
pb = OnnxPerformanceBenchmark(pipe, clinc["test"], optim_type,
                              model_path=model_output)
perf_metrics.update(pb.run_benchmark())
Model size (MB) - 185.71
Average latency (ms) - 6.95 +- 4.75
Accuracy on test set - 0.875
plot_metrics(perf_metrics, optim_type)

Wow, ORT quantization has reduced the model size and latency by around a factor of two compared to the model obtained from PyTorch quantization (the Distillation + quantization blob). One reason for this is that PyTorch only optimizes the nn.Linear modules, while ONNX quantizes the embedding layer as well. From the plot we can also see that applying ORT quantization to our distilled model has provided an almost 7-fold gain in latency compared to our BERT baseline!
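The ratios relative to the plain distilled model can be recomputed directly from the benchmark printouts in this section. The snippet below uses only the numbers reported above for the distilled model and its two ORT variants; the dictionary layout is an illustrative choice.

```python
benchmarks = {
    "Distillation": {"size_mb": 255.89, "latency_ms": 24.58},
    "Distillation + ORT": {"size_mb": 255.89, "latency_ms": 10.54},
    "Distillation + ORT (quantized)": {"size_mb": 185.71, "latency_ms": 6.95},
}

reference = benchmarks["Distillation"]
speedups, compressions = {}, {}
for name, metrics in benchmarks.items():
    speedups[name] = reference["latency_ms"] / metrics["latency_ms"]
    compressions[name] = reference["size_mb"] / metrics["size_mb"]
    print(f"{name}: {speedups[name]:.2f}x faster, "
          f"{compressions[name]:.2f}x smaller")
```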

This concludes our analysis of techniques to speed up Transformers for inference. We have seen that methods such as quantization reduce the model size by reducing the precision of the representation. Another strategy to reduce the size is to remove some weights altogether. This technique is called weight pruning and is the focus of the next section.

Making Models Sparser with Weight Pruning

So far we’ve seen that knowledge distillation and weight quantization are quite effective at producing faster models for inference, but in some cases you might also have strong constraints on the memory footprint of your model. For example, if your product manager suddenly decides that the text-assistant needs to be deployed on a mobile device then we’ll need our intent classifier to take up as little storage space as possible. To round out our survey of compression methods, let’s take a look at how we can shrink the number of parameters in our model by identifying and removing the least important weights in the network.

Sparsity in Deep Neural Networks

As shown in Figure 3-11, the main idea behind pruning is to gradually remove weight connections (and potentially neurons) during training such that the model becomes progressively sparser. The resulting pruned model has a smaller number of non-zero parameters, which can then be stored in a compact sparse matrix format. Pruning can also be combined with quantization to obtain further compression.

Network Pruning
Figure 3-11. Weights and neurons before and after pruning. Image from Learning both Weights and Connections for Efficient Neural Networks by S. Han et al (2015).

Weight Pruning Methods

Mathematically, the way most weight pruning methods work is to calculate a matrix $\mathbf{S} \in \mathbb{R}^{n \times n}$ of importance scores and then select the top-k percent of weights by importance:

$$\mathrm{Top}_k(\mathbf{S})_{ij} = \begin{cases} 1 & \text{if } S_{ij} \text{ in top } k\% \\ 0 & \text{otherwise} \end{cases}$$

In effect, k acts as a new hyperparameter to control the amount of sparsity in the model, that is, the proportion of weights that are zero-valued. Lower values of k correspond to sparser matrices. From these scores we can then define a mask matrix $\mathbf{M} \in \{0,1\}^{n \times n}$ that masks the weights $W_{ij}$ during the forward pass with some input $x_i$ and effectively creates a sparse network of activations $a_i$:

$$a_i = \sum_{k=1}^{n} W_{ik} M_{ik} x_k.$$
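The top-k selection and the masked forward pass can be sketched with NumPy as follows. The topk_mask helper is hypothetical, and the weights, scores, and input are random stand-ins; how the scores are actually computed is what distinguishes the pruning methods discussed in this section.

```python
import numpy as np

def topk_mask(scores, k_percent):
    """Binary mask that is 1 for the top k percent of scores and 0 elsewhere."""
    n_keep = int(round(scores.size * k_percent / 100))
    mask = np.zeros(scores.size, dtype=scores.dtype)
    if n_keep > 0:
        keep_idx = np.argsort(scores, axis=None)[-n_keep:]
        mask[keep_idx] = 1
    return mask.reshape(scores.shape)

rng = np.random.default_rng(0)
n = 4
W = rng.normal(size=(n, n))  # weights
S = rng.random((n, n))       # stand-in importance scores
x = rng.normal(size=n)       # input

M = topk_mask(S, k_percent=25)
a = (W * M) @ x              # masked forward pass: sum over W_ik * M_ik * x_k
sparsity = 1.0 - M.mean()
print(M, a, sparsity)
```

With k = 25 only a quarter of the 16 weights survive, so the resulting sparsity is 75%.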

As discussed in the tongue-in-cheek Optimal Brain Surgeon paper16, at the heart of each pruning method is a set of questions that need to be considered:

  • Which weights should be eliminated?

  • How should the remaining weights be adjusted for best performance?

  • How can such network pruning be done in a computationally efficient way?

The answers to these questions inform how the score matrix $\mathbf{S}$ is computed, so let’s begin by looking at one of the earliest and most popular pruning methods: magnitude pruning.

Magnitude Pruning

As the name suggests, magnitude pruning calculates the scores according to the magnitude of the weights $\mathbf{S} = \left(|W_{ij}|\right)_{1 \le i,j \le n}$ and then derives the masks from $\mathbf{M} = \mathrm{Top}_k(\mathbf{S})$. In the literature it is common to apply magnitude pruning in an iterative fashion17 by first training the model to learn which connections are important and pruning the weights of least importance. The sparse model is then re-trained and the process repeated until the desired sparsity is reached.

One drawback with this approach is that it is computationally demanding: at every step of pruning we need to train the model to convergence. For this reason it is generally better to gradually increase the initial sparsity $s_i$ (which is usually zero) to a final value $s_f$ after some number of steps N:18

$$s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{N \Delta t}\right)^3 \quad \text{for } t \in \{t_0, t_0 + \Delta t, \ldots, t_0 + N \Delta t\}.$$

Here the idea is to update the binary masks $\mathbf{M}$ every Δt steps to allow masked weights to reactivate during training and recover from any potential loss in accuracy that is induced by the pruning process. As shown in Figure 3-12, the cubic factor implies that the rate of weight pruning is highest in the early phases (when the number of redundant weights is large) and gradually tapers off.

Sparsity scheduler
Figure 3-12. The cubic sparsity scheduler used for pruning.
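The cubic schedule is easy to evaluate directly. The function below is a small sketch in plain Python; the function name and the default values (a final sparsity of 0.9 reached after 10 updates spaced 100 steps apart) are arbitrary choices for illustration.

```python
def cubic_sparsity(t, s_i=0.0, s_f=0.9, t_0=0, n_steps=10, delta_t=100):
    """Target sparsity at training step t under the cubic schedule."""
    progress = (t - t_0) / (n_steps * delta_t)
    progress = min(max(progress, 0.0), 1.0)  # hold the end values outside the schedule
    return s_f + (s_i - s_f) * (1.0 - progress) ** 3

for t in range(0, 1001, 100):
    print(t, round(cubic_sparsity(t), 4))
```

The printed values start at the initial sparsity, rise steeply during the first few updates, and then flatten out as they approach the final sparsity.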

One problem with magnitude pruning is that it is really designed for pure supervised learning, where the importance of each weight is directly related to the task at hand. By contrast, in transfer learning the importance of the weights is primarily determined by the pretraining phase, so magnitude pruning can remove connections that are important for the fine-tuning task. Recently, an adaptive approach19 called movement pruning has been proposed by the Hugging Face team. Let’s take a look.

Movement Pruning

The basic idea behind movement pruning is to gradually remove weights during fine-tuning such that the model becomes progressively sparser. The key novelty is that both the weights and the scores are learned during fine-tuning. So instead of deriving the scores directly from the weights (like magnitude pruning does), the scores in movement pruning are arbitrary, and learned through gradient descent like any other neural network parameter. This implies that in the backward pass, we also track the gradient of the loss L with respect to the scores $S_{ij}$. We can calculate the gradient from the expression of the activations $a_i$ as follows:

$$\frac{\partial L}{\partial S_{ij}} = \frac{\partial L}{\partial a_i} \frac{\partial a_i}{\partial S_{ij}} = \frac{\partial L}{\partial a_i} W_{ij} x_j.$$

Once the scores are learned, it is then straightforward to generate the binary mask using $\mathbf{M} = \mathrm{Top}_k(\mathbf{S})$. There is also a “soft” version of movement pruning where instead of picking the top-k% of weights, one uses a global threshold τ to define the binary mask: $\mathbf{M} = (\mathbf{S} > \tau)$.

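The intuition behind movement pruning is that the weights which are “moving” the most from zero are the most important ones to keep. To see this, we first note that the gradient of $L$ with respect to the weights $W_{ij}$ is given by

$$\frac{\partial L}{\partial W_{ij}} = \frac{\partial L}{\partial a_i} M_{ij} x_j,$$

which can be combined with the expression for $\partial L / \partial S_{ij}$ to yield

$$\frac{\partial L}{\partial S_{ij}} = \frac{\partial L}{\partial W_{ij}} \frac{W_{ij}}{M_{ij}}.$$

Since the scores are increasing when the gradient $\partial L / \partial S_{ij}$ is negative, we see that this occurs under two scenarios (we can drop $\mathbf{M}$ since its entries are non-negative and therefore do not change the sign):

$$\text{a)} \; \frac{\partial L}{\partial W_{ij}} < 0 \text{ and } W_{ij} > 0 \qquad \text{b)} \; \frac{\partial L}{\partial W_{ij}} > 0 \text{ and } W_{ij} < 0$$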

In other words, the scores increase when positive weights grow during fine-tuning and when negative weights become more negative, which is equivalent to saying that the scores increase as the weights move away from zero. As shown in Figure 3-13, this behavior differs from magnitude pruning, which selects as the most important weights those which are furthest from zero.

Magnitude vs Movement Pruning
Figure 3-13. Comparison of weights removed (in grey) during magnitude pruning (left) and movement pruning (right).
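The sign argument can be verified with a few numbers. The sketch below assumes a plain gradient-descent update on the scores and uses the relation $\partial L / \partial S_{ij} = (\partial L / \partial W_{ij})\, W_{ij}$ with the mask factor dropped; the function name, learning rate, and toy values are illustrative.

```python
def score_update(weight, grad_w, lr=0.1):
    """Change in a score after one gradient-descent step.

    Uses dL/dS = (dL/dW) * W (mask factor dropped), so the score
    changes by -lr * grad_w * weight. Illustrative helper only.
    """
    return -lr * grad_w * weight

# (weight, dL/dW) pairs; gradient descent moves W in the direction -dL/dW
cases = {
    "positive weight, growing": (0.5, -1.0),    # scenario a)
    "negative weight, growing": (-0.5, 1.0),    # scenario b)
    "positive weight, shrinking": (0.5, 1.0),
    "negative weight, shrinking": (-0.5, -1.0),
}
deltas = {name: score_update(w, g) for name, (w, g) in cases.items()}

for name, delta in deltas.items():
    print(f"{name:>28}: score change {delta:+.3f}")
```

Scores rise only in the two cases where the weight moves away from zero, and fall when the weight is being pulled back toward zero, independent of how large the weight currently is.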

These differences between the two pruning methods are also evident in the distribution of the remaining weights. As shown in Figure 3-14, magnitude pruning produces two clusters of weights, while movement pruning produces a smoother distribution.

Pruning Distributions
Figure 3-14. Distribution of remaining weights for magnitude pruning (MaP) and movement pruning (MvP)

In this chapter we’ll examine how well movement pruning works with a top-k scorer on our intent classifier. As of this book’s writing, Transformers does not support pruning methods “out of the box”, so we’ll have to implement the main classes ourselves. Fortunately, the code used to produce the results from the movement pruning paper is available in the examples/research_projects/movement-pruning folder of the Transformers repository, so we have a great foundation to work from. Let’s get started!

Creating Masked Transformers

To implement movement pruning, we’ll need a few different ingredients:

  • A Topk operator that we can use to binarize the scores by selecting the top-k% of weights.

  • A way to apply the adaptive masking on-the-fly to our BERT-base model.

  • A cubic sparsity scheduler.

Let’s start by implementing the Topk binarizer. From the definition, we need to calculate a binary mask matrix $\mathbf{M}$ from a real-valued matrix $\mathbf{S}$ such that $M_{ij} = 1$ if and only if $S_{ij}$ is among the k% highest values of $\mathbf{S}$. Since the back-propagated gradients will flow through the binary mask, we’ll use the autograd.Function from PyTorch to compute the mask on the forward pass and pass the incoming gradients straight through to the scores in the backward pass:

import torch
from torch.autograd import Function

class TopKBinarizer(Function):
    @staticmethod
    def forward(ctx, inputs, threshold):
        # The threshold can arrive as a tensor (e.g., a column of the
        # validation set), in which case we take its first element
        if not isinstance(threshold, float):
            threshold = threshold[0]
        # Sort the scores in descending order
        mask = inputs.clone()
        _, idx = inputs.flatten().sort(descending=True)
        # Number of elements to keep (the top-k fraction)
        j = int(threshold * inputs.numel())
        # Set the top-j entries to one and all others to zero
        flat_out = mask.flatten()
        flat_out[idx[j:]] = 0
        flat_out[idx[:j]] = 1
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        # Straight-through estimator: the gradient with respect to the
        # mask is passed unchanged to the scores; the threshold gets none
        return gradOutput, None
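To see what the pass-through backward method implies for learning, the sketch below uses a compact stand-in for the binarizer (the class name `TopKSTE` and the toy tensors are assumptions for illustration) and checks that the gradient reaching the scores equals the gradient with respect to the mask, even for entries that were masked out.

```python
import torch
from torch.autograd import Function

class TopKSTE(Function):
    """Compact stand-in for a top-k binarizer (hypothetical name)."""

    @staticmethod
    def forward(ctx, scores, keep_fraction):
        k = int(keep_fraction * scores.numel())
        mask = torch.zeros_like(scores)
        if k > 0:
            # Mark the k highest scores with a one
            top_idx = scores.flatten().topk(k).indices
            mask.view(-1)[top_idx] = 1.0
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: forward the gradient unchanged to the scores
        return grad_output, None

scores = torch.tensor([[0.3, -0.1], [0.8, 0.2]], requires_grad=True)
weights = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

mask = TopKSTE.apply(scores, 0.5)
loss = (mask * weights).sum()
loss.backward()

print(mask)
print(scores.grad)
```

Because the loss here is just the sum of the masked weights, its gradient with respect to the mask is the weight matrix itself, and that is exactly what ends up in `scores.grad`. Scores of pruned entries therefore keep receiving updates, which is what allows masked weights to come back later in training.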

Let’s test this out on a random matrix of scores:

import torch

torch.manual_seed(123)
# Tensors track gradients directly, so no Variable wrapper is needed
scores = torch.randn(2, 2, requires_grad=True)
scores
tensor([[-0.1115,  0.1204],
        [-0.3696, -0.2404]], requires_grad=True)

Now let’s see what happens if we zero-out half of the scores:

# Custom autograd functions are called through the static apply method
TopKBinarizer.apply(scores, 0.5)
tensor([[1., 1.],
        [0., 0.]], grad_fn=<TopKBinarizerBackward>)

Great, this makes sense since the entries in the top row of scores are the ones with the highest value. Okay, so now that we have a way to binarize the scores, the next step is to implement a fully connected layer that can calculate the binary mask on-the-fly and multiply this by the weight matrix and inputs. As of this book’s writing, there is no simple way to do this beyond extending the various BERT classes in Transformers. We refer the reader to the implementation in the Transformers repository, but note that the main ingredient is a replacement of all nn.Linear layers with a new layer that computes the mask on the forward pass:

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init

class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, mask_scale=0.0):
        super().__init__(
            in_features=in_features, out_features=out_features, bias=bias)
        # Constant used to initialize the scores
        self.mask_scale = mask_scale
        # One learnable score per weight
        self.mask_scores = nn.Parameter(torch.empty(self.weight.size()))
        self.init_mask()

    def init_mask(self):
        init.constant_(self.mask_scores, val=self.mask_scale)

    def forward(self, input, threshold):
        # Get the mask
        mask = TopKBinarizer.apply(self.mask_scores, threshold)
        # Mask weights with computed mask
        weight_thresholded = mask * self.weight
        # Compute output (linear layer) with masked weights
        return F.linear(input, weight_thresholded, self.bias)

This new linear layer can then be used to build up a custom MaskedBertModel, MaskedBertForSequenceClassification, and so on by allowing the threshold parameter to be passed along with the inputs. We won’t show the explicit code here but refer the reader to the examples/research_projects/movement-pruning folder of the Transformers repository. These new masked classes work in the same way as the ordinary ones, so let’s load the configuration and define a model_init function so that we can do multiple fine-pruning runs:

from pruning import MaskedBertConfig, MaskedBertForSequenceClassification

masked_config = MaskedBertConfig(num_labels=num_labels)

def model_init():
    return (MaskedBertForSequenceClassification
            .from_pretrained(bert_ckpt, config=masked_config).to(device))

Creating a Pruning Trainer

Now that we have our masked model, the next step is to implement a custom trainer that we can use for fine-pruning. Similar to knowledge distillation, we’ll need a few ingredients:

  • New hyperparameters like the amount of sparsity to start and end with during the training run. We’ll also need to specify what fraction of the steps is used for warmup and for cool-down.

  • A way to optimize the scores with their own learning rate.

  • A custom loss that can calculate the threshold at each step and feed that to the model to generate the loss.

The new training arguments are simple to include, and again we just subclass TrainingArguments:

class PruningTrainingArguments(TrainingArguments):
    def __init__(self, *args, initial_threshold=1., final_threshold=0.1,
                 initial_warmup=1, final_warmup=2,
                 mask_scores_learning_rate=1e-2, **kwargs):
        super().__init__(*args, **kwargs)
        self.initial_threshold = initial_threshold
        self.final_threshold = final_threshold
        self.initial_warmup = initial_warmup
        self.final_warmup = final_warmup
        self.mask_scores_learning_rate = mask_scores_learning_rate

For the trainer we’ll have to do a bit more work, so let’s look at a skeleton of what we need to override:

class PruningTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.t_total = (len(self.get_train_dataloader()) //
                        self.args.gradient_accumulation_steps
                        * self.args.num_train_epochs)


    def create_optimizer_and_scheduler(self, num_training_steps):
        pass

    def compute_loss(self, model, inputs):
        pass

    def _schedule_threshold(self, step, total_step, warmup_steps,
        initial_threshold, final_threshold, initial_warmup, final_warmup):
        pass

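First we need to implement the cubic sparsity scheduler. This is similar to the equation we saw earlier, but in movement pruning we also allow for some warmup steps $t_i$ and cool-down steps $t_f$ out of $T$ total steps, so the definition is as follows:

$$s_t = \begin{cases} s_i & 0 \leq t < t_i \\ s_f + (s_i - s_f)\left(1 - \frac{t - t_i - t_f}{N\Delta t}\right)^3 & t_i \leq t < T - t_f \\ s_f & \text{otherwise} \end{cases}$$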

The following function implements this logic:

def _schedule_threshold(self, step, total_step, warmup_steps,
                        initial_threshold, final_threshold,
                        initial_warmup, final_warmup):
    if step <= initial_warmup * warmup_steps:
        # Warmup: keep the initial threshold
        threshold = initial_threshold
    elif step > (total_step - final_warmup * warmup_steps):
        # Cool-down: hold the final threshold
        threshold = final_threshold
    else:
        # Cubic decay between the warmup and cool-down phases
        spars_warmup_steps = initial_warmup * warmup_steps
        spars_schedu_steps = ((final_warmup + initial_warmup)
                              * warmup_steps)
        mul_coeff = 1 - ((step - spars_warmup_steps)
                         / (total_step - spars_schedu_steps))
        threshold = final_threshold + (
            (initial_threshold - final_threshold) * (mul_coeff ** 3))
    return threshold

PruningTrainer._schedule_threshold = _schedule_threshold
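To get a feel for the three phases, we can evaluate the schedule at a few steps. The standalone function below is a simplified sketch that mirrors the logic of the method above so that it can be run on its own; the name `threshold_at` and the step counts are illustrative assumptions.

```python
def threshold_at(step, total_steps, warmup_steps, initial=1.0, final=0.1,
                 initial_warmup=1, final_warmup=2):
    """Fraction of weights retained at a given step (illustrative copy)."""
    start = initial_warmup * warmup_steps      # end of the warmup phase
    cooldown = final_warmup * warmup_steps     # length of the cool-down
    if step <= start:
        return initial
    if step > total_steps - cooldown:
        return final
    # Cubic decay from the initial to the final threshold
    progress = (step - start) / (total_steps - start - cooldown)
    return final + (initial - final) * (1 - progress) ** 3

total_steps, warmup_steps = 1000, 100
checkpoints = [1, 100, 250, 450, 800, 801, 1000]
thresholds = {s: threshold_at(s, total_steps, warmup_steps)
              for s in checkpoints}

for step, value in thresholds.items():
    print(f"step {step:4d}: keep {100 * value:5.1f}% of the weights")
```

With these values the threshold stays at 1.0 for the first 100 steps, decays cubically until step 800, and is then held at 0.1 for the final 200 steps so the model can adapt to its final sparsity.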

In addition to the usual inputs, the masked model expects the sparsity threshold produced by the sparsity scheduler. A simple way to provide this information is to override the compute_loss function of the Trainer and compute the threshold at each training step:

def compute_loss(self, model, inputs):
    threshold = self._schedule_threshold(step=self.state.global_step+1,
        total_step=self.t_total, warmup_steps=self.args.warmup_steps,
        final_threshold=self.args.final_threshold,
        initial_threshold=self.args.initial_threshold,
        final_warmup=self.args.final_warmup,
        initial_warmup=self.args.initial_warmup)
    inputs["threshold"] = threshold
    outputs = model(**inputs)
    loss, _ = outputs
    return loss

PruningTrainer.compute_loss = compute_loss

As noted earlier, a key property of movement pruning is that the scores are learned during fine-tuning. This means that we need to instruct the Trainer to optimize both the usual weights and these new score parameters. The way this is done in practice is to override the Trainer.create_optimizer_and_scheduler function and assign the scores to their own parameter group with a dedicated learning rate:

from transformers import AdamW, get_linear_schedule_with_warmup

def create_optimizer_and_scheduler(self, num_training_steps: int):
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [p for n, p in self.model.named_parameters()
                   if "mask_score" in n and p.requires_grad],
        "lr": self.args.mask_scores_learning_rate},
        {"params": [p for n, p in self.model.named_parameters()
                    if "mask_score" not in n and p.requires_grad
                    and not any(nd in n for nd in no_decay)],
         "lr": self.args.learning_rate,
         "weight_decay": self.args.weight_decay},
        {"params": [p for n, p in self.model.named_parameters()
                    if "mask_score" not in n and p.requires_grad
                    and any(nd in n for nd in no_decay)],
         "lr": self.args.learning_rate,
         "weight_decay": 0.0}]

    self.optimizer = AdamW(optimizer_grouped_parameters,
                           lr=self.args.learning_rate,
                           eps=self.args.adam_epsilon)
    self.lr_scheduler = get_linear_schedule_with_warmup(
        self.optimizer, num_warmup_steps=self.args.warmup_steps,
        num_training_steps=self.t_total)

PruningTrainer.create_optimizer_and_scheduler = create_optimizer_and_scheduler

Now that we’ve created our trainer it’s time to give it a spin!

Fine-Pruning With Increasing Sparsity

To evaluate the effect of pruning we’ll fine-tune BERT-base on our dataset at increasing levels of sparsity. We expect some accuracy drop compared to the 94.3% that the unpruned model achieves, but hopefully it is not too much. First we need to define the base training arguments for our runs:

num_train_epochs = 5
logging_steps = len(clinc_enc['train']) // batch_size
warmup_steps = logging_steps * num_train_epochs * 0.1
mask_scores_learning_rate = 1e-2

pruning_training_args = PruningTrainingArguments(
    output_dir="checkpoints", evaluation_strategy="epoch", learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size, logging_steps=logging_steps,
    warmup_steps=warmup_steps, num_train_epochs=num_train_epochs,
    mask_scores_learning_rate=mask_scores_learning_rate, weight_decay=0.01)

Next, we’ll gradually decrease the threshold parameter in our mask which is equivalent to increasing the sparsity of the weights. Since Transformers are quite robust to sparsity with movement pruning, we’ll start by retaining 30% of the weights and decrease down to 1%. The following loop implements the sparsity increase by updating the training arguments and validation set with the current threshold value, fine-tuning and then saving the models and accuracies for further analysis:

accuracies = {}

for threshold in [0.3, 0.1, 0.05, 0.03, 0.01]:
    model_ckpt = f"prunebert-{int(threshold * 100)}"
    pruning_training_args.final_threshold = threshold
    pruning_training_args.run_name = model_ckpt
    # Include the current sparsity in the validation set
    eval_ds = clinc_enc['validation'].map(lambda x : {'threshold': threshold})

    pruning_trainer = PruningTrainer(
        model_init=model_init, args=pruning_training_args,
        train_dataset=clinc_enc["train"], eval_dataset=eval_ds,
        tokenizer=bert_tokenizer, compute_metrics=compute_metrics)

    pruning_trainer.train()
    pruning_trainer.save_model(f"models/{model_ckpt}")
    preds = pruning_trainer.evaluate()
    accuracies[threshold] = preds["eval_accuracy"]

    wandb.finish()

As shown in Figure 3-15, pruning has only a small impact on accuracy, which starts to degrade only once we prune more than 95% of the weights!

Accuracies vs sparsity
Figure 3-15. Effect of removing weights on BERT-base’s accuracy

Since the best performing model appears to have around 5% of the weights, let’s use this one to count the true number of remaining values and convert the model into a format suitable for native model classes in Transformers.

Counting the Number of Pruned Weights

Now that we’ve pruned a model, let’s do a sanity check to count the number of parameters we’ve removed. A simple way to do this is via PyTorch’s state_dict object which we encountered when saving the model to calculate its size. First, let’s load the state_dict associated with our pruned model:

prunebert_model_ckpt = "models/prunebert-5"
prunebert_args = torch.load(
    f"{prunebert_model_ckpt}/training_args.bin", map_location="cpu")
state_dict = torch.load(
    f"{prunebert_model_ckpt}/pytorch_model.bin", map_location="cpu")

To count the number of pruned weights, we’ll iterate through the dictionary, count the number of elements retained in every layer that was masked with a layer_name.mask_scores entry, and calculate the sparsity relative to the total number of encoder parameters. Since we’ve only pruned the trainable parameters of the model we’ll exclude the embedding parameters from the count, so we arrive at the following code:

# Number of remaining (not pruned) params in the encoder
remaining_count = 0
# Number of params in the encoder
encoder_count = 0
# Fraction of remaining weights
final_threshold = prunebert_args.final_threshold

for name, param in state_dict.items():
    if "encoder" not in name:
        continue

    if "mask_scores" in name:
        mask_ones = TopKBinarizer.apply(param, final_threshold).sum().item()
        remaining_count += mask_ones
    else:
        encoder_count += param.numel()
        if "bias" in name or "LayerNorm" in name:
            remaining_count += param.numel()

print(f"Remaining weights: {100 * remaining_count / encoder_count:.2f}%")
Remaining weights: 5.13%
Note

Movement pruning only eliminates weights in the encoder or decoder stack and task-specific head. In particular, the embedding modules are frozen during fine-pruning, so the total number of parameters in a fine-pruned model is larger than simply counting the remaining weights.

Pruning Once and For All

Now that we’re confident that we’ve pruned the model as expected, the final step is to convert the model back into a form that is suitable for the standard BertForSequenceClassification class. Here we need to collect all the dense tensors and apply the mask to those that have an associated mask_scores entry. The resulting state_dict can then be saved along with the configuration and tokenizer from our fine-pruned model:

import shutil

model_path = prunebert_model_ckpt
target_path = "models/bertarized"
model = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
pruned_model = {}

for name, tensor in model.items():
    # Tensors that were never masked are copied over unchanged
    if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
        pruned_model[name] = tensor
    elif "classifier" in name:
        pruned_model[name] = tensor
    elif "bias" in name:
        pruned_model[name] = tensor
    else:
        # The scores are only needed to rebuild the masks, so skip them
        if "mask_scores" in name:
            continue
        # Strip the trailing "weight" to find the matching scores
        prefix_ = name[:-len("weight")]
        scores = model[f"{prefix_}mask_scores"]
        mask = TopKBinarizer.apply(scores, final_threshold)
        pruned_model[name] = tensor * mask

# Copy the config and tokenizer files, then overwrite the weights
shutil.copytree(model_path, target_path, dirs_exist_ok=True)
torch.save(pruned_model, f"{target_path}/pytorch_model.bin")

As a sanity check, we can now load our fine-pruned model with the BertConfig and AutoModelForSequenceClassification as follows:

from transformers import BertConfig

config = BertConfig.from_pretrained(target_path, id2label=id2label,
                                    label2id=label2id)
model = (AutoModelForSequenceClassification
                .from_pretrained(target_path, config=config).to('cpu'))

Finally we can run our model through the performance benchmark:

pipe = TextClassificationPipeline(model=model, tokenizer=bert_tokenizer)
PerformanceBenchmark(pipe, clinc["test"]).run_benchmark();
Model size (MB) - 418.13
Average latency (ms) - 90.73 +- 45.57
Accuracy on test set - 0.840

Looking at the benchmark you might be surprised: the pruned model is neither smaller nor faster! The reason is that we stored the weights as dense matrices, which occupy the same space irrespective of how many values are set to zero. Similarly, a dense matrix multiplication does not get faster if more values are zero. Unfortunately, modern frameworks still lack fast sparse operations and hence it is hard to get a speedup from pruning. However, we can store the matrices in a more compact format, which we will explore in the next section.
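The storage point is easy to confirm on a single matrix. The sketch below uses a random 768 × 768 float32 array as a stand-in for one BERT weight matrix and zeroes out roughly 95% of it by magnitude; the values and the magnitude-based cutoff are illustrative assumptions and not the movement pruning procedure itself.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for one dense weight matrix (random values for illustration)
W = rng.normal(size=(768, 768)).astype(np.float32)

# Zero out roughly 95% of the entries, keeping the largest magnitudes
W_pruned = W.copy()
cutoff = np.quantile(np.abs(W_pruned), 0.95)
W_pruned[np.abs(W_pruned) < cutoff] = 0.0

sparsity = float((W_pruned == 0).mean())
print(f"Sparsity: {100 * sparsity:.1f}%")
print(f"Dense size: {W.nbytes} bytes, pruned size: {W_pruned.nbytes} bytes")
```

Both arrays occupy exactly the same number of bytes, since a dense array stores every entry whether or not it is zero.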

Quantizing and Storing in Sparse Format

As of this book’s writing, pruning in PyTorch or TensorFlow does not lead to improved inference times or a reduced model size since a dense tensor filled with zeros is still dense. However, when combined with compression algorithms like gzip, pruning does allow us to reduce the size of the model on disk. To get the most compression, we’ll first apply quantization to our pruned model:

from torch.quantization import quantize_dynamic

quantized_model = quantize_dynamic(model=model, qconfig_spec={torch.nn.Linear},
                                   dtype=torch.qint8)

qtz_st = quantized_model.state_dict()

Next let’s wrap this quantized model in a pipeline so we can get a sense for its size:

pipe = TextClassificationPipeline(model=quantized_model,
                                  tokenizer=bert_tokenizer)
PerformanceBenchmark(pipe, clinc["test"]).compute_size();
Model size (MB) - 173.15

The CSR Representation

Great, so naive quantization has reduced our dense model from 418 MB down to 173 MB. To get further compression we can convert the sparse quantized tensors in our model into the Compressed Sparse Row (CSR) representation. In this representation, a sparse matrix is stored as three arrays: the non-zero values, their column indices, and a row-pointer array that marks where each row begins. Since the other values in a sparse matrix are zero, we do not need to keep track of their locations, which can be inferred from the non-zero indices. The CSR representation is commonly used in machine learning because it provides better support for matrix operations than other compressed formats. To deepen our intuition, let’s create a sparse matrix by masking most of the elements in a dense one:

X = np.random.uniform(size=(3, 3))
X[X < 0.6] = 0
X
array([[0.        , 0.        , 0.        ],
       [0.60754485, 0.        , 0.        ],
       [0.94888554, 0.96563203, 0.80839735]])

Now we can store the matrix in the CSR format by using SciPy’s csr_matrix function:

from scipy.sparse import csr_matrix

X_csr = csr_matrix(X)
print(X_csr)
  (1, 0)        0.6075448519014384
  (2, 0)        0.9488855372533332
  (2, 1)        0.9656320330745594
  (2, 2)        0.8083973481164611

As expected, we see that each non-zero value is associated with a (row, column) tuple which for large sparse matrices provides a dramatic reduction in memory. Now what we can do is apply this compression to the sparse quantized tensors in our fine-pruned model. To see this, let’s get the first quantized tensor in our state_dict:

for name, param in qtz_st.items():
    if "dtype" not in name and param.is_quantized:
        scale = param.q_scale()
        zero_point = param.q_zero_point()
        print(f"Layer name - {name}")
        print(param)
        break
Layer name - bert.encoder.layer.0.attention.self.query._packed_params.weight
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0535, 0.0576,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       size=(768, 768), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.004113025031983852,
       zero_point=0)

We can convert this tensor into the CSR format by first getting its integer representation with the int_repr() function:

param.int_repr()
tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 0, 13, 14,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0]], dtype=torch.int8)
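These integers are related to the floating-point values printed earlier through the per-tensor affine map, x ≈ scale · (q − zero_point). As a quick arithmetic check, the sketch below applies this map to the two non-zero integers visible in the output, using the scale and zero point reported for this layer; the helper name `dequantize` is illustrative.

```python
# Quantization parameters reported for the layer above
scale = 0.004113025031983852
zero_point = 0

def dequantize(q, scale, zero_point):
    """Map a quantized integer back to its floating-point value."""
    return scale * (q - zero_point)

values = [dequantize(q, scale, zero_point) for q in (0, 13, 14)]
print([round(v, 4) for v in values])
```

The integers 13 and 14 map back to approximately 0.0535 and 0.0576, matching the entries shown in the quantized tensor, while zeros stay exactly zero because the zero point is 0. This is what makes the integer representation safe to store in a sparse format.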

So if we wrap this integer-valued tensor with csr_matrix, we should get a compressed representation:

print(csr_matrix(param.int_repr()[:1]))
  (0, 73)       16
  (0, 91)       -2
  (0, 102)      7
  (0, 249)      -9
  (0, 250)      7
  (0, 277)      20
  (0, 374)      -4
  (0, 395)      5
  (0, 490)      7
  (0, 496)      -16
  (0, 520)      11
  (0, 560)      -13
  (0, 575)      -5
  (0, 581)      1
  (0, 638)      20
  (0, 644)      1
  (0, 655)      5
  (0, 719)      1
  (0, 739)      20

Now that we know how to create CSR matrices, let’s loop over our state_dict and convert each sparse quantized tensor into CSR format. We can then store this data in a new state_dict:

elementary_qtz_st = {}

for name, param in qtz_st.items():
    if "dtype" not in name and param.is_quantized:
        scale = param.q_scale()
        zero_point = param.q_zero_point()
        elementary_qtz_st[f"{name}.scale"] = scale
        elementary_qtz_st[f"{name}.zero_point"] = zero_point
        # Convert sparse quantized tensors into CSR format
        int_repr = param.int_repr()
        int_repr_cs = csr_matrix(int_repr)
        elementary_qtz_st[f"{name}.int_repr.data"] = int_repr_cs.data
        elementary_qtz_st[f"{name}.int_repr.indptr"] = int_repr_cs.indptr
        elementary_qtz_st[
            f"{name}.int_repr.indices"
        ] = np.uint16(int_repr_cs.indices)
        elementary_qtz_st[
            f"{name}.int_repr.shape"
        ] = int_repr_cs.shape
    else:
        elementary_qtz_st[name] = param
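The data, indices, and indptr arrays saved above, together with the shape, are enough to rebuild each dense integer matrix. The pure-Python sketch below illustrates the round trip on a tiny matrix; the helper names `to_csr` and `from_csr` and the toy values are assumptions for illustration, and in practice SciPy's csr_matrix handles both directions.

```python
def to_csr(dense):
    """Encode a list-of-lists matrix as CSR arrays (illustrative helper)."""
    data, indices, indptr = [], [], [0]
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)      # non-zero value
                indices.append(col)     # its column index
        indptr.append(len(data))        # where the next row starts
    return data, indices, indptr

def from_csr(data, indices, indptr, shape):
    """Rebuild the dense matrix from its CSR arrays."""
    n_rows, n_cols = shape
    dense = [[0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        for pos in range(indptr[row], indptr[row + 1]):
            dense[row][indices[pos]] = data[pos]
    return dense

# Toy int8-like matrix with mostly zero entries
int_repr = [[0, 0, 0, 0],
            [0, 13, 14, 0],
            [0, 0, 0, -2]]
data, indices, indptr = to_csr(int_repr)
restored = from_csr(data, indices, indptr, shape=(3, 4))

print(data, indices, indptr)
```

Only three values, three column indices, and four row pointers are stored for the twelve entries of the matrix, and the original is recovered exactly. The saving grows with the fraction of zeros, which is why the format pays off for a model with around 95% of its encoder weights pruned.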

The final check is to see how much space we’ve saved using the CSR format. We can use torch.save and the Linux du command to get the result:

torch.save(elementary_qtz_st, "tmp.pt")
!du -h tmp.pt
110M    tmp.pt

Nice, by combining fine-pruning, quantization, and the CSR format we have managed to reduce the storage space of our original model from 418 MB to 110 MB! This could be further optimized with the ONNX format, but we leave this as an exercise for the reader.
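As mentioned at the start of this section, a general-purpose compressor such as gzip also benefits from the zeros introduced by pruning. The sketch below compares the compressed size of a random byte string with a copy in which roughly 95% of the bytes are set to zero; the data is synthetic and only meant to illustrate the effect, so the exact ratio will differ for real model weights.

```python
import gzip
import random

random.seed(0)
n_values = 100_000

# Synthetic "dense" weights: one random non-zero byte per value
dense = bytes(random.randrange(1, 256) for _ in range(n_values))
# Synthetic "pruned" weights: keep about 5% of the values, zero the rest
sparse = bytes(b if random.random() < 0.05 else 0 for b in dense)

dense_size = len(gzip.compress(dense))
sparse_size = len(gzip.compress(sparse))

print(f"Raw size: {len(dense)} bytes for both")
print(f"gzip size (dense):  {dense_size} bytes")
print(f"gzip size (pruned): {sparse_size} bytes")
```

Both byte strings have the same raw length, but the mostly-zero one compresses to a much smaller file because long runs of identical bytes are cheap to encode.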

Conclusion

We’ve seen that optimizing Transformers for deployment in production environments involves compression along two dimensions: latency and memory footprint. Starting from a fine-tuned model we applied distillation, quantization, and optimizations through ORT to reduce the latency and memory footprint sevenfold. In particular, we found that quantization and conversion in ORT gave the largest gains with minimal effort.

Although pruning is an effective strategy for reducing the storage size of Transformer models, current hardware is not optimized for sparse matrix operations, which limits the usefulness of this technique. However, this is an active and fast-moving area of research, and by the time this book hits the shelves many of these limitations may have been resolved.

So where to from here? All of the techniques in this chapter can be adapted to other tasks such as question answering, named entity recognition, or language modeling. If you find yourself struggling to meet the latency requirements, or your model is eating up all your compute budget, we suggest giving one of these techniques a try.

1 An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction, S. Larson et al. (2019)

2 As described by Emmanuel Ameisen in Building Machine Learning Powered Applications (O’Reilly), business or product metrics are the most important ones to consider; after all, it doesn’t matter how accurate your model is if it doesn’t solve a problem your business cares about. In this chapter we’ll assume that you have already defined the metrics that matter for your application and focus on optimizing the model metrics.

3 Model Compression, C. Bucila, R. Caruana, and A. Niculescu-Mizil (2006)

4 Distilling the Knowledge in a Neural Network, G. Hinton, O. Vinyals, and J. Dean (2015)

5 Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity, W. Fedus, B. Zoph, and N. Shazeer (2021)

6 Geoff Hinton coined this term in a talk to refer to the observation that softened probabilities reveal the hidden knowledge of the teacher.

7 DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter, V. Sanh et al. (2019)

8 FastFormers: Highly Efficient Transformer Models for Natural Language Understanding, Y. Kim and H. Awadalla (2020)

9 This approach of fine-tuning a general-purpose, distilled language model is sometimes referred to as “task-agnostic” distillation.

10 Optuna: A Next-generation Hyperparameter Optimization Framework, T. Akiba et al. (2019)

11 Sometimes the significand is also called the mantissa.

12 More precisely, the radix point which applies to all number bases.

13 An affine map is just a fancy name for the y=Ax+b map that you’re familiar with in the linear layers of a neural network.

14 There is a separate standard called ONNX-ML which is designed for traditional machine learning models like Random Forests and frameworks like Scikit-Learn.

15 A fused operation consists of a set of primitive operations that are combined in a composite operator like “Layer Normalization”. Constant folding refers to the process of evaluating constant expressions at compile time instead of runtime.

16 Second order derivatives for network pruning: Optimal Brain Surgeon, B. Hassibi and D. Stork (1993)

17 Learning both Weights and Connections for Efficient Neural Networks, S. Han et al. (2015)

18 To prune, or not to prune: exploring the efficacy of pruning for model compression, M. Zhu and S. Gupta (2017)

19 Movement Pruning: Adaptive Sparsity by Fine-Tuning, V. Sanh, T. Wolf, and S. Rush (2020)
