# coding: utf-8
"""
Module to implement training loss
"""
import torch
from torch import nn, Tensor
from torch.autograd import Variable


class XentLoss(nn.Module):
    """
    Cross-Entropy Loss with optional label smoothing
    """

    def __init__(self, pad_index: int, smoothing: float = 0.0):
        super().__init__()
        self.smoothing = smoothing
        self.pad_index = pad_index
        if self.smoothing <= 0.0:
            # standard xent loss
            self.criterion = nn.NLLLoss(ignore_index=self.pad_index,
                                        reduction='sum')
        else:
            # custom label-smoothed loss, computed with KL divergence;
            # minimizing it against a fixed smoothed target distribution is
            # equivalent (up to a constant) to minimizing the cross-entropy
            self.criterion = nn.KLDivLoss(reduction='sum')

    def _smooth_targets(self, targets: Tensor, vocab_size: int):
        """
        Smooth target distribution. All non-reference words get uniform
        probability mass according to "smoothing".

        :param targets: target indices, batch*seq_len
        :param vocab_size: size of the output vocabulary
        :return: smoothed target distributions, batch*seq_len x vocab_size
        """
        # batch*seq_len x vocab_size
        smooth_dist = targets.new_zeros((targets.size(0), vocab_size)).float()
        # spread the smoothing mass uniformly over all tokens except the
        # gold token and padding (hence vocab_size - 2)
        smooth_dist.fill_(self.smoothing / (vocab_size - 2))
        # assign true label the probability of 1-smoothing ("confidence")
        smooth_dist.scatter_(1, targets.unsqueeze(1).data, 1.0-self.smoothing)
        # give padding probability of 0 everywhere
        smooth_dist[:, self.pad_index] = 0
        # masking out padding area (sum of probabilities for padding area = 0)
        padding_positions = torch.nonzero(targets.data == self.pad_index,
                                          as_tuple=False)
        # pylint: disable=len-as-condition
        if len(padding_positions) > 0:
            smooth_dist.index_fill_(0, padding_positions.squeeze(), 0.0)
        return Variable(smooth_dist, requires_grad=False)

    # pylint: disable=arguments-differ
    def forward(self, log_probs: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the cross-entropy between log probabilities and targets.

        If label smoothing is used, target distributions are not one-hot, but
        "1-smoothing" for the correct target token and the rest of the
        probability mass is uniformly spread across the other tokens.

        :param log_probs: log probabilities as predicted by model
        :param targets: target indices
        :return: summed loss over all non-pad tokens in the batch
        """
        if self.smoothing > 0:
            targets = self._smooth_targets(
                targets=targets.contiguous().view(-1),
                vocab_size=log_probs.size(-1))
            # targets: distributions with batch*seq_len x vocab_size
            assert log_probs.contiguous().view(-1, log_probs.size(-1)).shape \
                == targets.shape
        else:
            # targets: indices with batch*seq_len
            targets = targets.contiguous().view(-1)
        loss = self.criterion(
            log_probs.contiguous().view(-1, log_probs.size(-1)), targets)
        return loss
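

# Minimal usage sketch (illustrative only, not part of the original module):
# the shapes, pad_index and smoothing values below are assumptions chosen
# purely for demonstration.
if __name__ == "__main__":
    batch_size, seq_len, vocab_size, pad_index = 2, 4, 10, 1

    # fake model outputs: log-probabilities over the vocabulary
    log_probs = torch.log_softmax(
        torch.randn(batch_size, seq_len, vocab_size), dim=-1)

    # fake references, with the last two positions of the first sequence
    # treated as padding
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    targets[0, -2:] = pad_index

    for smoothing in (0.0, 0.1):
        loss_fn = XentLoss(pad_index=pad_index, smoothing=smoothing)
        loss = loss_fn(log_probs, targets)
        print(f"smoothing={smoothing}: summed loss = {loss.item():.4f}")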