Source code for joeynmt.data

# coding: utf-8
"""
Data module
"""
from typing import Dict, List, Optional, Tuple

from joeynmt.datasets import BaseDataset, build_dataset
from joeynmt.helpers_for_ddp import get_logger
from joeynmt.tokenizers import build_tokenizer
from joeynmt.vocabulary import Vocabulary, build_vocab

logger = get_logger(__name__)


def load_data(
    cfg: Dict, datasets: Optional[List[str]] = None
) -> Tuple[Vocabulary, Vocabulary, Optional[BaseDataset], Optional[BaseDataset],
           Optional[BaseDataset]]:
    """
    Load train, dev, and optionally test data as specified in the configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq` (specified in the
    configuration dictionary).

    The training data is filtered to include sentences up to `max_length` on
    source and target side.

    If you set `sample_{train|dev}_subset`, a random selection of this size is
    used from the {train|dev} set instead of the full set.

    :param cfg: configuration dictionary for data ("data" part of config file)
    :param datasets: list of dataset names to load; defaults to
        ["train", "dev", "test"]
    :returns:
        - src_vocab: source vocabulary
        - trg_vocab: target vocabulary
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
    """
    if datasets is None:
        datasets = ["train", "dev", "test"]
    assert len(datasets) > 0, datasets

    src_cfg = cfg["src"]
    trg_cfg = cfg["trg"]

    # load data from files
    src_lang = src_cfg["lang"]
    trg_lang = trg_cfg["lang"]
    train_path = cfg.get("train", None)
    dev_path = cfg.get("dev", None)
    test_path = cfg.get("test", None)
    if train_path is None and dev_path is None and test_path is None:
        raise ValueError("Please specify at least one data source path.")

    # build tokenizer
    logger.info("Building tokenizer...")
    tokenizer = build_tokenizer(cfg)

    dataset_type = cfg.get("dataset_type", "plain")
    dataset_cfg = cfg.get("dataset_cfg", {})

    has_prompt = {
        src_lang: src_cfg.get("has_prompt", False),
        trg_lang: trg_cfg.get("has_prompt", False),
    }

    # train data
    train_data = None
    if "train" in datasets and train_path is not None:
        train_subset = cfg.get("sample_train_subset", -1)
        if "random_train_subset" in cfg:
            logger.warning(
                "`random_train_subset` option is obsolete. "
                "Please use `sample_train_subset` instead."
            )
            train_subset = cfg.get("random_train_subset", train_subset)
        logger.info("Loading train set...")
        train_data = build_dataset(
            dataset_type=dataset_type,
            path=train_path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="train",
            tokenizer=tokenizer,
            has_prompt=has_prompt,
            random_subset=train_subset,
            **dataset_cfg,
        )

    # build vocab
    logger.info("Building vocabulary...")
    src_vocab, trg_vocab = build_vocab(cfg, dataset=train_data)

    # set vocab to tokenizer
    tokenizer[src_lang].set_vocab(src_vocab)
    tokenizer[trg_lang].set_vocab(trg_vocab)

    # encoding func
    sequence_encoder = {
        src_lang: src_vocab.sentences_to_ids,
        trg_lang: trg_vocab.sentences_to_ids,
    }
    if train_data is not None:
        train_data.sequence_encoder = sequence_encoder

    # dev data
    dev_data = None
    if "dev" in datasets and dev_path is not None:
        dev_subset = cfg.get("sample_dev_subset", -1)
        if "random_dev_subset" in cfg:
            logger.warning(
                "`random_dev_subset` option is obsolete. "
                "Please use `sample_dev_subset` instead."
            )
            dev_subset = cfg.get("random_dev_subset", dev_subset)
        logger.info("Loading dev set...")
        dev_data = build_dataset(
            dataset_type=dataset_type,
            path=dev_path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="dev",
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            has_prompt=has_prompt,
            random_subset=dev_subset,
            **dataset_cfg,
        )

    # test data
    test_data = None
    if "test" in datasets and test_path is not None:
        logger.info("Loading test set...")
        test_data = build_dataset(
            dataset_type=dataset_type,
            path=test_path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="test",
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            has_prompt=has_prompt,
            random_subset=-1,  # no subsampling for test
            **dataset_cfg,
        )

    if "stream" in datasets:
        test_data = build_dataset(
            dataset_type="stream",
            path=None,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="test",
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            has_prompt=has_prompt,
        )
    logger.info("Data loaded.")

    # Log statistics of data and vocabulary
    logger.info("Train dataset: %s", train_data)
    logger.info("Valid dataset: %s", dev_data)
    logger.info(" Test dataset: %s", test_data)
    if train_data:
        src = "\n\t[SRC] " + " ".join(
            train_data.get_item(idx=0, lang=train_data.src_lang, is_train=False)
        )
        trg = "\n\t[TRG] " + " ".join(
            train_data.get_item(idx=0, lang=train_data.trg_lang, is_train=False)
        )
        logger.info("First training example:%s%s", src, trg)

    logger.info("First 10 Src tokens: %s", src_vocab.log_vocab(10))
    logger.info("First 10 Trg tokens: %s", trg_vocab.log_vocab(10))

    logger.info("Number of unique Src tokens (vocab_size): %d", len(src_vocab))
    logger.info("Number of unique Trg tokens (vocab_size): %d", len(trg_vocab))

    return src_vocab, trg_vocab, train_data, dev_data, test_data
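
A minimal usage sketch, not part of the module: the dictionary below mirrors the "data" section that JoeyNMT normally parses from a YAML config file. The file paths, language codes, and vocabulary settings are hypothetical placeholders, and the exact key layout is an assumption based on the keys `load_data`, `build_tokenizer`, and `build_vocab` read above.

from joeynmt.data import load_data

# Hypothetical "data" section; paths assume plain-text parallel files
# such as toy/train.de and toy/train.en.
data_cfg = {
    "train": "toy/train",
    "dev": "toy/dev",
    "dataset_type": "plain",
    "src": {"lang": "de", "level": "word", "max_length": 30,
            "voc_limit": 2000, "voc_min_freq": 1},
    "trg": {"lang": "en", "level": "word", "max_length": 30,
            "voc_limit": 2000, "voc_min_freq": 1},
}

# No "test" path is configured, so test_data comes back as None.
src_vocab, trg_vocab, train_data, dev_data, test_data = load_data(
    data_cfg, datasets=["train", "dev"]
)
print(len(src_vocab), len(trg_vocab))  # vocab sizes built from the train set

Passing `datasets=["stream"]` instead builds a stream dataset with no file path, as the final branch of the function shows, e.g. for translating input supplied at run time rather than read from files.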