NFNets Explained — DeepMind’s New State-Of-The-Art Image Classifier

Is this the beginning of the end for Batch Normalization?

Introduction

DeepMind has recently released a new family of image classifiers that sets a new state of the art for accuracy on the ImageNet dataset. The family, named NFNets (short for Normalizer-Free Networks), includes a model that matches the accuracy of EfficientNet-B7 while training a whopping 8.7x faster.

NFNet-F1 trains 8.7x faster than EfficientNet-B7, while achieving comparable accuracy. NFNet-F5 achieves state-of-the-art accuracy, surpassing previous accuracies of the EfficientNet family.

This improvement in training speed was partly achieved by replacing batch normalization with other techniques. This represents an important paradigm shift in the world of image classifiers, which has relied heavily on batch normalization as a key component.

Batch Normalization — The Good

First, let’s understand the benefits that batch normalization brings. With that knowledge, we can then devise alternative methods that recover these benefits.

Batch normalization downscales the residual branch

In ResNet-like architectures (e.g. ResNet, ResNeXt, SE-ResNeXt), batch normalization is often applied in the residual branch, which has the effect of reducing the scale of activations on the residual branch (compared to the skip branch). This stabilizes the gradient early in training, enabling the training of significantly deeper networks.

Batch normalization eliminates mean-shift

When using activation functions that are not anti-symmetric, such as ReLUs, the mean of the activations is non-zero, resulting in activations that are typically large and positive. This can cause the network to predict the same label for all training samples early in training, resulting in unstable training. This phenomenon is known as mean-shift. Batch normalization eliminates mean-shift by ensuring that the mean activation is zero across each batch.
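To make mean-shift concrete, here is a minimal sketch in plain Python (an illustration added here, not code from the paper). It assumes the inputs to a ReLU are zero-mean, unit-variance Gaussians, and shows that the outputs still have a clearly positive mean, close to the analytic value 1/sqrt(2π), roughly 0.399. The function name `mean_after_relu` and the sample count are hypothetical choices.

```python
import math
import random


def relu(x):
    """ReLU is not anti-symmetric: relu(-x) != -relu(x)."""
    return max(0.0, x)


def mean_after_relu(num_samples=100_000, seed=0):
    """Average ReLU output for zero-mean, unit-variance Gaussian inputs."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        total += relu(rng.gauss(0.0, 1.0))
    return total / num_samples


empirical = mean_after_relu()
theoretical = 1.0 / math.sqrt(2.0 * math.pi)  # E[max(0, z)] for z ~ N(0, 1)
print(f"empirical mean after ReLU:   {empirical:.3f}")
print(f"theoretical mean after ReLU: {theoretical:.3f}")
```

Even though the inputs are centred on zero, the outputs are not, and this offset is what compounds from layer to layer when nothing re-centres the activations.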

Batch normalization has a regularizing effect

With the introduction of batch normalization, researchers found that dropout was often no longer necessary, as batch normalization itself helps to regularize the network. The regularizing effect comes from noise in the batch statistics: each mini-batch is only a sample of the training data, so its mean and variance fluctuate from step to step, and this noise discourages overfitting.

Batch normalization allows efficient large-batch training

By smoothing the loss landscape, batch normalization increases the largest stable learning rate, which allows us to train with larger batch sizes and learning rates without training becoming unstable.

Batch Normalization — The Bad

Even though batch normalization has enabled image classifiers to make substantial gains in recent years, it does have many negative consequences.

Batch normalization is expensive

Computing batch-level statistics is an expensive operation. By eliminating batch normalization, we can train networks faster.

Batch normalization breaks the assumption of data independence

The computation of batch-level statistics breaks the independence between training samples in the mini-batch. This is the root cause of many implementation errors, especially in distributed training. In fact, several well-known results in machine learning research cannot be replicated precisely due to issues caused by batch normalization.

To that end, many researchers have attempted to create normalizer-free networks, while still enjoying the benefits afforded by batch normalization. Most of them have failed to create such normalizer-free networks with competitive accuracies, until now.

NFNets — State-of-the-art normalizer free networks

NFNets are a family of modified ResNets that achieve competitive accuracies without batch normalization. To do so, they apply 3 different techniques:

  • Modified residual branches and convolutions with Scaled Weight Standardization
  • Adaptive Gradient Clipping
  • Architecture optimization for improved accuracy and training speed

Each of these 3 techniques is a complex topic in its own right, so we’ll go through them individually in the following sections.

Modified residual branches and convolutions

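In order to train deep ResNets without normalization, it is crucial to suppress the scale of the activations on the residual branch. To achieve this, NFNets use 2 scalars (α and β) to scale the activations at the start and end of the residual branch, so that each residual block takes the form:

$$h_{i+1} = h_i + \alpha \, f_i(h_i / \beta_i)$$

where f_i is the function computed by the residual branch of the i-th block.

α is set to a small constant of 0.2, while β for each block is defined as:

$$\beta_i = \sqrt{\mathrm{Var}(h_i)}$$

where,

$$\mathrm{Var}(h_{i+1}) = \mathrm{Var}(h_i) + \alpha^2$$

and h_i refers to the inputs to the i-th block.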
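As a rough sketch of how these scalars fit together, the plain-Python snippet below assumes the formulation from the NFNets paper: a block computes h + α·f(h / β), with β equal to the square root of the expected variance of the block's input, and that variance growing by α² after every block. The helper names `expected_betas` and `residual_block`, and the starting variance of 1.0, are assumptions made for illustration. Details such as how the variance is handled at stage transitions are ignored.

```python
import math


def expected_betas(num_blocks, alpha=0.2, initial_variance=1.0):
    """Return the analytically predicted beta for each residual block."""
    betas = []
    variance = initial_variance  # assumed variance of the first block's input
    for _ in range(num_blocks):
        betas.append(math.sqrt(variance))  # beta_i = sqrt(Var(h_i))
        variance += alpha ** 2             # Var(h_{i+1}) = Var(h_i) + alpha^2
    return betas


def residual_block(h, branch, alpha, beta):
    """Compute h + alpha * branch(h / beta) for a list of activations."""
    scaled_input = [x / beta for x in h]    # downscale at the start of the branch
    branch_output = branch(scaled_input)    # stand-in for the conv layers
    return [x + alpha * y for x, y in zip(h, branch_output)]


betas = expected_betas(num_blocks=4)
print([round(b, 4) for b in betas])

# Toy usage with an identity function standing in for the residual branch.
print(residual_block([1.0, 2.0], lambda v: v, alpha=0.2, beta=betas[1]))
```

Because β is computed from a predicted variance rather than from batch statistics, nothing here depends on the other samples in the mini-batch.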

In addition, NFNets use Scaled Weight Standardization to prevent mean-shift. Scaled Weight Standardization normalizes the weights of the convolutional layers in NFNets such that:

$$\hat{W}_{ij} = \frac{W_{ij} - \mu_i}{\sigma_i \sqrt{N}}$$

where μ_i and σ_i refer to the mean and standard deviation of the weights of the i-th output unit respectively, and N refers to the fan-in of the convolutional layer.
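The following is a minimal sketch of that normalization in plain Python, assuming each output unit's weights are flattened into one row of length N (the fan-in), and that each weight has its row mean subtracted and is then divided by the row standard deviation times sqrt(N). The small `eps` term is an assumption added here to avoid division by zero; any extra gain applied after standardization is left out.

```python
import math


def scaled_weight_standardization(weights, eps=1e-8):
    """Standardize each row (one output unit) over its fan-in.

    weights: list of rows; each row holds the N fan-in weights of one unit.
    """
    standardized = []
    for row in weights:
        fan_in = len(row)
        mean = sum(row) / fan_in
        variance = sum((w - mean) ** 2 for w in row) / fan_in
        # sigma * sqrt(N); eps is an illustrative safeguard, not from the text.
        denom = math.sqrt(variance * fan_in) + eps
        standardized.append([(w - mean) / denom for w in row])
    return standardized


weights = [[0.5, -1.0, 2.0, 0.1], [3.0, 3.5, 2.5, 3.0]]
w_hat = scaled_weight_standardization(weights)
for row in w_hat:
    row_mean = sum(row) / len(row)
    row_var = sum((w - row_mean) ** 2 for w in row) / len(row)
    print(f"mean={row_mean:+.6f}  variance*N={row_var * len(row):.6f}")
```

After standardization every row has zero mean and variance 1/N, which is why the second row, whose raw weights are all strongly positive, no longer pushes its output towards positive values.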

Adaptive Gradient Clipping

Adaptive Gradient Clipping was also used to train NFNets with larger batch sizes and learning rates. Traditionally, gradient clipping is used to restrict gradient magnitudes and to prevent exploding gradients:

$$G \rightarrow \begin{cases} \lambda \dfrac{G}{\lVert G \rVert} & \text{if } \lVert G \rVert > \lambda \\ G & \text{otherwise} \end{cases}$$

where G refers to the gradient and λ is an arbitrary threshold value. However, the authors found that the training stability of NFNets is extremely sensitive to the choice of λ. Therefore, the authors proposed Adaptive Gradient Clipping, a modified form of gradient clipping.
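For reference, traditional clipping by norm can be sketched as follows. This is a generic plain-Python illustration that treats the gradient as a flat list of numbers: the gradient is rescaled to have norm λ whenever its norm exceeds λ, and left untouched otherwise. The function name `clip_by_norm` is a hypothetical choice.

```python
import math


def clip_by_norm(gradient, threshold):
    """Rescale the gradient if its L2 norm exceeds the threshold (lambda)."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm > threshold:
        scale = threshold / norm
        return [g * scale for g in gradient]
    return list(gradient)


print(clip_by_norm([3.0, 4.0], threshold=1.0))  # norm 5 exceeds 1, so it is rescaled
print(clip_by_norm([0.3, 0.4], threshold=1.0))  # norm 0.5 is below 1, so it is unchanged
```

Note that the threshold is an absolute number: the same λ is applied regardless of how large the weights being updated are, which is one way to see why a single value is hard to choose.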

The intuition behind Adaptive Gradient Clipping is that the gradient-to-weights ratio provides a simple measure of how much a single gradient descent step will change the original weights. We expect that training will become unstable when this ratio is large, and we should therefore clip the gradient when this ratio exceeds a certain threshold.

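In Adaptive Gradient Clipping, the gradient is clipped unit-wise (i.e. row-wise in a matrix) according to:

$$G_i^{\ell} \rightarrow \begin{cases} \lambda \dfrac{\lVert W_i^{\ell} \rVert_F^{*}}{\lVert G_i^{\ell} \rVert_F} \, G_i^{\ell} & \text{if } \dfrac{\lVert G_i^{\ell} \rVert_F}{\lVert W_i^{\ell} \rVert_F^{*}} > \lambda \\ G_i^{\ell} & \text{otherwise} \end{cases}$$

where W_i^ℓ and G_i^ℓ are the i-th rows of the weight matrix and of its gradient in layer ℓ, and ‖W_i‖*_F = max(‖W_i‖_F, ϵ).

Note that the ϵ constant is used to prevent zero-initialized parameters from having their gradients clipped to zero, which helps to stabilize the gradient early in training.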
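The unit-wise rule can be sketched in plain Python as below, assuming weights and gradients are given as lists of rows. A row's gradient is rescaled whenever the ratio of its norm to the norm of the matching weight row exceeds the threshold, and the weight norm is floored at ϵ. The default values `clip=0.01` and `eps=1e-3`, and the function names, are illustrative assumptions rather than settings taken from the text above.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def adaptive_gradient_clip(weights, gradients, clip=0.01, eps=1e-3):
    """Clip each gradient row relative to the norm of the matching weight row."""
    clipped = []
    for w_row, g_row in zip(weights, gradients):
        w_norm = max(l2_norm(w_row), eps)  # floor keeps zero-initialized rows trainable
        g_norm = l2_norm(g_row)
        max_norm = clip * w_norm           # largest gradient norm allowed for this row
        if g_norm > max_norm:              # same test as g_norm / w_norm > clip
            scale = max_norm / g_norm
            clipped.append([g * scale for g in g_row])
        else:
            clipped.append(list(g_row))
    return clipped


weights = [[3.0, 4.0], [0.0, 0.0]]       # second row is zero-initialized
gradients = [[0.6, 0.8], [0.001, 0.0]]
for row in adaptive_gradient_clip(weights, gradients, clip=0.1):
    print(row, "norm:", l2_norm(row))
```

In this toy example the second row has all-zero weights. Without the ϵ floor its allowed gradient norm would be zero and the row could never move away from its initial value.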

Architecture optimization for improved accuracy and training speed

Even with Adaptive Gradient Clipping, and the modified residual branch and convolutions, normalizer-free networks still could not surpass the accuracies of EfficientNet.

The authors therefore looked into optimizing the architecture of existing ResNet-based models in order to improve model accuracy. The authors used an SE-ResNeXt-D model as a baseline and made the following changes.

First, the group width of the 3×3 convolutions in each bottleneck block was fixed at 128. Group width refers to the number of filters in the 3×3 convolution ÷ cardinality (i.e. number of groups in the grouped convolution) of the bottleneck block.
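As a small worked example of this definition, the sketch below computes the cardinality implied by a fixed group width of 128. The channel counts in the loop are hypothetical values chosen for illustration, not the actual widths of any NFNet stage.

```python
def cardinality(conv_channels, group_width=128):
    """Number of groups in a grouped 3x3 convolution with a fixed group width."""
    if conv_channels % group_width != 0:
        raise ValueError("conv_channels must be a multiple of group_width")
    return conv_channels // group_width


for channels in (128, 256, 768):  # hypothetical 3x3 convolution widths
    print(channels, "channels ->", cardinality(channels), "group(s)")
```

With the group width held constant, wider convolutions simply get more groups, rather than wider ones.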

Next, the depth scaling pattern was also modified. In the original SE-ResNeXt-D models, the number of repeating blocks in the first and last stage is fixed at 3, while the number of repeating blocks in the second and third stage is scaled non-uniformly. For example, SE-ResNeXt-50 has a scaling pattern of [3,4,6,3] while SE-ResNeXt-152 has a scaling pattern of [3,8,36,3].

The authors argued that this scaling strategy is suboptimal, as layers in the early and later stages do not operate with sufficient capacity. Instead, the authors proposed a simpler scaling pattern. The smallest model, NFNet-F0 has a scaling pattern of [1,2,6,3], while NFNet-F1 (the next bigger model) scales 2x from F0 and has a scaling pattern of [2,4,12,6].
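The scaling pattern can be written down in a few lines. The sketch below reproduces the F0 and F1 patterns quoted above by multiplying the F0 depths by an integer. Extending the same rule to larger variants (a multiplier of 3 for F2, and so on) is an assumption made here, and the function name `nfnet_depths` is hypothetical.

```python
BASE_DEPTHS = [1, 2, 6, 3]  # blocks per stage for NFNet-F0, as quoted above


def nfnet_depths(variant_index):
    """Blocks per stage for NFNet-F<variant_index>, assuming linear scaling."""
    if variant_index < 0:
        raise ValueError("variant_index must be non-negative")
    multiplier = variant_index + 1  # F0 -> 1x, F1 -> 2x, assumed to continue
    return [multiplier * depth for depth in BASE_DEPTHS]


for index in range(3):
    print(f"NFNet-F{index}: {nfnet_depths(index)}")
```

Unlike the SE-ResNeXt pattern, every stage grows by the same factor, so the first and last stages gain capacity along with the middle ones.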

Finally, the authors modified the default width pattern of the original SE-ResNeXt-D model, where the first stage has 256 channels which are doubled at each subsequent stage, resulting in a width pattern of (256, 512, 1024, 2048). Instead, NFNets modify the width of the third and fourth stages, resulting in a width pattern of (256, 512, 1536, 1536). The authors found empirically that increasing the width of the third stage while reducing the width of the fourth stage improves model performance while preserving training speed.

The table below summarizes the difference in architecture between an SE-ResNeXt-50 model and an NFNet-F0 model.

Difference in architecture between SE-ResNeXt-50 and NFNet-F0. Note that batch normalization is applied between successive convolutional blocks in an SE-ResNeXt-50, while NFNets are normalizer-free. The weights of the convolution layers used in NFNets are also scaled using Scaled Weight Standardization.

Implementation of NFNets

DeepMind has released the code for NFNets in their GitHub repository. However, the code is written in JAX.

A PyTorch implementation of NFNets can also be found in the popular timm library.

Conclusion

This paper was an interesting read, and it was an important milestone towards creating competitive normalizer-free networks. Although significant steps were taken towards substituting batch normalization in deep networks, a significant portion of the performance gains in NFNets came from optimization of the existing SE-ResNeXt architecture.

Nevertheless, I believe that more state-of-the-art normalizer-free networks will be discovered in the near future, providing faster training speeds and higher accuracies.
