Source code for joeynmt.search

# coding: utf-8
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np

from joeynmt.decoders import TransformerDecoder
from joeynmt.model import Model
from joeynmt.batch import Batch
from joeynmt.helpers import tile

__all__ = ["greedy", "transformer_greedy", "beam_search", "run_batch"]


def greedy(src_mask: Tensor, max_output_length: int, model: Model,
           encoder_output: Tensor, encoder_hidden: Tensor) \
        -> (np.array, np.array):
    """
    Greedy decoding. Select the token with the highest probability at each
    time step. This function is a wrapper that calls recurrent_greedy for
    recurrent decoders and transformer_greedy for transformer decoders.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array) or None
    """
    if isinstance(model.decoder, TransformerDecoder):
        # Transformer greedy decoding
        greedy_fun = transformer_greedy
    else:
        # Recurrent greedy decoding
        greedy_fun = recurrent_greedy

    return greedy_fun(
        src_mask, max_output_length, model, encoder_output, encoder_hidden)


def recurrent_greedy(
        src_mask: Tensor, max_output_length: int, model: Model,
        encoder_output: Tensor,
        encoder_hidden: Tensor) -> (np.array, np.array):
    """
    Greedy decoding: in each step, choose the word that gets the highest
    score. Version for recurrent decoder.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    bos_index = model.bos_index
    eos_index = model.eos_index
    batch_size = src_mask.size(0)
    prev_y = src_mask.new_full(size=[batch_size, 1], fill_value=bos_index,
                               dtype=torch.long)
    output = []
    attention_scores = []
    hidden = None
    prev_att_vector = None
    finished = src_mask.new_zeros((batch_size, 1)).byte()

    # pylint: disable=unused-variable
    for t in range(max_output_length):
        # decode one single step
        with torch.no_grad():
            logits, hidden, att_probs, prev_att_vector = model(
                return_type="decode",
                trg_input=prev_y,
                encoder_output=encoder_output,
                encoder_hidden=encoder_hidden,
                src_mask=src_mask,
                unroll_steps=1,
                decoder_hidden=hidden,
                att_vector=prev_att_vector)
            # logits: batch x time=1 x vocab (logits)

        # greedy decoding: choose arg max over vocabulary in each step
        next_word = torch.argmax(logits, dim=-1)  # batch x time=1
        output.append(next_word.squeeze(1).detach().cpu().numpy())
        prev_y = next_word
        attention_scores.append(att_probs.squeeze(1).detach().cpu().numpy())
        # batch, max_src_length

        # check if previous symbol was <eos>
        is_eos = torch.eq(next_word, eos_index)
        finished += is_eos
        # stop predicting if <eos> reached for all elements in batch
        if (finished >= 1).sum() == batch_size:
            break

    stacked_output = np.stack(output, axis=1)  # batch, time
    stacked_attention_scores = np.stack(attention_scores, axis=1)
    return stacked_output, stacked_attention_scores
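

# Illustration only (not part of joeynmt): the loop above boils down to an
# argmax over the vocabulary plus a running "finished" mask that stops
# decoding once every hypothesis has produced </s>. A minimal standalone
# sketch with dummy logits in place of the decoder call:
#
#   batch_size, vocab_size, eos_index = 2, 5, 3
#   finished = torch.zeros(batch_size, 1).byte()
#   prev_y = torch.full((batch_size, 1), 1, dtype=torch.long)  # assume 1 = BOS
#   for _ in range(10):
#       logits = torch.randn(batch_size, 1, vocab_size)  # stand-in for the decoder
#       next_word = torch.argmax(logits, dim=-1)         # batch x time=1
#       prev_y = next_word
#       finished += torch.eq(next_word, eos_index)
#       if (finished >= 1).sum() == batch_size:          # every hypothesis ended
#           break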


# pylint: disable=unused-argument
def transformer_greedy(
        src_mask: Tensor, max_output_length: int, model: Model,
        encoder_output: Tensor,
        encoder_hidden: Tensor) -> (np.array, None):
    """
    Special greedy function for transformer, since it works differently.
    The transformer remembers all previous states and attends to them.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder final state (unused in Transformer)
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: always None for the Transformer
    """
    bos_index = model.bos_index
    eos_index = model.eos_index
    batch_size = src_mask.size(0)

    # start with BOS-symbol for each sentence in the batch
    ys = encoder_output.new_full([batch_size, 1], bos_index, dtype=torch.long)

    # a subsequent mask is intersected with this in decoder forward pass
    trg_mask = src_mask.new_ones([1, 1, 1])
    if isinstance(model, torch.nn.DataParallel):
        trg_mask = torch.stack(
            [src_mask.new_ones([1, 1]) for _ in model.device_ids])

    finished = src_mask.new_zeros(batch_size).byte()

    for _ in range(max_output_length):

        # pylint: disable=unused-variable
        with torch.no_grad():
            logits, _, _, _ = model(
                return_type="decode",
                trg_input=ys,  # model.trg_embed(ys) # embed the previous tokens
                encoder_output=encoder_output,
                encoder_hidden=None,
                src_mask=src_mask,
                unroll_steps=None,
                decoder_hidden=None,
                trg_mask=trg_mask
            )

        logits = logits[:, -1]
        _, next_word = torch.max(logits, dim=1)
        next_word = next_word.data
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)

        # check if previous symbol was <eos>
        is_eos = torch.eq(next_word, eos_index)
        finished += is_eos
        # stop predicting if <eos> reached for all elements in batch
        if (finished >= 1).sum() == batch_size:
            break

    ys = ys[:, 1:]  # remove BOS-symbol
    return ys.detach().cpu().numpy(), None
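

# Illustration only (not part of joeynmt): unlike the recurrent version, the
# transformer loop re-feeds the whole prefix ``ys`` at every step and only
# uses the logits of the last position. A minimal sketch with dummy logits:
#
#   batch_size, vocab_size, bos_index, eos_index = 2, 5, 1, 3
#   ys = torch.full((batch_size, 1), bos_index, dtype=torch.long)
#   for _ in range(10):
#       logits = torch.randn(batch_size, ys.size(1), vocab_size)  # full prefix scored
#       next_word = torch.argmax(logits[:, -1], dim=1)            # last position only
#       ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)      # grow the prefix
#       if torch.eq(next_word, eos_index).all():
#           break
#   hypotheses = ys[:, 1:]  # strip BOS, as above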
# pylint: disable=too-many-statements,too-many-branches
def run_batch(model: Model, batch: Batch, max_output_length: int,
              beam_size: int, beam_alpha: float,
              n_best: int = 1) -> (np.array, np.array):
    """
    Get outputs and attention scores for a given batch.

    :param model: Model class
    :param batch: batch to generate hypotheses for
    :param max_output_length: maximum length of hypotheses
    :param beam_size: size of the beam for beam search, if < 2 use greedy
    :param beam_alpha: alpha value (length penalty) for beam search
    :param n_best: number of candidates to return
    :return:
        - stacked_output: hypotheses for batch,
        - stacked_attention_scores: attention scores for batch
    """
    with torch.no_grad():
        encoder_output, encoder_hidden, _, _ = model(
            return_type="encode", **vars(batch))

    # if maximum output length is not globally specified, adapt to src len
    if max_output_length is None:
        max_output_length = int(max(batch.src_length.cpu().numpy()) * 1.5)

    # greedy decoding
    if beam_size < 2:
        stacked_output, stacked_attention_scores = greedy(
            src_mask=batch.src_mask, max_output_length=max_output_length,
            model=model, encoder_output=encoder_output,
            encoder_hidden=encoder_hidden)
        # batch, time, max_src_length
    else:  # beam search
        stacked_output, stacked_attention_scores = beam_search(
            model=model, size=beam_size,
            encoder_output=encoder_output, encoder_hidden=encoder_hidden,
            src_mask=batch.src_mask, max_output_length=max_output_length,
            alpha=beam_alpha, n_best=n_best)

    return stacked_output, stacked_attention_scores
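

# Usage sketch (illustration only): assumes a trained joeynmt ``Model``
# instance ``model`` and a ``Batch`` ``batch`` built elsewhere, e.g. during
# validation; the variable names and hyperparameter values are assumptions.
#
#   # greedy decoding: any beam_size < 2 dispatches to ``greedy`` above
#   hyps, attn = run_batch(model=model, batch=batch, max_output_length=None,
#                          beam_size=1, beam_alpha=-1)
#
#   # beam search with 5 hypotheses and length penalty alpha = 1.0
#   hyps, attn = run_batch(model=model, batch=batch, max_output_length=None,
#                          beam_size=5, beam_alpha=1.0, n_best=1)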