Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks


WHY?

Batch normalization is known as a good method to stabilize the optimization of neural networks by reducing internal covariate shift. However, batch normalization inherently depends on the minibatch, which impedes its use in recurrent models.

WHAT?

The core idea of weight normalization is to reparameterize each weight vector into a scale parameter and a direction parameter, and then perform gradient descent with respect to each of them. Weight normalization is applied per neuron unit:

$$y = \phi(\mathbb{w}\cdot \mathbb{x} + b), \qquad \mathbb{w} = \frac{g}{\|\mathbb{v}\|}\mathbb{v}$$

The gradients with respect to the new parameters can be obtained with a minor modification of the ordinary gradient. Here $M_{\mathbb{w}}$ is a projection matrix that projects onto the complement of the $\mathbb{w}$ vector:

$$\nabla_g L = \frac{\nabla_{\mathbb{w}} L\cdot \mathbb{v}}{\|\mathbb{v}\|}$$

$$\nabla_{\mathbb{v}} L = \frac{g}{\|\mathbb{v}\|}\nabla_{\mathbb{w}}L - \frac{g\nabla_g L}{\|\mathbb{v}\|^2}\mathbb{v} = \frac{g}{\|\mathbb{v}\|} M_{\mathbb{w}}\nabla_{\mathbb{w}}L, \qquad M_{\mathbb{w}} = I - \frac{\mathbb{w}\mathbb{w}'}{\|\mathbb{w}\|^2}$$

We can see that weight normalization scales the weight gradient and projects it away from the current weight vector. These two effects help bring the covariance matrix of the gradient closer to the identity, which benefits optimization.

Unlike batch normalization, weight normalization does not directly scale the features, so proper initialization of the parameters is needed. The vector $\mathbb{v}$ is sampled from a Gaussian with mean zero and standard deviation 0.05, and then $g$ and $b$ are initialized using a single minibatch of data so that the pre-activations start with zero mean and unit variance:

$$t = \frac{\mathbb{v}\cdot\mathbb{x}}{\|\mathbb{v}\|}, \qquad g \leftarrow \frac{1}{\sigma[t]}, \qquad b \leftarrow \frac{-\mu[t]}{\sigma[t]}$$

Mean-only batch normalization can be combined with weight normalization to center the pre-activations by subtracting their minibatch mean:

$$t = \mathbb{w}\cdot \mathbb{x}, \qquad \tilde{t} = t - \mu[t] + b, \qquad y = \phi(\tilde{t})$$
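
As a concrete illustration, here is a minimal NumPy sketch of the reparameterization, the gradient decomposition, and the data-dependent initialization described above. The function and variable names are my own, not from the paper, and the nonlinearity $\phi$ is omitted so the zero-mean/unit-variance check at the end applies to the pre-activations.

```python
import numpy as np

def weight_norm_forward(x, v, g, b):
    """Pre-activations of a weight-normalized linear layer: w = (g / ||v||) v."""
    v_norm = np.linalg.norm(v, axis=0)        # ||v|| per output unit (column of v)
    w = v * (g / v_norm)                      # reparameterized weights
    return x @ w + b

def weight_norm_grads(grad_w, v, g):
    """Map a gradient w.r.t. w onto gradients w.r.t. g and v (per output unit)."""
    v_norm = np.linalg.norm(v, axis=0)
    grad_g = (grad_w * v).sum(axis=0) / v_norm                    # (grad_w . v) / ||v||
    grad_v = (g / v_norm) * grad_w - (g * grad_g / v_norm**2) * v
    return grad_g, grad_v

def data_dependent_init(x, in_features, out_features, rng):
    """Sample v ~ N(0, 0.05) and set g, b from one minibatch so that
    the pre-activations have zero mean and unit variance."""
    v = rng.normal(0.0, 0.05, size=(in_features, out_features))
    t = x @ (v / np.linalg.norm(v, axis=0))   # t = v . x / ||v||
    mu, sigma = t.mean(axis=0), t.std(axis=0)
    return v, 1.0 / sigma, -mu / sigma        # g <- 1/sigma[t], b <- -mu[t]/sigma[t]

# Usage on a random minibatch
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 32))
v, g, b = data_dependent_init(x, 32, 4, rng)
y = weight_norm_forward(x, v, g, b)
print(y.mean(axis=0).round(3), y.std(axis=0).round(3))  # ~0s and ~1s
```

In practice one would let automatic differentiation compute these gradients simply by expressing $\mathbb{w}$ as a function of $g$ and $\mathbb{v}$; the explicit formulas are written out here only to mirror the equations above.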

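A similarly minimal sketch of the mean-only batch normalization variant, again with my own naming and with an arbitrary nonlinearity standing in for $\phi$ (at test time the minibatch mean would be replaced by a running average):

```python
import numpy as np

def wn_mean_only_bn_forward(x, v, g, b, phi=np.tanh):
    """t = w . x,  t_tilde = t - mu[t] + b,  y = phi(t_tilde)."""
    w = v * (g / np.linalg.norm(v, axis=0))   # weight-normalized weights
    t = x @ w                                 # pre-activations, no bias yet
    t_tilde = t - t.mean(axis=0) + b          # subtract minibatch mean, then add bias
    return phi(t_tilde)
```
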
So?

WN + mean-only BN was shown to improve performance in supervised classification (CIFAR-10), generative modelling (convolutional VAE, DRAW), and reinforcement learning (DQN).

Critique

Implementing both WN and mean-only BN sounds a little cumbersome… I’m not sure they are worth it.

Salimans, Tim, and Diederik P. Kingma. “Weight normalization: A simple reparameterization to accelerate training of deep neural networks.” Advances in Neural Information Processing Systems. 2016.


