Source code for joeynmt.search

# coding: utf-8
"""
Search module
"""
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor

from joeynmt.batch import Batch
from joeynmt.decoders import RecurrentDecoder, TransformerDecoder
from joeynmt.helpers import adjust_mask_size, tile
from joeynmt.helpers_for_ddp import ddp_merge
from joeynmt.model import DataParallelWrapper, Model

__all__ = ["greedy", "beam_search", "search"]


def greedy(
    src_mask: Tensor,
    max_output_length: int,
    model: Model,
    encoder_output: Tensor,
    encoder_hidden: Tensor,
    **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Greedy decoding. Select the token with the highest probability at each time
    step. This function is a wrapper that calls `recurrent_greedy` for recurrent
    decoders and `transformer_greedy` for transformer decoders.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_scores: scores (2d array of token-wise log probabilities),
        - stacked_attention_scores: attention scores (3d array)
    """
    # pylint: disable=no-else-return
    if isinstance(model.decoder, TransformerDecoder):
        return transformer_greedy(
            src_mask,
            max_output_length,
            model,
            encoder_output,
            encoder_hidden,
            **kwargs,
        )
    elif isinstance(model.decoder, RecurrentDecoder):
        return recurrent_greedy(
            src_mask, max_output_length, model, encoder_output, encoder_hidden,
            **kwargs
        )
    else:
        raise NotImplementedError(
            f"model.decoder({model.decoder.__class__.__name__}) not supported."
        )
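
# --- Illustrative usage sketch (not part of the original module) ---
# Shows how `greedy` is typically fed from an encoder forward pass. The
# `return_type="encode"` call and its 4-tuple return value are assumptions
# mirroring the `return_type="decode"` calls used throughout this module;
# `model` and `batch` stand in for a trained Model and a prepared Batch.
def _example_greedy_decode(model: Model, batch: Batch,
                           max_output_length: int = 100) -> Tensor:
    with torch.no_grad():
        encoder_output, encoder_hidden, _, _ = model(
            return_type="encode",
            src=batch.src,
            src_length=batch.src_length,
            src_mask=batch.src_mask,
        )
    output, _, _ = greedy(
        src_mask=batch.src_mask,
        max_output_length=max_output_length,
        model=model,
        encoder_output=encoder_output,
        encoder_hidden=encoder_hidden,
    )
    return output  # (batch_size, out_len) token indices, BOS already stripped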
def recurrent_greedy(
    src_mask: Tensor,
    max_output_length: int,
    model: Model,
    encoder_output: Tensor,
    encoder_hidden: Tensor,
    **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Greedy decoding: in each step, choose the word that gets the highest score.
    Version for recurrent decoder.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_scores: scores (2d array of token-wise log probabilities),
        - stacked_attention_scores: attention scores (3d array)
    """
    bos_index = model.bos_index
    eos_index = model.eos_index
    pad_index = model.pad_index
    sep_index = model.sep_index
    unk_index = model.unk_index
    batch_size = src_mask.size(0)
    min_output_length: int = kwargs.get("min_output_length", 1)
    generate_unk: bool = kwargs.get("generate_unk", True)  # whether to generate UNK
    return_prob: bool = kwargs.get("return_prob", "none") == "hyp"

    prev_y = src_mask.new_full((batch_size, 1), fill_value=bos_index, dtype=torch.long)
    output = []
    scores = []
    attention_scores = []
    hidden = None
    prev_att_vector = None
    finished = src_mask.new_zeros((batch_size, 1)).byte()

    device = encoder_output.device
    autocast = kwargs.get("autocast", {"device_type": device.type, "enabled": False})

    for step in range(max_output_length):
        # decode one single step
        with torch.autocast(**autocast):
            with torch.no_grad():
                out, hidden, att_probs, prev_att_vector = model(
                    return_type="decode",
                    trg_input=prev_y,
                    encoder_output=encoder_output,
                    encoder_hidden=encoder_hidden,
                    src_mask=src_mask,
                    unroll_steps=1,
                    decoder_hidden=hidden,
                    att_vector=prev_att_vector,
                )
                # out: batch x time=1 x vocab (logits)

        if return_prob:
            out = F.log_softmax(out, dim=-1)

        # don't generate BOS, PAD, SEP, language tags
        for forbidden_index in [bos_index, pad_index, sep_index] + model.lang_tags:
            if forbidden_index is not None and forbidden_index < out.size(2):
                out[:, :, forbidden_index] = float("-inf")

        if not generate_unk:
            out[:, :, unk_index] = float("-inf")

        # don't generate EOS until we reached min_output_length
        if step < min_output_length:
            out[:, :, eos_index] = float("-inf")

        # greedy decoding: choose arg max over vocabulary in each step
        prob, next_word = torch.max(out, dim=-1)  # batch x time=1
        output.append(next_word.squeeze(1).detach().cpu())
        if return_prob:
            scores.append(prob.squeeze(1).detach().cpu())
        prev_y = next_word
        attention_scores.append(att_probs.squeeze(1).detach().cpu())
        # shape: (batch_size, max_src_length)

        # check if previous symbol was <eos>
        is_eos = torch.eq(next_word, eos_index)
        finished += is_eos
        # stop predicting if <eos> reached for all elements in batch
        if (finished >= 1).sum() == batch_size:
            break

    stacked_output = torch.stack(output, dim=1).long()  # batch, time
    stacked_scores = torch.stack(scores, dim=1).float() if return_prob else None
    stacked_attention_scores = torch.stack(attention_scores, dim=1).float()
    return stacked_output, stacked_scores, stacked_attention_scores
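
# --- Tiny self-contained illustration (not from the original module) ---
# Demonstrates the masking-then-argmax pattern used above: indices set to
# -inf (e.g. EOS before `min_output_length` is reached) can never win
# `torch.max`, so the corresponding tokens are never generated.
def _example_masked_argmax() -> None:
    out = torch.tensor([[[0.5, 2.0, 1.0, 3.0]]])  # batch=1, time=1, vocab=4
    eos_index = 3
    out[:, :, eos_index] = float("-inf")  # as done before min_output_length
    prob, next_word = torch.max(out, dim=-1)
    assert next_word.item() == 1  # index 3 is blocked, so argmax falls to 2.0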
def transformer_greedy(
    src_mask: Tensor,
    max_output_length: int,
    model: Model,
    encoder_output: Tensor,
    encoder_hidden: Tensor,
    **kwargs,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Special greedy function for transformer, since it works differently.
    The transformer remembers all previous states and attends to them.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder final state (unused in Transformer)
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_scores: scores (2d array of token-wise log probabilities),
        - stacked_attention_scores: attention scores (3d array)
    """
    # pylint: disable=unused-argument,too-many-statements
    bos_index = model.bos_index
    eos_index = model.eos_index
    sep_index = model.sep_index
    unk_index = model.unk_index
    pad_index = model.pad_index
    batch_size, _, src_len = src_mask.size()
    device = encoder_output.device
    autocast = kwargs.get("autocast", {"device_type": device.type, "enabled": False})

    # options to control generation
    generate_unk: bool = kwargs.get("generate_unk", True)  # whether to generate UNK
    return_attn: bool = kwargs.get("return_attention", False)
    return_prob: bool = kwargs.get("return_prob", "none") == "hyp"
    min_output_length: int = kwargs.get("min_output_length", 1)
    repetition_penalty: float = kwargs.get("repetition_penalty", -1)
    no_repeat_ngram_size: int = kwargs.get("no_repeat_ngram_size", -1)
    encoder_input: Tensor = kwargs.get("encoder_input", None)  # for repetition blocker
    decoder_prompt: Tensor = kwargs.get("decoder_prompt", None)  # for forced decoding
    trg_prompt_mask: Tensor = kwargs.get("trg_prompt_mask", None)  # for forced decoding
    compute_softmax: bool = (
        return_prob or repetition_penalty > 0 or no_repeat_ngram_size > 0
        or encoder_input is not None
    )

    # start with BOS-symbol for each sentence in the batch
    ys = encoder_output.new_full((batch_size, 1), bos_index, dtype=torch.long)

    # placeholder for scores
    yv = ys.new_zeros((batch_size, 1), dtype=torch.float) if return_prob else None
    # placeholder for attentions
    yt = ys.new_zeros((batch_size, 1, src_len), dtype=torch.float) \
        if return_attn else None

    # a subsequent mask is intersected with this in decoder forward pass
    trg_mask = src_mask.new_ones([1, 1, 1])
    if isinstance(model, DataParallelWrapper):
        trg_mask = torch.stack([src_mask.new_ones([1, 1]) for _ in model.device_ids])

    finished = src_mask.new_zeros((batch_size, 1)).byte()

    for step in range(max_output_length):
        # `forced_word` shape: (batch_size, 1)
        forced_word = decoder_prompt[:, step + 1].unsqueeze(1) \
            if decoder_prompt is not None and decoder_prompt.size(1) > step + 1 \
            else ys.new_full((batch_size, 1), pad_index, dtype=torch.long)
        forced_word_mask = trg_prompt_mask[:, step + 1].unsqueeze(1) \
            if trg_prompt_mask is not None and trg_prompt_mask.size(1) > step + 1 \
            else ys.new_zeros((batch_size, 1), dtype=torch.long)
        forced_prob = yv.new_zeros((batch_size, 1)) if return_prob else None
        forced_att = yt.new_zeros((batch_size, 1, src_len)) if return_attn else None

        if torch.any(~forced_word_mask).item():
            with torch.autocast(**autocast):
                with torch.no_grad():
                    log_probs, _, att, _ = model(
                        return_type="decode",
                        trg_input=ys,  # model.trg_embed(ys) # embed the previous tokens
                        encoder_output=encoder_output,
                        encoder_hidden=None,
                        src_mask=src_mask,
                        unroll_steps=None,
                        decoder_hidden=None,
                        trg_mask=trg_mask,
                        return_attention=return_attn,
                        trg_prompt_mask=adjust_mask_size(
                            trg_prompt_mask, batch_size, ys.size(1)
                        ),
                    )
                    log_probs = log_probs[:, -1]  # logits

            if compute_softmax:
                log_probs = F.log_softmax(log_probs, dim=-1)

                # ngram blocker
                if no_repeat_ngram_size > 1:
                    log_probs = block_repeat_ngrams(
                        ys,
                        log_probs,
                        no_repeat_ngram_size,
                        step,
                        src_tokens=encoder_input,
                        exclude_tokens=model.specials + model.lang_tags,
                    )

                # repetition penalty
                if repetition_penalty > 1.0:
                    log_probs = penalize_repetition(
                        ys,
                        log_probs,
                        repetition_penalty,
                        exclude_tokens=model.specials + model.lang_tags,
                    )
                    if encoder_input is not None:
                        log_probs = penalize_repetition(
                            encoder_input,
                            log_probs,
                            repetition_penalty,
                            exclude_tokens=model.specials + model.lang_tags,
                        )

            # don't generate BOS, SEP, language tags
            for forbidden_index in [bos_index, sep_index] + model.lang_tags:
                if forbidden_index is not None and forbidden_index < log_probs.size(1):
                    log_probs[:, forbidden_index] = float("-inf")

            if not generate_unk:
                log_probs[:, unk_index] = float("-inf")

            # don't generate EOS until we reached min_output_length
            if step < min_output_length:
                log_probs[:, eos_index] = float("-inf")

            # take the most likely token
            prob, next_word = torch.max(log_probs, dim=1)
            next_word = next_word.data.unsqueeze(-1)
            next_word = torch.where(forced_word_mask.bool(), forced_word, next_word)
            if return_prob:
                prob = prob.data.unsqueeze(-1)
                prob = torch.where(forced_word_mask.bool(), forced_prob, prob)
            if return_attn:
                assert att is not None
                att = att.data[:, -1, :].unsqueeze(1)  # take last trg token only
                att = torch.where(
                    forced_word_mask.expand(-1, src_len).unsqueeze(1).bool(),
                    forced_att, att
                )
                # `att` shape: (batch_size, 1, src_len)
        else:
            next_word = forced_word
            prob = forced_prob if return_prob else None
            att = forced_att if return_attn else None

        ys = torch.cat([ys, next_word], dim=1)
        yv = torch.cat([yv, prob], dim=1) if return_prob else None
        yt = torch.cat([yt, att], dim=1) if return_attn else None
        # `yt` shape: (batch_size, trg_len, src_len)

        # check if we reached EOS
        is_eos = torch.eq(next_word, eos_index)
        finished += is_eos
        # stop predicting if we reached EOS for all elements in batch
        if (finished >= 1).sum() == batch_size:
            break

    # gather
    # Note: ys.size(0) will become batch_size * world_size in DDP!
    ys = ddp_merge(ys, pad_index)
    yv = ddp_merge(yv, 0.0) if return_prob else None
    yt = ddp_merge(yt, 0.0) if return_attn else None

    # remove BOS-symbol
    output = ys[:, 1:].detach().cpu().long()
    scores = yv[:, 1:].detach().cpu().float() if return_prob else None
    attention = yt[:, 1:, :].detach().cpu().float() if return_attn else None
    return output, scores, attention
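
# --- Illustrative sketch (not part of the original module) ---
# Forced decoding with a target-side prompt: `decoder_prompt` holds forced
# token ids starting with BOS, and `trg_prompt_mask` marks forced positions
# with 1; both reach `transformer_greedy` via **kwargs. The token ids 7 and 8
# are made-up placeholders, and `encoder_output` is assumed to come from a
# prior encoder pass.
def _example_forced_decoding(model: Model, batch: Batch,
                             encoder_output: Tensor) -> Tensor:
    batch_size = batch.src_mask.size(0)
    prompt = torch.tensor(
        [[model.bos_index, 7, 8]], device=batch.src_mask.device
    ).repeat(batch_size, 1)
    output, _, _ = transformer_greedy(
        src_mask=batch.src_mask,
        max_output_length=50,
        model=model,
        encoder_output=encoder_output,
        encoder_hidden=None,
        decoder_prompt=prompt,
        trg_prompt_mask=torch.ones_like(prompt),  # steps 0-1 are forced
    )
    return output  # hypotheses begin with the forced tokens 7, 8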
def block_repeat_ngrams(
    tokens: Tensor, scores: Tensor, no_repeat_ngram_size: int, step: int, **kwargs
) -> Tensor:
    """
    For each hypothesis, check a list of previous ngrams and set associated log probs
    to -inf. Taken from fairseq's NGramRepeatBlock.

    :param tokens: target tokens generated so far
    :param scores: log probabilities of the next token to generate in this time step
    :param no_repeat_ngram_size: ngram size to prohibit
    :param step: generation step (= length of hypotheses so far)
    """
    hyp_size = tokens.size(0)
    banned_batch_tokens = [set([]) for _ in range(hyp_size)]

    trg_tokens = tokens.cpu().tolist()
    check_end_pos = step + 2 - no_repeat_ngram_size
    offset = no_repeat_ngram_size - 1

    src_tokens = kwargs.get("src_tokens", None)
    if src_tokens is not None:
        src_length = src_tokens.size(-1)
        assert src_tokens.size(0) == hyp_size, (src_tokens.size(), hyp_size)
        src_tokens = src_tokens.cpu().tolist()
    exclude_tokens = kwargs.get("exclude_tokens", [])

    # get repeated ngrams
    for hyp_idx in range(hyp_size):
        if len(trg_tokens[hyp_idx]) > no_repeat_ngram_size:
            # (n-1) token prefix at the time step
            # 0 1 2 3 4 <- time step
            # if tokens[hyp_idx] = [2, 5, 5, 6, 5] at step 4 with ngram_size = 3,
            #                             ^  ^  ^
            # then ngram_to_check = [6, 5], and set the token in the next position
            # to -inf, if there are ngrams that start with [6, 5].
            ngram_to_check = trg_tokens[hyp_idx][-offset:]

            for i in range(1, check_end_pos):  # ignore BOS
                if ngram_to_check == trg_tokens[hyp_idx][i:i + offset]:
                    banned_batch_tokens[hyp_idx].add(trg_tokens[hyp_idx][i + offset])

            # src_tokens
            if src_tokens is not None:
                check_end_pos_src = src_length + 1 - no_repeat_ngram_size
                for i in range(check_end_pos_src):  # no BOS in src
                    if ngram_to_check == src_tokens[hyp_idx][i:i + offset]:
                        banned_batch_tokens[hyp_idx].add(
                            src_tokens[hyp_idx][i + offset]
                        )

    # set the score of the banned tokens to -inf
    for i, banned_tokens in enumerate(banned_batch_tokens):
        banned_tokens = set(banned_tokens) - set(exclude_tokens)
        scores[i, list(banned_tokens)] = float("-inf")
    return scores


def penalize_repetition(
    tokens: Tensor, scores: Tensor, penalty: float, exclude_tokens: List[int] = None
) -> Tensor:
    """
    Reduce probability of the given tokens.
    Taken from Huggingface's RepetitionPenaltyLogitsProcessor.

    :param tokens: token ids to penalize
    :param scores: log probabilities of the next token to generate
    :param penalty: penalty value, bigger value implies less probability
    :param exclude_tokens: list of token ids to exclude from penalizing
    """
    # keep a copy so that excluded tokens can be restored after the in-place
    # update below (without the clone, the restore would read already-penalized
    # values from the same tensor)
    scores_before = scores.clone() if exclude_tokens else None
    score = torch.gather(scores, 1, tokens)

    # if score < 0 then repetition penalty has to be multiplied
    # to reduce the previous token probability
    score = torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_(1, tokens, score)

    # exclude special tokens
    if exclude_tokens:
        for token in exclude_tokens:
            # pylint: disable=unsubscriptable-object
            scores[:, token] = scores_before[:, token]
    return scores
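
# --- Self-contained smoke test (illustrative, not part of the original module) ---
# Shows the effect of `penalize_repetition` on a toy distribution: the score of
# a previously generated token (id 1) drops, while all other tokens keep their
# scores. Note the function updates `scores` in place, hence the `clone()` to
# keep a reference copy for comparison.
if __name__ == "__main__":
    scores = torch.log(torch.tensor([[0.1, 0.6, 0.3]]))  # log probs, vocab=3
    tokens = torch.tensor([[1]])  # token 1 was generated before
    penalized = penalize_repetition(tokens, scores.clone(), penalty=2.0)
    assert penalized[0, 1] < scores[0, 1]  # log(0.6) < 0, so multiplied by 2.0
    assert torch.equal(penalized[0, [0, 2]], scores[0, [0, 2]])  # untouched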