Source code for joeynmt.datasets

# coding: utf-8
"""
Dataset module
"""
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, Sampler

from joeynmt.batch import Batch
from joeynmt.config import ConfigurationError
from joeynmt.helpers import read_list_from_file
from joeynmt.helpers_for_ddp import (
    DistributedSubsetSampler,
    RandomSubsetSampler,
    get_logger,
    use_ddp,
)
from joeynmt.tokenizers import BasicTokenizer

logger = get_logger(__name__)
CPU_DEVICE = torch.device("cpu")


class BaseDataset(Dataset):
    """
    BaseDataset which loads and looks up data.
    - holds pointers to tokenizers and encoding functions

    :param path: path to data directory
    :param src_lang: source language code, e.g. `en`
    :param trg_lang: target language code, e.g. `de`
    :param has_trg: bool indicator if trg exists
    :param has_prompt: bool indicator if prompt exists
    :param split: split name, one of {`train`, `dev`, `test`}
    :param tokenizer: tokenizer objects
    :param sequence_encoder: encoding functions
    :param random_subset: take a random subset of this size if > 0; -1 means
        no subsampling
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
    ):
        self.path = path
        self.src_lang = src_lang
        self.trg_lang = trg_lang
        self.has_trg = has_trg
        self.split = split
        if self.split == "train":
            assert self.has_trg

        self.tokenizer = tokenizer
        self.sequence_encoder = sequence_encoder
        self.has_prompt = has_prompt

        # for random subsampling
        self.random_subset = random_subset
        self.indices = None  # range(self.__len__())
        # Note: self.indices is kept sorted, even if shuffle = True in make_iter()
        # (the sampler yields permuted indices)
        self.seed = 1  # random seed for generator

    def reset_indices(self, random_subset: int = None):
        # should be called after data are loaded;
        # otherwise self.__len__() is undefined.
        self.indices = list(range(self.__len__())) if self.__len__() > 0 else []

        if random_subset is not None:
            self.random_subset = random_subset

        if 0 < self.random_subset:
            assert (self.split != "test" and self.random_subset < self.__len__()), \
                ("Can only subsample from train or dev set "
                 f"larger than {self.random_subset}.")

    def load_data(self, path: Path, **kwargs) -> Any:
        """
        load data
            - preprocessing (lowercasing etc.) is applied here.
        """
        raise NotImplementedError

    def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]:
        """
        seek one src/trg item of the given index.
            - tokenization is applied here.
            - length-filtering, bpe-dropout etc. are also triggered if
              self.split == "train"
        """

        # workaround if the tokenizer prepends an extra escape symbol
        # before the lang_tag ...
        def _remove_escape(item):
            if (
                item is not None and self.tokenizer[lang] is not None
                and item[0] == self.tokenizer[lang].SPACE_ESCAPE
                and item[1] in self.tokenizer[lang].lang_tags
            ):
                return item[1:]
            return item

        line, prompt = self.lookup_item(idx, lang)
        is_train = self.split == "train" if is_train is None else is_train
        item = _remove_escape(self.tokenizer[lang](line, is_train=is_train))

        if self.has_prompt[lang] and prompt is not None:
            prompt = _remove_escape(self.tokenizer[lang](prompt, is_train=False))
            item = item if item is not None else []

            max_length = self.tokenizer[lang].max_length
            if 0 < max_length < len(prompt) + len(item) + 1:
                # truncate the prompt, keeping a leading lang_tag if present
                offset = max_length - len(item) - 1
                if prompt[0] in self.tokenizer[lang].lang_tags:
                    prompt = [prompt[0]] + prompt[-(offset - 1):]
                else:
                    prompt = prompt[-offset:]

            item = prompt + [self.tokenizer[lang].sep_token] + item
        return item
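
    # Editor's note: a worked example of the truncation above, with assumed
    # values (not taken from the library). With max_length=8, a tokenized item
    # of 4 tokens and a prompt of 6 tokens, len(prompt) + len(item) + 1 = 11 > 8,
    # so offset = 8 - 4 - 1 = 3 and the prompt is cut to its last 3 tokens
    # before `sep_token` is inserted:
    #
    #     prompt = ["p1", "p2", "p3", "p4", "p5", "p6"]  # hypothetical tokens
    #     item   = ["a", "b", "c", "d"]
    #     prompt = prompt[-3:]                           # ["p4", "p5", "p6"]
    #     result = prompt + ["<sep>"] + item             # 8 tokens == max_length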

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        raise NotImplementedError

    def __getitem__(self, idx: Union[int, str]) -> Tuple[int, List[str], List[str]]:
        """
        lookup one item pair of the given index.

        :param idx: index of the instance to lookup
        :return:
            - index  # needed to recover the original order
            - tokenized src sentences
            - tokenized trg sentences
        """
        if idx >= self.__len__():  # valid indices are 0 .. len - 1
            raise KeyError
        src, trg = None, None
        src = self.get_item(idx=idx, lang=self.src_lang)
        if self.has_trg or self.has_prompt[self.trg_lang]:
            trg = self.get_item(idx=idx, lang=self.trg_lang)
            if trg is None:
                src = None
        return idx, src, trg

    def get_list(
        self,
        lang: str,
        tokenized: bool = False,
        subsampled: bool = True,
    ) -> Union[List[str], List[List[str]]]:
        """get data column-wise."""
        raise NotImplementedError

    @property
    def src(self) -> List[str]:
        """get detokenized preprocessed data in src language."""
        return self.get_list(self.src_lang, tokenized=False, subsampled=True)

    @property
    def trg(self) -> List[str]:
        """get detokenized preprocessed data in trg language."""
        return (
            self.get_list(self.trg_lang, tokenized=False, subsampled=True)
            if self.has_trg else []
        )

    def collate_fn(
        self,
        batch: List[Tuple],
        pad_index: int,
        eos_index: int,
        device: torch.device = CPU_DEVICE,
    ) -> Batch:
        """
        Custom collate function.
        See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn for
        details. Override the batch class here (not in TrainManager).

        :param batch:
        :param pad_index:
        :param eos_index:
        :param device:
        :return: joeynmt batch object
        """
        idx, src_list, trg_list = zip(*batch)
        assert len(batch) == len(src_list) == len(trg_list), \
            (len(batch), len(src_list), len(trg_list))
        assert all(s is not None for s in src_list), src_list
        src, src_length, src_prompt_mask = self.sequence_encoder[self.src_lang](
            src_list, bos=False, eos=True
        )

        if self.has_trg or self.has_prompt[self.trg_lang]:
            if self.has_trg:
                assert all(t is not None for t in trg_list), trg_list
            trg, _, trg_prompt_mask = self.sequence_encoder[self.trg_lang](
                trg_list, bos=True, eos=self.has_trg  # no EOS if not self.has_trg
            )
        else:
            assert all(t is None for t in trg_list)
            trg, trg_prompt_mask = None, None
        # Note: we don't need trg_length!

        return Batch(
            src=torch.tensor(src).long(),
            src_length=torch.tensor(src_length).long(),
            src_prompt_mask=(
                torch.tensor(src_prompt_mask).long()
                if self.has_prompt[self.src_lang] else None
            ),
            trg=torch.tensor(trg).long() if trg else None,
            trg_prompt_mask=(
                torch.tensor(trg_prompt_mask).long()
                if self.has_prompt[self.trg_lang] else None
            ),
            indices=torch.tensor(idx).long(),
            device=device,
            pad_index=pad_index,
            eos_index=eos_index,
            is_train=self.split == "train",
        )
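
    # Editor's sketch (hypothetical encoder, not part of the module): any
    # `sequence_encoder` compatible with the call above must turn a list of
    # token lists into (padded_ids, lengths, prompt_mask). A minimal toy
    # version, assuming STOI, BOS, EOS and PAD vocabulary objects (in joeynmt
    # these functions are normally derived from the Vocabulary built at
    # training time):
    #
    #     def toy_encoder(batch, bos=False, eos=True):
    #         ids = [([BOS] if bos else []) + [STOI[t] for t in seq]
    #                + ([EOS] if eos else []) for seq in batch]
    #         lengths = [len(seq) for seq in ids]
    #         max_len = max(lengths)
    #         padded = [seq + [PAD] * (max_len - len(seq)) for seq in ids]
    #         mask = [[0] * max_len for _ in padded]  # prompt positions would be 1
    #         return padded, lengths, mask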

    def make_iter(
        self,
        batch_size: int,
        batch_type: str = "sentence",
        seed: int = 42,
        shuffle: bool = False,
        num_workers: int = 0,
        pad_index: int = 1,
        eos_index: int = 3,
        device: torch.device = CPU_DEVICE,
        generator_state: torch.Tensor = None,
    ) -> DataLoader:
        """
        Returns a torch DataLoader for a torch Dataset. (no bucketing)

        :param batch_size: size of the batches the iterator prepares
        :param batch_type: measure batch size by sentence count or by token count
        :param seed: random seed for shuffling
        :param shuffle: whether to shuffle the order of sequences before each epoch
            (for testing, no effect even if set to True; the generator is still
            used for random subsampling, but not for permutation!)
        :param num_workers: number of cpu workers for multiprocessing
        :param pad_index:
        :param eos_index:
        :param device:
        :param generator_state:
        :return: torch DataLoader
        """
        shuffle = shuffle and self.split == "train"

        # for decoding in DDP, we cannot use TokenBatchSampler
        if use_ddp() and self.split != "train":
            assert batch_type == "sentence", self

        generator = torch.Generator()
        generator.manual_seed(seed)
        if generator_state is not None:
            generator.set_state(generator_state)

        # define sampler which yields an integer
        sampler: Sampler[int]
        if use_ddp():  # use ddp sampler
            sampler = DistributedSubsetSampler(
                self, shuffle=shuffle, drop_last=True, generator=generator
            )
        else:
            sampler = RandomSubsetSampler(self, shuffle=shuffle, generator=generator)

        # batch sampler which yields a list of integers
        if batch_type == "sentence":
            batch_sampler = SentenceBatchSampler(
                sampler, batch_size=batch_size, drop_last=False, seed=seed
            )
        elif batch_type == "token":
            batch_sampler = TokenBatchSampler(
                sampler, batch_size=batch_size, drop_last=False, seed=seed
            )
        else:
            raise ConfigurationError(f"{batch_type}: Unknown batch type")

        # initialize generator seed
        batch_sampler.set_seed(seed)  # set seed and resample

        # ensure that sequence_encoder (padding func) exists
        assert self.sequence_encoder[self.src_lang] is not None
        if self.has_trg:
            assert self.sequence_encoder[self.trg_lang] is not None

        # data iterator
        return DataLoader(
            dataset=self,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                self.collate_fn,
                eos_index=eos_index,
                pad_index=pad_index,
                device=device,
            ),
            num_workers=num_workers,
        )
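
    # Editor's sketch (assumed setup, not part of the module): once a dataset
    # has its tokenizers and sequence encoders attached, iteration follows the
    # standard pattern below; `train_data` is a hypothetical instance of a
    # BaseDataset subclass.
    #
    #     train_iter = train_data.make_iter(
    #         batch_size=4096,
    #         batch_type="token",  # or "sentence"
    #         seed=42,
    #         shuffle=True,
    #         pad_index=1,
    #         eos_index=3,
    #     )
    #     for batch in train_iter:  # yields joeynmt.batch.Batch objects
    #         ...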

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split}, len={self.__len__()}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]})"
        )


class PlaintextDataset(BaseDataset):
    """
    PlaintextDataset which stores plain text pairs.
    - used for text file data in the format of one sentence per line.
    """

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs,
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
        )

        # load data
        self.data = self.load_data(path, **kwargs)
        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:

        def _pre_process(seq, lang):
            if self.tokenizer[lang] is not None:
                seq = [self.tokenizer[lang].pre_process(s) for s in seq if len(s) > 0]
            return seq

        path = Path(path)
        src_file = path.with_suffix(f"{path.suffix}.{self.src_lang}")
        assert src_file.is_file(), f"{src_file} not found. Abort."

        src_list = read_list_from_file(src_file)
        data = {self.src_lang: _pre_process(src_list, self.src_lang)}

        if self.has_trg:
            trg_file = path.with_suffix(f"{path.suffix}.{self.trg_lang}")
            assert trg_file.is_file(), f"{trg_file} not found. Abort."

            trg_list = read_list_from_file(trg_file)
            data[self.trg_lang] = _pre_process(trg_list, self.trg_lang)
            assert len(src_list) == len(trg_list)
        return data

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        try:
            line = self.data[lang][idx]
            prompt = (
                self.data[f"{lang}_prompt"][idx]
                if f"{lang}_prompt" in self.data else None
            )
            return line, prompt
        except Exception as e:
            logger.error("Lookup failed for index %d: %r", idx, e)
            raise ValueError from e

    def get_list(
        self,
        lang: str,
        tokenized: bool = False,
        subsampled: bool = True,
    ) -> Union[List[str], List[List[str]]]:
        """
        Return a list of preprocessed sentences in the given language.
        (not length-filtered, no bpe-dropout)
        """
        indices = self.indices if subsampled else range(self.__len__())
        item_list = []
        for idx in indices:
            item, _ = self.lookup_item(idx, lang)
            if tokenized:
                item = self.tokenizer[lang](item, is_train=False)
            item_list.append(item)
        assert len(indices) == len(item_list), (len(indices), len(item_list))
        return item_list

    def __len__(self) -> int:
        return len(self.data[self.src_lang])
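
# Editor's sketch (hypothetical file names; a minimal illustration of the
# expected file layout, not part of the module): for path="data/train",
# src_lang="en", trg_lang="de", `load_data` above reads
#
#     data/train.en    # one source sentence per line
#     data/train.de    # one target sentence per line (only if has_trg=True)
#
# and both files must have the same number of lines.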


class TsvDataset(BaseDataset):
    """
    TsvDataset which handles data in tsv format.
    - file_name should be specified without the extension `.tsv`
    - needs src_lang and trg_lang (e.g. `en`, `de`) in the header.
      see: test/data/toy/dev.tsv
    """

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs,
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
        )

        # load tsv file
        self.df = self.load_data(path, **kwargs)
        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:
        path = Path(path)
        file_path = path.with_suffix(f"{path.suffix}.tsv")
        assert file_path.is_file(), f"{file_path} not found. Abort."

        try:
            import pandas as pd  # pylint: disable=import-outside-toplevel

            # TODO: use `chunksize` for online data loading.
            df = pd.read_csv(
                file_path.as_posix(),
                sep="\t",
                header=0,
                encoding="utf-8",
                index_col=None,
            )
            df = df.dropna()
            df = df.reset_index()

            assert self.src_lang in df.columns
            df[self.src_lang] = df[self.src_lang].apply(
                self.tokenizer[self.src_lang].pre_process
            )

            if self.trg_lang not in df.columns:
                self.has_trg = False
                assert self.split == "test"
            if self.has_trg:
                df[self.trg_lang] = df[self.trg_lang].apply(
                    self.tokenizer[self.trg_lang].pre_process
                )

            if f"{self.src_lang}_prompt" in df.columns:
                self.has_prompt[self.src_lang] = True
                df[f"{self.src_lang}_prompt"] = df[f"{self.src_lang}_prompt"].apply(
                    self.tokenizer[self.src_lang].pre_process, allow_empty=True
                )
            if f"{self.trg_lang}_prompt" in df.columns:
                self.has_prompt[self.trg_lang] = True
                df[f"{self.trg_lang}_prompt"] = df[f"{self.trg_lang}_prompt"].apply(
                    self.tokenizer[self.trg_lang].pre_process, allow_empty=True
                )
            return df

        except ImportError as e:
            logger.error(e)
            raise ImportError from e

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        try:
            row = self.df.iloc[idx]
            line = row[lang]
            prompt = row.get(f"{lang}_prompt", None)
            return line, prompt
        except Exception as e:
            logger.error("Lookup failed for index %d: %r", idx, e)
            raise ValueError from e

    def get_list(
        self,
        lang: str,
        tokenized: bool = False,
        subsampled: bool = True,
    ) -> Union[List[str], List[List[str]]]:
        indices = self.indices if subsampled else range(self.__len__())
        df = self.df.iloc[indices]
        return (
            df[lang].apply(self.tokenizer[lang]).to_list()
            if tokenized else df[lang].to_list()
        )

    def __len__(self) -> int:
        return len(self.df)
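
# Editor's sketch (hypothetical contents, not from the repository): for
# path="data/dev", `load_data` above expects a file data/dev.tsv whose header
# names the language columns, with optional prompt columns, e.g.
#
#     en<TAB>de<TAB>en_prompt<TAB>de_prompt
#     Hello.<TAB>Hallo.<TAB><TAB>
#
# A missing `de` column is only allowed for split="test" (has_trg is then
# switched off).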


class StreamDataset(BaseDataset):
    """
    StreamDataset which interacts with stream inputs.
    - called by the `translate()` func in `prediction.py`.
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "test",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs,
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
        )

        # placeholder
        self.cache = []

    def _split_at_sep(self, line: str, prompt: str, lang: str, sep_token: str):
        """
        Split a string at sep_token.

        :param line: (non-empty) input string
        :param prompt: input prompt
        :param lang:
        :param sep_token:
        """
        if (
            sep_token is not None and line is not None and sep_token in line
            and prompt is None
        ):
            # split at the first occurrence only, in case the token repeats
            line, prompt = line.split(sep_token, 1)

        if line:
            line = self.tokenizer[lang].pre_process(line, allow_empty=False)
        if prompt:
            prompt = self.tokenizer[lang].pre_process(prompt, allow_empty=True)
            self.has_prompt[lang] = True
        return line, prompt

    def set_item(
        self,
        src_line: str,
        trg_line: Optional[str] = None,
        src_prompt: Optional[str] = None,
        trg_prompt: Optional[str] = None,
    ) -> None:
        """
        Set input text to the cache.

        :param src_line: (non-empty) str
        :param trg_line: Optional[str]
        :param src_prompt: Optional[str]
        :param trg_prompt: Optional[str]
        """
        assert isinstance(src_line, str) and src_line.strip() != "", \
            "The input sentence is empty! Please make sure " \
            "that you are feeding a valid input."

        src_line, src_prompt = self._split_at_sep(
            src_line, src_prompt, self.src_lang, self.tokenizer[self.src_lang].sep_token
        )
        assert src_line is not None

        trg_line, trg_prompt = self._split_at_sep(
            trg_line, trg_prompt, self.trg_lang, self.tokenizer[self.trg_lang].sep_token
        )
        if self.has_trg:
            assert trg_line is not None

        self.cache.append((src_line, trg_line, src_prompt, trg_prompt))
        self.reset_indices()

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        # pylint: disable=no-else-return
        try:
            assert lang in [self.src_lang, self.trg_lang]
            if lang == self.trg_lang:
                assert self.has_trg or self.has_prompt[lang]

            src_line, trg_line, src_prompt, trg_prompt = self.cache[idx]
            if lang == self.src_lang:
                return src_line, src_prompt
            elif lang == self.trg_lang:
                return trg_line, trg_prompt
            else:
                raise ValueError
        except Exception as e:
            logger.error("Lookup failed for index %d: %r", idx, e)
            raise ValueError from e

    def reset_cache(self):
        self.cache = []
        self.reset_indices()

    def __len__(self) -> int:
        return len(self.cache)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split}, len={len(self.cache)}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]})"
        )
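
# Editor's sketch (assumed objects, not part of the module): in interactive
# translation, `prediction.translate()` feeds one line at a time; the same
# flow can be reproduced manually, where `tokenizer` and `encoder` are dicts
# keyed by language code, as built by the training/prediction pipeline:
#
#     stream_data = build_dataset(
#         dataset_type="stream", path="", src_lang="en", trg_lang="de",
#         split="test", tokenizer=tokenizer, sequence_encoder=encoder,
#     )
#     stream_data.set_item("Hello world!")  # appends to the cache
#     ...                                   # decode, then
#     stream_data.reset_cache()             # clear before the next input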


class BaseHuggingfaceDataset(BaseDataset):
    """
    Wrapper for Huggingface's dataset object
    cf.) https://huggingface.co/docs/datasets
    """

    COLUMN_NAME = "sentence"  # dummy column name. should be overridden.

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        has_trg: bool = True,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs,
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=kwargs["split"],
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
        )
        # load data
        self.dataset = self.load_data(path, **kwargs)
        self._kwargs = kwargs  # should contain arguments passed to `load_dataset()`
        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:
        # pylint: disable=import-outside-toplevel
        try:
            from datasets import Dataset as Dataset_hf
            from datasets import DatasetDict, config, load_dataset, load_from_disk

            if Path(path, config.DATASET_STATE_JSON_FILENAME).exists() \
                    or Path(path, config.DATASETDICT_JSON_FILENAME).exists():
                hf_dataset = load_from_disk(path)
                if isinstance(hf_dataset, DatasetDict):
                    assert kwargs["split"] in hf_dataset
                    hf_dataset = hf_dataset[kwargs["split"]]
            else:
                hf_dataset = load_dataset(path, **kwargs)
                assert isinstance(hf_dataset, Dataset_hf)
            assert self.COLUMN_NAME in hf_dataset.features
            return hf_dataset
        except ImportError as e:
            logger.error(e)
            raise ImportError from e

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        try:
            line = self.dataset[idx]
            assert lang in line[self.COLUMN_NAME], (line, lang)
            prompt = line.get(f"{lang}_prompt", None)
            return line[self.COLUMN_NAME][lang], prompt
        except Exception as e:
            logger.error("Lookup failed for index %d: %r", idx, e)
            raise ValueError from e

    def get_list(
        self,
        lang: str,
        tokenized: bool = False,
        subsampled: bool = True,
    ) -> Union[List[str], List[List[str]]]:
        dataset = self.dataset
        if subsampled:
            dataset = dataset.filter(
                lambda x, idx: idx in self.indices, with_indices=True
            )
            assert len(dataset) == len(self.indices), (len(dataset), len(self.indices))
        if tokenized:

            def _tok(item):
                item[f"tok_{lang}"] = self.tokenizer[lang](item[self.COLUMN_NAME][lang])
                return item

            return dataset.map(_tok, desc=f"Tokenizing {lang}...")[f"tok_{lang}"]
        return dataset.flatten()[f"{self.COLUMN_NAME}.{lang}"]

    def __len__(self) -> int:
        return self.dataset.num_rows

    def __repr__(self) -> str:
        ret = (
            f"{self.__class__.__name__}(len={self.__len__()}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]}"
        )
        for k, v in self._kwargs.items():
            ret += f", {k}={v}"
        ret += ")"
        return ret


class HuggingfaceTranslationDataset(BaseHuggingfaceDataset):
    """
    Wrapper for Huggingface's `datasets.features.Translation` class
    cf.) https://github.com/huggingface/datasets/blob/master/src/datasets/features/translation.py
    """  # noqa

    COLUMN_NAME = "translation"

    def load_data(self, path: str, **kwargs) -> Any:
        dataset = super().load_data(path=path, **kwargs)

        # pylint: disable=import-outside-toplevel
        try:
            from datasets.features import Translation as Translation_hf

            assert isinstance(dataset.features[self.COLUMN_NAME], Translation_hf), \
                f"Data type mismatch. Please cast `{self.COLUMN_NAME}` column to " \
                "datasets.features.Translation class."
            assert self.src_lang in dataset.features[self.COLUMN_NAME].languages
            if self.has_trg:
                assert self.trg_lang in dataset.features[self.COLUMN_NAME].languages
        except ImportError as e:
            logger.error(e)
            raise ImportError from e

        # preprocess (lowercase, pretokenize, etc.) + validity check
        def _pre_process(item):
            sl = self.src_lang
            tl = self.trg_lang
            item[self.COLUMN_NAME][sl] = self.tokenizer[sl].pre_process(
                item[self.COLUMN_NAME][sl]
            )
            if self.has_trg:
                item[self.COLUMN_NAME][tl] = self.tokenizer[tl].pre_process(
                    item[self.COLUMN_NAME][tl]
                )
            if self.has_prompt[sl]:
                item[f"{sl}_prompt"] = self.tokenizer[sl].pre_process(
                    item[f"{sl}_prompt"], allow_empty=True
                )
            if self.has_prompt[tl]:
                item[f"{tl}_prompt"] = self.tokenizer[tl].pre_process(
                    item[f"{tl}_prompt"], allow_empty=True
                )
            return item

        def _drop_nan(item):
            src_item = item[self.COLUMN_NAME][self.src_lang]
            is_src_valid = src_item is not None and len(src_item) > 0
            if self.has_trg:
                trg_item = item[self.COLUMN_NAME][self.trg_lang]
                is_trg_valid = trg_item is not None and len(trg_item) > 0
                return is_src_valid and is_trg_valid
            return is_src_valid

        dataset = dataset.filter(_drop_nan, desc="Dropping NaN...")
        dataset = dataset.map(_pre_process, desc="Preprocessing...")
        return dataset
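
# Editor's sketch (hypothetical dataset name/kwargs, not from the repository):
# kwargs are forwarded to `datasets.load_dataset()`, so a remote translation
# dataset can be wrapped like this, with `tokenizer` and `encoder` assumed
# dicts keyed by language code:
#
#     dataset = HuggingfaceTranslationDataset(
#         path="wmt14", name="de-en", split="validation",
#         src_lang="de", trg_lang="en", has_trg=True,
#         has_prompt={"de": False, "en": False},
#         tokenizer=tokenizer, sequence_encoder=encoder,
#     )
#
# The wrapped dataset must expose a `translation` feature of type
# datasets.features.Translation covering both language codes.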


def build_dataset(
    dataset_type: str,
    path: str,
    src_lang: str,
    trg_lang: str,
    split: str,
    tokenizer: Dict = None,
    sequence_encoder: Dict = None,
    has_prompt: Dict = None,
    random_subset: int = -1,
    **kwargs,
):
    """
    Builds a dataset.

    :param dataset_type: (str) one of {`plain`, `tsv`, `stream`, `huggingface`}
    :param path: (str) either a local file name or a dataset name to download
        from remote
    :param src_lang: (str) language code for source
    :param trg_lang: (str) language code for target
    :param split: (str) one of {`train`, `dev`, `test`}
    :param tokenizer: tokenizer objects for both source and target
    :param sequence_encoder: encoding functions for both source and target
    :param has_prompt: prompt indicators
    :param random_subset: (int) size of the random subset; -1 means no subsampling
    :return: loaded Dataset
    """
    dataset = None
    has_trg = True  # by default, we expect src-trg pairs

    _placeholder = {src_lang: None, trg_lang: None}
    tokenizer = _placeholder if tokenizer is None else tokenizer
    sequence_encoder = _placeholder if sequence_encoder is None else sequence_encoder
    has_prompt = _placeholder if has_prompt is None else has_prompt

    if dataset_type == "plain":
        if not Path(path).with_suffix(f"{Path(path).suffix}.{trg_lang}").is_file():
            # no target is given -> create dataset from src only
            has_trg = False
        dataset = PlaintextDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    elif dataset_type == "tsv":
        dataset = TsvDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    elif dataset_type == "stream":
        dataset = StreamDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="test",
            has_trg=False,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=-1,
            **kwargs,
        )
    elif dataset_type == "huggingface":
        # "split" should be specified in kwargs
        if "split" not in kwargs:
            kwargs["split"] = "validation" if split == "dev" else split
        dataset = HuggingfaceTranslationDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    else:
        raise ConfigurationError(f"{dataset_type}: Unknown dataset type.")
    return dataset
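
# Editor's sketch (hypothetical paths; a minimal illustration, not part of the
# module): the training pipeline typically builds a dataset in two steps --
# construct it with tokenizers, then attach the vocabulary-based encoders
# before calling `make_iter()`:
#
#     train_data = build_dataset(
#         dataset_type="plain", path="data/train",
#         src_lang="en", trg_lang="de", split="train",
#         tokenizer=tokenizer,  # {"en": ..., "de": ...}
#     )
#     train_data.sequence_encoder["en"] = src_encode_fn  # assumed callables
#     train_data.sequence_encoder["de"] = trg_encode_fn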


class SentenceBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices based on the number
    of instances. An instance longer than dataset.max_len will be filtered out.

    :param sampler: Base sampler. Can be any iterable object
    :param batch_size: Size of mini-batch.
    :param drop_last: If `True`, the sampler will drop the last batch if its size
        would be less than `batch_size`
    :param seed: random seed for resampling
    """

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, seed: int):
        super().__init__(sampler, batch_size, drop_last)
        self.seed = seed

    @property
    def num_samples(self) -> int:
        """
        Returns the number of samples in the dataset. This may change during
        sampling.

        Note: len(dataset) won't change during sampling;
              use len(dataset) to retrieve the original dataset length.
        """
        assert self.sampler.data_source.indices is not None
        try:
            return len(self.sampler)
        except NotImplementedError:  # pylint: disable=unused-variable # noqa: F841
            return len(self.sampler.data_source.indices)

    def __iter__(self):
        batch = []
        d = self.sampler.data_source
        for idx in self.sampler:
            _, src, trg = d[idx]  # pylint: disable=unused-variable
            if src is not None:  # otherwise drop the instance
                batch.append(idx)
                if len(batch) >= self.batch_size:
                    yield batch
                    batch = []
        if len(batch) > 0:
            if not self.drop_last:
                yield batch
            else:
                logger.warning(f"Drop indices {batch}.")

    def __len__(self) -> int:
        # pylint: disable=no-else-return
        if self.drop_last:
            return self.num_samples // self.batch_size
        else:
            return (self.num_samples + self.batch_size - 1) // self.batch_size

    def set_seed(self, seed: int) -> None:
        assert seed is not None, seed
        self.sampler.data_source.seed = seed

        if hasattr(self.sampler, "set_seed"):
            self.sampler.set_seed(seed)  # set seed and resample
        elif hasattr(self.sampler, "generator"):
            self.sampler.generator.manual_seed(seed)

        if self.num_samples < len(self.sampler.data_source):
            logger.info(
                "Sample random subset from %s data: n=%d, seed=%d",
                self.sampler.data_source.split, self.num_samples, seed,
            )

    def reset(self) -> None:
        if hasattr(self.sampler, "reset"):
            self.sampler.reset()

    def get_state(self):
        if hasattr(self.sampler, "generator"):
            return self.sampler.generator.get_state()
        return None

    def set_state(self, state) -> None:
        if hasattr(self.sampler, "generator"):
            self.sampler.generator.set_state(state)
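
    # Editor's sketch (assumed usage, not part of the module): get_state() /
    # set_state() allow the data order to be restored when resuming training:
    #
    #     state = batch_sampler.get_state()  # save alongside the checkpoint
    #     ...
    #     batch_sampler.set_state(state)     # restore after reloading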


class TokenBatchSampler(SentenceBatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices based on the number
    of tokens (incl. padding). An instance longer than dataset.max_len or
    shorter than dataset.min_len will be filtered out.
    * no bucketing implemented

    .. warning::
        In DDP, we shouldn't use TokenBatchSampler for prediction, because we
        cannot ensure that the data points will be distributed evenly across
        devices. `ddp_merge()` (`dist.all_gather()`) called in `predict()` can
        get stuck.

    :param sampler: Base sampler. Can be any iterable object
    :param batch_size: Size of mini-batch.
    :param drop_last: If `True`, the sampler will drop the last batch if its size
        would be less than `batch_size`
    """

    def __iter__(self):
        """yields list of indices"""
        batch = []
        max_tokens = 0
        d = self.sampler.data_source
        for idx in self.sampler:
            _, src, trg = d[idx]  # call __getitem__()
            if src is not None:  # otherwise drop the instance
                src_len = 0 if src is None else len(src)
                trg_len = 0 if trg is None else len(trg)
                n_tokens = 0 if src_len == 0 else max(src_len + 1, trg_len + 1)
                batch.append(idx)

                if n_tokens > max_tokens:
                    max_tokens = n_tokens
                if max_tokens * len(batch) >= self.batch_size:
                    yield batch
                    batch = []
                    max_tokens = 0
        if len(batch) > 0:
            if not self.drop_last:
                yield batch
            else:
                logger.warning(f"Drop indices {batch}.")

    def __len__(self):
        raise NotImplementedError
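
# Editor's note: a worked example of the token-batching rule above, with
# assumed lengths (not from the library). With batch_size=100 (tokens) and
# instances of src lengths 20, 30, 48 (the +1 accounts for EOS):
#
#     after idx0: max_tokens = 21, 21 * 1 = 21  < 100 -> keep filling
#     after idx1: max_tokens = 31, 31 * 2 = 62  < 100 -> keep filling
#     after idx2: max_tokens = 49, 49 * 3 = 147 >= 100 -> yield [idx0, idx1, idx2]
#
# i.e. the cost of a batch is its padded size, max_tokens * len(batch).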