import math
import torch
import torch.nn as nn
Transformer
def masked_X(X, valid_lens, value=-1e6):
    '''
    Calculate the masked version of `X` based on `valid_lens`. Usually `X` is
    the scores used in the dot product attention function, with shape
    (batch_size, num_queries, num_keys). Consider the i-th sequence in the
    batch. To specify that the j-th query should not attend to the keys after
    index k, we set `valid_lens[i][j] = k`. If `valid_lens` is a 1d tensor, it
    is treated as `valid_lens[i][j] = valid_lens[i]` for all `j`s.

    Args:
        X (torch.Tensor): The tensor to be masked. Should have shape
            (batch_size, num_queries, num_keys).
        valid_lens (torch.Tensor): Specifies which entries in `X` are masked.
            Tensor of shape (batch_size, ) or (batch_size, num_queries).
            If shape (batch_size, ), set `X[i, :, valid_lens[i]:] = value` for
            every `i in range(len(valid_lens))`.
            If (batch_size, num_queries), set `X[i, j, valid_lens[i][j]:] = value`
            for every `i in range(len(valid_lens))` and `j in range(len(valid_lens[i]))`.
        value (float): The value to set for the masked entries in `X`.

    Returns:
        (torch.Tensor): Masked scores of shape (batch_size, num_queries, num_keys).
    '''
    # Repeat or reshape `valid_lens` to have length batch_size * num_queries.
    # Each number in `valid_lens` corresponds to the number of valid tokens in
    # each query sequence.
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, X.shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    # Create a 1d range array [0, ..., num_keys - 1]
    seq_len_range = torch.arange(X.shape[-1], dtype=torch.float32, device=X.device)

    # Create masks by broadcasting
    masks = seq_len_range[None, :] < valid_lens[:, None]

    masked_X = X.reshape(-1, X.shape[-1])
    masked_X[~masks] = value
    masked_X = masked_X.reshape(X.shape)

    return masked_X
X = torch.rand(2, 2, 4)
valid_lens = torch.tensor([2, 3])
print(masked_X(X, valid_lens))
tensor([[[ 9.7588e-01, 2.7506e-01, -1.0000e+06, -1.0000e+06],
[ 4.3251e-01, 6.0360e-01, -1.0000e+06, -1.0000e+06]],
[[ 2.2331e-02, 6.7170e-01, 5.6041e-01, -1.0000e+06],
[ 6.1315e-02, 4.9694e-02, 4.6259e-01, -1.0000e+06]]])
Scaled Dot Product Attention
The attention function in transformers is a mechanism that calculates a weighted combination of input values to capture dependencies between tokens in a sequence, regardless of their distance.
The most commonly used attention function is scaled dot-product attention:
\mathrm{Attention} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax} \left( \frac{ \mathbf{Q} \mathbf{K}^{\top} }{ \sqrt{d_k} } \right) \mathbf{V},
where:
\mathbf{Q}: Query matrix. Each row in \mathbf{Q} is a query token.
\mathbf{K}: Key matrix. Each row in \mathbf{K} is a key token.
\mathbf{V}: Value matrix. Each row in \mathbf{V} is a value token.
d_{k}: Dimensionality of the key tokens.
Explanations:
Compute dot products between \mathbf{Q} and \mathbf{K} to get similarity scores. The i-th row of the matrix \mathbf{Q} \mathbf{K}^{\top} contains the similarity scores (dot products) between the i-th query token in \mathbf{Q} and all key tokens in \mathbf{K}.
Scale the scores by 1 \mathbin{/} \sqrt{d_k} so that their magnitude does not grow with d_k, which would otherwise push the softmax into regions with very small gradients.
Apply softmax to convert the scores into attention weights. The softmax is applied to each row of the matrix \frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_{k}}}, i.e. each row in \mathrm{softmax} \left( \frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_{k}}} \right) sums up to 1.
Multiply the weights by \mathbf{V} to produce the attention output.
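Before packaging these steps into a module, here is a minimal sketch of the four steps using raw tensor operations; the names `Q`, `K`, `V`, and `d_k` are ad hoc and used only for this illustration.
torch.manual_seed(0)
Q = torch.rand(1, 2, 3)   # (batch_size=1, num_queries=2, d_k=3)
K = torch.rand(1, 4, 3)   # (batch_size=1, num_keys=4, d_k=3)
V = torch.rand(1, 4, 5)   # (batch_size=1, num_keys=4, value_dim=5)
d_k = Q.shape[-1]

# Step 1: dot products between queries and keys -> (1, 2, 4)
scores = torch.bmm(Q, K.transpose(1, 2))
# Step 2: scale by sqrt(d_k)
scores = scores / math.sqrt(d_k)
# Step 3: softmax over the key dimension; each row sums to 1
weights = torch.softmax(scores, dim=-1)
print(weights.sum(dim=-1))
# Step 4: weighted combination of the values -> (1, 2, 5)
print(torch.bmm(weights, V).shape)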
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        queries: (batch_size, num_queries, query_dim)
        keys: (batch_size, num_keys, query_dim)
        values: (batch_size, num_keys, value_dim)
        valid_lens: (batch_size, ) or (batch_size, num_queries)
        '''
        # scores: (batch_size, num_queries, num_keys)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(queries.shape[-1])
        if valid_lens is not None:
            scores = masked_X(scores, valid_lens)

        # attention_weights: (batch_size, num_queries, num_keys)
        self.attention_weights = torch.softmax(scores, dim=-1)
        # attentions: (batch_size, num_queries, value_dim)
        attentions = torch.bmm(self.dropout(self.attention_weights), values)

        return attentions
queries = torch.normal(0, 1, (2, 4, 2))
keys = torch.normal(0, 1, (2, 6, 2))
values = torch.normal(0, 1, (2, 6, 4))
valid_lens = torch.tensor([2, 4])

attention = ScaledDotProductAttention(dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
print(attention.attention_weights)
tensor([[[ 0.1430, -0.1285, -0.1106, 0.7467],
[ 0.2903, -0.1678, 0.3139, 0.9371],
[ 0.0870, -0.1135, -0.2719, 0.6743],
[ 0.0252, -0.0970, -0.4500, 0.5944]],
[[-0.2120, -0.1235, -0.6992, 0.7657],
[-0.2176, -0.1362, -0.6724, 0.7451],
[-0.0597, -0.1562, -0.7385, 0.7999],
[-0.1990, -0.2225, -0.5296, 0.6328]]])
tensor([[[0.4356, 0.5644, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1318, 0.8682, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5511, 0.4489, 0.0000, 0.0000, 0.0000, 0.0000],
[0.6785, 0.3215, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.2134, 0.2993, 0.2441, 0.2431, 0.0000, 0.0000],
[0.2087, 0.2950, 0.2389, 0.2573, 0.0000, 0.0000],
[0.2731, 0.2140, 0.2473, 0.2657, 0.0000, 0.0000],
[0.1981, 0.2416, 0.2108, 0.3494, 0.0000, 0.0000]]])
Multi-Head Attention
In practice, given the same set of queries, keys, and values, we may want our model to combine different aspects of the knowledge in the data, which can be mathematically represented using different subspaces of the same data.
Given n tokens as rows of a matrix \mathbf{X}, they can be projected into vectors in another subspace by using a linear transformation matrix \mathbf{W}
\mathbf{X}' = \mathbf{X} \mathbf{W}.
The idea of multi-head attention is to perform the same attention mechanism on different learnable subspaces of the same set of queries, keys, and values; the results are then concatenated and linearly transformed again to combine the information from different aspects of the data
\mathrm{MultiHead} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \begin{bmatrix} \mathrm{head}_{1}, \dots, \mathrm{head}_{h} \end{bmatrix} \mathbf{W}_{O},
where
\mathrm{head}_{i} = \mathrm{Attention} (\mathbf{Q} \mathbf{W}_{Q}^{(i)}, \mathbf{K} \mathbf{W}_{K}^{(i)}, \mathbf{V} \mathbf{W}_{V}^{(i)}).
Parallel Implementation
The linear transformations of inputs for each attention layer can be implemented using fully connected (linear) layers. A vanilla implementation of the multi-head attention layer is to create a separate set of linear layers for each attention head and use a for loop to get the outputs from the h heads.
multi_head_outputs = []
for i in range(h):
    W_q = linear(head_hidden_dim)
    W_k = linear(head_hidden_dim)
    W_v = linear(head_hidden_dim)
    output = Attention(W_q(Q), W_k(K), W_v(V))
    multi_head_outputs.append(output)
A parallel version of the same operation can be implemented using a single set of linear layers. To understand this, first observe the following facts.
Batch processing of the inputs is supported in `Attention()`. That is, the input `X` in `Attention(X, X, X)` has a shape of `(batch_size, seq_len, hidden_dim)`.
The h different linear transformations \mathbf{W}_{1}, \dots, \mathbf{W}_{h} \in \mathbb{R}^{n \times d} of the same input \mathbf{X} can be grouped and replaced by a single linear transformation matrix \mathbf{W} \in \mathbb{R}^{n \times dh},
\begin{bmatrix} | & & | \\ \mathbf{X} \mathbf{W}_{1} & \dots & \mathbf{X} \mathbf{W}_{h} \\ | & & | \end{bmatrix} = \mathbf{X} \begin{bmatrix} | & & | \\ \mathbf{W}_{1} & \dots & \mathbf{W}_{h} \\ | & & | \end{bmatrix} = \mathbf{X} \mathbf{W}.
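As a quick sanity check of this identity, the sketch below compares the concatenated per-head projections against a single grouped projection; the names `X_tokens`, `W_heads`, and `W` are ad hoc and used only here.
torch.manual_seed(0)
n, d, h = 5, 4, 3                        # input dim, per-head output dim, number of heads
X_tokens = torch.rand(7, n)              # 7 tokens as rows
W_heads = [torch.rand(n, d) for _ in range(h)]
W = torch.cat(W_heads, dim=1)            # single grouped matrix of shape (n, d * h)

per_head = torch.cat([X_tokens @ W_i for W_i in W_heads], dim=1)
grouped = X_tokens @ W
print(torch.allclose(per_head, grouped))  # expected: True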
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_dim, dropout=0.1, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention(dropout)
        # Here we set `hidden_dim / num_heads` as the input embedding dim of
        # queries, keys, and values in each head.
        # The reason we use `hidden_dim` instead of `hidden_dim / num_heads` as
        # the output dim for `W_q`, `W_k`, and `W_v` is to enable the parallel
        # computation for all heads, i.e. the i-th set of `hidden_dim / num_heads`
        # outputs is for the i-th head.
        self.W_q = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_k = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_v = nn.LazyLinear(hidden_dim, bias=bias)
        # The input dim for the `W_o` layer is still `hidden_dim` since the
        # outputs of the `num_heads` heads are concatenated before being fed
        # into `W_o`.
        self.W_o = nn.LazyLinear(hidden_dim, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        '''
        The forward computation of a transformer attention layer.

        Args:
            queries (batch_size, num_queries, query_dim)
            keys (batch_size, num_keys, query_dim)
            values (batch_size, num_keys, value_dim)
            valid_lens (batch_size, ) or (batch_size, num_queries)

        Returns:
            (batch_size, num_queries, hidden_dim)
        '''
        # multi_queries: (batch_size * num_heads, num_queries, hidden_dim / num_heads)
        multi_queries = self.transpose_qkv(self.W_q(queries))
        # multi_keys: (batch_size * num_heads, num_keys, hidden_dim / num_heads)
        multi_keys = self.transpose_qkv(self.W_k(keys))
        # multi_values: (batch_size * num_heads, num_keys, hidden_dim / num_heads)
        multi_values = self.transpose_qkv(self.W_v(values))

        # Repeat each element in `valid_lens` `num_heads` times to align with
        # the shape of `multi_queries`.
        multi_valid_lens = None
        if valid_lens is not None:
            # multi_valid_lens: (batch_size * num_heads, ) or (batch_size * num_heads, num_queries)
            multi_valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)

        # multi_output: (batch_size * num_heads, num_queries, hidden_dim / num_heads)
        multi_output = self.attention(
            multi_queries, multi_keys, multi_values, valid_lens=multi_valid_lens
        )
        # output: (batch_size, num_queries, hidden_dim)
        output = self.W_o(self.transpose_output(multi_output))

        return output

    def transpose_qkv(self, X):
        '''
        Reshape X for parallel computation of multiple attention heads. Assume
        `X` has shape (batch_size, seq_len, hidden_dim) and `hidden_dim` is
        divisible by `num_heads`. We want to make it (batch_size * num_heads,
        seq_len, hidden_dim / num_heads), so that the attention performed
        later is done on the `hidden_dim / num_heads` dimension.

        Args:
            X (batch_size, seq_len, hidden_dim)

        Returns:
            (batch_size * num_heads, seq_len, hidden_dim / num_heads)
        '''
        # X: (batch_size, seq_len, num_heads, hidden_dim / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # X: (batch_size, num_heads, seq_len, hidden_dim / num_heads)
        X = X.permute(0, 2, 1, 3)
        # X: (batch_size * num_heads, seq_len, hidden_dim / num_heads)
        X = X.reshape(-1, X.shape[2], X.shape[3])

        return X

    def transpose_output(self, X):
        '''
        Reverse the operation of `transpose_qkv`.

        Args:
            X (batch_size * num_heads, seq_len, hidden_dim / num_heads)

        Returns:
            (batch_size, seq_len, hidden_dim)
        '''
        # X: (batch_size, num_heads, seq_len, hidden_dim / num_heads)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        # X: (batch_size, seq_len, num_heads, hidden_dim / num_heads)
        X = X.permute(0, 2, 1, 3)
        # X: (batch_size, seq_len, hidden_dim)
        X = X.reshape(X.shape[0], X.shape[1], -1)

        return X
attention = MultiHeadAttention(2, 4, 0)
queries = torch.ones((2, 4, 4))
keys = torch.ones((2, 6, 4))
values = torch.ones((2, 6, 6))
valid_lens = torch.tensor([3, 2])
print(attention(queries, keys, values, valid_lens).shape)
torch.Size([2, 4, 4])
Positional Encoding
Given a sequence of n tokens as rows of a matrix \mathbf{X} \in \mathbb{R}^{n \times d}, the positional encoding will inject positional information into \mathbf{X} by generating a new matrix \mathbf{X}'
\mathbf{X}' = \mathbf{X} + \mathbf{P}
where \mathbf{P} \in \mathbb{R}^{n \times d} is a positional encoding matrix with each row being the positional encoding vector of the corresponding token in \mathbf{X}.
Usually \mathbf{P} should provide two types of positional information.
Absolute positional information. This type requires the encoding to provide the positional information that is unique across the entire sequence.
Relative positional information. This type requires the encoding to provide the positional information that encodes the relative order of the tokens.
Sinusoidal Positional Encoding
In sinusoidal positional encoding, the positional encoding matrix \mathbf{P} uses sine and cosine functions with varying periods at the even and odd columns, respectively.
Each element p_{i, j} at the i-th row and j-th column in \mathbf{P} is calculated as
p_{i, j} = \begin{aligned} \begin{cases} \sin (\omega_{j} i) & \quad \text{when } j \text{ is even} \\ \cos (\omega_{j} i) & \quad \text{when } j \text{ is odd} \end{cases} \end{aligned}
where
\omega_{j} = \begin{aligned} \begin{cases} 1 \mathbin{/} \left( 10000^{j \mathbin{/} d} \right) & \quad \text{when } j \text{ is even} \\ 1 \mathbin{/} \left( 10000^{(j - 1) \mathbin{/} d} \right) & \quad \text{when } j \text{ is odd}. \end{cases} \end{aligned}
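For example, with d = 4 the formula gives \omega_{0} = \omega_{1} = 1 and \omega_{2} = \omega_{3} = 1 \mathbin{/} 10000^{2 / 4} = 1 \mathbin{/} 100, so the i-th row of \mathbf{P} is
p_{i, :} = \begin{bmatrix} \sin(i) & \cos(i) & \sin(i \mathbin{/} 100) & \cos(i \mathbin{/} 100) \end{bmatrix},
i.e. the first pair of columns has period 2 \pi and the second pair has period 200 \pi.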
Encoding absolute positional information
The number of unique encodings that sinusoidal positional encoding can represent depends on d.
If d = 1, the encoding for the single column is \sin(i) for i = 1, \dots, n, which has a period of \lambda = 2 \pi. Since the sine function repeats after each period, the positional encoding using a single sine function can represent at most \lfloor 2 \pi \rfloor = 6 tokens.
If d = 2, the encodings for the 1st and 2nd columns are \sin(i) and \cos(i) for i = 1, \dots, n, which have the same period \lambda = 2\pi. Since the sine and cosine functions at corresponding even and odd columns always have the same period, the number of unique tokens that can be represented with an odd d is the same as with the corresponding even d.
If d > 2, the number of unique positions that the sinusoidal positional encoding can represent is determined by the least common multiple of the d \mathbin{/} 2 different periods, which is quite large for any reasonable d.
Encoding relative positional information
For any fixed offset \delta, the encodings at position i + \delta can be expressed as a linear transformation of the encodings at position i. To see this, we can use trigonometric sum identities to rewrite the encodings at position i + \delta:
\sin(\omega_{j} (i + \delta)) = \sin(\omega_{j} i) \cos(\omega_{j} \delta) + \cos(\omega_{j} i) \sin(\omega_j \delta),
\cos(\omega_{j} (i + \delta)) = \cos(\omega_{j} i) \cos(\omega_{j} \delta) - \sin(\omega_{j} i) \sin(\omega_j \delta),
which can be represented using the matrix multiplication
\begin{bmatrix} \sin(\omega_{j} (i + \delta)) \\ \cos(\omega_{j} (i + \delta)) \end{bmatrix} = \begin{bmatrix} \cos(\omega_{j} \delta) & \sin(\omega_{j} \delta) \\ - \sin(\omega_{j} \delta) & \cos(\omega_{j} \delta) \end{bmatrix} \begin{bmatrix} \sin(\omega_{j} i) \\ \cos(\omega_{j} i) \end{bmatrix}.
The positional encoding at position i + \delta can therefore be obtained by multiplying the encoding at position i with a 2 \times 2 rotation matrix whose values do not depend on the position i, which shows that the encodings at different positions are related by a linear transformation.
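The following is a small numerical check of this rotation relation; all names here are ad hoc and not part of the module below.
omega, i, delta = 0.01, 7.0, 3.0
rotation = torch.tensor([
    [math.cos(omega * delta), math.sin(omega * delta)],
    [-math.sin(omega * delta), math.cos(omega * delta)],
])
p_i = torch.tensor([math.sin(omega * i), math.cos(omega * i)])
p_i_plus_delta = torch.tensor([math.sin(omega * (i + delta)), math.cos(omega * (i + delta))])
print(torch.allclose(rotation @ p_i, p_i_plus_delta))  # expected: True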
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # i: (max_len, 1)
        i = torch.arange(max_len).reshape(-1, 1)
        # two_j: (hidden_dim / 2, )
        two_j = torch.arange(0, hidden_dim, 2)
        # X: (max_len, hidden_dim / 2)
        X = i / torch.pow(10000, two_j / hidden_dim)

        # P is the same for each sequence matrix in the mini-batch.
        # The shape (1, max_len, hidden_dim) of `P` can be directly
        # broadcast to (batch_size, max_len, hidden_dim) when added to `X`.
        self.P = torch.zeros((1, max_len, hidden_dim))
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        '''
        X (batch_size, seq_len, hidden_dim)
        '''
        # Since `X` may have a seq_len smaller than `max_len`, we take the
        # first seq_len positions of `P` when adding it to `X`.
        X = X + self.P[:, :X.shape[1], :].to(X.device)

        return self.dropout(X)
encoding_dim, num_steps = 32, 60
pos_encoding = SinusoidalPositionalEncoding(encoding_dim, 0)
X = torch.zeros((10, num_steps, encoding_dim))
X_p = pos_encoding(X)
P = pos_encoding.P[:, :X.shape[1], :]
print(P.shape)
torch.Size([1, 60, 32])
Transformer Encoder
class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.LazyLinear(ffn_hidden_dim, bias=bias),
            nn.ReLU(),
            nn.LazyLinear(hidden_dim, bias=bias)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, X, valid_lens):
        '''
        X (batch_size, seq_len, hidden_dim)
        valid_lens (batch_size, ) or (batch_size, seq_len)
        '''
        # Y: (batch_size, seq_len, hidden_dim)
        Y = self.norm1(X + self.attention(X, X, X, valid_lens))
        # Z: (batch_size, seq_len, hidden_dim)
        Z = self.norm2(Y + self.ffn(Y))

        return Z
class TransformerEncoder(nn.Module):
    def __init__(self, num_blocks, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.positional_encoding = SinusoidalPositionalEncoding(hidden_dim, dropout)
        # Use `nn.ModuleList` so that the parameters of each block are
        # registered with this module.
        self.attention_blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.attention_blocks.append(
                TransformerEncoderBlock(num_heads, hidden_dim, ffn_hidden_dim, dropout, bias)
            )

    def forward(self, X, valid_lens):
        '''
        X (batch_size, seq_len, hidden_dim)
        valid_lens (batch_size, ) or (batch_size, seq_len)
        '''
        # X: (batch_size, seq_len, hidden_dim)
        X = self.positional_encoding(X)
        # X: (batch_size, seq_len, hidden_dim)
        for block in self.attention_blocks:
            X = block(X, valid_lens)

        return X
encoder = TransformerEncoder(2, 4, 8, 16, 0.5)
X = torch.ones((2, 4, 8))
valid_lens = torch.tensor([2, 3])
print(encoder(X, valid_lens).shape)
torch.Size([2, 4, 8])
class TransformerDecoderBlock(nn.Module):
    def __init__(self, block_index, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.block_index = block_index
        self.attention1 = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attention2 = MultiHeadAttention(num_heads, hidden_dim, dropout, bias)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.LazyLinear(ffn_hidden_dim, bias=bias),
            nn.ReLU(),
            nn.LazyLinear(hidden_dim, bias=bias)
        )
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, X, state):
        '''
        Args:
            X (torch.Tensor): (batch_size, seq_len, hidden_dim)
            state (tuple): A tuple of 3 elements (enc_outputs, enc_valid_lens,
                dec_outputs).
                enc_outputs (batch_size, seq_len, hidden_dim): the encoder outputs.
                enc_valid_lens: (batch_size, ) or (batch_size, seq_len)
                dec_outputs: a list of `num_blocks` tensors of shape
                    (batch_size, num_tokens_so_far, hidden_dim) caching the
                    previously decoded tokens for each block (None before the
                    first prediction step).
        '''
        # enc_outputs: (batch_size, seq_len, hidden_dim)
        # enc_valid_lens: (batch_size, ) or (batch_size, seq_len)
        # dec_outputs: num_blocks tensors of (batch_size, num_tokens_so_far, hidden_dim)
        enc_outputs, enc_valid_lens, dec_outputs = state

        # During training, all tokens in any sequence are available, so `X` has
        # a shape of (batch_size, seq_len, hidden_dim). The mask is used to
        # ensure that attention is performed only on the previous tokens.
        if self.training:
            # dec_valid_lens: (batch_size, seq_len)
            # Each row in `dec_valid_lens` is [1, ..., seq_len], which masks out
            # the upper right triangle of each score matrix in the batch.
            dec_valid_lens = torch.arange(
                1, X.shape[1] + 1, device=X.device
            ).repeat(X.shape[0], 1)
            Y = self.norm1(X + self.attention1(X, X, X, dec_valid_lens))
        # During prediction, one token is available per call of the function, so
        # `X` has a shape of (batch_size, 1, hidden_dim).
        else:
            # prev_X: (batch_size, num_tokens_so_far, hidden_dim)
            if dec_outputs[self.block_index] is None:
                prev_X = X
            else:
                prev_X = torch.cat((dec_outputs[self.block_index], X), dim=1)
            state[2][self.block_index] = prev_X
            Y = self.norm1(X + self.attention1(X, prev_X, prev_X, None))

        Z = self.norm2(Y + self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens))
        return self.norm3(Z + self.ffn(Z)), state
class TransformerDecoder(nn.Module):
    def __init__(self, num_blocks, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias=False):
        super().__init__()
        self.positional_encoding = SinusoidalPositionalEncoding(hidden_dim, dropout=dropout)
        self.num_blocks = num_blocks
        # Use `nn.ModuleList` so that the parameters of each block are
        # registered with this module.
        self.decoder_blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.decoder_blocks.append(
                TransformerDecoderBlock(i, num_heads, hidden_dim, ffn_hidden_dim, dropout, bias)
            )
        self.dense = nn.LazyLinear(hidden_dim)

    def forward(self, X, state):
        X = self.positional_encoding(X)
        for block in self.decoder_blocks:
            X, state = block(X, state)

        return self.dense(X), state
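The section does not include a decoder usage example, so the following is a minimal sketch, assuming the same toy dimensions used for the encoder above, of how the state tuple might be initialized and threaded through prediction steps. In a real model the next input would come from embedding the predicted token rather than feeding the output back directly.
num_blocks, num_heads, hidden_dim, ffn_hidden_dim = 2, 4, 8, 16
encoder = TransformerEncoder(num_blocks, num_heads, hidden_dim, ffn_hidden_dim, 0.0)
decoder = TransformerDecoder(num_blocks, num_heads, hidden_dim, ffn_hidden_dim, 0.0)
encoder.eval()
decoder.eval()

src = torch.ones((2, 6, hidden_dim))
enc_valid_lens = torch.tensor([4, 6])
enc_outputs = encoder(src, enc_valid_lens)

# One cached tensor of previously decoded tokens per block, initialized to None.
state = (enc_outputs, enc_valid_lens, [None] * num_blocks)
token = torch.ones((2, 1, hidden_dim))
for _ in range(3):
    token, state = decoder(token, state)
print(token.shape)  # expected: torch.Size([2, 1, 8])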