Deriving LSTM Gradient for Backpropagation
Recurrent Neural Networks (RNNs) have been hot in recent years, especially with the boom of Deep Learning. An RNN can be seen as a (very) deep neural network if we “unroll” it with respect to the time step. Hence, with all the advances that enabled training vanilla deep networks, training RNNs has become more and more feasible too.
The most popular RNN model right now is the LSTM (Long Short-Term Memory) network. For the background theory, there are many excellent resources, such as Andrej Karpathy’s blog and Chris Olah’s blog.
Using modern Deep Learning libraries like TensorFlow, Torch, or Theano, building an LSTM model is a breeze, as we don’t need to derive the backpropagation step analytically. However, to understand the model better, it is absolutely a good exercise, albeit optional, to derive the LSTM gradient and implement backpropagation “manually”.
Here, we will first implement the forward computation step according to the LSTM formula, then derive the network gradient analytically. Finally, we will implement it using numpy.
We will follow this model for a single LSTM cell:
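In row-vector form, with X = [h_old, x] and the weight names used throughout this post (Wf, Wi, Wc, Wo, Wy), a standard LSTM cell reads as follows. The symbols f, i, \tilde{c}, o for the gates are an assumed notation.

$$
\begin{aligned}
X &= [h_{t-1}, x_t] \\
f_t &= \sigma(X W_f + b_f) \\
i_t &= \sigma(X W_i + b_i) \\
\tilde{c}_t &= \tanh(X W_c + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma(X W_o + b_o) \\
h_t &= o_t \odot \tanh(c_t) \\
y_t &= h_t W_y + b_y \\
p_t &= \mathrm{softmax}(y_t)
\end{aligned}
$$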
Let’s implement it!
Our LSTM net model is a set of weight matrices and bias vectors. Notice that, from the formula above, we’re concatenating the old hidden state h with the current input x, hence the input to our LSTM net has size Z = H + D, where D is the input dimension, i.e. the vocabulary size, since x is one-hot encoded. And because our LSTM layer outputs H neurons, each gate’s weight matrix has size ZxH and each gate’s bias vector has size 1xH.
One difference is for Wy and by. This weight and bias are used for a fully connected layer, whose output is fed to a softmax layer. The resulting output should be a probability distribution over all possible items in the vocabulary, which has size 1xD. Hence, Wy’s size must be HxD and by’s size must be 1xD.
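The parameter shapes described above can be sketched as a plain dict of numpy arrays. The helper name init_model, the dict layout, the example values of H and D, and the scaled Gaussian initialisation are all assumptions for illustration, not necessarily what the original code used.

```python
import numpy as np

def init_model(H, D, seed=0):
    """Hypothetical helper: gate weights ZxH, gate biases 1xH, output layer HxD and 1xD."""
    rng = np.random.default_rng(seed)
    Z = H + D                                   # gates see the concatenation [h_old, x]
    model = {}
    for gate in ("f", "i", "c", "o"):           # forget, input, candidate, output
        # Scaled Gaussian initialisation (an assumption; other schemes work too)
        model["W" + gate] = rng.standard_normal((Z, H)) / np.sqrt(Z / 2.0)
        model["b" + gate] = np.zeros((1, H))
    model["Wy"] = rng.standard_normal((H, D)) / np.sqrt(D / 2.0)
    model["by"] = np.zeros((1, D))
    return model

model = init_model(H=128, D=65)                 # H and D values are arbitrary examples
for name in sorted(model):
    print(name, model[name].shape)
```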
The forward step for a single LSTM cell follows the formula above exactly. The only additions are the one-hot encoding of the input and the concatenation of the hidden state with the input.
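A minimal numpy sketch of that forward step follows. The function name lstm_forward, the dict-based model, the gate variable names (hf, hi, ho, hc) and the layout of the returned cache are assumptions; the input is passed as an integer index and one-hot encoded inside the function.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(y):
    e = np.exp(y - y.max())                            # subtract the max for numerical stability
    return e / e.sum()

def lstm_forward(x_idx, state, model):
    """One LSTM step. x_idx is the integer index of the input item;
    state is (h_old, c_old), each of shape 1xH. Returns (prob, (h, c), cache)."""
    h_old, c_old = state
    D = model["Wy"].shape[1]
    x = np.zeros((1, D))
    x[0, x_idx] = 1.0                                  # one-hot encoding of the input
    X = np.hstack((h_old, x))                          # concatenation [h_old, x], shape 1xZ
    hf = sigmoid(X @ model["Wf"] + model["bf"])        # forget gate
    hi = sigmoid(X @ model["Wi"] + model["bi"])        # input gate
    ho = sigmoid(X @ model["Wo"] + model["bo"])        # output gate
    hc = np.tanh(X @ model["Wc"] + model["bc"])        # candidate cell state
    c = hf * c_old + hi * hc                           # new cell state
    h = ho * np.tanh(c)                                # new hidden state
    y = h @ model["Wy"] + model["by"]                  # fully connected output layer
    prob = softmax(y)                                  # distribution over the vocabulary
    cache = (X, hf, hi, ho, hc, c_old, c, h)           # kept for the backward step
    return prob, (h, c), cache

# Usage with a small, randomly initialised model (sizes are arbitrary)
rng = np.random.default_rng(0)
H, D = 8, 5
model = {k: rng.standard_normal((H + D, H)) * 0.1 for k in ("Wf", "Wi", "Wc", "Wo")}
model.update({k: np.zeros((1, H)) for k in ("bf", "bi", "bc", "bo")})
model["Wy"] = rng.standard_normal((H, D)) * 0.1
model["by"] = np.zeros((1, D))

state = (np.zeros((1, H)), np.zeros((1, H)))
prob, state, cache = lstm_forward(2, state, model)
print(prob.shape, prob.sum())
```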
Now, we will dive into the main point of this post: the LSTM backward computation. We will assume that the derivative functions for sigmoid and tanh are already known.
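For reference, here is one common way to write those derivatives, expressed in terms of the activation's output rather than its input, which is convenient because the forward step already stores the gate outputs. The helper names dsigmoid and dtanh are hypothetical; the snippet checks both against central finite differences.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def dsigmoid(s):
    """Derivative of the sigmoid, written in terms of its output s = sigmoid(a)."""
    return s * (1.0 - s)

def dtanh(t):
    """Derivative of tanh, written in terms of its output t = tanh(a)."""
    return 1.0 - t ** 2

# Compare against central finite differences on a few points
a = np.linspace(-3.0, 3.0, 7)
eps = 1e-6
num_dsig = (sigmoid(a + eps) - sigmoid(a - eps)) / (2 * eps)
num_dtanh = (np.tanh(a + eps) - np.tanh(a - eps)) / (2 * eps)
print(np.max(np.abs(num_dsig - dsigmoid(sigmoid(a)))))
print(np.max(np.abs(num_dtanh - dtanh(np.tanh(a)))))
```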
The backward computation is a bit long. However, it is actually easy enough to derive the LSTM gradients if you understand how to take the partial derivative of a function and how to apply the chain rule, although a few tricky things are going on. For a refresher, I would recommend CS231n.
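Written out with the chain rule, one timestep of the backward pass looks like this. The notation is an assumption: a_f = X W_f + b_f (and likewise a_i, a_c, a_o) are the gate pre-activations, L = -log p_t[target] is the per-step cross-entropy, δ denotes a gradient of the loss, and δh_next, δc_next are the gradients arriving from the following timestep. The sigmoid and tanh derivatives are used in the form σ' = σ(1 − σ) and tanh' = 1 − tanh².

$$
\begin{aligned}
\delta y_t &= p_t - \mathbf{1}_{\text{target}} \\
\delta W_y &= h_t^\top \delta y_t, \qquad \delta b_y = \delta y_t \\
\delta h_t &= \delta y_t W_y^\top + \delta h_{\text{next}} \\
\delta c_t &= \delta h_t \odot o_t \odot \bigl(1 - \tanh^2(c_t)\bigr) + \delta c_{\text{next}} \\
\delta a_o &= \delta h_t \odot \tanh(c_t) \odot o_t \odot (1 - o_t) \\
\delta a_f &= \delta c_t \odot c_{t-1} \odot f_t \odot (1 - f_t) \\
\delta a_i &= \delta c_t \odot \tilde{c}_t \odot i_t \odot (1 - i_t) \\
\delta a_c &= \delta c_t \odot i_t \odot \bigl(1 - \tilde{c}_t^{\,2}\bigr) \\
\delta W_k &= X^\top \delta a_k, \qquad \delta b_k = \delta a_k, \qquad k \in \{f, i, c, o\} \\
\delta X &= \sum_{k \in \{f, i, c, o\}} \delta a_k W_k^\top \\
\delta h_{t-1} &= \text{first } H \text{ entries of } \delta X, \qquad \delta c_{t-1} = \delta c_t \odot f_t
\end{aligned}
$$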
Things that are tricky and not-so-obvious when deriving the LSTM gradients are:
- Adding dh_next to dh, because h is branched in forward propagation: it is used in y = h @ Wy + by and, in the next time step, concatenated with x. Hence the gradient is split and has to be added back here.
- Adding dc_next to dc, for the same reason: c is used both to compute h and in the next time step.
- Adding up dX = dXo + dXc + dXi + dXf. Similar reason as above: X is used in many places, so its gradient is split and needs to be accumulated back.
- Getting dh_next, which is the gradient of h_old. As X = [h_old, x], dh_next is just the reverse of the concatenation: a split operation on dX.
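The four points above can be seen in the following sketch of a single backward step, numbered (1) to (4) in the comments. The names lstm_backward and num_grad, the cache layout, and the returned (dh_prev, dc_prev) pair are assumptions; a compact copy of the forward step is included so the block runs on its own. To exercise dh_next and dc_next, the gradient check adds linear terms g_h·h + g_c·c to the loss as a stand-in for whatever later timesteps contribute.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(y):
    e = np.exp(y - y.max())
    return e / e.sum()

def lstm_forward(x_idx, state, model):
    """Compact forward step: one-hot input, [h_old, x] concatenation, four gates."""
    h_old, c_old = state
    x = np.zeros((1, model["Wy"].shape[1]))
    x[0, x_idx] = 1.0
    X = np.hstack((h_old, x))
    hf = sigmoid(X @ model["Wf"] + model["bf"])
    hi = sigmoid(X @ model["Wi"] + model["bi"])
    ho = sigmoid(X @ model["Wo"] + model["bo"])
    hc = np.tanh(X @ model["Wc"] + model["bc"])
    c = hf * c_old + hi * hc
    h = ho * np.tanh(c)
    prob = softmax(h @ model["Wy"] + model["by"])
    return prob, (h, c), (X, hf, hi, ho, hc, c_old, c, h)

def lstm_backward(prob, y_true, d_next, cache, model):
    """Gradients of one timestep. d_next = (dh_next, dc_next) comes from the following step."""
    X, hf, hi, ho, hc, c_old, c, h = cache
    dh_next, dc_next = d_next
    H = h.shape[1]
    dy = prob.copy()
    dy[0, y_true] -= 1.0                               # softmax + cross-entropy: p - onehot
    grads = {"Wy": h.T @ dy, "by": dy}
    dh = dy @ model["Wy"].T + dh_next                  # (1) h feeds the output layer and the next step
    dc = dh * ho * (1.0 - np.tanh(c) ** 2) + dc_next   # (2) c feeds h and the next step
    d_pre = {                                          # gradients w.r.t. gate pre-activations
        "f": dc * c_old * hf * (1.0 - hf),
        "i": dc * hc * hi * (1.0 - hi),
        "c": dc * hi * (1.0 - hc ** 2),
        "o": dh * np.tanh(c) * ho * (1.0 - ho),
    }
    dX = np.zeros_like(X)
    for k, da in d_pre.items():
        grads["W" + k] = X.T @ da
        grads["b" + k] = da
        dX += da @ model["W" + k].T                    # (3) X is used by all four gates
    dh_prev = dX[:, :H]                                # (4) split: undo X = [h_old, x]
    dc_prev = hf * dc
    return grads, (dh_prev, dc_prev)

def num_grad(arr, f, idx, eps=1e-6):
    """Central finite difference of f() w.r.t. arr[idx]; arr is perturbed in place, then restored."""
    old = arr[idx]
    arr[idx] = old + eps
    up = f()
    arr[idx] = old - eps
    down = f()
    arr[idx] = old
    return (up - down) / (2 * eps)

# Gradient check on a toy model (sizes, seed and data are arbitrary)
rng = np.random.default_rng(0)
H, D = 4, 3
model = {k: rng.standard_normal((H + D, H)) * 0.5 for k in ("Wf", "Wi", "Wc", "Wo")}
model.update({k: rng.standard_normal((1, H)) * 0.1 for k in ("bf", "bi", "bc", "bo")})
model["Wy"] = rng.standard_normal((H, D)) * 0.5
model["by"] = rng.standard_normal((1, D)) * 0.1
state = (rng.standard_normal((1, H)) * 0.5, rng.standard_normal((1, H)) * 0.5)
g_h, g_c = rng.standard_normal((1, H)), rng.standard_normal((1, H))
x_idx, y_true = 1, 2

def objective():
    prob, (h, c), _ = lstm_forward(x_idx, state, model)
    return -np.log(prob[0, y_true]) + np.sum(g_h * h) + np.sum(g_c * c)

prob, _, cache = lstm_forward(x_idx, state, model)
grads, (dh_prev, dc_prev) = lstm_backward(prob, y_true, (g_h, g_c), cache, model)
print(grads["Wf"][0, 0], num_grad(model["Wf"], objective, (0, 0)))
```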
With the forward and backward computations implemented, we can stitch them together into a full training step that optimization algorithms can use.
LSTM Training Step
This training step consists of three steps: forward computation, loss calculation, and backward computation.
In the full training step, we first do forward propagation over all items in the training sequence and store the results, namely the softmax probabilities and the cache of each timestep, in lists, because we are going to use them in the backward step.
Next, at each timestep, we calculate the cross-entropy loss (because we’re using softmax). We then accumulate the losses over all timesteps and average them.
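As a small illustration of that loss, assuming each softmax output is a 1xD array and each target is an integer index, the per-timestep cross-entropy and its average could be computed like this. The helper names cross_entropy and sequence_loss are hypothetical.

```python
import numpy as np

def cross_entropy(prob, y_true):
    """Cross-entropy of one softmax output (shape 1xD) against an integer target."""
    return -np.log(prob[0, y_true])

def sequence_loss(probs, targets):
    """Accumulate the per-timestep losses, then average over the timesteps."""
    total = sum(cross_entropy(p, t) for p, t in zip(probs, targets))
    return total / len(targets)

# Toy example: two timesteps over a 3-item vocabulary
probs = [np.array([[0.7, 0.2, 0.1]]), np.array([[0.1, 0.1, 0.8]])]
targets = [0, 2]
loss = sequence_loss(probs, targets)
print(loss)
```

Since the loss is averaged over T timesteps, the matching gradient is the sum of the per-timestep gradients divided by T.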
Lastly, we do backpropagation based on the forward step results. Notice that while we iterate over the data forward in the forward step, here we go in the reverse direction.
Also notice that dc_next for the first timestep in the backward step is zero. Why? Because at the last timestep in forward propagation, c won’t be used in a next timestep, as there are no more timesteps! So the gradient of c at the last timestep is not split and can be derived directly without dc_next.
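Putting the pieces together, a full training step might look like the sketch below. The names train_step, lstm_forward and lstm_backward and the (loss, grads, state) return value are assumptions, and compact copies of the forward and backward steps are included so the block runs on its own. Both dh_next and dc_next start at zero for the last timestep, and the accumulated gradients are divided by T so that they are the exact gradient of the averaged loss.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(y):
    e = np.exp(y - y.max())
    return e / e.sum()

def lstm_forward(x_idx, state, model):
    h_old, c_old = state
    x = np.zeros((1, model["Wy"].shape[1]))
    x[0, x_idx] = 1.0                                          # one-hot input
    X = np.hstack((h_old, x))                                  # [h_old, x]
    hf = sigmoid(X @ model["Wf"] + model["bf"])
    hi = sigmoid(X @ model["Wi"] + model["bi"])
    ho = sigmoid(X @ model["Wo"] + model["bo"])
    hc = np.tanh(X @ model["Wc"] + model["bc"])
    c = hf * c_old + hi * hc
    h = ho * np.tanh(c)
    prob = softmax(h @ model["Wy"] + model["by"])
    return prob, (h, c), (X, hf, hi, ho, hc, c_old, c, h)

def lstm_backward(prob, y_true, d_next, cache, model):
    X, hf, hi, ho, hc, c_old, c, h = cache
    dh_next, dc_next = d_next
    dy = prob.copy()
    dy[0, y_true] -= 1.0                                       # softmax + cross-entropy
    grads = {"Wy": h.T @ dy, "by": dy}
    dh = dy @ model["Wy"].T + dh_next
    dc = dh * ho * (1.0 - np.tanh(c) ** 2) + dc_next
    d_pre = {"f": dc * c_old * hf * (1.0 - hf), "i": dc * hc * hi * (1.0 - hi),
             "c": dc * hi * (1.0 - hc ** 2), "o": dh * np.tanh(c) * ho * (1.0 - ho)}
    dX = np.zeros_like(X)
    for k, da in d_pre.items():
        grads["W" + k], grads["b" + k] = X.T @ da, da
        dX += da @ model["W" + k].T
    return grads, (dX[:, :h.shape[1]], hf * dc)

def train_step(inputs, targets, state, model):
    """Forward over the sequence, average cross-entropy, then backpropagation through time."""
    T = len(inputs)
    probs, caches, loss = [], [], 0.0
    for x_idx, y_true in zip(inputs, targets):                 # forward: first -> last
        prob, state, cache = lstm_forward(x_idx, state, model)
        loss += -np.log(prob[0, y_true])
        probs.append(prob)
        caches.append(cache)
    loss /= T                                                  # average over timesteps
    h, c = state
    d_next = (np.zeros_like(h), np.zeros_like(c))              # nothing flows back past the last step
    grads = {k: np.zeros_like(v) for k, v in model.items()}
    for prob, y_true, cache in reversed(list(zip(probs, targets, caches))):  # backward: last -> first
        step_grads, d_next = lstm_backward(prob, y_true, d_next, cache, model)
        for k in grads:
            grads[k] += step_grads[k] / T                      # gradient of the averaged loss
    return loss, grads, state                                  # final state goes to the next step

# Usage on a toy sequence (sizes and data are arbitrary)
rng = np.random.default_rng(1)
H, D = 5, 4
model = {k: rng.standard_normal((H + D, H)) * 0.3 for k in ("Wf", "Wi", "Wc", "Wo")}
model.update({k: np.zeros((1, H)) for k in ("bf", "bi", "bc", "bo")})
model["Wy"] = rng.standard_normal((H, D)) * 0.3
model["by"] = np.zeros((1, D))
inputs, targets = [0, 3, 1, 2], [3, 1, 2, 0]
state0 = (np.zeros((1, H)), np.zeros((1, H)))
loss, grads, state = train_step(inputs, targets, state0, model)
print(loss)
```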
With this function in hand, we can plug it into any optimization algorithm, such as RMSProp or Adam, with one modification: we have to take the state of the network into account. The state produced by the current step needs to be passed on to the next step.
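A minimal sketch of that modification, assuming a training step with the hypothetical signature train_step(inputs, targets, state, model) -> (loss, grads, state). adam_update applies the standard Adam rule (default hyperparameters from the Adam paper) to every array in the model; the toy_step stand-in, with a quadratic loss and a state that only counts calls, exists purely to make the loop runnable and to show that the state is threaded through every step.

```python
import numpy as np

def adam_update(params, grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step, applied in place to every array in params; t is the 1-based step count."""
    for k in params:
        m[k] = beta1 * m[k] + (1 - beta1) * grads[k]          # first-moment estimate
        v[k] = beta2 * v[k] + (1 - beta2) * grads[k] ** 2     # second-moment estimate
        m_hat = m[k] / (1 - beta1 ** t)                       # bias correction
        v_hat = v[k] / (1 - beta2 ** t)
        params[k] -= lr * m_hat / (np.sqrt(v_hat) + eps)

def optimize(train_step, model, batches, init_state, lr=1e-3):
    """Run train_step over the batches, carrying the network state from one step to the next."""
    m = {k: np.zeros_like(p) for k, p in model.items()}
    v = {k: np.zeros_like(p) for k, p in model.items()}
    state, losses = init_state, []
    for t, (inputs, targets) in enumerate(batches, start=1):
        loss, grads, state = train_step(inputs, targets, state, model)  # state carried over
        adam_update(model, grads, m, v, t, lr)
        losses.append(loss)
    return losses, state

# Toy stand-in for the LSTM training step: quadratic loss, state that just counts calls
def toy_step(inputs, targets, state, model):
    loss = float(np.sum((model["w"] - 3.0) ** 2))
    return loss, {"w": 2.0 * (model["w"] - 3.0)}, state + 1

model = {"w": np.zeros(2)}
losses, final_state = optimize(toy_step, model, [(None, None)] * 2000, init_state=0, lr=0.05)
print(model["w"], final_state)
```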
And, that’s it. We can train our LSTM net now!
Using Adam to optimize the network, here’s the result when I feed it text about Japan copy-pasted from Wikipedia. Each data point is a character in the text, and its target is the next character.
Every 100 iterations, the network is sampled.
It works like this:
- Do forward propagation and get the softmax distribution
- Sample the distribution
- Feed the sampled character as the input of the next time step
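The three sampling steps above can be sketched as follows, assuming a trained model in the same dict layout as before. The names lstm_step and sample, the toy vocabulary and the random (untrained) weights are illustrative assumptions, so the printed text is gibberish here.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(y):
    e = np.exp(y - y.max())
    return e / e.sum()

def lstm_step(x_idx, state, model):
    """Forward step only; no cache is needed when sampling."""
    h_old, c_old = state
    x = np.zeros((1, model["Wy"].shape[1]))
    x[0, x_idx] = 1.0
    X = np.hstack((h_old, x))
    hf = sigmoid(X @ model["Wf"] + model["bf"])
    hi = sigmoid(X @ model["Wi"] + model["bi"])
    ho = sigmoid(X @ model["Wo"] + model["bo"])
    hc = np.tanh(X @ model["Wc"] + model["bc"])
    c = hf * c_old + hi * hc
    h = ho * np.tanh(c)
    return softmax(h @ model["Wy"] + model["by"]), (h, c)

def sample(model, seed_idx, state, n, rng):
    """Generate n item indices, feeding each sampled item back in as the next input."""
    x_idx, out = seed_idx, []
    for _ in range(n):
        prob, state = lstm_step(x_idx, state, model)               # 1) softmax distribution
        x_idx = int(rng.choice(prob.shape[1], p=prob.ravel()))     # 2) sample from it
        out.append(x_idx)                                          # 3) becomes the next input
    return out

vocab = "abcde "                                                   # toy vocabulary
H, D = 8, len(vocab)
rng = np.random.default_rng(0)
model = {k: rng.standard_normal((H + D, H)) * 0.5 for k in ("Wf", "Wi", "Wc", "Wo")}
model.update({k: np.zeros((1, H)) for k in ("bf", "bi", "bc", "bo")})
model["Wy"] = rng.standard_normal((H, D)) * 0.5
model["by"] = np.zeros((1, D))
idxs = sample(model, seed_idx=0, state=(np.zeros((1, H)), np.zeros((1, H))), n=20, rng=rng)
print("".join(vocab[i] for i in idxs))
```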
And here’s the snippet of the results:
```
=========================================================================
Iter-100 loss: 4.2125
=========================================================================
best c ehpnpgteHihcpf,M tt" ao tpo Teoe ep S4 Tt5.8"i neai neyoserpiila o rha aapkhMpl rlp pclf5i
=========================================================================
...
=========================================================================
Iter-52800 loss: 0.1233
=========================================================================
tary shoguns who ruled in the name of the Uprea wal motrko, the copulation of Japan is a sour the wa
=========================================================================
```
Our network definitely learned something here!
Here, we looked at the general formula for LSTM and implemented the forward propagation step based on it, which is very straightforward to do.
Then, we derived the backward computation step. This step was also straightforward, but there were some tricky things that we had to ponder, especially the recurrent connections through h and c.
We then stitched the forward and backward step together to build the full training step that can be used with any optimization algorithm.
Lastly, we ran the network on some test data and showed that it was learning by looking at the loss value and at the text samples it produced.