Prelude

Required Background

This document assumes that you have some familiarity with neural networks and training them (and all the machine learning, linear algebra and vector calculus that entails). Terms like “gradient”, “parameters”, “softmax”, and “regularization” will be tossed around without accompanying explanation. Other than that, I am basically assuming that you haven’t heard of a Transformer, and am building intuition from the ground up, starting with self-attention. If you don’t have this background, I’d suggest getting up to speed on that first before reading this.

Word Embeddings

Implicit in all following discussion of machine translation and language modeling is that you can feed “words” to a neural network. Of course, neural networks speak in numbers, not words, so each word has to be represented as a word embedding vector in $\mathbb{R}^d$. There is a fixed vocabulary, and each word in the vocabulary corresponds to a d-dimensional vector in a lookup table. In the past, these might have been learned separately (via a procedure like GloVe or Word2Vec), and then taken as fixed inputs to the model. In more recent research, including the Transformer, embeddings are initialized randomly, and learned during training to minimize loss.
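
To make this concrete, here is a minimal sketch (in NumPy, with made-up sizes) of what that lookup table amounts to: a matrix with one row per vocabulary entry, indexed by integer token ids.

```python
import numpy as np

# Hypothetical sizes: a 10,000-entry vocabulary and d = 512 dimensional embeddings.
vocab_size, d = 10_000, 512

# The embedding "lookup table" is just a (vocab_size, d) matrix. In modern models
# it is initialized randomly and trained along with the rest of the network.
embedding_table = np.random.randn(vocab_size, d) * 0.02

# A sentence arrives as a sequence of integer indices into the vocabulary
# (see the tokenization step below).
token_ids = np.array([17, 4203, 9, 811])

# Embedding lookup: each index selects one row, giving a (sentence_length, d) array.
word_vectors = embedding_table[token_ids]
print(word_vectors.shape)  # (4, 512)
```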

For open-vocabulary tasks (where we want the model to be able to understand words it may not have seen during training), sub-word units—chunks of words, characters, or Unicode code-points—are used instead of words. Whether using words or sub-words, an important preprocessing step is to break pieces of text into tokens (words or sub-words), then map these tokens to integer indices into the vocabulary. It is these integer indices, rather than passages of text, that are fed directly to the neural network.
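
As a toy illustration of that preprocessing step, the sketch below uses a whitespace “tokenizer” and a tiny made-up vocabulary; real systems use a learned sub-word scheme such as BPE or WordPiece, but the output is the same kind of thing: a list of integer indices.

```python
# Made-up vocabulary for illustration only.
vocab = {"<unk>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}

def tokenize(text):
    # Split into tokens, then map each token to its integer index,
    # falling back to the unknown-token index for out-of-vocabulary words.
    return [vocab.get(tok, vocab["<unk>"]) for tok in text.lower().split()]

print(tokenize("The cat sat on the mat"))  # [1, 2, 3, 4, 1, 5]
```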

Transformer Architecture & Attention

The Transformer model originates from the 2017 paper “Attention is All You Need” (Vaswani et al.). The Transformer was conceived as a method for sequence-to-sequence mapping for machine translation, but the attention-based encoder and decoder introduced in the paper have spread like wildfire, and have been used to great effect for all sorts of language-modeling tasks. Pretty much all recent state-of-the-art work in large language models (BERT, GPT-3, etc.) is derived from this attention-based architecture. It has also caught on in computer vision, and more recently, even reinforcement learning! In this section, I’ll give a high-level overview of what attention is, and how transformers use it to model sequences in the natural language setting.

The Idea of Attention

N.B. — for a more detailed/visual explanation, see this article or this video.

Attention in Recurrent Neural Networks

Before the advent of transformers, neural machine translation (translating from one language to another using neural networks) was based on recurrent models, which processed the input sentence one word at a time using shared weights, generating a new hidden state $h_t$ for each position $t$ in the input sentence (based on $h_{t-1}$ and the $t$-th word in the sentence). At the end of the sequence, the final hidden state would be used as the representation of the whole sentence. This was limiting, because all the information from the whole sentence had to be compressed into that one state, which made it hard for these models to deal effectively with longer sentences.
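
For intuition, here is a toy (Elman-style) recurrent encoder in NumPy. The details differ from the gated RNNs actually used for translation (LSTMs/GRUs), but the structure is the same: one shared set of weights, applied one word at a time, producing a hidden state per position.

```python
import numpy as np

def rnn_encode(word_vectors, W_h, W_x, b):
    """Toy RNN encoder: one hidden state per input position.

    word_vectors: (seq_len, d_in) input embeddings
    W_h: (d_h, d_h), W_x: (d_h, d_in), b: (d_h,)  -- shared across positions
    """
    h = np.zeros(W_h.shape[0])
    hidden_states = []
    for x_t in word_vectors:                  # process one word at a time
        h = np.tanh(W_h @ h + W_x @ x_t + b)  # h_t depends on h_{t-1} and word t
        hidden_states.append(h)
    return np.stack(hidden_states)            # (seq_len, d_h)

# Pre-attention seq2seq models kept only hidden_states[-1] as the sentence summary.
```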

Recurrent neural network encoding a sequence.

Attention was initially popularized as a solution to this problem: instead of just using the final hidden state to represent the sentence, Bahdanau et al. proposed using all of the hidden states (one for each input word) to represent the sentence. At each step of decoding (where the representation of the input sentence is “decoded” into the target language one word at a time), the hidden states are averaged, weighted by their relevance to the current step of decoding. In this way, the decoder can assign more weight (“pay attention”) to hidden states that are most relevant to the current decoding step. The weights, i.e. the relevance of one vector to another, can be determined most simply by a dot product. This (roughly) works because vectors that point in similar directions have large positive dot products, while orthogonal vectors have a dot product near zero and vectors pointing in opposite directions have negative dot products. The raw dot-product scores are then normalized into weights by a softmax function.
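
Putting that together, a single attention step looks something like the sketch below (dot-product scoring, as described above; Bahdanau et al.’s original formulation actually used a small learned scoring network, but the idea is the same).

```python
import numpy as np

def softmax(scores):
    scores = scores - scores.max()   # subtract the max for numerical stability
    exp = np.exp(scores)
    return exp / exp.sum()

def attend(hidden_states, query):
    """Dot-product attention: a weighted average of the encoder hidden states.

    hidden_states: (seq_len, d) -- one vector per input word
    query:         (d,)         -- e.g. the current decoder state
    """
    scores = hidden_states @ query   # relevance of each hidden state to the query
    weights = softmax(scores)        # normalize scores into a distribution
    return weights @ hidden_states   # weighted average ("context vector")
```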

Dot-product attention weights. Source: Chris Manning’s CS224N slides.

Self-Attention Without RNNs

Attention caught on as a way to improve recurrent models for machine translation, as described in the previous section. But it really took off with the seminal paper, “Attention is All You Need,” by Vaswani et al. This paper dispenses with the recurrent, “one-word-at-a-time” architecture in favor of what it calls the Transformer, which processes an entire sequence in parallel. In order to capture dependencies between words (or tokens) in a sequence, the authors use self-attention, which is like the attention mechanism discussed before, but applied reflexively to compute the relevance of words in a sequence $S$ with respect to each other (rather than with respect to a separate decoder state).

This can be hard to grasp at first, but it helps to think about attention as a function. This function takes a sequence $S$ and a query $x$, and outputs a weighted average of the elements of $S$, where the weights for each $s_i$ are determined by their relevance to $x$ (for now, measured by a dot product).
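
As a minimal sketch of that function applied reflexively, the code below computes self-attention with plain dot products. The real Transformer adds learned query/key/value projections and a scaling factor, but the core operation is this: a row-wise softmax over pairwise dot products, followed by a weighted average of the sequence.

```python
import numpy as np

def self_attention(S):
    """Simplified self-attention (no learned projections): every element of S
    is used as a query against the whole sequence.

    S: (seq_len, d) array of token vectors.
    Returns a (seq_len, d) array; row i is the attention output for token i.
    """
    scores = S @ S.T                                          # (seq_len, seq_len) pairwise dot products
    scores = scores - scores.max(axis=-1, keepdims=True)      # stabilize the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ S                                        # weighted averages of S

S = np.random.randn(6, 8)   # 6 tokens, 8-dimensional vectors
print(self_attention(S).shape)  # (6, 8)
```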