Breaking the Softmax Bottleneck: A High-Rank RNN Language Model
WHY?
This paper first proves that the expressiveness of a language model is restricted by the softmax layer and then suggests a way to overcome this limit.
WHAT?
The last part of a language model usually consists of a softmax layer applied to the dot product of a context vector \mathbf{h}_c and a word embedding \mathbf{w}_x.
P_{\theta}(x|c) = \frac{\exp\mathbf{h}_c^T\mathbf{w}_x}{\sum_{x'}\exp \mathbf{h}_c^T\mathbf{w}_{x'}}
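As a sanity check, here is a minimal numpy sketch of this output layer (the function and variable names are mine, not the paper's; in practice h comes from an RNN and W is the word-embedding matrix).

```python
import numpy as np

def softmax_lm_head(h, W):
    """Compute P(x | c) for every word x given one context vector h.

    h: (d,)   context vector produced by the model for context c
    W: (M, d) word embedding matrix, one row per vocabulary word
    """
    logits = W @ h                     # (M,) dot products h^T w_x
    logits -= logits.max()             # subtract max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()         # softmax over the vocabulary

# toy usage (hypothetical sizes): vocabulary of 5 words, embedding size 3
rng = np.random.default_rng(0)
h = rng.normal(size=3)
W = rng.normal(size=(5, 3))
print(softmax_lm_head(h, W).sum())     # ~1.0
```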
This paper formulates language modeling as a matrix factorization problem. To do this, three matrices are defined: the context vectors \mathbf{H}_{\theta}, the word embeddings \mathbf{W}_{\theta}, and the log probabilities of the true data distribution \mathbf{A}.
\mathbf{H}_{\theta} =
\begin{bmatrix}
\mathbf{h}^T_{c_1}\\
\mathbf{h}^T_{c_2}\\
\vdots\\
\mathbf{h}^T_{c_N}\\
\end{bmatrix}\\
\mathbf{W}_{\theta} =
\begin{bmatrix}
\mathbf{w}^T_{x_1}\\
\mathbf{w}^T_{x_2}\\
\vdots\\
\mathbf{w}^T_{x_M}\\
\end{bmatrix}\\
\mathbf{A} =
\begin{bmatrix}
\log P^*(x_1|c_1) & \log P^*(x_2|c_1) & \cdots & \log P^*(x_M|c_1) \\
\log P^*(x_1|c_2) & \log P^*(x_2|c_2) & \cdots & \log P^*(x_M|c_2) \\
\vdots & \vdots & \ddots & \vdots \\
\log P^*(x_1|c_N) & \log P^*(x_2|c_N) & \cdots & \log P^*(x_M|c_N) \\
\end{bmatrix}\\
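For concreteness, here is a toy numpy sketch of how these three matrices are shaped (the sizes and the random "true" distribution are made up purely for illustration).

```python
import numpy as np

# Toy sizes (hypothetical): N contexts, M vocabulary words, embedding size d.
N, M, d = 4, 6, 3
rng = np.random.default_rng(0)

# H_theta: one context vector per context, stacked row-wise -> (N, d)
H = rng.normal(size=(N, d))
# W_theta: one embedding per vocabulary word, stacked row-wise -> (M, d)
W = rng.normal(size=(M, d))

# A: log of a (made-up) true conditional distribution P*(x | c) -> (N, M)
P_true = rng.dirichlet(np.ones(M), size=N)   # each row sums to 1
A = np.log(P_true)

print(H.shape, W.shape, A.shape)             # (4, 3) (6, 3) (4, 6)
```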
Also, we can define the set of matrices obtained by applying a row-wise shift to \mathbf{A}.
F(\mathbf{A}) = \{\mathbf{A} + \mathbf{\Lambda}\mathbf{J}_{N,M} \mid \mathbf{\Lambda} \text{ is diagonal and } \mathbf{\Lambda} \in \mathbb{R}^{N \times N}\}
This set has two properties: F(\mathbf{A}) contains exactly the logit matrices that recover the true data distribution, and all matrices in F(\mathbf{A}) have similar rank, differing by at most 1. If we want \mathbf{H}_{\theta}\mathbf{W}_{\theta}^T to be in F(\mathbf{A}), its rank must be at least as large as the minimum rank over F(\mathbf{A}). However, the rank of \mathbf{H}_{\theta}\mathbf{W}_{\theta}^T is upper bounded by the embedding size d.
\mathbf{H}_{\theta}\mathbf{W}_{\theta}^T = \mathbf{A}'\\
d \geq \min_{\mathbf{A}'\in F(\mathbf{A})} \mathrm{rank}(\mathbf{A}')
This proves the softmax bottleneck: the softmax layer does not have the capacity to express the true data distribution if the embedding dimension d is too small.
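A small numerical illustration of the bound (sizes are arbitrary): the model's logit matrix \mathbf{H}_{\theta}\mathbf{W}_{\theta}^T can never exceed rank d, while the log-probability matrix of a generic true distribution is full rank.

```python
import numpy as np

N, M, d = 20, 30, 5                      # hypothetical sizes with d << min(N, M)
rng = np.random.default_rng(0)

H = rng.normal(size=(N, d))              # context vectors
W = rng.normal(size=(M, d))              # word embeddings
logits = H @ W.T                         # (N, M) matrix of h_c^T w_x

# The model's logit matrix can never exceed rank d ...
print(np.linalg.matrix_rank(logits))     # 5

# ... but the log probabilities of a generic "true" distribution are full rank.
A = np.log(rng.dirichlet(np.ones(M), size=N))
print(np.linalg.matrix_rank(A))          # 20 (= min(N, M))
```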
To solve this problem, the paper suggests a Mixture of Softmaxes (MoS), which has improved expressiveness. Since the resulting log-probability matrix is a nonlinear (log-sum-exp) function of the context vectors and word embeddings, its rank is not restricted to d.
P_{\theta}(x|c) = \sum^K_{k=1}\pi_{c,k}\frac{\exp\mathbf{h}_{c,k}^T\mathbf{w}_x}{\sum_{x'}\exp \mathbf{h}_{c,k}^T\mathbf{w}_{x'}}; \ s.t. \sum^K_{k=1}\pi_{c,k} = 1\\
\pi_{c,k} = \frac{\exp\mathbf{w}^T_{\pi, k}\mathbf{g}_t}{\sum_{k'=1}^K \exp\mathbf{w}^T_{\pi, k'}\mathbf{g}_t}\\
\mathbf{h}_{c_t, k} = \tanh(\mathbf{W}_{h,k}\mathbf{g}_t)\\
\hat{\mathbf{A}}_{\text{MoS}} = \log \sum_{k=1}^K \mathbf{\Pi}_k \exp(\mathbf{H}_{\theta, k}\mathbf{W}_{\theta}^T)
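A minimal PyTorch sketch of such an MoS output layer, assuming a shared word-embedding decoder and one RNN output \mathbf{g}_t per step (class and argument names are mine, not from the paper's code).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfSoftmaxes(nn.Module):
    """Minimal MoS output layer sketch (sizes and names are illustrative)."""

    def __init__(self, d_hidden, d_embed, vocab_size, n_components):
        super().__init__()
        self.K = n_components
        self.d = d_embed
        # projects g_t to K latent context vectors h_{c,k} (with tanh)
        self.latent = nn.Linear(d_hidden, n_components * d_embed)
        # mixture-weight logits w_{pi,k}^T g_t
        self.prior = nn.Linear(d_hidden, n_components)
        # shared word embeddings W_theta
        self.decoder = nn.Linear(d_embed, vocab_size, bias=False)

    def forward(self, g):                                  # g: (batch, d_hidden)
        B = g.size(0)
        h = torch.tanh(self.latent(g)).view(B, self.K, self.d)
        log_softmaxes = F.log_softmax(self.decoder(h), dim=-1)   # (B, K, vocab)
        log_prior = F.log_softmax(self.prior(g), dim=-1)         # (B, K)
        # mix in probability space: log sum_k pi_{c,k} * softmax_k
        return torch.logsumexp(log_prior.unsqueeze(-1) + log_softmaxes, dim=1)

# toy usage (hypothetical sizes)
mos = MixtureOfSoftmaxes(d_hidden=8, d_embed=8, vocab_size=100, n_components=3)
log_p = mos(torch.randn(2, 8))
print(log_p.exp().sum(dim=-1))   # ~1.0 for each row
```

Note that the mixture is taken in probability space (via logsumexp), not over the logits; averaging the logits would again produce a single low-rank softmax.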
SO?
On Penn Treebank, WikiText, and the 1B Word dataset, MoS achieved clearly better perplexity than the standard softmax. Even though MoS is 2~3 times slower to compute, it was better at making context-dependent predictions.
CRITIQUE
Mixture of Softmaxes seems impressive, but there may be a more computationally efficient way of achieving the same effect.