GRU units

Feed forward pass

rt = sigmoid(ht − 1 * Wr + xt * Ur)

zt = sigmoid(ht − 1 * Wz + xt * Uz)

gt = tanh(Wg * (ht − 1 ⋅ rt) + xt * Ug)

ht = yt = ht − 1 ⋅ (1 − zt) + (zt ⋅ gt)

Back propagation pass

To perform the BPTT with a GRU unit, we have the eror comming from the top layer (δ1), the future hidden states (δ2). Also, we have stored during the feed forward the states at each step of the feeding. In the case of the future layer, this error is just set to zero if not calculated yet. For convention, correspond to point wise multiplication, while * correspond to matrix multiplication.

The rules on how to back prpagate come from this post.

δ3 = δ1 + δ2

δ4 = (1 − zt) ⋅ δ3

δ5 = δ3 ⋅ ht − 1

δ6 = 1 − δ5

δ7 = δ3 ⋅ gt

δ8 = δ3 ⋅ zt

δ9 = δ7 + δ8

δ10 = δ8 ⋅ tanh′(gt)

δ11 = δ9 ⋅ sigmoid′(zt)

δ12 = δ10 * WgT δ13 = δ10 * UgT δ14 = δ11 * WzT δ15 = δ11 * UzT

δ16 = δ13 ⋅ ht − 1 δ17 = δ13 ⋅ rt

δ18 = δ17 ⋅ sigmoid′(rt)

δ19 = δ17 + δ4

δ20 = δ18 * WrT δ21 = δ18 * UrT

δ22 = δ21 + δ15

δ23 = δ19 + δ22

δ24 = δ12 + δ14 + δ20

The error δ23 and δ24 are used for the next layers. Once all those errors are available, it is possible to calculate the weight update.

δWr = δWf + ht − 1T * δ10 δUr = δUf + xtT * δ10

δWz = δWi + ht − 1T * δ11 δUz = δUi + xtT * δ11

δWg = δWg + (ht − 1T ⋅ rt) * δ18 δUg = δUg + xtT * δ18