---
title: "Simple Self-Attention from Scratch"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Simple Self-Attention from Scratch}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(rnn)
library(attention)
```

This vignette describes how to implement the [attention mechanism](https://en.m.wikipedia.org/wiki/Attention_(machine_learning)), which forms the basis of [transformers](https://en.m.wikipedia.org/wiki/Transformer_(machine_learning_model)), in the [R language](https://en.m.wikipedia.org/wiki/R_(programming_language)).

We begin by generating encoder representations of four different words.

```{r}
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
```

Next, we stack the word embeddings into a single array (in this case a matrix), which we call `words`. Each row of `words` holds the embedding of one word.

```{r}
# stacking the word embeddings into a single array
words = rbind(word_1, word_2, word_3, word_4)
```

Let's see what this looks like.

```{r}
print(words)
```

Next, we initialize the weight matrices `W_Q`, `W_K`, and `W_V` with random integers drawn from `{0, 1, 2}` (the floor of uniform draws on `[0, 3)`).

```{r}
# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
```

Next, we generate the Queries (`Q`), Keys (`K`), and Values (`V`) by multiplying `words` with the corresponding weight matrix. The `%*%` operator performs the matrix multiplication. You can view the R help page using `help('%*%')` (or the online [An Introduction to R](https://cran.r-project.org/doc/manuals/r-release/R-intro.html#Multiplication)).

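
As a brief aside before the projection itself, the chunk below is a minimal sketch of how `%*%` handles dimensions: a 2 x 3 matrix times a 3 x 3 matrix gives a 2 x 3 result, and a one-hot input row simply picks out the matching row of the weight matrix. The `toy_*` objects are hypothetical names used only for this illustration and are not part of the vignette's data.

```{r}
# toy inputs (hypothetical, for illustration only): two one-hot rows
toy_inputs = matrix(c(1, 0, 0,
                      0, 1, 0), nrow = 2, byrow = TRUE)   # 2 x 3

# toy weight matrix, filled column-wise with 1..9
toy_weights = matrix(1:9, nrow = 3, ncol = 3)             # 3 x 3

# (2 x 3) %*% (3 x 3) yields a 2 x 3 matrix
toy_projection = toy_inputs %*% toy_weights

# the first one-hot row selects row 1 of toy_weights, the second selects row 2
print(dim(toy_projection))
print(toy_projection)
```

The same shape rule applies to the embeddings: a matrix with one row per word and three columns, multiplied by a 3 x 3 weight matrix, again has one row per word.
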
```{r}
# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
```

Following this, we score the Queries (`Q`) against the Keys (`K`). The key matrix is transposed for the multiplication using `t()` (see `help('t')` for more info), so that entry `[i, j]` of `scores` is the dot product of query `i` with key `j`.

```{r}
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
```

We now generate the `weights` matrix from the scores, using `ComputeWeights()` from the attention package.

```{r}
# computing the weights from the scores
weights = attention::ComputeWeights(scores)
```

Let's have a look at the `weights` matrix.

```{r}
print(weights)
```

Finally, we compute the `attention` as a weighted sum of the value vectors (which are combined in the matrix `V`).

```{r}
# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
```

Now we can view the results using:

```{r}
print(attention)
```
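
The weights step above is delegated to `attention::ComputeWeights()`. As a rough sketch of what such a step typically looks like in scaled dot-product attention, the chunk below divides the scores by the square root of the key dimension and applies a row-wise softmax. The helper `softmax_rows()` and the `example_*` objects are hypothetical names introduced here, and the scaling factor is an assumption; the package's own implementation may differ in detail.

```{r}
# Hypothetical helper (not part of the attention package): row-wise softmax.
softmax_rows = function(x) {
  # subtract each row's maximum for numerical stability
  # (a vector of length nrow(x) recycles down the columns, so it aligns with rows)
  shifted = x - apply(x, 1, max)
  exps = exp(shifted)
  # divide each row by its sum so that every row adds up to 1
  exps / rowSums(exps)
}

# toy scores (hypothetical): 2 queries scored against 3 keys
example_scores = matrix(c(2, 0, 1,
                          0, 2, 1), nrow = 2, byrow = TRUE)

# assumed key dimension, used for the usual 1 / sqrt(d_k) scaling
d_k = 3

example_weights = softmax_rows(example_scores / sqrt(d_k))

# each row is a probability distribution over the keys
print(rowSums(example_weights))
```

Because every row of the weight matrix sums to 1, multiplying it by a value matrix produces, for each query, a convex combination of the value vectors.
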