
Natural Gradient Descent

Previously, we looked at the Fisher Information Matrix. We saw that it is equal to the negative expected Hessian of the log likelihood. Thus, an immediate application of the Fisher Information Matrix is as a drop-in replacement for the Hessian in second order optimization algorithms. In this article, we will look deeper at the intuition behind what exactly the Fisher Information Matrix represents and how to interpret it.

Distribution Space

As in the previous article, we have a probabilistic model represented by its likelihood \( p(x \vert \theta) \). We want to maximize this likelihood function to find the most likely parameter \( \theta \). An equivalent formulation is to minimize the loss function \( \mathcal{L}(\theta) \), which is the negative log likelihood.

The usual way to solve this optimization problem is to use gradient descent. In this case, we take a step whose direction is given by \( -\nabla_\theta \mathcal{L}(\theta) \). This is the steepest descent direction in the local neighbourhood of the current value of \( \theta \) in parameter space. Formally, we have

\[\frac{-\nabla_\theta \mathcal{L}(\theta)}{\lVert \nabla_\theta \mathcal{L}(\theta) \rVert} = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \mathop{\text{arg min}}_{d \text{ s.t. } \lVert d \rVert \leq \epsilon} \mathcal{L}(\theta + d) \, .\]

The above expression says that the steepest descent direction in parameter space is found by picking a vector \( d \) such that the new parameter \( \theta + d \) lies within the \( \epsilon \)-neighbourhood of the current parameter \( \theta \), and choosing the \( d \) that minimizes the loss. Notice that we express this neighbourhood by means of the Euclidean norm. Thus, the optimization in gradient descent depends on the Euclidean geometry of the parameter space.
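As a rough numerical check of the expression above, the sketch below sweeps unit directions in a two-dimensional parameter space and picks the one that minimizes the loss on the boundary of a small \( \epsilon \)-ball (for small \( \epsilon \) and a nonzero gradient, the minimizer lies on the boundary). The quadratic `loss`, the starting point, and the sweep resolution are illustrative assumptions, not part of the derivation.

```python
import math

def loss(theta):
    # Hypothetical quadratic loss, chosen only for illustration.
    return (theta[0] - 1.0) ** 2 + 10.0 * (theta[1] + 2.0) ** 2

def grad(theta):
    # Analytic gradient of the loss above.
    return (2.0 * (theta[0] - 1.0), 20.0 * (theta[1] + 2.0))

def best_direction(theta, eps=1e-4, n_angles=3600):
    """Unit vector d minimizing loss(theta + eps * d), found by brute-force sweep."""
    best, best_val = None, float("inf")
    for i in range(n_angles):
        angle = 2.0 * math.pi * i / n_angles
        d = (math.cos(angle), math.sin(angle))
        val = loss((theta[0] + eps * d[0], theta[1] + eps * d[1]))
        if val < best_val:
            best, best_val = d, val
    return best

theta = (0.0, 0.0)
g = grad(theta)
g_norm = math.hypot(g[0], g[1])
neg_unit_grad = (-g[0] / g_norm, -g[1] / g_norm)

d_star = best_direction(theta)
# Cosine between the swept minimizer and the normalized negative gradient.
cosine = d_star[0] * neg_unit_grad[0] + d_star[1] * neg_unit_grad[1]
print(cosine)  # close to 1 when the two directions agree
```

The brute-force sweep is only practical in two dimensions; it is here to make the limit statement concrete, not as a way to compute descent directions.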

Meanwhile, if our objective is to minimize the loss function (i.e. to maximize the likelihood), then it is natural to take steps in the space of all possible likelihoods realizable by the parameter \( \theta \). As the likelihood function itself is a probability distribution, we call this space the distribution space. Thus it makes sense to take the steepest descent direction in this distribution space instead of in parameter space.

Which metric/distance do we then need to use in this space? A popular choice is the KL-divergence, which measures the “closeness” of two distributions. Although the KL-divergence is non-symmetric and thus not a true metric, we can use it anyway. This is because as \( d \) goes to zero, the KL-divergence is asymptotically symmetric. So, within a local neighbourhood, the KL-divergence is approximately symmetric [1].

We can see the problem with using only the Euclidean metric in parameter space from the illustrations below. Consider a Gaussian parameterized only by its mean, with the variance kept fixed at 2 and 0.5 for the first and second image respectively:

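[Figure Param1: two Gaussians with the variance fixed to 2]

[Figure Param2: two Gaussians with the variance fixed to 0.5]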

In both images, the distance between the two Gaussians is the same, i.e. 4, according to the Euclidean metric (red line). However, in distribution space, i.e. when we take into account the shape of the Gaussians, the distance clearly differs between the first and second image. In the first image, the KL-divergence should be lower as there is more overlap between the Gaussians. Therefore, if we only work in parameter space, we cannot take into account this information about the distribution realized by the parameter.
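To put numbers on this comparison, here is a minimal sketch using the closed-form KL-divergence between two univariate Gaussians. The means \( -2 \) and \( 2 \) are an assumption (the text only states that they are 4 apart); the variances 2 and 0.5 are the ones given above.

```python
import math

def kl_gaussians(mu1, var1, mu2, var2):
    """Closed-form KL[N(mu1, var1) || N(mu2, var2)] for univariate Gaussians."""
    return 0.5 * (math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

# Assumed means: any pair that is 4 apart gives the same result.
mu_a, mu_b = -2.0, 2.0

kl_wide = kl_gaussians(mu_a, 2.0, mu_b, 2.0)    # first image, variance 2
kl_narrow = kl_gaussians(mu_a, 0.5, mu_b, 0.5)  # second image, variance 0.5

print(abs(mu_a - mu_b), kl_wide)    # same Euclidean distance, smaller KL
print(abs(mu_a - mu_b), kl_narrow)  # same Euclidean distance, larger KL
```

Under these assumptions the Euclidean distance is 4 in both cases, while the KL-divergence is four times larger for the narrower Gaussians.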

Fisher Information and KL-divergence

One question that still needs to be answered is: what exactly is the connection between the Fisher Information Matrix and the KL-divergence? It turns out that the Fisher Information Matrix defines the local curvature in distribution space, for which the KL-divergence is the metric.

Claim: Fisher Information Matrix \( \text{F} \) is the Hessian of KL-divergence between two distributions \( p(x \vert \theta) \) and \( p(x \vert \theta') \), with respect to \( \theta' \), evaluated at \( \theta' = \theta \).

Proof.    The KL-divergence can be decomposed into a (negative) entropy term and a cross-entropy term, i.e.:

\[\text{KL} [p(x \vert \theta) \, \Vert \, p(x \vert \theta')] = \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \, .\]

The first derivative wrt. \( \theta' \) is:

\[\begin{align} \nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] &= \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \\[5pt] &= - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_{\theta'} \log p(x \vert \theta') ] \\[5pt] &= - \int p(x \vert \theta) \nabla_{\theta'} \log p(x \vert \theta') \, \text{d}x \, . \end{align}\]

The second derivative is:

\[\begin{align} \nabla_{\theta'}^2 \, \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] &= - \int p(x \vert \theta) \, \nabla_{\theta'}^2 \log p(x \vert \theta') \, \text{d}x \\[5pt] \end{align}\]

Thus, the Hessian wrt. \( \theta' \) evaluated at \( \theta' = \theta \) is:

\[\begin{align} \text{H}_{\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]} &= - \int p(x \vert \theta) \, \left. \nabla_{\theta'}^2 \log p(x \vert \theta') \right\vert_{\theta' = \theta} \, \text{d}x \\[5pt] &= - \int p(x \vert \theta) \, \text{H}_{\log p(x \vert \theta)} \, \text{d}x \\[5pt] &= - \mathop{\mathbb{E}}_{p(x \vert \theta)} [\text{H}_{\log p(x \vert \theta)}] \\[5pt] &= \text{F} \, . \end{align}\]

The last line follows from the previous article about Fisher Information Matrix, in which we showed that the negative expected Hessian of log likelihood is the Fisher Information Matrix.

\( \square \)
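The claim can be sanity-checked numerically on a model where everything is available in closed form. The sketch below uses a Bernoulli distribution with parameter \( \theta \), which is an illustrative choice and not a model used in this article; its Fisher information is \( 1 / (\theta (1 - \theta)) \). The Hessian of the KL-divergence is approximated with a central finite difference, so the step size `h` is also an assumption.

```python
import math

def kl_bernoulli(theta, theta_prime):
    """KL[Bernoulli(theta) || Bernoulli(theta_prime)]."""
    return (theta * math.log(theta / theta_prime)
            + (1.0 - theta) * math.log((1.0 - theta) / (1.0 - theta_prime)))

def fisher_bernoulli(theta):
    """Closed-form Fisher information of a Bernoulli distribution."""
    return 1.0 / (theta * (1.0 - theta))

def kl_hessian_at(theta, h=1e-4):
    """Second derivative of the KL wrt. theta', evaluated at theta' = theta."""
    return (kl_bernoulli(theta, theta + h)
            - 2.0 * kl_bernoulli(theta, theta)
            + kl_bernoulli(theta, theta - h)) / h ** 2

theta = 0.3
print(kl_hessian_at(theta), fisher_bernoulli(theta))  # the two values should nearly agree
```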

Steepest Descent in Distribution Space

Now we are ready to use the Fisher Information Matrix to enhance gradient descent. But first, we need to derive the Taylor series expansion of the KL-divergence around \( \theta \).

Claim: Let \( d \to 0 \). The second order Taylor series expansion of KL-divergence is \( \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] \approx \frac{1}{2} d^\text{T} \text{F} d \).

Proof.    We will use \( p_{\theta} \) as a notational shortcut for \( p(x \vert \theta) \). By definition, the second order Taylor series expansion of KL-divergence is:

\[\begin{align} \text{KL}[p_{\theta} \, \Vert \, p_{\theta + d}] &\approx \text{KL}[p_{\theta} \, \Vert \, p_{\theta}] + (\left. \nabla_{\theta'} \text{KL}[p_{\theta} \, \Vert \, p_{\theta'}] \right\vert_{\theta' = \theta})^\text{T} d + \frac{1}{2} d^\text{T} \text{F} d \\[5pt] &= \text{KL}[p_{\theta} \, \Vert \, p_{\theta}] - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_\theta \log p(x \vert \theta) ]^\text{T} d + \frac{1}{2} d^\text{T} \text{F} d \\[5pt] \end{align}\]

Notice that the first term is zero as it is the same distribution. Furthermore, from the previous article, we saw that the expected value of the gradient of log likelihood, which is exactly the gradient of KL-divergence as shown in the previous proof, is also zero. Thus the only thing left is:

\[\begin{align} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] &\approx \frac{1}{2} d^\text{T} \text{F} d \, . \end{align}\]

\( \square \)
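As a worked instance of this claim, consider the Gaussian from the earlier illustration, parameterized by its mean \( \mu \) with the variance \( \sigma^2 \) held fixed. The symbols \( \mu \) and \( \sigma^2 \) are notation introduced here for illustration. In this case the Fisher information is a scalar and the second order expansion happens to be exact:

\[\begin{align} \log p(x \vert \mu) &= -\frac{(x - \mu)^2}{2 \sigma^2} + \text{const} \\[5pt] \text{F} &= - \mathop{\mathbb{E}}_{p(x \vert \mu)} \left[ \frac{\partial^2}{\partial \mu^2} \log p(x \vert \mu) \right] = \frac{1}{\sigma^2} \\[5pt] \text{KL}[p_{\mu} \, \Vert \, p_{\mu + d}] &= \frac{d^2}{2 \sigma^2} = \frac{1}{2} d^\text{T} \text{F} d \, . \end{align}\]

With \( d = 4 \), this gives a KL-divergence of 4 for \( \sigma^2 = 2 \) and 16 for \( \sigma^2 = 0.5 \), even though the Euclidean distance is the same in both cases.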

Now, we would like to know which update vector \( d \) minimizes the loss function \( \mathcal{L} (\theta) \) in distribution space, i.e. which direction decreases the loss the most when the step size is measured by the KL-divergence. This is analogous to the method of steepest descent, but in distribution space with the KL-divergence as metric, instead of the usual parameter space with the Euclidean metric. For that, we solve this minimization problem:

\[d^* = \mathop{\text{arg min}}_{d \text{ s.t. } \text{KL}[p_\theta \Vert p_{\theta + d}] = c} \mathcal{L} (\theta + d) \, ,\]

where \( c \) is some constant. The purpose of fixing the KL-divergence to some constant is to make sure that we move through the space at constant speed, regardless of the curvature. A further benefit is that this makes the algorithm more robust to reparametrization of the model, i.e. the algorithm does not care how the model is parametrized, it only cares about the distribution induced by the parameter [3].

If we write the above minimization in Lagrangian form, with the KL-divergence constraint approximated by its second order Taylor series expansion and \( \mathcal{L}(\theta + d) \) approximated by its first order Taylor series expansion, we get:

\[\begin{align} d^* &= \mathop{\text{arg min}}_d \, \mathcal{L} (\theta + d) + \lambda \, (\text{KL}[p_\theta \Vert p_{\theta + d}] - c) \\ &\approx \mathop{\text{arg min}}_d \, \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c \, . \end{align}\]

To solve this minimization, we set its derivative wrt. \( d \) to zero:

\[\begin{align} 0 &= \frac{\partial}{\partial d} \left( \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c \right) \\[5pt] &= \nabla_\theta \mathcal{L}(\theta) + \lambda \, \text{F} d \\[5pt] \lambda \, \text{F} d &= -\nabla_\theta \mathcal{L}(\theta) \\[5pt] d &= -\frac{1}{\lambda} \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta) \\[5pt] \end{align}\]

Up to a constant factor of \( \frac{1}{\lambda} \), we get the optimal descent direction, i.e. the opposite direction of the gradient, rescaled by \( \text{F}^{-1} \) to take into account the local curvature in distribution space. We can absorb this constant factor into the learning rate.

Definition: Natural gradient is defined as

\[\tilde{\nabla}_\theta \mathcal{L}(\theta) = \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta) \, .\]

\( \square \)
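The robustness to reparametrization mentioned earlier can be illustrated on a toy case. The sketch below assumes a Gaussian with fixed variance whose mean is written as \( \mu = k \phi \) for some scale \( k \), and compares one plain gradient step and one natural gradient step taken in \( \phi \), both converted back to a change in \( \mu \). The model, the scale `k`, and the single summary observation `x_bar` are all hypothetical choices made for illustration.

```python
def updates_in_mu(mu, x_bar, var, k, lr):
    """Return (plain, natural) updates, measured in mu, when optimizing phi = mu / k."""
    grad_mu = -(x_bar - mu) / var   # gradient of the negative log likelihood wrt. mu
    grad_phi = k * grad_mu          # chain rule: d mu / d phi = k
    fisher_phi = k ** 2 / var       # Fisher information wrt. phi
    plain = k * (-lr * grad_phi)                 # step in phi, mapped back to mu
    natural = k * (-lr * grad_phi / fisher_phi)  # natural step in phi, mapped back to mu
    return plain, natural

plain_1, natural_1 = updates_in_mu(0.0, 3.0, 2.0, k=1.0, lr=0.1)
plain_10, natural_10 = updates_in_mu(0.0, 3.0, 2.0, k=10.0, lr=0.1)

print(plain_1, plain_10)      # the plain gradient step changes with the scale k
print(natural_1, natural_10)  # the natural gradient step does not
```

In this one-dimensional case the factor \( k^2 \) picked up by the gradient step is cancelled exactly by the Fisher information, which is the invariance the natural gradient is designed to provide.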

As a corollary, we have the following algorithm:

Algorithm: Natural Gradient Descent

  1. Repeat:
    1. Do forward pass on our model and compute loss \( \mathcal{L}(\theta) \).
    2. Compute the gradient \( \nabla_\theta \mathcal{L}(\theta) \).
    3. Compute the Fisher Information Matrix \( \text{F} \), or its empirical version (wrt. our training data).
    4. Compute the natural gradient \( \tilde{\nabla}_\theta \mathcal{L}(\theta) = \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta) \).
    5. Update the parameter: \( \theta = \theta - \alpha \, \tilde{\nabla}_\theta \mathcal{L}(\theta) \), where \( \alpha \) is the learning rate.
  2. Until convergence.
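The steps above can be sketched on a small model where the Fisher Information Matrix is known exactly. The example below fits the mean and standard deviation of a univariate Gaussian, which is an assumed toy model, and replaces "until convergence" with a fixed number of steps. The data, learning rate, and helper names are hypothetical; for this model the exact Fisher wrt. \( (\mu, \sigma) \) is \( \mathrm{diag}(1/\sigma^2, 2/\sigma^2) \).

```python
import math

def nll_grad(data, mu, sigma):
    """Gradient of the average negative log likelihood of N(mu, sigma^2)."""
    n = len(data)
    mean_resid = sum(x - mu for x in data) / n
    mean_sq = sum((x - mu) ** 2 for x in data) / n
    return [-mean_resid / sigma ** 2, 1.0 / sigma - mean_sq / sigma ** 3]

def fisher(mu, sigma):
    """Exact Fisher Information Matrix of N(mu, sigma^2) wrt. (mu, sigma)."""
    return [[1.0 / sigma ** 2, 0.0], [0.0, 2.0 / sigma ** 2]]

def solve_2x2(a, b):
    """Solve the 2x2 linear system a x = b by Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [(b[0] * a[1][1] - a[0][1] * b[1]) / det,
            (a[0][0] * b[1] - a[1][0] * b[0]) / det]

def natural_gradient_descent(data, mu, sigma, lr=0.5, n_steps=100):
    """Run a fixed number of natural gradient steps and return (mu, sigma)."""
    for _ in range(n_steps):
        g = nll_grad(data, mu, sigma)   # gradient of the loss
        f = fisher(mu, sigma)           # Fisher Information Matrix
        nat = solve_2x2(f, g)           # natural gradient F^{-1} g, without forming F^{-1}
        mu, sigma = mu - lr * nat[0], sigma - lr * nat[1]  # parameter update
    return mu, sigma

data = [1.0, 2.0, 3.0, 4.0, 5.0]
mu_hat, sigma_hat = natural_gradient_descent(data, mu=0.0, sigma=1.0)
print(mu_hat, sigma_hat)  # should approach the sample mean and (biased) sample std
```

Solving the linear system \( \text{F} x = \nabla_\theta \mathcal{L}(\theta) \) rather than inverting \( \text{F} \) explicitly is a common numerical choice; for larger models a general linear solver would take the place of `solve_2x2`.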

Discussion

For a very simple model with a small amount of data, natural gradient descent is easy to implement. But how easy is it to do this in the real world? As we know, the number of parameters in deep learning models is very large, typically in the millions. The Fisher Information Matrix for these kinds of models is then infeasible to compute, store, or invert. This is the same reason why second order optimization methods are not popular in deep learning.

One way to get around this problem is to approximate the Fisher/Hessian instead. Methods like Adam [4] compute running averages of the first and second moments of the gradient. The first moment can be seen as momentum, which is not of interest in this article. The second moment can be seen as an approximation of the Fisher Information Matrix that is constrained to be a diagonal matrix. Thus in Adam, we only need \( O(n) \) space to store (the approximation of) \( \text{F} \) instead of \( O(n^2) \), and the inversion can be done in \( O(n) \) instead of \( O(n^3) \). In practice Adam works really well and is currently the de facto standard for optimizing deep neural networks.
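To make the storage and inversion costs concrete, here is a minimal sketch of preconditioning with a diagonal empirical Fisher, estimated as the mean of squared per-example gradients. This is not Adam itself: Adam divides by the square root of its second moment estimate and uses bias-corrected exponential moving averages. The function name and the `damping` term are assumptions added for illustration.

```python
def diagonal_natural_gradient(per_example_grads, damping=1e-8):
    """Precondition the mean gradient with a diagonal empirical Fisher.

    Stores only n numbers for the Fisher approximation, and "inverts" it
    with n elementwise divisions.
    """
    n_examples = len(per_example_grads)
    dim = len(per_example_grads[0])
    mean_grad = [sum(g[j] for g in per_example_grads) / n_examples
                 for j in range(dim)]
    # Diagonal of the empirical Fisher: mean of squared per-example gradients.
    diag_fisher = [sum(g[j] ** 2 for g in per_example_grads) / n_examples
                   for j in range(dim)]
    # Damping keeps the division well defined when an entry is near zero.
    return [mean_grad[j] / (diag_fisher[j] + damping) for j in range(dim)]

grads = [[1.0, 2.0], [3.0, -2.0]]  # hypothetical per-example gradients
step = diagonal_natural_gradient(grads)
print(step)
```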

References

  1. Martens, James. “New insights and perspectives on the natural gradient method.” arXiv preprint arXiv:1412.1193 (2014).
  2. Ly, Alexander, et al. “A tutorial on Fisher information.” Journal of Mathematical Psychology 80 (2017): 40-55.
  3. Pascanu, Razvan, and Yoshua Bengio. “Revisiting natural gradient for deep networks.” arXiv preprint arXiv:1301.3584 (2013).
  4. Kingma, Diederik P., and Jimmy Ba. “Adam: A method for stochastic optimization.” arXiv preprint arXiv:1412.6980 (2014).