LSTM units

Feed forward pass

ft = sigmoid(ht − 1 * Wf + xt * Uf)

it = sigmoid(ht − 1 * Wi + xt * Ui)

gt = tanh(ht − 1 * Wg + xt * Ug)

ot = sigmoid(ht − 1 * Wo + xt * Uo)

Ct = Ct − 1 ⋅ ft + it ⋅ gt

ht = yt = tanh(Ct) ⋅ ot

Back propagation pass

To perform the BPTT with a LSTM unit, we have the eror comming from the top layer (δ1), the future cell (δ4), the future hidden state (δ2). Also, we have stored during the feed forward the states at each step of the feeding. In the case of the future layer, this error is just set to zero if not calculated yet. For convention, correspond to point wise multiplication, while * correspond to matrix multiplication.

The rules on how to back prpagate come from this post.

δ3 = δ1 + δ2

δ5 = δ3 ⋅ 6 = δ3 ⋅ ot

δ6 = δ3 ⋅ 5 = δ3 ⋅ tanh(ct)

δ7 = δ5 ⋅ f′(5) = δ5 ⋅ tanh′(tanh(ct))

δ8 = δ7 ⋅ δ4

δ9 = δ8 ⋅ 10 = δ8 ⋅ it

δ10 = δ8 ⋅ 9 = δ8 ⋅ gt

δ11 = δ8 ⋅ 12 = δ8 ⋅ ft

δ12 = δ8 ⋅ 11 = δ8 ⋅ ct − 1

δ13 = δ6 ⋅ f′(6) = δ6 ⋅ sigmoid′(ot) δ14 = δ9 ⋅ f′(9) = δ9 ⋅ tanh′(gt) δ15 = δ10 ⋅ f′(10) = δ10 ⋅ sigmoid′(it) δ16 = δ12 ⋅ f′(12) = δ12 ⋅ sigmoid′(ft)

δ17 = δ13 * UoT δ19 = δ14 * UgT δ21 = δ15 * UiT δ23 = δ16 * WfT δ18 = δ13 * WoT δ20 = δ14 * WgT δ22 = δ16 * WiT δ24 = δ16 * WfT

δ25 = δ18 + δ20 + δ22 + δ24 δ26 = δ17 + δ19 + δ21 + δ23

The error δ11, δ25 and δ26 are used for the next layers. Once all those errors are available, it is possible to calculate the weight update.

δWf = δWf + ht − 1T * δ16 δUf = δUf + xtT * δ16

δWi = δWi + ht − 1T * δ15 δUi = δUi + xtT * δ15

δWg = δWg + ht − 1T * δ14 δUg = δUg + xtT * δ14

δWo = δWo + ht − 1T * δ13 δUo = δUo + xtT * δ13