While learning RNNs, I had a very tough time deducing BPTT for the vanilla RNN. There are some pieces of work online that try to explain this process (click here or here). However, I found the derivation related to the hidden state unsatisfying. Specifically, they mix up the partial derivative and the total derivative, which is a total disaster for beginners trying to understand it. So I want to write this blog to put down some of the tough points in the process. Meanwhile, I will skip the parts that have been explained thoroughly in the posts above. Hopefully, this blog will save you a bunch of time wrapping your head around this fantastic algorithm. The blog is generally for readers who already know at least the setup of neural networks and RNNs. The body of this blog is structured as follows:
- Math Recap: the difference between partial derivative and total derivative.
- Partial derivative for Neural Network
- The essence of BPTT
Math Recap: the difference between partial derivative and total derivative.
I see a lot of people misusing these two terms, and in most cases it won't cause any trouble. But for something as involved as BPTT, I think clarifying the distinction beforehand will help you understand it.
In terms of math symbols, the partial derivative looks like \(\frac{\partial{y}}{\partial{x}}\) and the total derivative looks like \(\frac{\mathrm{d}y}{\mathrm{d}x}\). The important difference between them is that when we take a partial derivative, we assume all other parts are static: they won't change, even if that's not the case in the real world. For example, assume \(F(x)= g(x)f(x)\), where \(g(x)\) and \(f(x)\) are both functions of \(x\) and generally change together as \(x\) changes. If we take the total derivative of \(F(x)\) with respect to \(x\), we've all seen the formula:
\[ \begin{aligned} \frac{\mathrm{d}F(x)}{\mathrm{d}x} &= \frac{\partial{F(x)}}{\partial{g(x)}} \frac{\mathrm{d}g(x)}{\mathrm{d}x} + \frac{\partial{F(x)}}{\partial{f(x)}} \frac{\mathrm{d}f(x)}{\mathrm{d}x} \\ \frac{\partial{F(x)}}{\partial{g(x)}} &= f(x) \\ \frac{\partial{F(x)}}{\partial{f(x)}} &= g(x) \end{aligned} \]
The partial derivative neglects the change of \(g(x)\) while \(f(x)\) changes. If instead we want \(\frac{\mathrm{d}F(x)}{\mathrm{d}g(x)}\), the answer will depend on the relationship between the functions \(g\) and \(f\).
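To make this concrete, here is a minimal SymPy sketch contrasting the two notions. The particular choices of \(g\) and \(f\) below are just illustrative assumptions:

```python
import sympy as sp

x = sp.symbols('x')
g = x**2           # an arbitrary choice of g(x), just for illustration
f = sp.sin(x)      # an arbitrary choice of f(x)
F = g * f

# Total derivative: both g(x) and f(x) are allowed to vary with x.
total = sp.diff(F, x)          # 2*x*sin(x) + x**2*cos(x)

# Partial derivatives: treat g and f as independent, frozen symbols.
g_sym, f_sym = sp.symbols('g f')
F_sym = g_sym * f_sym
dF_dg = sp.diff(F_sym, g_sym)  # = f
dF_df = sp.diff(F_sym, f_sym)  # = g

# The chain rule combines the partials back into the total derivative.
recovered = (dF_dg.subs(f_sym, f) * sp.diff(g, x)
             + dF_df.subs(g_sym, g) * sp.diff(f, x))
assert sp.simplify(total - recovered) == 0
```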
Partial derivative for Neural Network
In this part, I want to talk about backpropagation in a vanilla neural network to reinforce your understanding of partial derivatives and, more importantly, to explain some tricks we use when doing backpropagation. If you already have a full understanding of BP in neural networks, feel free to skip this section.
Without loss of generality, I will use a two-layer neural network (one hidden layer) as our example; the model is for a multi-class classification problem:
\[ \begin{aligned} \mathbf{z}^{(1)} &= \mathbf{X}\mathbf{W_1} + \mathbf{b_1} \\ \mathbf{a}^{(1)} &= f^{(1)}(\mathbf{z}^{(1)}) \\ \mathbf{z}^{(2)} &= \mathbf{a}^{(1)}\mathbf{W_2} + \mathbf{b_2} \\ \hat{\mathbf{Y}} &= g(\mathbf{z}^{(2)}) \\ Loss &= \sum^{N}_{i=1} L(y_i, \hat{\mathbf{Y_i}}) \end{aligned} \]
The dimensions of \(\mathbf{X}\), \(\mathbf{W_1}\), \(\mathbf{z}^{(1)}\), \(\mathbf{W_2}\), and \(\hat{\mathbf{Y}}\) are \((N, D)\), \((D, D_1)\), \((N, D_1)\), \((D_1, C)\), and \((N, C)\) respectively. The superscripts 1 and 2 indicate the layer, and the subscript \(i\) indicates the corresponding row of data.
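For reference, a minimal NumPy forward pass of this model might look like the sketch below. The sizes, the tanh activation for \(f^{(1)}\), and the softmax-plus-cross-entropy output are assumptions for illustration, not something fixed by the model above:

```python
import numpy as np

N, D, D1, C = 32, 10, 64, 5          # assumed sizes for illustration
rng = np.random.default_rng(0)

X  = rng.normal(size=(N, D))
W1 = rng.normal(size=(D, D1)) * 0.01
b1 = np.zeros(D1)
W2 = rng.normal(size=(D1, C)) * 0.01
b2 = np.zeros(C)
y  = rng.integers(0, C, size=N)      # integer class labels

Z1 = X @ W1 + b1                     # (N, D1)
A1 = np.tanh(Z1)                     # f^(1): tanh chosen as an example
Z2 = A1 @ W2 + b2                    # (N, C)
Z2 -= Z2.max(axis=1, keepdims=True)  # numerical stability for softmax
Y_hat = np.exp(Z2) / np.exp(Z2).sum(axis=1, keepdims=True)  # g: softmax

loss = -np.log(Y_hat[np.arange(N), y]).sum()  # summed cross-entropy
```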
When we try to calculate the total derivative with respect to \(\mathbf{W_1}\), there are generally two ways to do it: a formal way and a tricky way.
Formal derivation
The process is straightforward: we only need to write it down as the chain rule suggests.
\[ \frac{\mathrm{d}L}{\mathrm{d}{\mathbf{W_1}}} = \frac{\mathrm{d}{\mathbf{Z}^{(1)}}}{\mathrm{d}{\mathbf{W_1}}} \frac{\partial{\mathbf{Z}^{(2)}}}{\partial{\mathbf{Z}^{(1)}}} \frac{\partial{\hat{\mathbf{Y}}}}{\partial{\mathbf{Z}^{(2)}}} \frac{\partial{L}}{\partial{\hat{\mathbf{Y}}}} \]
Although the product of all these terms is a simple matrix, there is a lot of waste in the process. Notice that the intermediate quantities are 4-dimensional tensors; e.g. \(\frac{\mathrm{d}{\mathbf{Z}^{(1)}}}{\mathrm{d}{\mathbf{W_1}}}\) is a tensor of shape \((D, D_1, N, D_1)\). For every element \(W_{1ij}\) among the \(D \times D_1\) elements of \(\mathbf{W_1}\), there is an \((N, D_1)\) matrix of derivatives of \(\mathbf{Z}^{(1)}\) with respect to \(W_{1ij}\). And among the \(N \times D_1\) elements of this matrix, only \(N\) elements are nonzero (the ones belonging to neuron \(j\) in layer 1, namely column \(j\) of \(\mathbf{Z}^{(1)}\)). As a result, a prevailing tricky alternative comes into view.
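To see just how sparse this Jacobian is, here is a small sketch, continuing from the forward-pass code above, that materializes \(\frac{\mathrm{d}{\mathbf{Z}^{(1)}}}{\mathrm{d}{\mathbf{W_1}}}\) explicitly and checks that each \((N, D_1)\) slice has only one nonzero column:

```python
# dZ1/dW1 as an explicit (D, D1, N, D1) tensor: entry [i, j, n, k]
# is the derivative of Z1[n, k] with respect to W1[i, j].
dZ1_dW1 = np.zeros((D, D1, N, D1))
for i in range(D):
    for j in range(D1):
        dZ1_dW1[i, j, :, j] = X[:, i]   # only column j is nonzero

# For element (0, 0) of W1, every column of the slice except column 0 is zero.
assert np.all(dZ1_dW1[0, 0][:, 1:] == 0)
```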
Tricky derivation
Please note that in the tricky version I slightly misuse the derivative symbol \(\mathrm{d}x\): in my case it means \(\frac{\partial{L}}{\partial{x}}\), just for the sake of convenience.
Instead of computing the result all at once at the end, we will divide the problem and conquer it. For example, we first derive \(\frac{\partial{L}}{\partial{\mathbf{Z}^{(2)}}}\) and then go on. In the Python implementation, we denote this as \(\mathrm{d}\mathbf{Z}^{(2)}\), which looks like a total-derivative sort of thing although it's a partial derivative. This way the partial derivatives flow down like water and are easier to compute. At each stage, we consider every element of the next matrix we want to differentiate with respect to, and try to summarize the result as a compact matrix multiplication. Take the example from the previous method: we want \(\frac{\mathrm{d}L}{\mathrm{d}{\mathbf{W_1}}}\), and by the time the flow reaches layer 1 we already have \(\mathrm{d}\mathbf{Z}^{(1)}\). Now just look at what we get for a single entry \(\frac{\mathrm{d}L}{\mathrm{d}{W_{1ij}}}\). With ease we see that it is the inner product of column \(j\) of \(\mathrm{d}\mathbf{Z}^{(1)}\) and column \(i\) of \(\mathbf{X}\), which are both of shape \((N,)\). As a result, the 4-dimensional tensor multiplication above becomes:
\[ \frac{\mathrm{d}L}{\mathrm{d}{\mathbf{W_1}}} = \mathbf{X}^{T} \mathrm{d}\mathbf{Z}^{(1)} \]
And of course, this tricky derivation is the one widely used in practical implementations of BP for neural networks.
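Continuing the forward-pass sketch above, the whole backward pass written in this "tricky" style is only a few lines. Softmax plus cross-entropy is assumed here, so \(\mathrm{d}\mathbf{Z}^{(2)} = \hat{\mathbf{Y}} - \mathbf{1}_{y}\):

```python
# Backward pass, reusing X, A1, W2, Y_hat, y, N from the forward sketch.
one_hot = np.zeros_like(Y_hat)
one_hot[np.arange(N), y] = 1.0

dZ2 = Y_hat - one_hot            # (N, C)   partial L / partial Z2
dW2 = A1.T @ dZ2                 # (D1, C)
db2 = dZ2.sum(axis=0)            # (C,)
dA1 = dZ2 @ W2.T                 # (N, D1)
dZ1 = dA1 * (1.0 - A1 ** 2)      # (N, D1)  tanh'(z) = 1 - tanh(z)^2
dW1 = X.T @ dZ1                  # (D, D1)  the compact form derived above
db1 = dZ1.sum(axis=0)            # (D1,)
```

Notice how every `d`-quantity is the same shape as the thing it differentiates, which is exactly what makes the waterflow style so convenient.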
The essence of BPTT in vanilla RNN
After the previous discussion, we are ready to deal with this hardcore problem!
First, we will describe the model and its parameters:
\[ \begin{aligned} \mathbf{h}^{(t)} &=\tanh\left( \mathbf{b} + \mathbf{h}^{(t-1)} \mathbf{W_h} + \mathbf{x}^{(t)} \mathbf{W_x} \right) \\ \mathbf{o}^{(t)} &= \mathbf{c} +\mathbf{h}^{(t)} \mathbf{V} \\ \mathbf{\hat{y}}^{(t)} &= softmax\left( \mathbf{o}^{(t)} \right) \\ \mathbf{L}^{(t)} &= L(\mathbf{y}^{(t)}, \mathbf{\hat{y}}^{(t)}) \\ L &= \sum^{T}_{t=1} \sum^{N}_{i=1} L^{(t)}_{i} \end{aligned} \]
Here we also take the input as a minibatch of \(N\) entries. The dimensions of \(\mathbf{x}^{(t)}\), \(\mathbf{h}^{(t-1)}\), \(\mathbf{W_h}\), \(\mathbf{W_x}\), \(\mathbf{V}\), and \(\mathbf{L}^{(t)}\) are \((N,D)\), \((N,H)\), \((H,H)\), \((D,H)\), \((H, C)\), and \((N,)\) respectively. \(L^{(t)}_{i}\) is the \(i\)-th element of \(\mathbf{L}^{(t)}\). The loss function is the common cross-entropy loss.
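A minimal NumPy forward pass for this RNN might look like the sketch below. The sizes are assumptions for illustration, and the shapes follow the table above:

```python
import numpy as np

N, D, H, C, T = 8, 10, 16, 4, 5      # assumed sizes for illustration
rng = np.random.default_rng(0)

Wx = rng.normal(size=(D, H)) * 0.01
Wh = rng.normal(size=(H, H)) * 0.01
b  = np.zeros(H)
V  = rng.normal(size=(H, C)) * 0.01
c  = np.zeros(C)

x = rng.normal(size=(T, N, D))       # inputs for all timesteps
y = rng.integers(0, C, size=(T, N))  # labels for all timesteps

h = [np.zeros((N, H))]               # h[0] is the initial hidden state h^(0)
y_hat = []
loss = 0.0
for t in range(T):
    a = b + h[-1] @ Wh + x[t] @ Wx
    h.append(np.tanh(a))                                    # h^(t)
    o = c + h[-1] @ V                                        # o^(t)
    o = o - o.max(axis=1, keepdims=True)                     # stability
    p = np.exp(o) / np.exp(o).sum(axis=1, keepdims=True)     # softmax
    y_hat.append(p)
    loss += -np.log(p[np.arange(N), y[t]]).sum()             # cross-entropy
```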
The most troublesome part involves the derivatives with respect to \(\mathbf{W_h}\), \(\mathbf{W_x}\) and \(\mathbf{b}\). Because they appear in every \(\mathbf{h}^{(t)}\), the derivatives take on a recursive fashion, which is quite a thorny thing to deal with. So here we will cut directly to this part. I will write down the derivation of the derivative of \(L\) with respect to \(\mathbf{W_h}\); the others are left as exercises for you :)
Goal: find the total derivative
\[ \frac{\mathrm{d}L}{\mathrm{d}\mathbf{W_h}}= \sum^{T}_{t=1} \frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}} \]
Notice that \(L^{(t)}\) here is a scalar, equal to \(\sum^{N}_{i=1} L^{(t)}_{i}\). We will just look at one of these terms to get a sense of how it turns out.
\[ \frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}} = \frac{\mathrm{d}\mathbf{h}^{(t)}}{\mathrm{d}\mathbf{W_h}} \frac{\partial \mathbf{o}^{(t)}}{\partial \mathbf{h}^{(t)}} \frac{\partial \mathbf{\hat{y}}^{(t)}}{\partial \mathbf{o}^{(t)}} \frac{\partial L^{(t)}}{\partial \mathbf{\hat{y}}^{(t)}} \]
In order to use the same trick as BP in the neural network, we will compute the partial derivatives in sequence. Note again that the symbol \(d\) is just for notational convenience and has nothing to do with the total derivative.
\[ \begin{aligned} d\mathbf{o}^{(t)} &= \frac{\partial L^{(t)}}{\partial \mathbf{o}^{(t)}} = \mathbf{\hat{y}}^{(t)} - \mathbf{1_{y^{(t)}}} \\ d\mathbf{h}^{(t)} &= d\mathbf{o}^{(t)} V^{T} \end{aligned} \]
\(\mathbf{1_{y^{(t)}}}\) is a special matrix of the same size as \(\mathbf{\hat{y}}^{(t)}\). For row \(i\) of \(\mathbf{1_{y^{(t)}}}\), the \(y^{(t)}_{i}\)-th element is 1 and the others are zero.
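In code, these two gradients are straightforward. The sketch below continues the RNN forward sketch above; `dL_doutput` is just a small helper name I introduce here, and `t` picks one fixed timestep:

```python
def dL_doutput(y_hat_t, y_t):
    """Return dO^(t) = y_hat^(t) - 1_{y^(t)} for one timestep."""
    one_hot = np.zeros_like(y_hat_t)
    one_hot[np.arange(y_hat_t.shape[0]), y_t] = 1.0
    return y_hat_t - one_hot

t = T - 1                         # pick any fixed timestep (0-indexed)
dO = dL_doutput(y_hat[t], y[t])   # (N, C)
dH = dO @ V.T                     # (N, H), this is dh^(t)
```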
Then we will define an intermediate matrix \(\mathbf{a}^{(t)}\) to avoid using a 4-dimensional tensor. Notice that the square of \(\mathbf{h}^{(t)}\) below is an element-wise operation, and \(\odot\) denotes element-wise multiplication.
\[ \begin{aligned} \mathbf{a}^{(t)} &=\mathbf{b} + \mathbf{h}^{(t-1)} \mathbf{W_h} + \mathbf{x}^{(t)} \mathbf{W_x} \\ \mathbf{h}^{(t)} &=\tanh\left( \mathbf{a}^{(t)} \right) \\ d\mathbf{h}^{(t)} &= d\mathbf{o}^{(t)} \mathbf{V}^{T} \\ d\mathbf{a}^{(t)} &= d\mathbf{h}^{(t)} \odot \left( 1 - (\mathbf{h}^{(t)})^2 \right) \\ \frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}} &= \frac{\mathrm{d}\mathbf{a}^{(t)}}{\mathrm{d}\mathbf{W_h}} d\mathbf{a}^{(t)} \\ &= \left( \frac{\mathrm{d}\mathbf{W_h}}{\mathrm{d}\mathbf{W_h}} \frac{\partial \mathbf{a}^{(t)}}{\partial \mathbf{W_h}} + \frac{\mathrm{d}\mathbf{h}^{(t-1)}}{\mathrm{d}\mathbf{W_h}} \frac{\partial \mathbf{a}^{(t)}}{\partial \mathbf{h}^{(t-1)}} \right) d\mathbf{a}^{(t)} \end{aligned} \]
For the last equation, the first part is simple: using the result from above, that product is just \(d{\mathbf{W_h}^{(t)}} = (\mathbf{h}^{(t-1)})^T d\mathbf{a}^{(t)}\). We can then further collapse the last part of the equation, and we get:
\[ \begin{aligned} \frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}} &= d{\mathbf{W_h}^{(t)}} + \frac{\mathrm{d}\mathbf{h}^{(t-1)}}{\mathrm{d}\mathbf{W_h}} d\mathbf{h}^{(t-1)} \\ d\mathbf{h}^{(t-1)} &= d\mathbf{a}^{(t)} \mathbf{W_h}^{T} \end{aligned} \]
And then we find that this can be done recursively: \(d\mathbf{h}^{(t-1)}\) is passed down just like the waterflow in BP for the neural network. Finally, after rearrangement, we get the following final version. (In order to calculate \(d{\mathbf{W_h}^{(1)}}\), there should be an \(\mathbf{h}^{(0)}\) serving as the initialization value of the hidden state.)
\[ \begin{aligned} \frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}} &= \sum^{t}_{i=1}d{\mathbf{W_h}^{(i)}} \end{aligned} \]
In practice, when we calculate \(\frac{\mathrm{d}L^{(t)}}{\mathrm{d}\mathbf{W_h}}\), at a given timestep \(i=t_0 \in \left[ 1, t-1 \right]\) we first receive the derivative message from timestep \(t_0+1\) and then pass the derivative message further down. The whole process looks like this:
\[ \begin{aligned} d\mathbf{a}^{(i)} &= d\mathbf{h}^{(i)} \odot \left( 1 - (\mathbf{h}^{(i)})^2 \right) \\ d{\mathbf{W_h}^{(i)}} &= (\mathbf{h}^{(i-1)})^T d\mathbf{a}^{(i)} \\ d{\mathbf{W_x}^{(i)}} &= \ \dots \\ d{\mathbf{X}^{(i)}} &= \ \dots \\ d{\mathbf{b}^{(i)}} &= \ \dots \\ d\mathbf{h}^{(i-1)} &= d\mathbf{a}^{(i)} \mathbf{W_h}^{T} \end{aligned} \]
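Putting the pieces together, here is a hedged sketch of the backward pass for a single \(L^{(t)}\), accumulating \(d{\mathbf{W_h}^{(i)}}\) down to \(i = 1\). It reuses the arrays from the RNN forward sketch and the `dL_doutput` helper introduced above; the gradients for \(\mathbf{W_x}\) and \(\mathbf{b}\) (the "..." lines) are filled in for convenience:

```python
t = T - 1                          # one timestep's loss L^(t) (0-indexed)
dWh = np.zeros_like(Wh)
dWx = np.zeros_like(Wx)
db_ = np.zeros_like(b)

dO = dL_doutput(y_hat[t], y[t])    # dO^(t) = y_hat^(t) - 1_{y^(t)}
dH = dO @ V.T                      # dh^(t), the message entering the recursion

for i in range(t, -1, -1):         # i = t, t-1, ..., 0 (math timesteps t..1)
    dA = dH * (1.0 - h[i + 1] ** 2)    # da^(i) = dh^(i) * (1 - h^(i)^2)
    dWh += h[i].T @ dA                 # dWh^(i) = (h^(i-1))^T da^(i)
    dWx += x[i].T @ dA                 # analogous term for W_x
    db_ += dA.sum(axis=0)              # analogous term for b
    # dX^(i) would be dA @ Wx.T, if input gradients are needed
    dH = dA @ Wh.T                     # dh^(i-1), passed further down
```

A finite-difference check on a few entries of `Wh` is an easy way to validate a sketch like this against the derivation.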
And there are \(T\) such backward passes to compute, corresponding to the \(T\) different \(L^{(t)}\).
If you find any mistakes in this passage, feel free to contact me through my email: FDSM_lhn@yahoo.com.
You're welcome to share this post on other websites or blogs, but please include my name and a link to this post alongside. Thank you!