Source code for joeynmt.encoders
# coding: utf-8
"""
Various encoders
"""
from typing import Tuple
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from joeynmt.helpers import freeze_params
from joeynmt.transformer_layers import PositionalEncoding, TransformerEncoderLayer
class Encoder(nn.Module):
"""
Base encoder class
"""
# pylint: disable=abstract-method
@property
def output_size(self):
"""
Return the output size
:return:
"""
return self._output_size
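# Illustrative sketch (not part of the original joeynmt source): subclasses are
# expected to set ``self._output_size`` so that the ``output_size`` property
# reports the width of the hidden states they return. ``_IdentityEncoder`` is a
# made-up minimal subclass that only demonstrates this contract.
class _IdentityEncoder(Encoder):
    def __init__(self, size: int = 1):
        super().__init__()
        self._output_size = size

    def forward(self, src_embed: Tensor, src_length: Tensor, mask: Tensor, **kwargs):
        # return the embeddings unchanged, mirroring the (output, hidden) convention
        return src_embed, None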
class RecurrentEncoder(Encoder):
"""Encodes a sequence of word embeddings"""
# pylint: disable=unused-argument
def __init__(
self,
rnn_type: str = "gru",
hidden_size: int = 1,
emb_size: int = 1,
num_layers: int = 1,
dropout: float = 0.0,
emb_dropout: float = 0.0,
bidirectional: bool = True,
freeze: bool = False,
**kwargs,
) -> None:
"""
Create a new recurrent encoder.
:param rnn_type: RNN type: `gru` or `lstm`.
:param hidden_size: Size of each RNN.
:param emb_size: Size of the word embeddings.
:param num_layers: Number of encoder RNN layers.
:param dropout: Is applied between RNN layers.
:param emb_dropout: Is applied to the RNN input (word embeddings).
:param bidirectional: Use a bi-directional RNN.
:param freeze: freeze the parameters of the encoder during training
:param kwargs:
"""
super().__init__()
self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
self.type = rnn_type
self.emb_size = emb_size
rnn = nn.GRU if rnn_type == "gru" else nn.LSTM
self.rnn = rnn(
emb_size,
hidden_size,
num_layers,
batch_first=True,
bidirectional=bidirectional,
dropout=dropout if num_layers > 1 else 0.0,
)
self._output_size = 2 * hidden_size if bidirectional else hidden_size
if freeze:
freeze_params(self)
def _check_shapes_input_forward(
self, src_embed: Tensor, src_length: Tensor, mask: Tensor
) -> None:
"""
Make sure the shape of the inputs to `self.forward` are correct.
Same input semantics as `self.forward`.
:param src_embed: embedded source tokens
:param src_length: source length
:param mask: source mask
"""
# pylint: disable=unused-argument
assert src_embed.shape[0] == src_length.shape[0]
assert src_embed.shape[2] == self.emb_size
# assert mask.shape == src_embed.shape
assert len(src_length.shape) == 1
def forward(self, src_embed: Tensor, src_length: Tensor, mask: Tensor,
**kwargs) -> Tuple[Tensor, Tensor]:
"""
Applies a bidirectional RNN to a sequence of embeddings x.
The input mini-batch x needs to be sorted by src length (longest first).
x and mask should have the same dimensions [batch, time, dim].
:param src_embed: embedded src inputs,
shape (batch_size, src_len, embed_size)
:param src_length: length of src inputs
(counting tokens before padding), shape (batch_size)
:param mask: indicates padding areas (zeros where padding), shape
(batch_size, src_len, embed_size)
:param kwargs:
:return:
- output: hidden states with
shape (batch_size, max_length, directions*hidden),
- hidden_concat: last hidden state with
shape (batch_size, directions*hidden)
"""
self._check_shapes_input_forward(
src_embed=src_embed, src_length=src_length, mask=mask
)
total_length = src_embed.size(1)
# apply dropout to the rnn input
src_embed = self.emb_dropout(src_embed)
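# lengths must live on the CPU; with the default enforce_sorted=True the batch
# has to be sorted by src length, longest sequence first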
packed = pack_padded_sequence(src_embed, src_length.cpu(), batch_first=True)
output, hidden = self.rnn(packed)
if isinstance(hidden, tuple):
hidden, memory_cell = hidden # pylint: disable=unused-variable
output, _ = pad_packed_sequence(
output, batch_first=True, total_length=total_length
)
# hidden: dir*layers x batch x hidden
# output: batch x max_length x directions*hidden
batch_size = hidden.size()[1]
# separate final hidden states by layer and direction
hidden_layerwise = hidden.view(
self.rnn.num_layers,
2 if self.rnn.bidirectional else 1,
batch_size,
self.rnn.hidden_size,
)
# hidden_layerwise: layers x directions x batch x hidden
# concatenate the final states of the last layer for each direction
# thanks to pack_padded_sequence final states don't include padding
fwd_hidden_last = hidden_layerwise[-1:, 0]
bwd_hidden_last = hidden_layerwise[-1:, 1]
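# indexing direction 1 above assumes bidirectional=True (the default); a
# unidirectional RNN only has a single direction at index 0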
# only feed the final state of the top-most layer to the decoder
# pylint: disable=no-member
hidden_concat = torch.cat([fwd_hidden_last, bwd_hidden_last], dim=2).squeeze(0)
# final: batch x directions*hidden
assert hidden_concat.size(0) == output.size(0), (
hidden_concat.size(),
output.size(),
)
return output, hidden_concat
def __repr__(self):
return f"{self.__class__.__name__}(rnn={self.rnn})"
class TransformerEncoder(Encoder):
"""
Transformer Encoder
"""
def __init__(
self,
hidden_size: int = 512,
ff_size: int = 2048,
num_layers: int = 8,
num_heads: int = 4,
dropout: float = 0.1,
emb_dropout: float = 0.1,
freeze: bool = False,
**kwargs,
):
"""
Initializes the Transformer.
:param hidden_size: hidden size and size of embeddings
:param ff_size: position-wise feed-forward layer size.
(The default of 2048 is 4*hidden_size for the default hidden_size of 512.)
:param num_layers: number of layers
:param num_heads: number of heads for multi-headed attention
:param dropout: dropout probability for Transformer layers
:param emb_dropout: Is applied to the input (word embeddings).
:param freeze: freeze the parameters of the encoder during training
:param kwargs:
"""
super().__init__()
self._output_size = hidden_size
# build all (num_layers) layers
self.layers = nn.ModuleList([
TransformerEncoderLayer(
size=hidden_size,
ff_size=ff_size,
num_heads=num_heads,
dropout=dropout,
alpha=kwargs.get("alpha", 1.0),
layer_norm=kwargs.get("layer_norm", "pre"),
activation=kwargs.get("activation", "relu"),
) for _ in range(num_layers)
])
self.pe = PositionalEncoding(hidden_size)
self.emb_dropout = nn.Dropout(p=emb_dropout)
self.layer_norm = (
nn.LayerNorm(hidden_size, eps=1e-6)
if kwargs.get("layer_norm", "post") == "pre" else None
)
if freeze:
freeze_params(self)
def forward(
self,
src_embed: Tensor,
src_length: Tensor, # unused
mask: Tensor = None,
**kwargs,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input (and mask) through each layer in turn.
Applies a Transformer encoder to a sequence of embeddings x.
Unlike the recurrent encoder, the batch does not need to be sorted by
src length.
:param src_embed: embedded src inputs,
shape (batch_size, src_len, embed_size)
:param src_length: length of src inputs
(counting tokens before padding), shape (batch_size)
:param mask: indicates padding areas (zeros where padding), shape
(batch_size, 1, src_len)
:param kwargs:
:return:
- output: hidden states with shape (batch_size, max_length, hidden)
- None
"""
# pylint: disable=unused-argument
x = self.pe(src_embed) # add position encoding to word embeddings
if kwargs.get("src_prompt_mask", None) is not None: # add src_prompt_mask
x = x + kwargs["src_prompt_mask"]
x = self.emb_dropout(x)
for layer in self.layers:
x = layer(x, mask)
if self.layer_norm is not None:
x = self.layer_norm(x)
return x, None
def __repr__(self):
return (
f"{self.__class__.__name__}(num_layers={len(self.layers)}, "
f"num_heads={self.layers[0].src_src_att.num_heads}, "
f"alpha={self.layers[0].alpha}, "
f'layer_norm="{self.layers[0]._layer_norm_position}", '
f"activation={self.layers[0].feed_forward.pwff_layer[1]})"
)
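# Illustrative usage sketch (not part of the original joeynmt source): a small
# smoke test for TransformerEncoder.forward with made-up sizes. src_length is
# accepted but unused, and the padding mask has shape (batch_size, 1, src_len).
def _example_transformer_encoder() -> None:
    encoder = TransformerEncoder(hidden_size=32, ff_size=64, num_layers=2, num_heads=4)
    src_embed = torch.zeros(2, 5, 32)              # (batch_size, src_len, hidden_size)
    src_length = torch.tensor([5, 3])
    mask = (torch.arange(5)[None, :] < src_length[:, None]).unsqueeze(1)  # (2, 1, 5)
    output, hidden = encoder(src_embed, src_length, mask)
    assert output.shape == (2, 5, 32)
    assert hidden is None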