ResNet

ResNet, also known as the residual network, is a popular deep learning (DL) architecture introduced in the 2015 paper "Deep Residual Learning for Image Recognition". Its key innovation is the residual (also known as skip) connection, which made it possible to build much deeper networks with improved performance.

In this topic, we will look at the problem with plain deep networks and see how residual connections let ResNet handle more layers than its predecessors.

The problem with very deep networks

The predecessors of ResNet (e.g., Inception or VGG) demonstrated that deeper networks perform better and made depth a central consideration in architectural design. The natural next question was: given the computational resources, does stacking more layers generally bring better performance (in other words, how deep can you actually go)?

The first problem that arises is that gradients vanish or explode in deep networks. The argument goes as follows: at initialization, all layers have random weights, so as the input passes through the layers, the activations are repeatedly multiplied by random weights. With many layers, the input loses its signal and turns into almost random noise by the time it reaches the output layer. The loss computed from this noisy output carries little useful information, and during backpropagation each layer multiplies the deltas by its weights again, so by the time the gradients reach the early layers, they have also shrunk, blown up, or turned into noise (due to the large number of multiplications).
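To see the magnitude side of this argument in numbers, here is a minimal pure-Python sketch that pushes one random input through a stack of fully connected ReLU layers (a simplification: real CNNs use convolutions) with weights drawn from a zero-mean Gaussian. The depth (30), width (64), and the two standard deviations are arbitrary illustrative choices, and `signal_norms` is a hypothetical helper. With too small a scale, the activations shrink toward zero layer after layer; with too large a scale, they blow up. The same repeated multiplication happens to the gradients on the way back.

```python
import math
import random


def relu(v):
    return [max(0.0, a) for a in v]


def matvec(w, v):
    # Matrix-vector product for a matrix stored as a list of rows.
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def signal_norms(depth, width, std, seed=0):
    """Push one random input through `depth` fully connected ReLU layers whose
    weights are drawn from N(0, std^2); return the activation norm after each layer."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    norms = []
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        x = relu(matvec(w, x))
        norms.append(math.sqrt(sum(a * a for a in x)))
    return norms


tiny = signal_norms(depth=30, width=64, std=0.01)   # scale too small
large = signal_norms(depth=30, width=64, std=1.0)   # scale too large
print(f"std=0.01: {tiny[0]:.3g} after layer 1, {tiny[-1]:.3g} after layer 30")
print(f"std=1.0:  {large[0]:.3g} after layer 1, {large[-1]:.3g} after layer 30")
```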

By the time ResNet was proposed, the vanishing/exploding gradient problem had largely been addressed by normalized initialization (think of Xavier or He initialization) and, above all, by intermediate normalization layers (BatchNorm), which allow deep networks to start converging with SGD.
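Normalized initialization picks the weight scale from the layer's fan-in so that the activation magnitude is preserved on average. Below is a minimal sketch in the same toy setup of fully connected ReLU layers (width and depth are again arbitrary assumptions, and the helper names are hypothetical), using He initialization, which draws weights with variance 2/fan_in. After 30 layers, the ratio of output norm to input norm stays on the order of 1 instead of vanishing or exploding. BatchNorm is not modeled here.

```python
import math
import random


def relu(v):
    return [max(0.0, a) for a in v]


def matvec(w, v):
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def norm(v):
    return math.sqrt(sum(a * a for a in v))


def he_std(fan_in):
    # He initialization for ReLU layers: Var(w) = 2 / fan_in.
    return math.sqrt(2.0 / fan_in)


def norm_ratio_he(depth, width, seed=0):
    """Ratio of output to input activation norm after `depth` square fully
    connected ReLU layers whose weights use He scaling."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    start = norm(x)
    std = he_std(width)  # fan_in == width for these square layers
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        x = relu(matvec(w, x))
    return norm(x) / start


ratio = norm_ratio_he(depth=30, width=64)
print(ratio)  # roughly order 1, unlike the too-small and too-large scales
```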

But there was another, more peculiar issue:

What this graph shows is that, contrary to the assumption that more layers result in better performance, adding layers beyond a certain point leads to the degradation problem: as depth increases, accuracy saturates and then degrades rapidly, and the deeper model ends up with a higher training error (and thus a higher test error) than its shallower counterpart. This degradation is not caused by overfitting: overfitting would show up as a low training error combined with a rising test error, whereas here the training error itself is higher.

This result is counterintuitive because of the following construction argument. Take a shallower model, copy its layers into a deeper model, and let the added layers learn the identity function. The deeper model then computes the same function as its shallower counterpart, so it should perform at least as well in terms of accuracy. Since deeper plain networks nonetheless do worse, the degradation suggests that the optimizer has difficulty learning identity mappings with a stack of nonlinear layers.
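The construction argument can be checked directly. In this sketch (pure Python, toy fully connected ReLU layers with random weights; all names are hypothetical), the "deeper" model is the shallow model followed by five extra layers with $W = I$ and $b = 0$. Because a ReLU output is never negative, such a layer passes its input through unchanged, so both models compute exactly the same output. A plain network trained from scratch, however, has to discover these exact weights on its own.

```python
import random


def relu(v):
    return [max(0.0, a) for a in v]


def dense_relu(w, b):
    """Return a fully connected layer x -> ReLU(Wx + b)."""
    def layer(x):
        return relu([sum(wij * xj for wij, xj in zip(row, x)) + bi
                     for row, bi in zip(w, b)])
    return layer


def identity_layer(width):
    """Dense ReLU layer with W = I and b = 0. On non-negative inputs (all that a
    preceding ReLU can emit) it returns its input unchanged."""
    eye = [[1.0 if i == j else 0.0 for j in range(width)] for i in range(width)]
    return dense_relu(eye, [0.0] * width)


def run(layers, x):
    for layer in layers:
        x = layer(x)
    return x


rng = random.Random(0)
width = 4
shallow = [
    dense_relu([[rng.uniform(-1, 1) for _ in range(width)] for _ in range(width)],
               [rng.uniform(-1, 1) for _ in range(width)])
    for _ in range(3)
]
deeper = shallow + [identity_layer(width) for _ in range(5)]  # 3 copied + 5 identity layers

x = [rng.uniform(-1, 1) for _ in range(width)]
print(run(deeper, x) == run(shallow, x))  # True
```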

The comparison between the regular approach and what ResNet does

This is where the residual connection comes in.

Introducing the residual connection

The residual connection is typically illustrated as follows:

Instead of only stacking weight layers one after another, the block adds a skip (or residual) connection that carries the input $x$ around them. Let $H(x)$ denote the desired output of the block, which can be an arbitrarily complex function. Rather than learning $H(x)$ directly, the two stacked weight layers learn the residual $F(x) = H(x) - x$, that is, what needs to be changed in the input $x$ to arrive at the output. The residual connection performs the identity mapping, and its output is added to the output of the stacked layers, so the block computes $H(x) = F(x) + x$.
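The block described above fits in a few lines. This is a minimal sketch, assuming fully connected layers in place of convolutions and omitting biases (as the paper's notation does): the branch computes $F(x) = W_2 \cdot \mathrm{ReLU}(W_1 x)$, the shortcut adds $x$, and a ReLU is applied after the addition, as in the original ResNet block. The weights are arbitrary values chosen so that the arithmetic is easy to follow, and the function names are hypothetical.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def matvec(w, v):
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def residual_branch(x, w1, w2):
    """F(x) = W2 · ReLU(W1 · x): what the two stacked weight layers learn."""
    return matvec(w2, relu(matvec(w1, x)))


def residual_block(x, w1, w2):
    """Add the identity shortcut to the branch, H(x) = F(x) + x, then apply ReLU."""
    f = residual_branch(x, w1, w2)
    return relu([fi + xi for fi, xi in zip(f, x)])


w1 = [[1.0, 0.0], [0.0, 1.0]]
w2 = [[0.5, 0.0], [0.0, -2.0]]
x = [1.0, 2.0]
print(residual_branch(x, w1, w2))  # [0.5, -4.0]  this is F(x)
print(residual_block(x, w1, w2))   # [1.5, 0.0]   this is ReLU(F(x) + x)
```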

Essentially, the job of the stacked layers is no longer to reconstruct everything important that needs to be passed forward, but only to figure out what has to be added to the input to arrive at the output, which turns out to be easier to optimize.

But why do residual connections work? One explanation comes down to initialization and regularization. Weights are usually initialized to small random numbers drawn from a zero-mean distribution (e.g., Gaussian), and regularization such as weight decay biases them further toward zero. In a plain network, a layer with near-zero weights therefore starts out close to the zero (or a random) function rather than the identity, which contributes to the degraded performance of deeper plain networks; learning the identity exactly through a stack of nonlinear layers is non-trivial by itself. A residual block, by contrast, defaults to the identity: even if its weight layers ($Wx + b$) output nearly zero, the block's input $x$ still passes through the skip connection to the next layer.
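The "default to identity" argument can be made concrete by collapsing all weights to zero, the extreme case of small initialization plus strong weight decay. In this sketch (toy fully connected layers, hypothetical helper names), ten stacked plain blocks map the input to zeros, while ten stacked residual blocks return it unchanged.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def matvec(w, v):
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def plain_block(x, w1, w2):
    # Two stacked weight layers, no shortcut: ReLU(W2 · ReLU(W1 · x)).
    return relu(matvec(w2, relu(matvec(w1, x))))


def residual_block(x, w1, w2):
    # Same two layers plus the identity shortcut: ReLU(F(x) + x).
    f = matvec(w2, relu(matvec(w1, x)))
    return relu([fi + xi for fi, xi in zip(f, x)])


zeros = [[0.0] * 3 for _ in range(3)]  # weights collapsed to zero
x = [0.7, 0.0, 2.5]  # non-negative, as if produced by a previous ReLU

h_plain, h_res = x, x
for _ in range(10):  # stack ten zero-weight blocks
    h_plain = plain_block(h_plain, zeros, zeros)
    h_res = residual_block(h_res, zeros, zeros)

print(h_plain)  # [0.0, 0.0, 0.0]: the input is lost
print(h_res)    # [0.7, 0.0, 2.5]: the input passes through unchanged
```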

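The identity shortcut adds no trainable parameters and practically no computational cost (only an element-wise addition). It lets the network grow deeper without forcing every added layer to learn an identity mapping from scratch, and it gives gradients a direct path back to earlier layers: since the block computes $F(x) + x$, its derivative with respect to $x$ contains an identity term in addition to the derivative of $F$.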
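A toy scalar model illustrates the gradient-flow point. Here each plain "layer" multiplies by the same weight $w$, and each residual block computes $x + wx$, so the derivative through $L$ layers is $w^L$ for the plain stack and $(1 + w)^L$ for the residual one. The values $w = 0.05$ and $L = 20$ are arbitrary, and real networks involve Jacobian matrices rather than scalars, so treat this only as an illustration; the finite-difference check confirms the analytic derivatives.

```python
def plain_forward(x, w, depth):
    # Toy plain network: each "layer" multiplies by the same scalar weight w.
    for _ in range(depth):
        x = w * x
    return x


def residual_forward(x, w, depth):
    # Toy residual network: each block computes x + F(x) with F(x) = w * x.
    for _ in range(depth):
        x = x + w * x
    return x


def numerical_grad(f, x, eps=1e-6):
    # Central finite difference as an independent check of the analytic gradient.
    return (f(x + eps) - f(x - eps)) / (2 * eps)


w, depth = 0.05, 20
g_plain = numerical_grad(lambda x: plain_forward(x, w, depth), 1.0)
g_res = numerical_grad(lambda x: residual_forward(x, w, depth), 1.0)
print(g_plain, w ** depth)        # both about 1e-26: the gradient has vanished
print(g_res, (1 + w) ** depth)    # both about 2.65: the gradient survives
```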

The architectural setup

The ResNet-34 architecture (34 is the number of layers with weights) looks like this:

There are multiple variants of ResNet, namely ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152, all introduced in the same paper (with ResNet-152 showing the best performance). The identity shortcut can be used directly when the input and output of a block have the same dimensions (solid-line shortcuts in the illustration). When the dimensions increase (dotted-line shortcuts), the shortcut either still performs the identity mapping with extra zero-padded channels, or uses a 1x1 convolution (a projection) to match the dimensions. In both cases, when the shortcut goes across feature maps of two different sizes, it is applied with a stride of 2.
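The two ways of matching dimensions can be sketched on a tiny feature map stored as nested lists indexed `[channel][row][col]` (a layout assumed here for readability; real frameworks use tensors). Both shortcuts first subsample with a stride of 2. The zero-padding shortcut then appends all-zero channels and has no parameters, while the projection shortcut mixes channels with a 1x1 convolution whose weights would be learned. The function names and example weights are hypothetical.

```python
def subsample(fmap, stride):
    """Keep every `stride`-th row and column of each channel (fmap[channel][row][col])."""
    return [[row[::stride] for row in channel[::stride]] for channel in fmap]


def zero_pad_shortcut(fmap, out_channels, stride=2):
    """Identity shortcut across a size change: subsample spatially, then append
    all-zero channels until there are `out_channels` channels. No parameters."""
    sub = subsample(fmap, stride)
    h, w = len(sub[0]), len(sub[0][0])
    extra = [[[0.0] * w for _ in range(h)] for _ in range(out_channels - len(sub))]
    return sub + extra


def projection_shortcut(fmap, weights, stride=2):
    """1x1 convolution with the given stride: output channel o is the weighted sum
    of the input channels, weights[o][i], at every spatial position."""
    sub = subsample(fmap, stride)
    h, w = len(sub[0]), len(sub[0][0])
    return [
        [[sum(weights[o][i] * sub[i][r][c] for i in range(len(sub))) for c in range(w)]
         for r in range(h)]
        for o in range(len(weights))
    ]


def shape(fmap):
    return (len(fmap), len(fmap[0]), len(fmap[0][0]))


# A 2-channel 4x4 input entering a stage with 4 channels and half the resolution.
x = [[[float(100 * ch + 10 * r + c) for c in range(4)] for r in range(4)] for ch in range(2)]
a = zero_pad_shortcut(x, out_channels=4)
b = projection_shortcut(x, weights=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]])
print(shape(a), shape(b))  # (4, 2, 2) (4, 2, 2)
```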

The residual block from one of the early intermediate layers is given as follows:

ResNet was trained with SGD with momentum (0.9) and weight decay (0.0001). It uses the ReLU non-linearity, and BatchNorm is applied right after each convolution and before the activation. The learning rate starts at 0.1 and is divided by 10 when the error plateaus. The weights are initialized with He initialization, and dropout is not used.
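The training recipe above can be sketched as a single SGD-with-momentum update plus a plateau-based learning-rate schedule. This is a minimal pure-Python illustration, not the authors' training code: `sgd_momentum_step` uses one common momentum formulation (deep learning frameworks differ slightly in where the learning rate enters), and `PlateauSchedule` is a hypothetical helper whose `patience` value is an assumption, since the text only says that the rate is divided by 10 when the error plateaus.

```python
def sgd_momentum_step(params, grads, velocity, lr, momentum=0.9, weight_decay=1e-4):
    """One in-place SGD update with momentum and L2 weight decay, using the
    formulation v <- momentum * v + (g + wd * p); p <- p - lr * v."""
    for i, (p, g) in enumerate(zip(params, grads)):
        g = g + weight_decay * p  # weight decay: extra gradient pulling p toward zero
        velocity[i] = momentum * velocity[i] + g
        params[i] = p - lr * velocity[i]


class PlateauSchedule:
    """Hypothetical helper: divide the learning rate by 10 when the monitored
    error has not improved for `patience` consecutive checks."""

    def __init__(self, lr=0.1, factor=0.1, patience=5):
        self.lr = lr
        self.factor = factor
        self.patience = patience  # assumption: no patience value is given in the text
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, error):
        if error < self.best:
            self.best = error
            self.bad_checks = 0
        else:
            self.bad_checks += 1
            if self.bad_checks >= self.patience:
                self.lr *= self.factor
                self.bad_checks = 0
        return self.lr


sched = PlateauSchedule(patience=2)
for err in [0.50, 0.40, 0.40, 0.40]:
    lr = sched.step(err)
print(lr)  # about 0.01: 0.1 divided by 10 after two checks without improvement
```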

For the deeper ResNets (ResNet-50 and up), bottleneck residual blocks are used to reduce the number of parameters. A bottleneck block has three convolutions: a 1x1 convolution that first reduces the number of channels, a 3x3 convolution that operates on the reduced channels, and another 1x1 convolution that restores the number of channels the block started with. The order of operations in the regular residual block is depicted in a) and in the bottleneck residual block in b) in the illustration below:
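The parameter savings are easy to count. This sketch counts convolution weights only (biases and BatchNorm parameters are ignored), and the channel widths are assumptions chosen for illustration: a basic block with two 3x3 convolutions on 64 channels, and a bottleneck block on a 256-channel input that squeezes to 64 channels inside. The bottleneck handles four times as many channels with slightly fewer weights, while a basic block applied directly to 256 channels would need about 17 times as many.

```python
def conv_params(k, c_in, c_out):
    """Weights in a k x k convolution from c_in to c_out channels (no bias)."""
    return k * k * c_in * c_out


def basic_block_params(c):
    # Regular residual block: two 3x3 convolutions, c -> c -> c.
    return conv_params(3, c, c) + conv_params(3, c, c)


def bottleneck_params(c_outer, c_inner):
    # Bottleneck block: 1x1 reduce, 3x3 on the reduced width, 1x1 restore.
    return (conv_params(1, c_outer, c_inner)
            + conv_params(3, c_inner, c_inner)
            + conv_params(1, c_inner, c_outer))


print(basic_block_params(64))      # 73728
print(bottleneck_params(256, 64))  # 69632
print(basic_block_params(256))     # 1179648
```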

The deeper variants of ResNet follow the same overall design as ResNet-34, except that bottleneck residual blocks are used in place of the regular residual blocks (and ResNet-101 and ResNet-152 stack more blocks per stage):
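The numbers in the model names can be reproduced from the number of blocks per stage. The block counts below are the standard configurations of the five variants; the count includes the first (7x7 stem) convolution, every convolution inside the residual blocks, and the final fully connected layer, and it leaves out the 1x1 projection shortcuts. The dictionary and function names are hypothetical.

```python
# Blocks per stage (four stages) and convolutions per block for each variant.
CONFIGS = {
    "ResNet-18": ([2, 2, 2, 2], 2),    # regular blocks: two 3x3 convolutions
    "ResNet-34": ([3, 4, 6, 3], 2),
    "ResNet-50": ([3, 4, 6, 3], 3),    # bottleneck blocks: 1x1, 3x3, 1x1
    "ResNet-101": ([3, 4, 23, 3], 3),
    "ResNet-152": ([3, 8, 36, 3], 3),
}


def weight_layers(blocks_per_stage, convs_per_block):
    """Stem convolution + all convolutions inside residual blocks + final FC layer.
    Projection shortcuts are not counted."""
    return 1 + convs_per_block * sum(blocks_per_stage) + 1


for name, (blocks, convs) in CONFIGS.items():
    print(name, weight_layers(blocks, convs))
```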

Conclusion

As a result, you are now familiar with:

  • The issue with deeper plain networks — the degradation problem;
  • How residual (or skip) connections carry the input of a residual block forward, so that the stacked layers only need to learn a residual on top of it;
  • The basic setup for the ResNet architecture;
  • The usage of bottleneck residual blocks, which make deeper networks more parameter-efficient by placing a 3x3 convolution between two 1x1 convolutions that first reduce and then restore the number of channels.