Source code for joeynmt.attention

# coding: utf-8
"""
Attention modules
"""
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class AttentionMechanism(nn.Module):
    """
    Base attention class
    """

    def forward(self, *inputs):
        raise NotImplementedError("Implement this.")

class BahdanauAttention(AttentionMechanism):
    """
    Implements Bahdanau (MLP) attention

    Section A.1.2 in https://arxiv.org/abs/1409.0473.
    """

    def __init__(self, hidden_size: int = 1, key_size: int = 1, query_size: int = 1):
        """
        Creates attention mechanism.

        :param hidden_size: size of the projection for query and key
        :param key_size: size of the attention input keys
        :param query_size: size of the query
        """
        super().__init__()

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)

        self.proj_keys = None  # to store projected keys
        self.proj_query = None  # projected query

    def forward(self, query: Tensor, mask: Tensor, values: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Bahdanau MLP attention forward pass.

        :param query: the item (decoder state) to compare with the keys/memory,
            shape (batch_size, 1, decoder.hidden_size)
        :param mask: mask out key positions (0 in invalid positions, 1 else),
            shape (batch_size, 1, src_length)
        :param values: values (encoder states),
            shape (batch_size, src_length, encoder.hidden_size)
        :return:
            - context vector of shape (batch_size, 1, value_size),
            - attention probabilities of shape (batch_size, 1, src_length)
        """
        # pylint: disable=arguments-differ
        self._check_input_shapes_forward(query=query, mask=mask, values=values)

        assert mask is not None, "mask is required"
        assert self.proj_keys is not None, "projection keys have to get pre-computed"

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        self.compute_proj_query(query)

        # Calculate scores.
        # proj_keys: batch x src_len x hidden_size
        # proj_query: batch x 1 x hidden_size
        scores = self.energy_layer(torch.tanh(self.proj_query + self.proj_keys))
        # scores: batch x src_len x 1

        scores = scores.squeeze(2).unsqueeze(1)
        # scores: batch x 1 x time

        # mask out invalid positions by filling the masked out parts with -inf
        scores = torch.where(mask > 0, scores, scores.new_full([1], -np.inf))

        # turn scores to probabilities
        alphas = F.softmax(scores, dim=-1)  # batch x 1 x time

        # the context vector is the weighted sum of the values
        context = alphas @ values  # batch x 1 x value_size

        return context, alphas

    def compute_proj_keys(self, keys: Tensor) -> None:
        """
        Compute the projection of the keys.
        Is efficient if pre-computed before receiving individual queries.

        :param keys:
        :return:
        """
        self.proj_keys = self.key_layer(keys)

    def compute_proj_query(self, query: Tensor):
        """
        Compute the projection of the query.

        :param query:
        :return:
        """
        self.proj_query = self.query_layer(query)

    def _check_input_shapes_forward(
        self, query: Tensor, mask: Tensor, values: Tensor
    ) -> None:
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`.

        :param query:
        :param mask:
        :param values:
        :return:
        """
        assert query.shape[0] == values.shape[0] == mask.shape[0]
        assert query.shape[1] == 1 == mask.shape[1]
        assert query.shape[2] == self.query_layer.in_features
        assert values.shape[2] == self.key_layer.in_features
        assert mask.shape[2] == values.shape[1]

    def __repr__(self):
        return "BahdanauAttention"
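
# --- Usage sketch (added for illustration; not part of the original joeynmt module) ---
# A minimal, hedged example of the intended calling pattern: project the encoder
# states once with `compute_proj_keys`, then query attention one decoder step at
# a time. The helper name `_example_bahdanau_usage` and all tensor sizes below are
# invented for this sketch.
def _example_bahdanau_usage() -> None:
    batch_size, src_len = 2, 7
    enc_size, dec_size, hidden = 12, 10, 8

    attn = BahdanauAttention(hidden_size=hidden, key_size=enc_size, query_size=dec_size)

    values = torch.randn(batch_size, src_len, enc_size)  # encoder states
    query = torch.randn(batch_size, 1, dec_size)         # current decoder state
    mask = torch.ones(batch_size, 1, src_len, dtype=torch.long)
    mask[:, :, -2:] = 0  # pretend the last two source positions are padding

    attn.compute_proj_keys(values)  # done once per batch, before decoding
    context, alphas = attn(query, mask=mask, values=values)
    # context: (batch_size, 1, enc_size); alphas: (batch_size, 1, src_len)
    assert float(alphas[:, :, -2:].sum()) == 0.0  # padded positions receive zero weight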

class LuongAttention(AttentionMechanism):
    """
    Implements Luong (bilinear / multiplicative) attention.

    Eq. 8 ("general") in http://aclweb.org/anthology/D15-1166.
    """

    def __init__(self, hidden_size: int = 1, key_size: int = 1):
        """
        Creates attention mechanism.

        :param hidden_size: size of the key projection layer, has to be equal
            to decoder hidden size
        :param key_size: size of the attention input keys
        """
        super().__init__()

        self.key_layer = nn.Linear(
            in_features=key_size, out_features=hidden_size, bias=False
        )
        self.proj_keys = None  # projected keys

    def forward(self, query: Tensor, mask: Tensor, values: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Luong (multiplicative / bilinear) attention forward pass.
        Computes context vectors and attention scores for a given query and
        all masked values and returns them.

        :param query: the item (decoder state) to compare with the keys/memory,
            shape (batch_size, 1, decoder.hidden_size)
        :param mask: mask out key positions (0 in invalid positions, 1 else),
            shape (batch_size, 1, src_length)
        :param values: values (encoder states),
            shape (batch_size, src_length, encoder.hidden_size)
        :return:
            - context vector of shape (batch_size, 1, value_size),
            - attention probabilities of shape (batch_size, 1, src_length)
        """
        # pylint: disable=arguments-differ
        self._check_input_shapes_forward(query=query, mask=mask, values=values)

        assert self.proj_keys is not None, "projection keys have to get pre-computed"
        assert mask is not None, "mask is required"

        # scores: batch_size x 1 x src_length
        scores = query @ self.proj_keys.transpose(1, 2)

        # mask out invalid positions by filling the masked out parts with -inf
        scores = torch.where(mask > 0, scores, scores.new_full([1], -np.inf))

        # turn scores to probabilities
        alphas = F.softmax(scores, dim=-1)  # batch x 1 x src_len

        # the context vector is the weighted sum of the values
        context = alphas @ values  # batch x 1 x value_size

        return context, alphas

    def compute_proj_keys(self, keys: Tensor) -> None:
        """
        Compute the projection of the keys and assign them to `self.proj_keys`.
        This pre-computation is efficiently done for all keys before receiving
        individual queries.

        :param keys: shape (batch_size, src_length, encoder.hidden_size)
        """
        # proj_keys: batch x src_len x hidden_size
        self.proj_keys = self.key_layer(keys)

    def _check_input_shapes_forward(
        self, query: Tensor, mask: Tensor, values: Tensor
    ) -> None:
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`.

        :param query:
        :param mask:
        :param values:
        :return:
        """
        assert query.shape[0] == values.shape[0] == mask.shape[0]
        assert query.shape[1] == 1 == mask.shape[1]
        assert query.shape[2] == self.key_layer.out_features
        assert values.shape[2] == self.key_layer.in_features
        assert mask.shape[2] == values.shape[1]

    def __repr__(self):
        return "LuongAttention"
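
# --- Usage sketch (added for illustration; not part of the original joeynmt module) ---
# Same calling pattern as the Bahdanau sketch above, but for the multiplicative
# variant: `hidden_size` must equal the decoder hidden size because the query is
# used directly (only the keys are projected). The helper name
# `_example_luong_usage` and all tensor sizes are invented for this sketch.
def _example_luong_usage() -> None:
    batch_size, src_len = 2, 7
    enc_size, dec_size = 12, 10

    attn = LuongAttention(hidden_size=dec_size, key_size=enc_size)

    values = torch.randn(batch_size, src_len, enc_size)  # encoder states
    query = torch.randn(batch_size, 1, dec_size)         # current decoder state
    mask = torch.ones(batch_size, 1, src_len, dtype=torch.long)

    attn.compute_proj_keys(values)  # proj_keys: (batch_size, src_len, dec_size)
    context, alphas = attn(query, mask=mask, values=values)
    # context: (batch_size, 1, enc_size); attention weights sum to 1 over the source axis
    assert torch.allclose(alphas.sum(dim=-1), torch.ones(batch_size, 1))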