Training deep learning models tends to be computationally heavy, and models that perform well typically have a large number of weights. The problem gets worse when it comes to actually serving the model and applying it to some task. The purpose of model minimization is to retain as much of the original model's performance as possible while making the model more lightweight and, thus, easier to use in a downstream application.
This topic covers three popular techniques (namely, knowledge distillation, quantization, and pruning) that help to minimize larger models for easier deployment.
Knowledge distillation
Knowledge distillation is a method of parameter reduction where a smaller (student) model is trained to approximate the behavior of a more complex (teacher) model or an ensemble of models.
In the multi-class classification setting, the teacher model is trained to maximize the average log probability of the correct prediction. As a side effect, the trained model also assigns probabilities to all incorrect classes, and although these probabilities are small, some of them are larger than others. These relative probabilities of the incorrect classes carry information about how the larger model generalizes.
Models are usually trained to optimize performance on the training data, even though the real goal is to generalize to new data, because the information needed to optimize for generalization directly is not normally available. In distillation, however, the student can be trained to generalize in the same way as its teacher. If the large model generalizes well (for example, because it is the average of a diverse ensemble), a small model trained to generalize in the same way will typically do better on unseen data than the same small model trained in the standard way on the same training set. This suggests that, to transfer the generalization ability of a larger model into a smaller one, the class probabilities described above can be used as soft targets for training the small model.
More formally, transferring knowledge from the teacher to the student means training the student to minimize a loss whose target is the distribution of class probabilities predicted by the teacher. With a standard softmax, this distribution usually assigns a high probability to the correct class and probabilities close to zero to all other classes, so it carries little information beyond the ground-truth labels that are already available. To give the student more information, the softmax temperature is introduced:
$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
where
$p_i$ is the probability of class $i$;
$z_i$ is the logit of class $i$;
$T$ is the temperature (the original paper uses values in the [1, 20] range).
When the temperature $T$ increases, the resulting probability distribution gets softer, giving the smaller model more information about the classes that the larger model found to be similar to the predicted class (essentially, this is the knowledge being transferred). $T = 1$ corresponds to the standard softmax function. The distillation loss measures the discrepancy between the predictions of the student and the teacher model.
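To make the effect of the temperature concrete, here is a minimal sketch of a temperature-scaled softmax in plain Python. The function name `softmax_with_temperature` and the example logits are illustrative assumptions, not taken from the original paper.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    # Subtracting the maximum keeps exp() numerically stable
    # and does not change the resulting probabilities.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes.
teacher_logits = [5.0, 2.0, 0.5]
hard = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=5.0)

print("T=1:", [round(p, 3) for p in hard])
print("T=5:", [round(p, 3) for p in soft])
```

At the higher temperature, the largest probability shrinks and the smaller ones grow, so the relative ordering of the incorrect classes becomes much easier for a student to learn from.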
The overall loss function comprises the distillation loss and the loss between the ground truths and the student model's predictions (this is the more 'standard' loss, sometimes referred to as the 'student loss', for which the temperature $T$ is set to 1):
$$\mathcal{L}(x; W) = \alpha \cdot H\big(y, \sigma(z_s; T = 1)\big) + \beta \cdot H\big(\sigma(z_t; T = \tau), \sigma(z_s; T = \tau)\big)$$
where
$x$ is the input;
$W$ are the student model parameters;
$y$ is the ground truth label;
$H$ is the cross-entropy loss function;
$\sigma$ is the softmax function parameterized by the temperature $T$;
$\alpha$ and $\beta$ are coefficients (discussed later on);
$z_s$ and $z_t$ are the logits of the student and teacher respectively.
There are 3 hyperparameters: $\tau$, $\alpha$, and $\beta$. $\tau$ is the temperature (again, in the [1, 20] range); the original paper reports that when the student model is much smaller than the teacher model, lower values of $\tau$ work better. $\alpha$ and $\beta$ weight the student loss and the distillation loss. They can form a weighted average ($\beta = 1 - \alpha$), or they can be configured in other ways: for example, both can be set to 0.5, or $\alpha$ can be set to 1 while $\beta$ is left tunable.
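The combined loss described above (a weighted sum of the student loss at temperature 1 and the distillation loss at a higher temperature) can be sketched for a single example as follows. This is a dependency-free illustration rather than a reference implementation: the names `kd_loss`, `alpha`, `beta`, and `tau`, the default values, and the example logits are all assumptions made for this sketch.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted, eps=1e-12):
    """H(target, predicted) = -sum_i target_i * log(predicted_i)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))

def kd_loss(student_logits, teacher_logits, label, alpha=0.5, beta=0.5, tau=4.0):
    """Weighted sum of the student loss (T = 1) and the distillation loss (T = tau)."""
    one_hot = [1.0 if i == label else 0.0 for i in range(len(student_logits))]
    # Student loss: cross-entropy against the ground-truth label at T = 1.
    student_loss = cross_entropy(one_hot, softmax(student_logits, 1.0))
    # Distillation loss: cross-entropy against the teacher's softened distribution.
    distillation_loss = cross_entropy(
        softmax(teacher_logits, tau), softmax(student_logits, tau)
    )
    return alpha * student_loss + beta * distillation_loss

# Hypothetical logits; class 0 is the ground-truth label.
teacher_logits = [5.0, 2.0, 0.5]
loss_match = kd_loss([5.0, 2.0, 0.5], teacher_logits, label=0)
loss_mismatch = kd_loss([0.5, 2.0, 5.0], teacher_logits, label=0)
print(round(loss_match, 3), round(loss_mismatch, 3))
```

A student whose logits agree with the teacher gets a lower loss than one that ranks the classes in the opposite order. In a real training loop the same quantity would be averaged over a batch and minimized with respect to the student parameters.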
Quantization
The idea behind quantization is straightforward. The main numerical format for most models is 32-bit floating point (FP32). Since the goal of minimization is to reduce computational cost and memory bandwidth while preserving as much accuracy as possible, a common approach is to lower the precision to 8-bit integers (INT8). Even lower precision, such as INT4, can be used, which is sometimes referred to as 'aggressive quantization'.
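As an illustration of what lowering the precision can look like, the sketch below maps a list of FP32-style values to integers in the INT8 range and back. It assumes one particular scheme (symmetric linear quantization with a single scale per tensor), which the text above does not prescribe; the function names and the sample weights are hypothetical.

```python
def quantize_int8(values):
    """Map floats to integers in [-127, 127] using one shared scale factor."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer and clip to the symmetric INT8 range.
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale

def dequantize(quantized, scale):
    """Recover approximate float values from the integers."""
    return [q * scale for q in quantized]

# Hypothetical weights of one layer.
weights = [0.42, -1.27, 0.003, 0.9, -0.31]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print("quantized:", q)
print("largest round-trip error:", max_error)
```

The round-trip error of each value is bounded by half the scale, which shows the trade-off: each weight now fits in 8 bits instead of 32, at the cost of a small, bounded loss of precision.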
Pruning
Pruning is the process of making the weight or activation matrices sparser (increasing the number of zeros) by setting the elements that fall below some predetermined threshold to 0. Weights and activations are the most commonly pruned elements. There is also research on pruning channels in convolutional neural networks, entire layers in residual networks, and so on, but we won't cover these here.
The intuition is that if an element is small enough (by some metric), its contribution is not significant enough to justify keeping it, so it can be safely removed. Another aspect is that, realistically, models are over-parametrized: introducing more parameters usually leads to better generalization but also brings along more redundant features, and the weights of these features can be set to zero because they are not very informative (think of L1 regularization for feature selection as an analogy).
[Figure: the pruning process in a very simple case.]
There are two modes of pruning: one-shot and incremental. One-shot pruning refers to the scenario where the model is trained and then pruned only once. This does make the weights sparser, but incremental pruning can go further. In incremental pruning, the model is trained, then pruned, and then the pruned model is fine-tuned; the (pruning, fine-tuning) sequence is applied multiple times, which gradually isolates the most important weights and makes the remaining non-zero weights more informative.
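The train, prune, fine-tune cycle can be sketched on a toy problem. The example below is only an illustration under stated assumptions: the "model" is a linear regression trained with plain gradient descent, the synthetic data, the sparsity schedule (25%, 50%, 62.5%), and all function names are hypothetical, and a mask keeps pruned weights at zero during fine-tuning.

```python
import random

def predict(weights, x):
    return sum(w * xi for w, xi in zip(weights, x))

def mse(weights, data):
    return sum((predict(weights, x) - y) ** 2 for x, y in data) / len(data)

def train(weights, mask, data, steps=200, lr=0.05):
    """Run gradient descent on the MSE; masked (pruned) weights stay at zero."""
    n = len(data)
    for _ in range(steps):
        grads = [0.0] * len(weights)
        for x, y in data:
            err = predict(weights, x) - y
            for i, xi in enumerate(x):
                grads[i] += 2 * err * xi / n
        weights = [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]
    return weights

def prune(weights, mask, sparsity):
    """Zero the smallest-magnitude weights so that `sparsity` of them are zero."""
    n_zero = int(sparsity * len(weights))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    new_mask = list(mask)
    for i in order[:n_zero]:  # already-pruned weights (|w| = 0) come first
        new_mask[i] = 0
    return [w * m for w, m in zip(weights, new_mask)], new_mask

# Synthetic data: only 3 of the 8 features actually matter.
random.seed(0)
true_weights = [3.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.5]
data = []
for _ in range(100):
    x = [random.gauss(0, 1) for _ in true_weights]
    data.append((x, predict(true_weights, x) + random.gauss(0, 0.1)))

mask = [1] * len(true_weights)
weights = train([0.0] * len(true_weights), mask, data)  # initial training
for sparsity in (0.25, 0.5, 0.625):                     # incremental schedule
    weights, mask = prune(weights, mask, sparsity)      # prune
    weights = train(weights, mask, data)                # fine-tune
    print(sparsity, weights.count(0), round(mse(weights, data), 4))
```

One-shot pruning would correspond to running the loop body a single time at the final sparsity level. The incremental schedule instead lets the surviving weights readjust after each pruning step before the next, more aggressive one.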
How do we actually decide which weights are not important? The most common approach is magnitude-based pruning: a weight is considered insignificant if its absolute value falls below some threshold. The underlying assumption is that the absolute value of a weight reflects its relative importance to the accuracy of the trained network. The threshold can be chosen as follows: for every parametrized layer in the network, set multiple pruning levels as percentages (i.e., sparsity levels), prune, and log the accuracy on the test set for each layer and each pruning level. This is referred to as sensitivity analysis, and it shows how each layer is affected by pruning.
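Magnitude-based pruning and the sensitivity analysis loop can be sketched as follows. This is a simplified illustration: layers are plain lists of weights, the function names are hypothetical, and `toy_score` is only a stand-in for a real evaluation of test-set accuracy (it measures the fraction of total weight magnitude that survives pruning), so the numbers it produces are not accuracies.

```python
def magnitude_threshold(weights, sparsity):
    """Return the magnitude below which roughly `sparsity` of the weights fall."""
    magnitudes = sorted(abs(w) for w in weights)
    k = int(sparsity * len(magnitudes))
    return magnitudes[k] if k < len(magnitudes) else float("inf")

def prune_by_magnitude(weights, sparsity):
    """Set every weight whose absolute value is below the threshold to zero."""
    threshold = magnitude_threshold(weights, sparsity)
    return [0.0 if abs(w) < threshold else w for w in weights]

def sensitivity_analysis(layers, evaluate, levels):
    """Prune one layer at a time at each sparsity level and record the score."""
    results = {}
    for name in layers:
        for level in levels:
            candidate = dict(layers)  # the other layers stay unpruned
            candidate[name] = prune_by_magnitude(layers[name], level)
            results[(name, level)] = evaluate(candidate)
    return results

# Hypothetical two-layer model.
layers = {
    "fc1": [0.8, -0.05, 0.02, -1.1],
    "fc2": [0.3, 0.01, -0.4, 0.25],
}
total_magnitude = sum(abs(w) for ws in layers.values() for w in ws)

def toy_score(model):
    # Stand-in for test-set accuracy: fraction of total weight magnitude kept.
    return sum(abs(w) for ws in model.values() for w in ws) / total_magnitude

results = sensitivity_analysis(layers, toy_score, levels=(0.0, 0.25, 0.5, 0.75))
for (name, level), score in results.items():
    print(name, level, round(score, 3))
```

In practice, `evaluate` would run the pruned model on held-out data, and the resulting table would show which layers tolerate high sparsity and which degrade quickly, which in turn guides the choice of a threshold per layer.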
Conclusion
As a result, you are now familiar with the following:
Knowledge distillation is a method where a smaller model is trained to approximate the behavior of a more complex model or an ensemble of models;
Quantization refers to changing the datatype from the default FP32 to a smaller one;
Pruning sets certain weights to zero if their contribution is not significant based on some predetermined metric.