Source code for joeynmt.data

# coding: utf-8
"""
Data module
"""
import sys
import random
import os
import os.path
from typing import Optional
import logging

from torchtext.datasets import TranslationDataset
from torchtext import data
from torchtext.data import Dataset, Iterator, Field

from joeynmt.constants import UNK_TOKEN, EOS_TOKEN, BOS_TOKEN, PAD_TOKEN
from joeynmt.vocabulary import build_vocab, Vocabulary

logger = logging.getLogger(__name__)


[docs]def load_data(data_cfg: dict, datasets: list = None)\ -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary): """ Load train, dev and optionally test data as specified in configuration. Vocabularies are created from the training set with a limit of `voc_limit` tokens and a minimum token frequency of `voc_min_freq` (specified in the configuration dictionary). The training data is filtered to include sentences up to `max_sent_length` on source and target side. If you set ``random_train_subset``, a random selection of this size is used from the training set instead of the full training set. :param data_cfg: configuration dictionary for data ("data" part of configuation file) :param datasets: list of dataset names to load :return: - train_data: training dataset - dev_data: development dataset - test_data: testdata set if given, otherwise None - src_vocab: source vocabulary extracted from training data - trg_vocab: target vocabulary extracted from training data """ if datasets is None: datasets = ["train", "dev", "test"] # load data from files src_lang = data_cfg["src"] trg_lang = data_cfg["trg"] train_path = data_cfg.get("train", None) dev_path = data_cfg.get("dev", None) test_path = data_cfg.get("test", None) if train_path is None and dev_path is None and test_path is None: raise ValueError('Please specify at least one data source path.') level = data_cfg["level"] lowercase = data_cfg["lowercase"] max_sent_length = data_cfg["max_sent_length"] tok_fun = lambda s: list(s) if level == "char" else s.split() src_field = data.Field(init_token=None, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN, tokenize=tok_fun, batch_first=True, lower=lowercase, unk_token=UNK_TOKEN, include_lengths=True) trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN, tokenize=tok_fun, unk_token=UNK_TOKEN, batch_first=True, lower=lowercase, include_lengths=True) train_data = None if "train" in datasets and train_path is not None: logger.info("Loading training data...") train_data = TranslationDataset(path=train_path, exts=("." + src_lang, "." + trg_lang), fields=(src_field, trg_field), filter_pred= lambda x: len(vars(x)['src']) <= max_sent_length and len(vars(x)['trg']) <= max_sent_length) random_train_subset = data_cfg.get("random_train_subset", -1) if random_train_subset > -1: # select this many training examples randomly and discard the rest keep_ratio = random_train_subset / len(train_data) keep, _ = train_data.split( split_ratio=[keep_ratio, 1 - keep_ratio], random_state=random.getstate()) train_data = keep src_max_size = data_cfg.get("src_voc_limit", sys.maxsize) src_min_freq = data_cfg.get("src_voc_min_freq", 1) trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize) trg_min_freq = data_cfg.get("trg_voc_min_freq", 1) src_vocab_file = data_cfg.get("src_vocab", None) trg_vocab_file = data_cfg.get("trg_vocab", None) assert (train_data is not None) or (src_vocab_file is not None) assert (train_data is not None) or (trg_vocab_file is not None) logger.info("Building vocabulary...") src_vocab = build_vocab(field="src", min_freq=src_min_freq, max_size=src_max_size, dataset=train_data, vocab_file=src_vocab_file) trg_vocab = build_vocab(field="trg", min_freq=trg_min_freq, max_size=trg_max_size, dataset=train_data, vocab_file=trg_vocab_file) dev_data = None if "dev" in datasets and dev_path is not None: logger.info("Loading dev data...") dev_data = TranslationDataset(path=dev_path, exts=("." + src_lang, "." + trg_lang), fields=(src_field, trg_field)) test_data = None if "test" in datasets and test_path is not None: logger.info("Loading test data...") # check if target exists if os.path.isfile(test_path + "." + trg_lang): test_data = TranslationDataset( path=test_path, exts=("." + src_lang, "." + trg_lang), fields=(src_field, trg_field)) else: # no target is given -> create dataset from src only test_data = MonoDataset(path=test_path, ext="." + src_lang, field=src_field) src_field.vocab = src_vocab trg_field.vocab = trg_vocab logger.info("Data loaded.") return train_data, dev_data, test_data, src_vocab, trg_vocab
# pylint: disable=global-at-module-level global max_src_in_batch, max_tgt_in_batch # pylint: disable=unused-argument,global-variable-undefined
[docs]def token_batch_size_fn(new, count, sofar): """Compute batch size based on number of tokens (+padding).""" global max_src_in_batch, max_tgt_in_batch if count == 1: max_src_in_batch = 0 max_tgt_in_batch = 0 max_src_in_batch = max(max_src_in_batch, len(new.src)) src_elements = count * max_src_in_batch if hasattr(new, 'trg'): # for monolingual data sets ("translate" mode) max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2) tgt_elements = count * max_tgt_in_batch else: tgt_elements = 0 return max(src_elements, tgt_elements)
[docs]def make_data_iter(dataset: Dataset, batch_size: int, batch_type: str = "sentence", train: bool = False, shuffle: bool = False) -> Iterator: """ Returns a torchtext iterator for a torchtext dataset. :param dataset: torchtext dataset containing src and optionally trg :param batch_size: size of the batches the iterator prepares :param batch_type: measure batch size by sentence count or by token count :param train: whether it's training time, when turned off, bucketing, sorting within batches and shuffling is disabled :param shuffle: whether to shuffle the data before each epoch (no effect if set to True for testing) :return: torchtext iterator """ batch_size_fn = token_batch_size_fn if batch_type == "token" else None if train: # optionally shuffle and sort during training data_iter = data.BucketIterator( repeat=False, sort=False, dataset=dataset, batch_size=batch_size, batch_size_fn=batch_size_fn, train=True, sort_within_batch=True, sort_key=lambda x: len(x.src), shuffle=shuffle) else: # don't sort/shuffle for validation/inference data_iter = data.BucketIterator( repeat=False, dataset=dataset, batch_size=batch_size, batch_size_fn=batch_size_fn, train=False, sort=False) return data_iter
[docs]class MonoDataset(Dataset): """Defines a dataset for machine translation without targets."""
[docs] @staticmethod def sort_key(ex): return len(ex.src)
def __init__(self, path: str, ext: str, field: Field, **kwargs) -> None: """ Create a monolingual dataset (=only sources) given path and field. :param path: Prefix of path to the data file :param ext: Containing the extension to path for this language. :param field: Containing the fields that will be used for data. :param kwargs: Passed to the constructor of data.Dataset. """ fields = [('src', field)] if hasattr(path, "readline"): # special usage: stdin src_file = path else: src_path = os.path.expanduser(path + ext) src_file = open(src_path) examples = [] for src_line in src_file: src_line = src_line.strip() if src_line != '': examples.append(data.Example.fromlist( [src_line], fields)) src_file.close() super().__init__(examples, fields, **kwargs)