Source code for joeynmt.attention

# coding: utf-8
"""
Attention modules
"""

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F


class AttentionMechanism(nn.Module):
    """
    Base attention class
    """

    def forward(self, *inputs):
        raise NotImplementedError("Implement this.")
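
# For reference, the two concrete attentions below differ only in how they
# score a query q against a key k (see the respective forward passes):
#   Bahdanau (MLP):     score(q, k) = v^T tanh(W_q q + W_k k)
#   Luong ("general"):  score(q, k) = q^T W k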
class BahdanauAttention(AttentionMechanism):
    """
    Implements Bahdanau (MLP) attention

    Section A.1.2 in https://arxiv.org/pdf/1409.0473.pdf.
    """

    def __init__(self, hidden_size=1, key_size=1, query_size=1):
        """
        Creates attention mechanism.

        :param hidden_size: size of the projection for query and key
        :param key_size: size of the attention input keys
        :param query_size: size of the query
        """
        super().__init__()

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)

        self.proj_keys = None   # to store projected keys
        self.proj_query = None  # projected query

    # pylint: disable=arguments-differ
    def forward(self, query: Tensor = None,
                mask: Tensor = None,
                values: Tensor = None):
        """
        Bahdanau MLP attention forward pass.

        :param query: the item (decoder state) to compare with the keys/memory,
            shape (batch_size, 1, decoder.hidden_size)
        :param mask: mask out keys position (0 in invalid positions, 1 else),
            shape (batch_size, 1, src_length)
        :param values: values (encoder states),
            shape (batch_size, src_length, encoder.hidden_size)
        :return: context vector of shape (batch_size, 1, value_size),
            attention probabilities of shape (batch_size, 1, src_length)
        """
        self._check_input_shapes_forward(query=query, mask=mask, values=values)

        assert mask is not None, "mask is required"
        assert self.proj_keys is not None, \
            "projection keys have to be pre-computed"

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        self.compute_proj_query(query)

        # Calculate scores.
        # proj_keys: batch x src_len x hidden_size
        # proj_query: batch x 1 x hidden_size
        scores = self.energy_layer(torch.tanh(
            self.proj_query + self.proj_keys))
        # scores: batch x src_len x 1

        scores = scores.squeeze(2).unsqueeze(1)
        # scores: batch x 1 x time

        # mask out invalid positions by filling the masked out parts with -inf
        scores = torch.where(mask, scores, scores.new_full([1], float('-inf')))

        # turn scores to probabilities
        alphas = F.softmax(scores, dim=-1)  # batch x 1 x time

        # the context vector is the weighted sum of the values
        context = alphas @ values  # batch x 1 x value_size

        return context, alphas
    def compute_proj_keys(self, keys: Tensor):
        """
        Compute the projection of the keys.
        Is efficient if pre-computed before receiving individual queries.

        :param keys: keys (encoder states),
            shape (batch_size, src_length, encoder.hidden_size)
        """
        self.proj_keys = self.key_layer(keys)

    def compute_proj_query(self, query: Tensor):
        """
        Compute the projection of the query.

        :param query: query (decoder state),
            shape (batch_size, 1, decoder.hidden_size)
        """
        self.proj_query = self.query_layer(query)
    def _check_input_shapes_forward(self, query: torch.Tensor,
                                    mask: torch.Tensor,
                                    values: torch.Tensor):
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`.

        :param query:
        :param mask:
        :param values:
        :return:
        """
        assert query.shape[0] == values.shape[0] == mask.shape[0]
        assert query.shape[1] == 1 == mask.shape[1]
        assert query.shape[2] == self.query_layer.in_features
        assert values.shape[2] == self.key_layer.in_features
        assert mask.shape[2] == values.shape[1]

    def __repr__(self):
        return "BahdanauAttention"
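
# Usage sketch (illustrative only, not part of the original joeynmt module):
# the decoder projects the encoder states once per batch and then queries the
# attention step by step. The sizes below are arbitrary assumptions.
#
#     attention = BahdanauAttention(hidden_size=64, key_size=32, query_size=64)
#     encoder_states = torch.rand(2, 5, 32)          # batch x src_len x key_size
#     attention.compute_proj_keys(encoder_states)    # pre-compute key projections
#     decoder_state = torch.rand(2, 1, 64)           # batch x 1 x query_size
#     mask = torch.ones(2, 1, 5, dtype=torch.bool)   # 1 = valid source position
#     context, alphas = attention(query=decoder_state, mask=mask,
#                                 values=encoder_states)
#     # context: batch x 1 x 32, alphas: batch x 1 x 5 (rows sum to 1)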
class LuongAttention(AttentionMechanism):
    """
    Implements Luong (bilinear / multiplicative) attention.

    Eq. 8 ("general") in http://aclweb.org/anthology/D15-1166.
    """

    def __init__(self, hidden_size: int = 1, key_size: int = 1):
        """
        Creates attention mechanism.

        :param hidden_size: size of the key projection layer, has to be equal
            to decoder hidden size
        :param key_size: size of the attention input keys
        """
        super().__init__()
        self.key_layer = nn.Linear(in_features=key_size,
                                   out_features=hidden_size,
                                   bias=False)
        self.proj_keys = None  # projected keys

    # pylint: disable=arguments-differ
    def forward(self, query: torch.Tensor = None,
                mask: torch.Tensor = None,
                values: torch.Tensor = None):
        """
        Luong (multiplicative / bilinear) attention forward pass.
        Computes context vectors and attention scores for a given query and
        all masked values and returns them.

        :param query: the item (decoder state) to compare with the keys/memory,
            shape (batch_size, 1, decoder.hidden_size)
        :param mask: mask out keys position (0 in invalid positions, 1 else),
            shape (batch_size, 1, src_length)
        :param values: values (encoder states),
            shape (batch_size, src_length, encoder.hidden_size)
        :return: context vector of shape (batch_size, 1, value_size),
            attention probabilities of shape (batch_size, 1, src_length)
        """
        self._check_input_shapes_forward(query=query, mask=mask, values=values)

        assert self.proj_keys is not None, \
            "projection keys have to be pre-computed"
        assert mask is not None, "mask is required"

        # scores: batch_size x 1 x src_length
        scores = query @ self.proj_keys.transpose(1, 2)

        # mask out invalid positions by filling the masked out parts with -inf
        scores = torch.where(mask, scores, scores.new_full([1], float('-inf')))

        # turn scores to probabilities
        alphas = F.softmax(scores, dim=-1)  # batch x 1 x src_len

        # the context vector is the weighted sum of the values
        context = alphas @ values  # batch x 1 x value_size

        return context, alphas
    def compute_proj_keys(self, keys: Tensor):
        """
        Compute the projection of the keys and assign them to `self.proj_keys`.
        This pre-computation is efficiently done for all keys
        before receiving individual queries.

        :param keys: shape (batch_size, src_length, encoder.hidden_size)
        """
        # proj_keys: batch x src_len x hidden_size
        self.proj_keys = self.key_layer(keys)
    def _check_input_shapes_forward(self, query: torch.Tensor,
                                    mask: torch.Tensor,
                                    values: torch.Tensor):
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`.

        :param query:
        :param mask:
        :param values:
        :return:
        """
        assert query.shape[0] == values.shape[0] == mask.shape[0]
        assert query.shape[1] == 1 == mask.shape[1]
        assert query.shape[2] == self.key_layer.out_features
        assert values.shape[2] == self.key_layer.in_features
        assert mask.shape[2] == values.shape[1]

    def __repr__(self):
        return "LuongAttention"
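
# Usage sketch (illustrative only, not part of the original joeynmt module):
# same pre-compute-then-query pattern as above, but hidden_size must match the
# decoder hidden size since the query is scored directly against the projected
# keys. The sizes below are arbitrary assumptions.
#
#     attention = LuongAttention(hidden_size=64, key_size=32)
#     encoder_states = torch.rand(2, 5, 32)          # batch x src_len x key_size
#     attention.compute_proj_keys(encoder_states)    # batch x src_len x hidden_size
#     decoder_state = torch.rand(2, 1, 64)           # batch x 1 x hidden_size
#     mask = torch.ones(2, 1, 5, dtype=torch.bool)   # 1 = valid source position
#     context, alphas = attention(query=decoder_state, mask=mask,
#                                 values=encoder_states)
#     # context: batch x 1 x 32, alphas: batch x 1 x 5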