Source code for joeynmt.batch

# coding: utf-8

"""
Implementation of a mini-batch.
"""
import torch


class Batch:
    """Object for holding a batch of data with mask during training.
    Input is a batch from a torch text iterator.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(self, torch_batch, pad_index, use_cuda=False):
        """
        Create a new joey batch from a torch batch.
        This batch extends torch text's batch attributes with src and trg
        length, masks, and the number of non-padded tokens in trg.
        Furthermore, it can be sorted by src length.

        :param torch_batch: batch produced by a torchtext iterator
        :param pad_index: index of the padding token in the vocabulary
        :param use_cuda: whether to move the batch to the GPU
        """
        self.src, self.src_length = torch_batch.src
        self.src_mask = (self.src != pad_index).unsqueeze(1)
        self.nseqs = self.src.size(0)
        self.trg_input = None
        self.trg = None
        self.trg_mask = None
        self.trg_length = None
        self.ntokens = None
        self.use_cuda = use_cuda
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        if hasattr(torch_batch, "trg"):
            trg, trg_length = torch_batch.trg
            # trg_input is used for teacher forcing, last one is cut off
            self.trg_input = trg[:, :-1]
            self.trg_length = trg_length
            # trg is used for loss computation, shifted by one since BOS
            self.trg = trg[:, 1:]
            # we exclude the padded areas from the loss computation
            self.trg_mask = (self.trg_input != pad_index).unsqueeze(1)
            self.ntokens = (self.trg != pad_index).data.sum().item()

        if self.use_cuda:
            self._make_cuda()

    def _make_cuda(self):
        """
        Move the batch to GPU

        :return:
        """
        self.src = self.src.to(self.device)
        self.src_mask = self.src_mask.to(self.device)
        self.src_length = self.src_length.to(self.device)

        if self.trg_input is not None:
            self.trg_input = self.trg_input.to(self.device)
            self.trg = self.trg.to(self.device)
            self.trg_mask = self.trg_mask.to(self.device)
    def sort_by_src_length(self):
        """
        Sort by src length (descending) and return index to revert sort

        :return: list of indices that restores the original batch order
        """
        _, perm_index = self.src_length.sort(0, descending=True)
        rev_index = [0] * perm_index.size(0)
        for new_pos, old_pos in enumerate(perm_index.cpu().numpy()):
            rev_index[old_pos] = new_pos

        sorted_src_length = self.src_length[perm_index]
        sorted_src = self.src[perm_index]
        sorted_src_mask = self.src_mask[perm_index]
        if self.trg_input is not None:
            sorted_trg_input = self.trg_input[perm_index]
            sorted_trg_length = self.trg_length[perm_index]
            sorted_trg_mask = self.trg_mask[perm_index]
            sorted_trg = self.trg[perm_index]

        self.src = sorted_src
        self.src_length = sorted_src_length
        self.src_mask = sorted_src_mask

        if self.trg_input is not None:
            self.trg_input = sorted_trg_input
            self.trg_mask = sorted_trg_mask
            self.trg_length = sorted_trg_length
            self.trg = sorted_trg

        if self.use_cuda:
            self._make_cuda()

        return rev_index
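In joeynmt this class is normally fed by the library's own data iterators, but a minimal, self-contained sketch can illustrate how it behaves. The snippet below is not part of joeynmt: it fakes a torchtext-style batch with a SimpleNamespace, and all tensor values as well as the pad index (1) and BOS/EOS indices (2/3) are illustrative assumptions. Sorting by descending source length is the order required for packing padded sequences into an RNN encoder; the returned rev_index restores the original order afterwards.

# A minimal usage sketch (not part of joeynmt): a SimpleNamespace mimics a
# torchtext batch so Batch can be built without a full data pipeline.
# Assumed indices: PAD=1, BOS=2, EOS=3; tensor contents are illustrative.
import torch
from types import SimpleNamespace

from joeynmt.batch import Batch

src = torch.tensor([[4, 5, 1, 1],
                    [6, 7, 8, 9]])         # two source sequences, padded
src_length = torch.tensor([2, 4])
trg = torch.tensor([[2, 4, 5, 3, 1],
                    [2, 6, 7, 8, 3]])      # BOS-prefixed, EOS-terminated
trg_length = torch.tensor([4, 5])

torch_batch = SimpleNamespace(src=(src, src_length), trg=(trg, trg_length))
batch = Batch(torch_batch, pad_index=1, use_cuda=False)

print(batch.nseqs)                         # 2
print(batch.ntokens)                       # 7 non-padded target tokens

rev_index = batch.sort_by_src_length()     # batch now sorted, longest first
# ... encode the sorted batch, e.g. with packed sequences ...
# afterwards, restore the input order so outputs align with the data file:
# outputs = outputs[rev_index]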