Source code for joeynmt.encoders

# coding: utf-8
"""
Various encoders
"""
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from joeynmt.helpers import freeze_params
from joeynmt.transformer_layers import PositionalEncoding, TransformerEncoderLayer


class Encoder(nn.Module):
    """
    Base encoder class
    """

    # pylint: disable=abstract-method
    @property
    def output_size(self):
        """
        Return the output size

        :return:
        """
        return self._output_size
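# --- Editor's illustration (not part of the original joeynmt module) ---------
# The base class only supplies the `output_size` property, which reads
# `self._output_size`. A concrete encoder is therefore expected to set that
# attribute in its constructor. `IdentityEncoder` below is a hypothetical,
# minimal subclass that shows the contract; joeynmt itself only ships the
# recurrent and Transformer encoders defined in this module.
class IdentityEncoder(Encoder):
    """Toy encoder that returns its input unchanged (illustration only)."""

    def __init__(self, size: int = 8):
        super().__init__()
        self._output_size = size  # consumed by the inherited `output_size`

    def forward(self, src_embed: Tensor, src_length: Tensor, mask: Tensor,
                **kwargs):
        # no transformation, no final hidden state
        return src_embed, None
# -----------------------------------------------------------------------------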
class RecurrentEncoder(Encoder):
    """Encodes a sequence of word embeddings"""

    # pylint: disable=unused-argument
    def __init__(
        self,
        rnn_type: str = "gru",
        hidden_size: int = 1,
        emb_size: int = 1,
        num_layers: int = 1,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        bidirectional: bool = True,
        freeze: bool = False,
        **kwargs,
    ) -> None:
        """
        Create a new recurrent encoder.

        :param rnn_type: RNN type: `gru` or `lstm`.
        :param hidden_size: Size of each RNN.
        :param emb_size: Size of the word embeddings.
        :param num_layers: Number of encoder RNN layers.
        :param dropout: Is applied between RNN layers.
        :param emb_dropout: Is applied to the RNN input (word embeddings).
        :param bidirectional: Use a bi-directional RNN.
        :param freeze: freeze the parameters of the encoder during training
        :param kwargs:
        """
        super().__init__()

        self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
        self.type = rnn_type
        self.emb_size = emb_size

        rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

        self.rnn = rnn(
            emb_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        self._output_size = 2 * hidden_size if bidirectional else hidden_size

        if freeze:
            freeze_params(self)

    def _check_shapes_input_forward(
        self, src_embed: Tensor, src_length: Tensor, mask: Tensor
    ) -> None:
        """
        Make sure the shape of the inputs to `self.forward` are correct.
        Same input semantics as `self.forward`.

        :param src_embed: embedded source tokens
        :param src_length: source length
        :param mask: source mask
        """
        # pylint: disable=unused-argument
        assert src_embed.shape[0] == src_length.shape[0]
        assert src_embed.shape[2] == self.emb_size
        # assert mask.shape == src_embed.shape
        assert len(src_length.shape) == 1
    def forward(self, src_embed: Tensor, src_length: Tensor, mask: Tensor,
                **kwargs) -> Tuple[Tensor, Tensor]:
        """
        Applies a bidirectional RNN to sequence of embeddings x.
        The input mini-batch x needs to be sorted by src length.

        :param src_embed: embedded src inputs,
            shape (batch_size, src_len, embed_size)
        :param src_length: length of src inputs
            (counting tokens before padding), shape (batch_size)
        :param mask: indicates padding areas (zeros where padding), shape
            (batch_size, src_len, embed_size)
        :param kwargs:
        :return:
            - output: hidden states with
              shape (batch_size, max_length, directions*hidden),
            - hidden_concat: last hidden state with
              shape (batch_size, directions*hidden)
        """
        self._check_shapes_input_forward(
            src_embed=src_embed, src_length=src_length, mask=mask
        )
        total_length = src_embed.size(1)

        # apply dropout to the rnn input
        src_embed = self.emb_dropout(src_embed)

        packed = pack_padded_sequence(src_embed, src_length.cpu(), batch_first=True)
        output, hidden = self.rnn(packed)

        if isinstance(hidden, tuple):
            hidden, memory_cell = hidden  # pylint: disable=unused-variable

        output, _ = pad_packed_sequence(
            output, batch_first=True, total_length=total_length
        )
        # hidden: dir*layers x batch x hidden
        # output: batch x max_length x directions*hidden
        batch_size = hidden.size()[1]

        # separate final hidden states by layer and direction
        hidden_layerwise = hidden.view(
            self.rnn.num_layers,
            2 if self.rnn.bidirectional else 1,
            batch_size,
            self.rnn.hidden_size,
        )
        # final_layers: layers x directions x batch x hidden

        # concatenate the final states of the last layer for each direction
        # thanks to pack_padded_sequence final states don't include padding
        fwd_hidden_last = hidden_layerwise[-1:, 0]
        bwd_hidden_last = hidden_layerwise[-1:, 1]

        # only feed the final state of the top-most layer to the decoder
        # pylint: disable=no-member
        hidden_concat = torch.cat([fwd_hidden_last, bwd_hidden_last], dim=2).squeeze(0)
        # final: batch x directions*hidden

        assert hidden_concat.size(0) == output.size(0), (
            hidden_concat.size(),
            output.size(),
        )
        return output, hidden_concat
    def __repr__(self):
        return f"{self.__class__.__name__}(rnn={self.rnn})"
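# --- Editor's usage sketch (not part of the original joeynmt module) ---------
# Calling the recurrent encoder on a toy batch. Because `forward` relies on
# `pack_padded_sequence` with its default `enforce_sorted=True`, the batch must
# be sorted by source length in descending order. The mask is not actually
# consumed by the recurrent forward pass (its shape assert is commented out),
# so a (batch_size, 1, src_len) padding mask is passed here only to satisfy the
# interface. All sizes below are arbitrary example values.
def _recurrent_encoder_example() -> None:
    encoder = RecurrentEncoder(
        rnn_type="lstm", emb_size=16, hidden_size=32, num_layers=2,
        dropout=0.1, emb_dropout=0.1, bidirectional=True,
    )
    src_embed = torch.randn(3, 7, 16)          # (batch_size, src_len, emb_size)
    src_length = torch.tensor([7, 5, 2])       # sorted, longest first
    mask = (torch.arange(7)[None, :] < src_length[:, None]).unsqueeze(1)

    output, hidden_concat = encoder(src_embed, src_length, mask)
    assert output.shape == (3, 7, 2 * 32)      # directions * hidden_size
    assert hidden_concat.shape == (3, 2 * 32)
    assert encoder.output_size == 64
# -----------------------------------------------------------------------------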
class TransformerEncoder(Encoder):
    """
    Transformer Encoder
    """

    def __init__(
        self,
        hidden_size: int = 512,
        ff_size: int = 2048,
        num_layers: int = 8,
        num_heads: int = 4,
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        freeze: bool = False,
        **kwargs,
    ):
        """
        Initializes the Transformer.

        :param hidden_size: hidden size and size of embeddings
        :param ff_size: position-wise feed-forward layer size.
            (Typically this is 2*hidden_size.)
        :param num_layers: number of layers
        :param num_heads: number of heads for multi-headed attention
        :param dropout: dropout probability for Transformer layers
        :param emb_dropout: Is applied to the input (word embeddings).
        :param freeze: freeze the parameters of the encoder during training
        :param kwargs:
        """
        super().__init__()

        self._output_size = hidden_size

        # build all (num_layers) layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                size=hidden_size,
                ff_size=ff_size,
                num_heads=num_heads,
                dropout=dropout,
                alpha=kwargs.get("alpha", 1.0),
                layer_norm=kwargs.get("layer_norm", "pre"),
                activation=kwargs.get("activation", "relu"),
            ) for _ in range(num_layers)
        ])

        self.pe = PositionalEncoding(hidden_size)
        self.emb_dropout = nn.Dropout(p=emb_dropout)

        self.layer_norm = (
            nn.LayerNorm(hidden_size, eps=1e-6)
            if kwargs.get("layer_norm", "post") == "pre" else None
        )

        if freeze:
            freeze_params(self)
    def forward(
        self,
        src_embed: Tensor,
        src_length: Tensor,  # unused
        mask: Tensor = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """
        Pass the input (and mask) through each layer in turn.
        Applies a Transformer encoder to a sequence of embeddings x.

        :param src_embed: embedded src inputs,
            shape (batch_size, src_len, embed_size)
        :param src_length: length of src inputs
            (counting tokens before padding), shape (batch_size)
        :param mask: indicates padding areas (zeros where padding), shape
            (batch_size, 1, src_len)
        :param kwargs:
        :return:
            - output: hidden states with shape (batch_size, max_length, hidden)
            - None
        """
        # pylint: disable=unused-argument
        x = self.pe(src_embed)  # add position encoding to word embeddings
        if kwargs.get("src_prompt_mask", None) is not None:  # add src_prompt_mask
            x = x + kwargs["src_prompt_mask"]
        x = self.emb_dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
        return x, None
    def __repr__(self):
        return (
            f"{self.__class__.__name__}(num_layers={len(self.layers)}, "
            f"num_heads={self.layers[0].src_src_att.num_heads}, "
            f"alpha={self.layers[0].alpha}, "
            f'layer_norm="{self.layers[0]._layer_norm_position}", '
            f"activation={self.layers[0].feed_forward.pwff_layer[1]})"
        )
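# --- Editor's usage sketch (not part of the original joeynmt module) ---------
# Building and calling the Transformer encoder. `hidden_size` must match the
# embedding size of the inputs and be divisible by `num_heads`. The padding
# mask has shape (batch_size, 1, src_len) with nonzero entries at real tokens,
# and the extra kwargs (`alpha`, `layer_norm`, `activation`) are forwarded to
# each TransformerEncoderLayer. All sizes below are arbitrary example values.
def _transformer_encoder_example() -> None:
    encoder = TransformerEncoder(
        hidden_size=64, ff_size=128, num_layers=2, num_heads=4,
        dropout=0.1, emb_dropout=0.1,
        layer_norm="pre", activation="relu", alpha=1.0,
    )
    src_embed = torch.randn(3, 7, 64)       # (batch_size, src_len, hidden_size)
    src_length = torch.tensor([7, 5, 2])    # unused, kept for the interface
    mask = (torch.arange(7)[None, :] < src_length[:, None]).unsqueeze(1)

    output, hidden = encoder(src_embed, src_length, mask)
    assert output.shape == (3, 7, 64)
    assert hidden is None
    assert encoder.output_size == 64
# -----------------------------------------------------------------------------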