Source code for joeynmt.transformer_layers
# -*- coding: utf-8 -*-
"""
Transformer layers
"""
import math
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from joeynmt.builders import build_activation
class MultiHeadedAttention(nn.Module):
"""
Multi-Head Attention module from "Attention is All You Need"
Implementation modified from OpenNMT-py.
https://github.com/OpenNMT/OpenNMT-py
"""
def __init__(self, num_heads: int, size: int, dropout: float = 0.1) -> None:
"""
Create a multi-headed attention layer.
:param num_heads: the number of heads
:param size: hidden size (must be divisible by num_heads)
:param dropout: probability of dropping a unit
"""
super().__init__()
        assert size % num_heads == 0, "size must be divisible by num_heads"
self.head_size = head_size = size // num_heads
self.model_size = size
self.num_heads = num_heads
self.k_layer = nn.Linear(size, num_heads * head_size)
self.v_layer = nn.Linear(size, num_heads * head_size)
self.q_layer = nn.Linear(size, num_heads * head_size)
self.output_layer = nn.Linear(size, size)
self.softmax = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
    def forward(
        self,
        k: Tensor,
        v: Tensor,
        q: Tensor,
        mask: Optional[Tensor] = None,
        return_weights: Optional[bool] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
"""
Computes multi-headed attention.
:param k: keys [batch_size, seq_len, hidden_size]
:param v: values [batch_size, seq_len, hidden_size]
:param q: query [batch_size, seq_len, hidden_size]
:param mask: optional mask [batch_size, 1, seq_len]
:param return_weights: whether to return the attention weights,
averaged over heads.
        :return:
            - output [batch_size, query_len, hidden_size]
            - attention_weights [batch_size, query_len, key_len], averaged over
              heads, or None if `return_weights` is False
"""
batch_size = k.size(0)
key_len = k.size(1)
query_len = q.size(1)
# project the queries (q), keys (k), and values (v)
k = self.k_layer(k)
v = self.v_layer(v)
q = self.q_layer(q)
        # reshape q, k, v for our computation to
        # [batch_size, num_heads, seq_len, head_size]
k = k.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
q = q.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
# compute scores
q = q / math.sqrt(self.head_size)
# [batch_size, num_heads, query_len, key_len]
scores = torch.matmul(q, k.transpose(2, 3))
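        # together with the softmax below, this computes scaled dot-product
        # attention: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V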
        # apply the mask (if given); unsqueezing adds a head dimension:
        # [batch_size, 1, 1, key_len]
if mask is not None:
scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
# apply attention dropout and compute context vectors.
attention_weights = self.softmax(scores)
attention_probs = self.dropout(attention_weights)
# get context vector (select values with attention) and reshape
# back to [batch_size, query_len, hidden_size]
context = torch.matmul(attention_probs, v)
context = context.transpose(1, 2).contiguous().view(
batch_size, -1, self.num_heads * self.head_size
)
output = self.output_layer(context)
if return_weights:
            # average attention weights over heads: [batch_size, query_len, key_len]
            attention_output_weights = attention_weights.view(
                batch_size, self.num_heads, query_len, key_len
            )
            return output, attention_output_weights.mean(dim=1)
return output, None
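
# Usage sketch (illustrative only, not part of the library API): self-attention over
# a batch of 2 sequences of length 7, with model size 512 and 8 heads.
#
#   mha = MultiHeadedAttention(num_heads=8, size=512)
#   x = torch.rand(2, 7, 512)                     # [batch_size, seq_len, hidden_size]
#   mask = torch.ones(2, 1, 7, dtype=torch.bool)  # True = attend, False = masked out
#   out, weights = mha(x, x, x, mask=mask, return_weights=True)
#   # out.shape == (2, 7, 512); weights.shape == (2, 7, 7)
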
class PositionwiseFeedForward(nn.Module):
"""
Position-wise Feed-forward layer
Projects to ff_size and then back down to input_size.
"""
def __init__(
self,
input_size: int,
ff_size: int,
dropout: float = 0.1,
alpha: float = 1.0,
layer_norm: str = "post",
activation: str = "relu",
) -> None:
"""
Initializes position-wise feed-forward layer.
:param input_size: dimensionality of the input.
:param ff_size: dimensionality of intermediate representation
:param dropout: dropout probability
:param alpha: weight factor for residual connection
:param layer_norm: either "pre" or "post"
:param activation: activation function
"""
super().__init__()
activation_fnc = build_activation(activation=activation)
self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
self.pwff_layer = nn.Sequential(
nn.Linear(input_size, ff_size),
activation_fnc(),
nn.Dropout(dropout),
nn.Linear(ff_size, input_size),
nn.Dropout(dropout),
)
self.alpha = alpha
self._layer_norm_position = layer_norm
assert self._layer_norm_position in {"pre", "post"}
def forward(self, x: Tensor) -> Tensor:
residual = x
if self._layer_norm_position == "pre":
x = self.layer_norm(x)
x = self.pwff_layer(x) + self.alpha * residual
if self._layer_norm_position == "post":
x = self.layer_norm(x)
return x
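
# Usage sketch (illustrative only): with layer_norm="pre" the input is normalized
# before the feed-forward block; with "post" the norm is applied after the residual
# sum. alpha=1.0 yields the standard residual connection x + F(x).
#
#   pwff = PositionwiseFeedForward(input_size=512, ff_size=2048, layer_norm="pre")
#   y = pwff(torch.rand(2, 7, 512))  # y.shape == (2, 7, 512)
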
class PositionalEncoding(nn.Module):
"""
    Pre-compute position encodings (PE).
    In the forward pass, the position encodings are added to the input for as many
    time steps as necessary.
Implementation based on OpenNMT-py.
https://github.com/OpenNMT/OpenNMT-py
"""
def __init__(self, size: int = 0, max_len: int = 5000) -> None:
"""
Positional Encoding with maximum length
:param size: embeddings dimension size
:param max_len: maximum sequence length
"""
if size % 2 != 0:
raise ValueError(
f"Cannot use sin/cos positional encoding with odd dim (got dim={size})"
)
pe = torch.zeros(max_len, size)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
(torch.arange(0, size, 2, dtype=torch.float) * -(math.log(10000.0) / size))
)
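        # div_term holds 1 / 10000^(2i/size); with the sin/cos assignments below
        # this gives PE(pos, 2i) = sin(pos / 10000^(2i/size)) and
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/size)), as in "Attention is All You Need"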
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0) # shape: (1, max_len, size)
super().__init__()
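        # register pe as a buffer: it is stored in the state_dict and moved by
        # .to(device), but never updated by the optimizer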
self.register_buffer("pe", pe)
self.dim = size
def forward(self, emb: Tensor) -> Tensor:
"""
Embed inputs.
:param emb: (Tensor) Sequence of word embeddings vectors
shape (seq_len, batch_size, dim)
:return: positionally encoded word embeddings
"""
# Add position encodings
return emb + self.pe[:, :emb.size(1)]
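
# Usage sketch (illustrative only): adding position encodings to a batch of
# embeddings.
#
#   pe = PositionalEncoding(size=512)
#   emb = torch.rand(2, 7, 512)  # (batch_size, seq_len, dim)
#   encoded = pe(emb)            # same shape, position information added
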
class TransformerEncoderLayer(nn.Module):
"""
One Transformer encoder layer has a Multi-head attention layer plus a position-wise
feed-forward layer.
"""
def __init__(
self,
size: int = 0,
ff_size: int = 0,
num_heads: int = 0,
dropout: float = 0.1,
alpha: float = 1.0,
layer_norm: str = "post",
activation: str = "relu",
) -> None:
"""
A single Transformer encoder layer.
Note: don't change the name or the order of members!
otherwise pretrained models cannot be loaded correctly.
:param size: model dimensionality
:param ff_size: size of the feed-forward intermediate layer
:param num_heads: number of heads
:param dropout: dropout to apply to input
:param alpha: weight factor for residual connection
:param layer_norm: either "pre" or "post"
:param activation: activation function
"""
super().__init__()
self.layer_norm = nn.LayerNorm(size, eps=1e-6)
self.src_src_att = MultiHeadedAttention(num_heads, size, dropout=dropout)
self.feed_forward = PositionwiseFeedForward(
size,
ff_size=ff_size,
dropout=dropout,
alpha=alpha,
layer_norm=layer_norm,
activation=activation,
)
self.dropout = nn.Dropout(dropout)
self.size = size
self.alpha = alpha
self._layer_norm_position = layer_norm
assert self._layer_norm_position in {"pre", "post"}
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
"""
Forward pass for a single transformer encoder layer.
First applies self attention, then dropout with residual connection (adding
the input to the result), then layer norm, and then a position-wise
feed-forward layer.
:param x: layer input
:param mask: input mask
:return: output tensor
"""
residual = x
if self._layer_norm_position == "pre":
x = self.layer_norm(x)
x, _ = self.src_src_att(x, x, x, mask)
x = self.dropout(x) + self.alpha * residual
if self._layer_norm_position == "post":
x = self.layer_norm(x)
out = self.feed_forward(x)
return out
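
# Usage sketch (illustrative only): one encoder layer over a padded batch; the mask
# marks valid (non-pad) source positions, as expected by MultiHeadedAttention above.
#
#   layer = TransformerEncoderLayer(size=512, ff_size=2048, num_heads=8)
#   x = torch.rand(2, 7, 512)
#   mask = torch.ones(2, 1, 7, dtype=torch.bool)
#   out = layer(x, mask)  # out.shape == (2, 7, 512)
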
class TransformerDecoderLayer(nn.Module):
"""
Transformer decoder layer.
Consists of self-attention, source-attention, and feed-forward.
"""
def __init__(
self,
size: int = 0,
ff_size: int = 0,
num_heads: int = 0,
dropout: float = 0.1,
alpha: float = 1.0,
layer_norm: str = "post",
activation: str = "relu",
) -> None:
"""
Represents a single Transformer decoder layer.
It attends to the source representation and the previous decoder states.
Note: don't change the name or the order of members!
otherwise pretrained models cannot be loaded correctly.
:param size: model dimensionality
:param ff_size: size of the feed-forward intermediate layer
:param num_heads: number of heads
:param dropout: dropout to apply to input
:param alpha: weight factor for residual connection
:param layer_norm: either "pre" or "post"
:param activation: activation function
"""
super().__init__()
self.size = size
self.trg_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout)
self.src_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout)
self.feed_forward = PositionwiseFeedForward(
size,
ff_size=ff_size,
dropout=dropout,
alpha=alpha,
layer_norm=layer_norm,
activation=activation,
)
self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)
self.dropout = nn.Dropout(dropout)
self.alpha = alpha
self._layer_norm_position = layer_norm
assert self._layer_norm_position in {"pre", "post"}
    def forward(
        self,
        x: Tensor,
        memory: Tensor,
        src_mask: Tensor,
        trg_mask: Tensor,
        return_attention: bool = False,
        **kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
"""
Forward pass of a single Transformer decoder layer.
First applies target-target self-attention, dropout with residual connection
(adding the input to the result), and layer norm.
Second computes source-target cross-attention, dropout with residual connection
(adding the self-attention to the result), and layer norm.
Finally goes through a position-wise feed-forward layer.
:param x: inputs
:param memory: source representations
:param src_mask: source mask
:param trg_mask: target mask (so as not to condition on future steps)
:param return_attention: whether to return the attention weights
:return:
- output tensor
- attention weights
"""
# pylint: disable=unused-argument
# 1. target-target self-attention
residual = x
if self._layer_norm_position == "pre":
x = self.x_layer_norm(x)
h1, _ = self.trg_trg_att(x, x, x, mask=trg_mask)
h1 = self.dropout(h1) + self.alpha * residual
if self._layer_norm_position == "post":
h1 = self.x_layer_norm(h1)
# 2. source-target cross-attention
h1_residual = h1
if self._layer_norm_position == "pre":
h1 = self.dec_layer_norm(h1)
h2, att = self.src_trg_att(
memory, memory, h1, mask=src_mask, return_weights=return_attention
)
h2 = self.dropout(h2) + self.alpha * h1_residual
if self._layer_norm_position == "post":
h2 = self.dec_layer_norm(h2)
# 3. final position-wise feed-forward layer
out = self.feed_forward(h2)
if return_attention:
return out, att
return out, None
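
# Usage sketch (illustrative only): one decoder layer attending to encoder states
# (`memory`). The target mask combines padding with a look-ahead constraint so that
# position i cannot attend to positions > i.
#
#   layer = TransformerDecoderLayer(size=512, ff_size=2048, num_heads=8)
#   trg = torch.rand(2, 5, 512)
#   memory = torch.rand(2, 7, 512)                    # encoder output
#   src_mask = torch.ones(2, 1, 7, dtype=torch.bool)
#   trg_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool)).unsqueeze(0)
#   out, att = layer(trg, memory, src_mask, trg_mask, return_attention=True)
#   # out.shape == (2, 5, 512); att.shape == (2, 5, 7)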