Research Articles

Paper summaries and personal insights


Understanding Batch Normalization

Paper: Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML. Read on arXiv

Problem

During training, the distribution of each layer's inputs keeps changing as the parameters of the preceding layers are updated. This slows training, forcing lower learning rates and careful parameter initialization. The phenomenon is known as internal covariate shift, and it is especially pronounced in deep networks.

Solution

To reduce internal covariate shift, activations are normalized before entering the next layer. To avoid limiting what the layer can represent, the normalized values are then scaled and shifted using the learned parameters \(\gamma\) and \(\beta\). This transformation is applied to each mini-batch during training.

Method

For a given mini-batch, the following transformation is applied to each input feature during training:

Calculation of mean:

\[\mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^m x_i\]

Calculation of variance:

\[\sigma_{\mathcal{B}}^2 \leftarrow \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2\]

Normalization:

\[\hat{x}_i \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}\]

where \(\epsilon\) is a small constant added for numerical stability.

Scaling and shifting:

\[y_i \leftarrow \gamma \hat{x}_i + \beta\]
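The four steps above map directly onto array operations. A minimal NumPy sketch for a (batch, features) input, with a tiny worked example (the function name and shapes are illustrative, not from the paper):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-time batch normalization over a (batch, features) array,
    following the four steps above."""
    mu = x.mean(axis=0)                    # mini-batch mean
    var = x.var(axis=0)                    # mini-batch (biased) variance
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize
    return gamma * x_hat + beta            # scale and shift

# worked example: one feature, batch of three
x = np.array([[1.0], [2.0], [3.0]])
y = batchnorm_forward(x, gamma=np.array([2.0]), beta=np.array([1.0]))
# mean is 2, variance is 2/3, so x_hat is roughly [-1.2247, 0, 1.2247]
```

Note that the variance is the biased estimate (dividing by \(m\), not \(m-1\)), matching the formula above.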

During inference, the network uses the moving average of the mean and variance (collected during training) rather than the batch statistics. This ensures a deterministic output for a given input.
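One way the training/inference split might be handled is shown below; the exponential moving average with a `momentum` parameter is a common convention, not a detail specified by the paper:

```python
import numpy as np

class BatchNorm:
    """Sketch of batch normalization with running statistics for inference."""
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.gamma = np.ones(num_features)   # learned scale
        self.beta = np.zeros(num_features)   # learned shift
        self.eps, self.momentum = eps, momentum
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def __call__(self, x, training=True):
        if training:
            mu, var = x.mean(axis=0), x.var(axis=0)
            # moving averages collected during training, used at inference
            self.running_mean += self.momentum * (mu - self.running_mean)
            self.running_var += self.momentum * (var - self.running_var)
        else:
            # deterministic: the stored statistics replace batch statistics
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta
```

In inference mode the output for a given input no longer depends on which other examples happen to share its batch.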

The gradients for the backward pass are:

Gradient for scale parameter:

\[\frac{\partial \mathcal{L}}{\partial \gamma} = \sum_{i=1}^m \frac{\partial \mathcal{L}}{\partial y_i} \cdot \hat{x}_i\]

Gradient for shift parameter:

\[\frac{\partial \mathcal{L}}{\partial \beta} = \sum_{i=1}^m \frac{\partial \mathcal{L}}{\partial y_i}\]

Gradient for input:

\[\frac{\partial \mathcal{L}}{\partial x_i} = \frac{1}{m \sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \left( m \frac{\partial \mathcal{L}}{\partial \hat{x}_i} - \sum_{j=1}^m \frac{\partial \mathcal{L}}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^m \frac{\partial \mathcal{L}}{\partial \hat{x}_j} \hat{x}_j \right)\]

where \(\frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \frac{\partial \mathcal{L}}{\partial y_i} \cdot \gamma\)

Implementation
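As a sketch of what an implementation of the backward pass might look like, the gradient formulas above translate to NumPy as follows; the function recomputes the forward quantities from the mini-batch, and its name and shapes are illustrative:

```python
import numpy as np

def batchnorm_backward(dy, x, gamma, eps=1e-5):
    """Gradients of the loss w.r.t. x, gamma, and beta for a
    (batch, features) input, following the formulas above."""
    m = x.shape[0]
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)

    dgamma = np.sum(dy * x_hat, axis=0)  # dL/dgamma
    dbeta = np.sum(dy, axis=0)           # dL/dbeta

    dx_hat = dy * gamma                  # dL/dx_hat
    dx = (m * dx_hat
          - dx_hat.sum(axis=0)
          - x_hat * np.sum(dx_hat * x_hat, axis=0)) / (m * np.sqrt(var + eps))
    return dx, dgamma, dbeta
```

A finite-difference check against the forward pass is a quick way to sanity-test these formulas in practice.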

Results

MNIST network training

As shown above, batch normalization reduces the number of training steps needed to reach high generalization accuracy.

Inception network training

Inception network results

As seen above, batch normalization speeds up convergence and, in some configurations, also improves final accuracy.

Comparison with previous state-of-the-art results

Strengths

Batch normalization significantly speeds up the training of neural networks. It also enables other modifications that accelerate training further: higher learning rates, less careful parameter initialization, and in some cases a reduced need for Dropout, since it acts as a mild regularizer.

Weaknesses

The scale and shift parameters introduced per feature add learnable weights to the network, and their gradients add computation to the backward pass.

My Thoughts

Batch normalization is now ubiquitous in deep learning pipelines, so it is remarkable to see how it came about. In practice it seems pretty simple, but the theory and foundations behind it are rather complex.