Batch normalization, often referred to as BatchNorm, is a technique that helps to accelerate the training process, improve the performance, and increase the stability of neural networks. It involves standardizing the inputs within a neural network, ensuring consistent distribution of data throughout the network.
In this topic, we will look at the motivation, the general working mechanism, and the effects of BatchNorm.
The problem setup
It’s well-known that feature scaling in general, and input standardization in particular, accelerates the convergence of gradient descent. Standardization refers to transforming the features to have a mean of 0 and a standard deviation of 1.
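As a quick illustration, here is a minimal NumPy sketch of per-feature standardization (the same transformation that scikit-learn's StandardScaler applies); the toy array is made up for the example:

```python
import numpy as np

# A toy design matrix: 4 samples, 2 features on very different scales.
X = np.array([[1.0, 200.0],
              [2.0, 400.0],
              [3.0, 600.0],
              [4.0, 800.0]])

# Standardize each feature (column) to zero mean and unit standard deviation.
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

print(X_std.mean(axis=0))  # ~[0. 0.]
print(X_std.std(axis=0))   # ~[1. 1.]
```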
The issue with regular input standardization is that it is only applied at the very beginning, while in a deep network every layer has its own inputs. As we move forward through the layers, the effect of the initial standardization fades, and the inputs of each consecutive layer will likely have a different distribution because they are affected by the parameters of all preceding layers in the network (so small changes in the parameters can accumulate as the network depth increases). This phenomenon of layer inputs changing their distribution over the course of training is often referred to as internal covariate shift, commonly abbreviated to ICS (later, we will see whether BatchNorm actually addresses internal covariate shift and whether ICS has a negative impact at all, but the original motivation for BatchNorm was the reduction of internal covariate shift).
Now, one might ask: "Why is it undesirable to have a change in the input distribution for each layer?". Essentially, the distribution shift means that the layers have to constantly re-adjust to a new distribution, which makes training inefficient. So it's helpful to keep the distribution of the inputs consistent throughout training.
And this is where BatchNorm is introduced.
The definition
Before delving into the details, let's look at BatchNorm on a high level. The goal of BatchNorm is to shift the mean and standard deviation of the batch to the ones "desired" by the network. To do that, it performs two steps:
Make the batch's mean and variance equal to 0 and 1 respectively;
Rescale the batch to make the mean and standard deviation equal to $\beta$ and $\gamma$, respectively.
We consider a small example of a neural network for illustration. Suppose that there is a mini-batch net input of a given training example, computed from the activations of the previous layer and associated with an activation in the second hidden layer (red bounding box in the illustration below):
At first, the mean and the variance for each individual feature in the mini-batch are computed. Formally, we can describe this for the net input $z$ as follows (note that we omit the layer index because the computation does not change from layer to layer):

$$\mu_j = \frac{1}{m}\sum_{i=1}^{m} z_j^{(i)}, \qquad \sigma_j^2 = \frac{1}{m}\sum_{i=1}^{m} \left(z_j^{(i)} - \mu_j\right)^2$$
where $j$ corresponds to the feature index and $m$ is the mini-batch size. Then, we perform the standardization (scale to have a zero mean and unit variance, the same as the StandardScaler transformation in scikit-learn):

$$\hat{z}_j^{(i)} = \frac{z_j^{(i)} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}$$
$\epsilon$ is a small value to avoid division by zero.
Next, we perform pre-activation scaling:

$$\tilde{z}_j^{(i)} = \gamma_j \hat{z}_j^{(i)} + \beta_j$$
where $\gamma_j$ and $\beta_j$ are two trainable parameters updated via backpropagation. $\gamma$ is a scaling parameter, and $\beta$ is a shift parameter (which controls the mean). Essentially, pre-activation scaling allows the network to decide what mean and standard deviation to have (they might or might not be 0 and 1, respectively: BatchNorm can learn to "undo" the standardization or learn another, more flexible distribution that does not have a mean of 0 and a standard deviation of 1). $\beta$ also eliminates the need to train the biases separately. Then, $\tilde{z}_j^{(i)}$ is fed into the activation function itself, and the output is produced.
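To make the three steps concrete, here is a minimal NumPy sketch of the BatchNorm forward pass in training mode; the function name and shapes are illustrative assumptions, not a reference implementation:

```python
import numpy as np

def batchnorm_forward(z, gamma, beta, eps=1e-5):
    """Apply batch normalization to a mini-batch of pre-activations.

    z:     array of shape (batch_size, num_features) -- net inputs of a layer
    gamma: array of shape (num_features,)            -- learnable scale
    beta:  array of shape (num_features,)            -- learnable shift
    """
    mu = z.mean(axis=0)                    # per-feature mean over the batch
    var = z.var(axis=0)                    # per-feature variance over the batch
    z_hat = (z - mu) / np.sqrt(var + eps)  # standardize: zero mean, unit variance
    return gamma * z_hat + beta            # rescale and shift (pre-activation scaling)

# Toy usage: a batch of 4 examples with 3 features.
rng = np.random.default_rng(0)
z = rng.normal(loc=5.0, scale=3.0, size=(4, 3))
gamma, beta = np.ones(3), np.zeros(3)
out = batchnorm_forward(z, gamma, beta)
print(out.mean(axis=0))  # ~0 for each feature
print(out.std(axis=0))   # ~1 for each feature (since gamma=1, beta=0)
```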
During inference, the batch normalization layer uses the mean and variance over the whole training set instead of the batch statistics (in practice, these are typically tracked as running averages of the batch statistics during training). The learned parameters $\beta$ (for shifting) and $\gamma$ (for scaling) are then used in the normalization process, just as during training.
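For example, PyTorch's nn.BatchNorm1d keeps running estimates of the mean and variance during training and switches to them in evaluation mode (a small sketch for a plain fully connected setting):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=3)   # gamma = bn.weight, beta = bn.bias

x = torch.randn(8, 3) * 3.0 + 5.0     # a mini-batch of 8 examples, 3 features

bn.train()
y_train = bn(x)                       # uses the statistics of this mini-batch
                                      # and updates bn.running_mean / bn.running_var

bn.eval()
y_eval = bn(x)                        # uses the accumulated running statistics
                                      # instead of the batch statistics

print(bn.running_mean, bn.running_var)
```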
The effects of BatchNorm
In this section, we will describe some of the effects that BatchNorm has on the training process. The first and probably the most important effect of BatchNorm is that it makes the loss surface smoother, leading to more stable training.
In a neural network without BatchNorm, the loss surface is typically highly non-convex: it has flat regions (which correspond to vanishing gradients) and many sharp local minima (which are associated with exploding gradients). BatchNorm makes the loss change at a smaller rate, and thus the gradients also have a smaller magnitude. This makes the loss surface more predictable, allowing the use of higher learning rates and reducing the need to carefully tune the hyperparameters in general.
Also, what BatchNorm essentially does is introduce noise into the training process, since the normalization depends on the batch statistics, which will be slightly different at each iteration. This acts as a mild regularizer.
As the name suggests, BatchNorm is affected by the choice of the mini-batch size. Smaller batch sizes (<32) lead to higher variance and higher error, because the batch statistics approximate the population statistics less accurately, while larger batch sizes produce statistics closer to the population ones.
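The sketch below (with made-up data) illustrates why: means estimated from small batches scatter much more around the population mean than means estimated from large batches:

```python
import numpy as np

rng = np.random.default_rng(42)
population = rng.normal(loc=0.0, scale=1.0, size=100_000)  # "training set" of one feature

for batch_size in (8, 32, 256):
    # Draw many mini-batches and look at how much their means vary.
    batch_means = [rng.choice(population, size=batch_size).mean() for _ in range(1_000)]
    print(f"batch_size={batch_size:4d}  std of batch means={np.std(batch_means):.3f}")

# Smaller batches -> noisier estimates of the mean (and variance),
# so the normalization itself becomes noisier.
```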
Revisiting the internal covariate shift reduction
BatchNorm produces all these desirable effects, but the specific reason for its success is not that straightforward. As previously mentioned, the initial motivation for BatchNorm is the reduction of the internal covariate shift.
In 2018, a paper called How Does Batch Normalization Help Optimization? came out, which claims that a) internal covariate shift does not really lead to degraded performance, and b) BatchNorm does not reduce the internal covariate shift as was previously thought.
To demonstrate the first point, the authors inject random noise after the BatchNorm layers. Each activation of each sample in the batch is perturbed with independent and identically distributed noise drawn from a distribution with a non-zero mean and non-unit variance, and, importantly, this noise distribution changes at every time step. Such noise injection creates a significant covariate shift that distorts the activations at each time step; as a result, all units in a layer encounter a different input distribution at every time step.
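A rough sketch of that "noisy BatchNorm" idea is shown below (this is not the authors' code; the noise distribution parameters are arbitrary placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

def noisy_batchnorm_output(bn_output):
    """Perturb BatchNorm outputs with i.i.d. noise whose mean and variance
    are non-zero / non-unit and are re-drawn at every training step."""
    shift = rng.uniform(-1.0, 1.0)    # non-zero mean, new at every step
    scale = rng.uniform(0.5, 2.0)     # non-unit standard deviation, new at every step
    noise = rng.normal(loc=shift, scale=scale, size=bn_output.shape)
    return bn_output + noise          # distorts the activation distribution

# Example: pretend `h` is the output of a BatchNorm layer for a batch of 4 examples.
h = rng.normal(size=(4, 3))
h_noisy = noisy_batchnorm_output(h)
```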
However, that noise injection does not lead to a significant difference in the performance (Source):
The second point revolves around a more general, optimization-based definition of ICS. The authors show that BatchNorm does lead to better performance in terms of accuracy and loss, but does not really suppress ICS (and sometimes even amplifies it):
Conclusion
As a result, you are now familiar with the following:
BatchNorm is a way to extend input standardization beyond the first layer;
BatchNorm introduces two trainable parameters, namely $\gamma$ and $\beta$, which help to learn a more suitable distribution of the data than simple standardization;
Although originally proposed to reduce the internal covariate shift, the effectiveness of BatchNorm has little to do with the ICS reduction;
The main effect of BatchNorm is the smoothing of the loss surface, which enables the use of higher learning rates in particular and makes careful hyperparameter tuning less critical in general.