# coding: utf-8
"""
Vocabulary module
"""
from collections import defaultdict, Counter
from typing import List

import numpy as np
from torchtext.data import Dataset

from joeynmt.constants import UNK_TOKEN, DEFAULT_UNK_ID, \
    EOS_TOKEN, BOS_TOKEN, PAD_TOKEN


class Vocabulary:
    """ Vocabulary represents mapping between tokens and indices. """

    def __init__(self, tokens: List[str] = None, file: str = None) -> None:
        """
        Create vocabulary from list of tokens or file.

        Special tokens are added if not already in file or list.
        File format: token with index i is in line i.

        :param tokens: list of tokens
        :param file: file to load vocabulary from
        """
        # don't rename stoi and itos since needed for torchtext
        # warning: stoi grows with unknown tokens, don't use for saving or size

        # special symbols
        self.specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]
        self.stoi = defaultdict(DEFAULT_UNK_ID)
        self.itos = []
        if tokens is not None:
            self._from_list(tokens)
        elif file is not None:
            self._from_file(file)
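
    # Example (illustrative; "cat" is an arbitrary token): the special
    # symbols are prepended in order, so
    #     Vocabulary(tokens=["cat"]).itos
    # == [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, "cat"],
    # and "cat" receives index 4.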

    def _from_list(self, tokens: List[str] = None) -> None:
        """
        Make vocabulary from list of tokens.
        Tokens are assumed to be unique and pre-selected.
        Special symbols are added if not in list.

        :param tokens: list of tokens
        """
        self.add_tokens(tokens=self.specials + tokens)
        assert len(self.stoi) == len(self.itos)

    def _from_file(self, file: str) -> None:
        """
        Make vocabulary from contents of file.
        File format: token with index i is in line i.

        :param file: path to file where the vocabulary is loaded from
        """
        tokens = []
        with open(file, "r") as open_file:
            for line in open_file:
                tokens.append(line.strip("\n"))
        self._from_list(tokens)

    def __str__(self) -> str:
        return str(self.stoi)

    def to_file(self, file: str) -> None:
        """
        Save the vocabulary to a file, by writing token with index i in line i.

        :param file: path to file where the vocabulary is written
        """
        with open(file, "w") as open_file:
            for t in self.itos:
                open_file.write("{}\n".format(t))

    def add_tokens(self, tokens: List[str]) -> None:
        """
        Add list of tokens to vocabulary.

        :param tokens: list of tokens to add to the vocabulary
        """
        for t in tokens:
            # add to vocab if not already there
            if t not in self.itos:
                new_index = len(self.itos)
                self.itos.append(t)
                self.stoi[t] = new_index

    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is covered by the vocabulary.

        :param token: token to look up
        :return: True if the token is unknown (not covered), False otherwise
        """
        # use .get to avoid growing the defaultdict on unknown tokens
        return self.stoi.get(token, DEFAULT_UNK_ID()) == DEFAULT_UNK_ID()

    def __len__(self) -> int:
        return len(self.itos)

    def array_to_sentence(self, array: np.ndarray, cut_at_eos: bool = True,
                          skip_pad: bool = True) -> List[str]:
        """
        Converts an array of IDs to a sentence, optionally cutting the result
        off at the end-of-sequence token.

        :param array: 1D array containing indices
        :param cut_at_eos: cut the decoded sentence at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of strings (tokens)
        """
        sentence = []
        for i in array:
            s = self.itos[i]
            if cut_at_eos and s == EOS_TOKEN:
                break
            if skip_pad and s == PAD_TOKEN:
                continue
            sentence.append(s)
        return sentence
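
    # Worked example (indices are hypothetical): assume itos maps 4 -> "hi"
    # and 5 -> "there", with <pad> at index 1 and <eos> at index 3 as in the
    # default specials order above. Then
    #     vocab.array_to_sentence(np.array([4, 5, 3, 1]))
    # returns ["hi", "there"]: decoding breaks at <eos>, so the trailing
    # <pad> is never reached.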

    def arrays_to_sentences(self, arrays: np.ndarray, cut_at_eos: bool = True,
                            skip_pad: bool = True) -> List[List[str]]:
        """
        Convert multiple arrays containing sequences of token IDs to their
        sentences, optionally cutting them off at the end-of-sequence token.

        :param arrays: 2D array containing indices
        :param cut_at_eos: cut the decoded sentences at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of lists of strings (tokens)
        """
        sentences = []
        for array in arrays:
            sentences.append(
                self.array_to_sentence(array=array, cut_at_eos=cut_at_eos,
                                       skip_pad=skip_pad))
        return sentences


def build_vocab(field: str, max_size: int, min_freq: int, dataset: Dataset,
                vocab_file: str = None) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from given `dataset` or
    `vocab_file`.

    :param field: attribute e.g. "src"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file containing a stored vocabulary;
        if not None, load the vocabulary from this file instead of
        building it from the dataset
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """
    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
    else:
        # build it from the dataset
        def filter_min(counter: Counter, min_freq: int) -> Counter:
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter
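
        # E.g. filter_min(Counter({"a": 5, "b": 1}), min_freq=2) keeps only
        # {"a": 5}: tokens below the threshold are dropped entirely.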

        def sort_and_cut(counter: Counter, limit: int) -> List[str]:
            """ Cut counter to the most frequent tokens,
            sorted by frequency, then alphabetically """
            # sort alphabetically first, then (stably) by frequency
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens
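
        # Worked example of the two-pass sort: for
        #     counter = Counter({"b": 3, "a": 3, "c": 5})
        # the alphabetical pre-sort gives [("a", 3), ("b", 3), ("c", 5)], and
        # the stable sort by frequency yields ["c", "a", "b"]; with limit=2
        # the result is ["c", "a"]. Ties are thus broken alphabetically.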

        tokens = []
        for ex in dataset.examples:
            if field == "src":
                tokens.extend(ex.src)
            elif field == "trg":
                tokens.extend(ex.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size

        vocab = Vocabulary(tokens=vocab_tokens)
        assert len(vocab) <= max_size + len(vocab.specials)
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    # check that all specials except UNK are covered (not mapped to UNK)
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)
    return vocab
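

# Minimal usage sketch (the tokens below are made up for illustration):
if __name__ == "__main__":
    vocab = Vocabulary(tokens=["hello", "world"])
    print(vocab.itos)                 # specials first, then "hello", "world"
    print(vocab.is_unk("hello"))      # False: covered by the vocabulary
    print(vocab.is_unk("banana"))     # True: maps to the UNK index
    # decode an index sequence back to tokens
    ids = np.array([vocab.stoi["hello"], vocab.stoi["world"]])
    print(vocab.array_to_sentence(ids))   # ["hello", "world"]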