NFNets Explained — DeepMind’s New State-Of-The-Art Image Classifier
Is this the beginning of the end for Batch Normalization?
Introduction
DeepMind has recently released a new family of image classifiers that achieved a new state-of-the-art accuracy on the ImageNet dataset. This new family, named NFNets (short for Normalizer-Free Networks), matches the accuracy of EfficientNet-B7 while training up to 8.7x faster.
This improvement in training speed was partly achieved by replacing batch normalization with other techniques. This represents an important paradigm shift in the world of image classifiers, which has relied heavily on batch normalization as a key component.
Batch Normalization — The Good
First, let’s understand the benefits that batch normalization brings. With that knowledge, we can then devise alternative methods that recover these benefits.
Batch normalization downscales the residual branch
In ResNet-like architectures (e.g. ResNet, ResNeXt, SE-ResNeXt), batch normalization is often applied in the residual branch, which has the effect of reducing the scale of activations on the residual branch (compared to the skip branch). This stabilizes the gradient early in training, enabling the training of significantly deeper networks.
Batch normalization eliminates mean-shift
When using activation functions that are not anti-symmetric, such as ReLUs, the mean of the activations is non-zero, resulting in activations that are typically large and positive. This causes the network to predict the same label for all training samples early in training, resulting in unstable training. This phenomenon is known as mean-shift. Batch normalization eliminates mean-shift by ensuring that the mean activation is zero across each batch.
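To make the mean-shift effect concrete, here is a minimal sketch in plain Python. It assumes zero-mean, unit-variance Gaussian pre-activations, which is an illustrative choice and not something taken from the NFNets paper. Under that assumption the mean after a ReLU should sit close to 1/√(2π) ≈ 0.399 rather than zero.

```python
import random


def relu(x):
    """Rectified linear unit: passes positive values, zeroes out the rest."""
    return max(0.0, x)


def mean(values):
    return sum(values) / len(values)


random.seed(0)

# Hypothetical pre-activations: zero-mean, unit-variance Gaussian samples.
pre_activations = [random.gauss(0.0, 1.0) for _ in range(100_000)]
post_activations = [relu(x) for x in pre_activations]

print(f"mean before ReLU: {mean(pre_activations):+.3f}")
print(f"mean after ReLU:  {mean(post_activations):+.3f}")
```

The input is centred on zero, but the output of the ReLU is not, and this positive offset is what gets passed on to the next layer.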
Batch normalization has a regularizing effect
With the introduction of batch normalization, researchers found that dropout was often no longer necessary, as batch normalization itself helps to regularize the network. This regularizing effect is widely attributed to the noise in the batch statistics, which are computed on only a subset of the training data.
Batch normalization allows efficient large-batch training
By smoothing the loss landscape, batch normalization raises the largest stable learning rate, which allows us to train efficiently with larger batch sizes and learning rates.
Batch Normalization — The Bad
Even though batch normalization has enabled image classifiers to make substantial gains in recent years, it does have many negative consequences.
Batch normalization is expensive
Computing batch-level statistics is an expensive operation. By eliminating batch normalization, we can train networks faster.
Batch normalization breaks the assumption of data independence
The computation of batch-level statistics breaks the independence between training samples in the mini-batch. This is the root cause of many implementation errors, especially in distributed training. In fact, several well-known results in machine learning research cannot be replicated precisely due to issues caused by batch normalization.
To that end, many researchers have attempted to create normalizer-free networks, while still enjoying the benefits afforded by batch normalization. Most of them have failed to create such normalizer-free networks with competitive accuracies, until now.
NFNets — State-of-the-art normalizer free networks
NFNets are a family of modified ResNets that achieve competitive accuracies without batch normalization. To do so, they apply 3 different techniques:
- Modified residual branches and convolutions with Scaled Weight Standardization
- Adaptive Gradient Clipping
- Architecture optimization for improved accuracy and training speed
Each of these 3 techniques is a complex topic in its own right, so we’ll go through them individually in the following sections.
Modified residual branches and convolutions
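In order to train deep ResNets without normalization, it is crucial to suppress the scale of the activations on the residual branch. To achieve this, NFNets use 2 scalars (α and β): the input to the residual branch is divided by β, and the output of the branch is multiplied by α before it is added back to the skip path. A residual block therefore computes:
$$h_{i+1} = h_i + \alpha \, f_i(h_i / \beta_i)$$
α is set to a small constant of 0.2, while β for each block is defined as:
$$\beta_i = \sqrt{\mathrm{Var}(h_i)}$$
where,
$$\mathrm{Var}(h_{i+1}) = \mathrm{Var}(h_i) + \alpha^2$$
and h_i refers to the inputs to the i-th block, while f_i is the function computed by its residual branch.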
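The variance bookkeeping can be sketched in a few lines of Python. This assumes the residual block has the form h_next = h + α · f(h / β) with β = sqrt(Var(h)) and Var(h_next) = Var(h) + α², as described in the NFNets paper, and it assumes the first block sees unit-variance inputs. The function names are hypothetical, and the sketch ignores details such as transition blocks, so treat it as an illustration only.

```python
import math


def expected_betas(num_blocks, alpha=0.2, initial_variance=1.0):
    """Predict beta for each block from Var(h_next) = Var(h) + alpha**2."""
    betas = []
    variance = initial_variance
    for _ in range(num_blocks):
        betas.append(math.sqrt(variance))  # beta_i = sqrt(Var(h_i))
        variance += alpha ** 2             # variance entering the next block
    return betas


def residual_block(h, branch_fn, alpha, beta):
    """Compute h + alpha * f(h / beta) for a list of activations."""
    scaled_input = [x / beta for x in h]
    branch_output = branch_fn(scaled_input)
    return [x + alpha * y for x, y in zip(h, branch_output)]


betas = expected_betas(num_blocks=4)
print([round(b, 4) for b in betas])

# Toy usage with an identity function standing in for the residual branch.
print(residual_block([1.0, 2.0], lambda v: v, alpha=0.2, beta=2.0))
```

Because β is predicted analytically, no activation statistics need to be computed at training time, which is the point of removing batch normalization.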
In addition, NFNets use Scaled Weight Standardization to prevent mean-shift. Scaled Weight Standardization normalizes the weights of the convolutional layers in NFNets such that:
$$\hat{W}_{ij} = \frac{W_{ij} - \mu_i}{\sigma_i \sqrt{N}}$$
where μ_i and σ_i refer to the mean and standard deviation of the weights feeding the i-th output unit, and N refers to the fan-in of the convolutional layer.
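As a rough illustration, the standardization can be written in plain Python for a 2-D weight matrix whose rows are output units. It assumes the form Ŵ = (W − μ) / (σ · √N), with μ and σ computed per row over the fan-in N. A convolution kernel would first be flattened to shape (out_channels, fan_in). The function name and the small `eps` guard against division by zero are assumptions of this sketch, not part of the article.

```python
import math
import random


def scaled_weight_standardization(weights, eps=1e-8):
    """Standardize each row so it has zero mean and variance 1 / fan_in.

    weights[i][j] is the weight from input j to output unit i.
    eps is a hypothetical safeguard for rows with zero variance.
    """
    standardized = []
    for row in weights:
        fan_in = len(row)
        mu = sum(row) / fan_in
        var = sum((w - mu) ** 2 for w in row) / fan_in
        scale = 1.0 / (math.sqrt(var * fan_in) + eps)  # 1 / (sigma * sqrt(N))
        standardized.append([(w - mu) * scale for w in row])
    return standardized


random.seed(1)
raw = [[random.gauss(0.5, 2.0) for _ in range(27)] for _ in range(4)]
w_hat = scaled_weight_standardization(raw)

for row in w_hat:
    row_mean = sum(row) / len(row)
    row_var = sum((w - row_mean) ** 2 for w in row) / len(row)
    print(f"mean={row_mean:+.2e}  var*N={row_var * len(row):.4f}")
```

Each standardized row has zero mean, which is what removes the mean-shift, and a variance of 1/N, which keeps the variance of the layer output from growing with the fan-in.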
Adaptive Gradient Clipping
Adaptive Gradient Clipping was also used to train NFNets with larger batch sizes and learning rates. Traditionally, gradient clipping is used to restrict gradient magnitudes and to prevent exploding gradients:
$$G \rightarrow \begin{cases} \lambda \dfrac{G}{\lVert G \rVert} & \text{if } \lVert G \rVert > \lambda \\ G & \text{otherwise} \end{cases}$$
where G refers to the gradient and λ is an arbitrary threshold value. However, the authors found that the training stability of NFNets is extremely sensitive to the choice of λ. Therefore, the authors proposed Adaptive Gradient Clipping, a modified form of gradient clipping.
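For reference, standard norm-based gradient clipping can be sketched as follows: if the norm of G exceeds λ, rescale G so that its norm equals λ, otherwise leave it unchanged. The function name is hypothetical and the gradient is represented as a flat Python list to keep the example simple.

```python
import math


def clip_gradient(grad, threshold):
    """Rescale grad to have norm `threshold` if its L2 norm is larger."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm > threshold:
        return [threshold * g / norm for g in grad]
    return list(grad)


print(clip_gradient([3.0, 4.0], threshold=1.0))   # norm 5 is above the threshold
print(clip_gradient([0.3, 0.4], threshold=1.0))   # norm 0.5 is left unchanged
```

The threshold here is a fixed number that does not depend on the size of the weights being updated, which is the property that Adaptive Gradient Clipping changes.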
The intuition behind Adaptive Gradient Clipping is that the gradient-to-weights ratio provides a simple measure of how much a single gradient descent step will change the original weights. We expect that training will become unstable when this ratio is large, and we should therefore clip the gradient when this ratio exceeds a certain threshold.
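In Adaptive Gradient Clipping, the gradient is clipped unit-wise (i.e. row-wise in a matrix) according to:
$$G_i^{\ell} \rightarrow \begin{cases} \lambda \dfrac{\lVert W_i^{\ell} \rVert_F^{\star}}{\lVert G_i^{\ell} \rVert_F} \, G_i^{\ell} & \text{if } \dfrac{\lVert G_i^{\ell} \rVert_F}{\lVert W_i^{\ell} \rVert_F^{\star}} > \lambda \\ G_i^{\ell} & \text{otherwise} \end{cases}$$
where W_i^ℓ and G_i^ℓ are the i-th rows of the weight and gradient matrices of layer ℓ, and $\lVert W_i \rVert_F^{\star} = \max(\lVert W_i \rVert_F, \epsilon)$. Note that the ϵ constant is used to prevent zero-initialized parameters from always having their gradients clipped to zero, which helps to stabilize the gradient early in training.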
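A minimal sketch of unit-wise Adaptive Gradient Clipping in plain Python is shown below. It assumes each row of the weight matrix is one unit, that the weight norm is floored at ϵ, and that a row's gradient is rescaled whenever its gradient-to-weight norm ratio exceeds λ. The function names and the default values `clip=0.01` and `eps=1e-3` are illustrative assumptions of this sketch.

```python
import math


def l2_norm(row):
    return math.sqrt(sum(x * x for x in row))


def adaptive_clip(weights, grads, clip=0.01, eps=1e-3):
    """Clip each gradient row so that ||G_i|| / max(||W_i||, eps) <= clip."""
    clipped = []
    for w_row, g_row in zip(weights, grads):
        w_norm = max(l2_norm(w_row), eps)   # floor keeps zero-init rows trainable
        g_norm = l2_norm(g_row)
        max_norm = clip * w_norm            # largest gradient norm allowed
        if g_norm > max_norm:
            g_row = [g * max_norm / g_norm for g in g_row]
        clipped.append(list(g_row))
    return clipped


weights = [[3.0, 4.0], [0.0, 0.0]]
grads = [[0.3, 0.4], [1.0, 0.0]]
for row in adaptive_clip(weights, grads):
    print(row)
```

In this toy example the first row has weight norm 5, so its gradient is scaled down to norm 0.05. The second row is zero-initialized, and the ϵ floor leaves it with a small but non-zero gradient instead of clipping it to exactly zero.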
Architecture optimization for improved accuracy and training speed
Even with Adaptive Gradient Clipping, and the modified residual branch and convolutions, normalizer-free networks still could not surpass the accuracies of EfficientNet.
The authors therefore looked into optimizing the architecture of existing ResNet-based models in order to improve model accuracy. The authors used an SE-ResNeXt-D model as a baseline and made the following changes.
First, the group width of the 3×3 convolutions in each bottleneck block was fixed at 128. Group width refers to the number of filters in the 3×3 convolution ÷ cardinality (i.e. number of groups in the grouped convolution) of the bottleneck block.
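The relationship between filters, cardinality and group width is simple arithmetic, sketched below. The helper name and the example filter counts are hypothetical and chosen only for illustration. They are not the actual NFNet layer sizes.

```python
def cardinality(conv_filters, group_width=128):
    """Number of groups in a grouped 3x3 convolution with a fixed group width."""
    if conv_filters % group_width != 0:
        raise ValueError("conv_filters must be a multiple of group_width")
    return conv_filters // group_width


# Hypothetical filter counts for the 3x3 convolution in a bottleneck block.
for filters in (128, 256, 768):
    print(f"{filters} filters -> {cardinality(filters)} groups of width 128")
```

With the group width fixed, wider convolutions get more groups instead of wider groups.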
Next, the depth scaling pattern was also modified. In the original SE-ResNeXt-D models, the number of repeating blocks in the first and last stage is fixed at 3, while the number of repeating blocks in the second and third stage is scaled non-uniformly. For example, SE-ResNeXt-50 has a scaling pattern of [3,4,6,3] while SE-ResNeXt-152 has a scaling pattern of [3,8,36,3].
The authors argued that this scaling strategy is suboptimal, as layers in the early and later stages do not operate with sufficient capacity. Instead, the authors proposed a simpler scaling pattern. The smallest model, NFNet-F0, has a scaling pattern of [1,2,6,3], while NFNet-F1 (the next bigger model) scales 2x from F0 and has a scaling pattern of [2,4,12,6].
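The depth scaling rule can be sketched as a single multiplication. The F0 and F1 patterns below come from the text above. Treating every larger variant FN as F0 multiplied by (N + 1) is an assumption of this sketch that extends the same rule, and the function name is hypothetical.

```python
F0_DEPTHS = (1, 2, 6, 3)  # blocks per stage for NFNet-F0


def nfnet_depths(variant):
    """Blocks per stage for NFNet-F<variant>, assuming F0 scaled by (variant + 1)."""
    if variant < 0:
        raise ValueError("variant must be non-negative")
    multiplier = variant + 1
    return [depth * multiplier for depth in F0_DEPTHS]


print(nfnet_depths(0))
print(nfnet_depths(1))
```

Unlike the SE-ResNeXt pattern, every stage grows by the same factor, so the first and last stages deepen along with the middle ones.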
Thirdly, the authors modified the default width pattern of the original SE-ResNeXt-D model, where the first stage has 256 channels which are doubled at each subsequent stage, resulting in a width pattern (256, 512, 1024, 2048). Instead, NFNets modifies the width of the third and fourth stage, resulting in a width pattern of (256, 512, 1536, 1536). The authors found that empirically, increasing width in the third stage while reducing width in the fourth stage improves model performance while preserving training speed.
The table below summarizes the difference in architecture between an SE-ResNeXt-50 model and an NFNet-F0 model.
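The figures quoted in this section for the two models can be collected into a small Python sketch. Only values stated in the text are filled in. The `BackboneConfig` class is a hypothetical container, and the group width of SE-ResNeXt-50 is left as `None` because the text does not give it.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class BackboneConfig:
    name: str
    depths: Tuple[int, int, int, int]   # repeating blocks per stage
    widths: Tuple[int, int, int, int]   # channels per stage
    group_width: Optional[int]          # None where the text gives no value


SE_RESNEXT_50 = BackboneConfig(
    "SE-ResNeXt-50", (3, 4, 6, 3), (256, 512, 1024, 2048), None
)
NFNET_F0 = BackboneConfig(
    "NFNet-F0", (1, 2, 6, 3), (256, 512, 1536, 1536), 128
)

for cfg in (SE_RESNEXT_50, NFNET_F0):
    print(f"{cfg.name}: {sum(cfg.depths)} blocks, widths {cfg.widths}")
```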
Implementation of NFNets
DeepMind has released the code for NFNets in their GitHub repository. However, the code is written in JAX.
A PyTorch implementation of NFNets can also be found in the popular timm library.
Conclusion
This paper was an interesting read, and it was an important milestone towards creating competitive normalizer-free networks. Although significant steps were taken towards substituting batch normalization in deep networks, a significant portion of the performance gains in NFNets came from optimization of the existing SE-ResNeXt architecture.
Nevertheless, I believe that more state-of-the-art normalizer-free networks will be discovered in the near future, providing faster training speed and higher accuracies.