Source code for joeynmt.model

# coding: utf-8
"""
Module to represent whole models
"""
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from joeynmt.config import ConfigurationError
from joeynmt.decoders import Decoder, RecurrentDecoder, TransformerDecoder
from joeynmt.embeddings import Embeddings
from joeynmt.encoders import Encoder, RecurrentEncoder, TransformerEncoder
from joeynmt.helpers_for_ddp import get_logger
from joeynmt.initialization import initialize_model
from joeynmt.loss import XentLoss
from joeynmt.vocabulary import Vocabulary

logger = get_logger(__name__)


class Model(nn.Module):
    """
    Base Model class
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        src_embed: Embeddings,
        trg_embed: Embeddings,
        src_vocab: Vocabulary,
        trg_vocab: Vocabulary,
    ) -> None:
        """
        Create a new encoder-decoder model

        :param encoder: encoder
        :param decoder: decoder
        :param src_embed: source embedding
        :param trg_embed: target embedding
        :param src_vocab: source vocabulary
        :param trg_vocab: target vocabulary
        """
        super().__init__()

        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.encoder = encoder
        self.decoder = decoder
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.pad_index = self.trg_vocab.pad_index
        self.bos_index = self.trg_vocab.bos_index
        self.eos_index = self.trg_vocab.eos_index
        self.sep_index = self.trg_vocab.sep_index
        self.unk_index = self.trg_vocab.unk_index
        self.specials = [self.trg_vocab.lookup(t) for t in self.trg_vocab.specials]
        self.lang_tags = [self.trg_vocab.lookup(t) for t in self.trg_vocab.lang_tags]
        self._loss_function = None  # set by `prepare()` func in prediction.py

    @property
    def loss_function(self):
        return self._loss_function

    @loss_function.setter
    def loss_function(self, cfg: Tuple):
        loss_type, label_smoothing = cfg
        assert loss_type == "crossentropy"
        self._loss_function = XentLoss(
            pad_index=self.pad_index, smoothing=label_smoothing
        )
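
    # Usage sketch (illustrative, not part of the original module): the setter
    # above expects a (loss_type, label_smoothing) tuple; the smoothing value
    # 0.1 is an arbitrary example.
    #     >>> model.loss_function = ("crossentropy", 0.1)  # builds an XentLoss internally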

    def forward(self,
                return_type: str = None,
                **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Interface for multi-gpu

        For DataParallel, we need to encapsulate all model calls, i.e.
        `model.encode()`, `model.decode()`, and `model.encode_decode()`,
        by `model.__call__()`. `model.__call__()` triggers `model.forward()`
        together with pre hooks and post hooks, which take care of
        multi-gpu distribution.

        :param return_type: one of {"loss", "encode", "decode"}
        """
        if return_type is None:
            raise ValueError(
                "Please specify return_type: {`loss`, `encode`, `decode`}."
            )

        if return_type == "loss":
            assert self.loss_function is not None
            assert "trg" in kwargs and "trg_mask" in kwargs  # need trg to compute loss

            out, _, att_probs, _ = self._encode_decode(**kwargs)

            # compute log probs
            log_probs = F.log_softmax(out, dim=-1)

            # compute batch loss
            # pylint: disable=not-callable
            batch_loss = self.loss_function(log_probs, **kwargs)

            # count correct tokens before decoding (for accuracy)
            trg_mask = kwargs["trg_mask"].squeeze(1)
            assert kwargs["trg"].size() == trg_mask.size()
            n_correct = torch.sum(
                log_probs.argmax(-1).masked_select(trg_mask).eq(
                    kwargs["trg"].masked_select(trg_mask)
                )
            )

            # return batch loss
            #     = sum over all elements in batch that are not pad
            return_tuple = (batch_loss, log_probs, att_probs, n_correct)

        elif return_type == "encode":
            kwargs["pad"] = True  # TODO: only if multi-gpu
            encoder_output, encoder_hidden = self._encode(**kwargs)

            # return encoder outputs
            return_tuple = (encoder_output, encoder_hidden, None, None)

        elif return_type == "decode":
            outputs, hidden, att_probs, att_vectors = self._decode(**kwargs)

            # return decoder outputs
            return_tuple = (outputs, hidden, att_probs, att_vectors)

        return tuple(return_tuple)

    def _encode_decode(
        self,
        src: Tensor,
        trg_input: Tensor,
        src_mask: Tensor,
        src_length: Tensor,
        trg_mask: Tensor = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        First encodes the source sentence.
        Then produces the target one word at a time.

        :param src: source input
        :param trg_input: target input
        :param src_mask: source mask
        :param src_length: length of source inputs
        :param trg_mask: target mask
        :return: decoder outputs
        """
        encoder_output, encoder_hidden = self._encode(
            src=src, src_length=src_length, src_mask=src_mask, **kwargs
        )

        unroll_steps = trg_input.size(1)

        return self._decode(
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            trg_input=trg_input,
            unroll_steps=unroll_steps,
            trg_mask=trg_mask,
            **kwargs,
        )

    def _encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor,
                **_kwargs) -> Tuple[Tensor, Tensor]:
        """
        Encodes the source sentence.

        :param src: source input
        :param src_length: length of source inputs
        :param src_mask: source mask
        :return:
            - encoder_output
            - hidden_concat
        """
        # embed src prompts if given
        if (
            _kwargs.get("src_prompt_mask", None) is not None
            and isinstance(self.encoder, TransformerEncoder)
        ):
            assert self.sep_index is not None and self.sep_index in self.specials, (
                f"Prompt marker {self.sep_index} not found. "
                "This model doesn't support prompting!"
            )
            assert src.size(1) == _kwargs["src_prompt_mask"].size(1)
            _kwargs["src_prompt_mask"] = self.src_embed(_kwargs["src_prompt_mask"])

        return self.encoder(self.src_embed(src), src_length, src_mask, **_kwargs)

    def _decode(
        self,
        encoder_output: Tensor,
        encoder_hidden: Tensor,
        src_mask: Tensor,
        trg_input: Tensor,
        unroll_steps: int,
        decoder_hidden: Tensor = None,
        att_vector: Tensor = None,
        trg_mask: Tensor = None,
        **_kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Decode, given an encoded source sentence.

        :param encoder_output: encoder states for attention computation
        :param encoder_hidden: last encoder state for decoder initialization
        :param src_mask: source mask, 1 at valid tokens
        :param trg_input: target inputs
        :param unroll_steps: number of steps to unroll the decoder for
        :param decoder_hidden: decoder hidden state (optional)
        :param att_vector: previous attention vector (optional)
        :param trg_mask: mask for target steps
        :return: decoder outputs
            - decoder_output
            - decoder_hidden
            - att_prob
            - att_vector
        """
        # embed trg prompts if given
        if (
            _kwargs.get("trg_prompt_mask", None) is not None
            and isinstance(self.decoder, TransformerDecoder)
        ):
            assert self.sep_index is not None and self.sep_index in self.specials, (
                f"Prompt marker {self.sep_index} not found. "
                "This model doesn't support prompting!"
            )
            assert trg_input.size(1) == _kwargs["trg_prompt_mask"].size(1), (
                trg_input.size(1), _kwargs["trg_prompt_mask"].size(1)
            )
            _kwargs["trg_prompt_mask"] = self.trg_embed(_kwargs["trg_prompt_mask"])

        return self.decoder(
            trg_embed=self.trg_embed(trg_input),
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            unroll_steps=unroll_steps,
            hidden=decoder_hidden,
            prev_att_vector=att_vector,
            trg_mask=trg_mask,
            **_kwargs,
        )

    def __repr__(self) -> str:
        """
        String representation: a description of encoder, decoder and embeddings

        :return: string representation
        """
        return (
            f"{self.__class__.__name__}(\n"
            f"\tencoder={self.encoder},\n"
            f"\tdecoder={self.decoder},\n"
            f"\tsrc_embed={self.src_embed},\n"
            f"\ttrg_embed={self.trg_embed},\n"
            f"\tloss_function={self.loss_function})"
        )

    def log_parameters_list(self) -> None:
        """
        Write all model parameters (name, shape) to the log.
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        n_params = sum([np.prod(p.size()) for p in model_parameters])
        logger.info("Total params: %d", n_params)
        trainable_params = [
            n for (n, p) in self.named_parameters() if p.requires_grad
        ]
        logger.debug("Trainable parameters: %s", sorted(trainable_params))
        assert trainable_params
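
# Usage sketch (illustrative, not part of the original module): the training loss
# is computed through `model.__call__()` so that the DataParallel pre/post hooks
# fire. The keyword names mirror what `forward(return_type="loss")` expects; the
# `batch` object itself is hypothetical.
#     >>> batch_loss, log_probs, _, n_correct = model(
#     ...     return_type="loss",
#     ...     src=batch.src, src_mask=batch.src_mask, src_length=batch.src_length,
#     ...     trg=batch.trg, trg_input=batch.trg_input, trg_mask=batch.trg_mask,
#     ... )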


class DataParallelWrapper(nn.Module):
    """
    DataParallel wrapper to pass through the model attributes

    ex. 1) for DataParallel
        >>> from torch.nn import DataParallel as DP
        >>> model = DataParallelWrapper(DP(model))

    ex. 2) for DistributedDataParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> model = DataParallelWrapper(DDP(model))
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module")
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to the twice-wrapped module."""
        try:  # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            try:  # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """saving the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """loading the twice-wrapped module."""
        self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
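
# Usage sketch (illustrative, not part of the original module): the wrapper keeps
# plain-model attributes reachable and checkpoints the bare (twice-wrapped) module.
#     >>> from torch.nn import DataParallel as DP
#     >>> wrapped = DataParallelWrapper(DP(model))
#     >>> wrapped.pad_index == model.pad_index   # attribute pass-through
#     True
#     >>> ckpt = wrapped.state_dict()            # parameters of the bare model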


def build_model(
    cfg: Dict = None, src_vocab: Vocabulary = None, trg_vocab: Vocabulary = None
) -> Model:
    """
    Build and initialize the model according to the configuration.

    :param cfg: dictionary configuration containing model specifications
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :return: built and initialized model
    """
    logger.info("Building an encoder-decoder model...")
    enc_cfg = cfg["encoder"]
    dec_cfg = cfg["decoder"]

    src_pad_index = src_vocab.pad_index
    trg_pad_index = trg_vocab.pad_index

    src_embed = Embeddings(
        **enc_cfg["embeddings"],
        vocab_size=len(src_vocab),
        padding_idx=src_pad_index,
    )

    # this ties source and target embeddings for softmax layer tying, see further below
    if cfg.get("tied_embeddings", False):
        if src_vocab == trg_vocab:
            trg_embed = src_embed  # share embeddings for src and trg
        else:
            raise ConfigurationError(
                "Embedding cannot be tied since vocabularies differ."
            )
    else:
        trg_embed = Embeddings(
            **dec_cfg["embeddings"],
            vocab_size=len(trg_vocab),
            padding_idx=trg_pad_index,
        )

    # build encoder
    enc_dropout = enc_cfg.get("dropout", 0.0)
    enc_emb_dropout = enc_cfg["embeddings"].get("dropout", enc_dropout)
    if enc_cfg.get("type", "recurrent") == "transformer":
        assert enc_cfg["embeddings"]["embedding_dim"] == enc_cfg["hidden_size"], (
            "for transformer, emb_size must be the same as hidden_size"
        )
        emb_size = src_embed.embedding_dim
        encoder = TransformerEncoder(
            **enc_cfg,
            emb_size=emb_size,
            emb_dropout=enc_emb_dropout,
            pad_index=src_pad_index,
        )
    else:
        encoder = RecurrentEncoder(
            **enc_cfg,
            emb_size=src_embed.embedding_dim,
            emb_dropout=enc_emb_dropout,
        )

    # build decoder
    dec_dropout = dec_cfg.get("dropout", 0.0)
    dec_emb_dropout = dec_cfg["embeddings"].get("dropout", dec_dropout)
    if dec_cfg.get("type", "transformer") == "transformer":
        decoder = TransformerDecoder(
            **dec_cfg,
            encoder=encoder,
            vocab_size=len(trg_vocab),
            emb_size=trg_embed.embedding_dim,
            emb_dropout=dec_emb_dropout,
        )
    else:
        decoder = RecurrentDecoder(
            **dec_cfg,
            encoder=encoder,
            vocab_size=len(trg_vocab),
            emb_size=trg_embed.embedding_dim,
            emb_dropout=dec_emb_dropout,
        )

    model = Model(
        encoder=encoder,
        decoder=decoder,
        src_embed=src_embed,
        trg_embed=trg_embed,
        src_vocab=src_vocab,
        trg_vocab=trg_vocab,
    )

    # tie softmax layer with trg embeddings
    if cfg.get("tied_softmax", False):
        if trg_embed.lut.weight.shape == model.decoder.output_layer.weight.shape:
            # (also) share trg embeddings and softmax layer:
            model.decoder.output_layer.weight = trg_embed.lut.weight
        else:
            raise ConfigurationError(
                "For tied_softmax, the decoder embedding_dim and decoder hidden_size "
                "must be the same. The decoder must be a Transformer."
            )

    # custom initialization of model parameters
    initialize_model(model, cfg, src_pad_index, trg_pad_index)

    # initialize embeddings from file
    enc_embed_path = enc_cfg["embeddings"].get("load_pretrained", None)
    dec_embed_path = dec_cfg["embeddings"].get("load_pretrained", None)
    if enc_embed_path:
        logger.info("Loading pretrained src embeddings...")
        model.src_embed.load_from_file(Path(enc_embed_path), src_vocab)
    if dec_embed_path and not cfg.get("tied_embeddings", False):
        logger.info("Loading pretrained trg embeddings...")
        model.trg_embed.load_from_file(Path(dec_embed_path), trg_vocab)

    logger.info("Enc-dec model built.")
    return model
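
# Usage sketch (illustrative, not part of the original module): `build_model` is
# driven by the "model" section of a JoeyNMT config. Only keys referenced in this
# module are shown; further encoder/decoder hyperparameters (layers, heads, etc.)
# are omitted and depend on the chosen architecture, and the example sizes are
# arbitrary.
#     >>> model_cfg = {
#     ...     "tied_embeddings": False,
#     ...     "tied_softmax": False,
#     ...     "encoder": {
#     ...         "type": "transformer",
#     ...         "hidden_size": 512,
#     ...         "dropout": 0.1,
#     ...         "embeddings": {"embedding_dim": 512, "dropout": 0.1},
#     ...     },
#     ...     "decoder": {
#     ...         "type": "transformer",
#     ...         "hidden_size": 512,
#     ...         "dropout": 0.1,
#     ...         "embeddings": {"embedding_dim": 512, "dropout": 0.1},
#     ...     },
#     ... }
#     >>> model = build_model(model_cfg, src_vocab=src_vocab, trg_vocab=trg_vocab)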