# coding: utf-8
"""
This module holds methods for generating predictions from a model.
"""
import os
import sys
from typing import List, Optional
import logging
import numpy as np
import torch
from torchtext.data import Dataset, Field
from joeynmt.helpers import bpe_postprocess, load_config, make_logger,\
get_latest_checkpoint, load_checkpoint, store_attention_plots
from joeynmt.metrics import bleu, chrf, token_accuracy, sequence_accuracy
from joeynmt.model import build_model, Model, _DataParallel
from joeynmt.search import run_batch
from joeynmt.batch import Batch
from joeynmt.data import load_data, make_data_iter, MonoDataset
from joeynmt.constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN
from joeynmt.vocabulary import Vocabulary
logger = logging.getLogger(__name__)
# pylint: disable=too-many-arguments,too-many-locals,no-member,too-many-branches
def validate_on_data(model: Model, data: Dataset,
batch_size: int,
use_cuda: bool, max_output_length: int,
level: str, eval_metric: Optional[str],
n_gpu: int,
batch_class: Batch = Batch,
compute_loss: bool = False,
beam_size: int = 1, beam_alpha: int = -1,
batch_type: str = "sentence",
postprocess: bool = True,
bpe_type: str = "subword-nmt",
sacrebleu: dict = None,
n_best: int = 1) \
-> (float, float, float, List[str], List[List[str]], List[str],
        List[str], List[List[str]], List[np.ndarray]):
"""
Generate translations for the given data.
If `compute_loss` is True and references are given,
also compute the loss.
:param model: model module
:param data: dataset for validation
:param batch_size: validation batch size
:param batch_class: class type of batch
:param use_cuda: if True, use CUDA
:param max_output_length: maximum length for generated hypotheses
:param level: segmentation level, one of "char", "bpe", "word"
:param eval_metric: evaluation metric, e.g. "bleu"
:param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
:param beam_size: beam size for validation.
If <2 then greedy decoding (default).
:param beam_alpha: beam search alpha for length penalty,
disabled if set to -1 (default).
:param batch_type: validation batch type (sentence or token)
:param postprocess: if True, remove BPE segmentation from translations
:param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
:param sacrebleu: sacrebleu options
    :param n_best: number of candidates to return
:return:
- current_valid_score: current validation score [eval_metric],
- valid_loss: validation loss,
        - valid_ppl: validation perplexity,
- valid_sources: validation sources,
- valid_sources_raw: raw validation sources (before post-processing),
- valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
- decoded_valid: raw validation hypotheses (before post-processing),
- valid_attention_scores: attention scores for validation hypotheses
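
    Example (an illustrative sketch, not part of the original API; assumes
    a built ``model`` and a ``dev_data`` Dataset loaded via
    ``joeynmt.data.load_data``)::

        results = validate_on_data(
            model, data=dev_data, batch_size=32, use_cuda=False,
            max_output_length=100, level="bpe", eval_metric="bleu",
            n_gpu=0, compute_loss=True, beam_size=5, beam_alpha=1)
        score, valid_loss, valid_ppl = results[0], results[1], results[2]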
"""
    assert batch_size >= n_gpu, "batch_size must be at least as large as n_gpu."
if sacrebleu is None: # assign default value
sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
if batch_size > 1000 and batch_type == "sentence":
logger.warning(
"WARNING: Are you sure you meant to work on huge batches like "
"this? 'batch_size' is > 1000 for sentence-batching. "
"Consider decreasing it or switching to"
" 'eval_batch_type: token'.")
valid_iter = make_data_iter(
dataset=data, batch_size=batch_size, batch_type=batch_type,
shuffle=False, train=False)
valid_sources_raw = data.src
pad_index = model.src_vocab.stoi[PAD_TOKEN]
# disable dropout
model.eval()
# don't track gradients during validation
with torch.no_grad():
all_outputs = []
valid_attention_scores = []
total_loss = 0
total_ntokens = 0
total_nseqs = 0
for valid_batch in iter(valid_iter):
# run as during training to get validation loss (e.g. xent)
batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
# sort batch now by src length and keep track of order
            reverse_indexes = batch.sort_by_src_length()
            # each sentence owns n_best contiguous rows in the output
            sort_reverse_index = [[] for _ in range(len(reverse_indexes))]
            for i, ix in enumerate(reverse_indexes):
                for n in range(0, n_best):
                    sort_reverse_index[i].append(ix * n_best + n)
            assert len(sort_reverse_index) == batch.nseqs
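            # worked example (illustrative): with n_best=2 and
            # reverse_indexes=[1, 0], sort_reverse_index becomes
            # [[2, 3], [0, 1]] -- each group of n_best contiguous output
            # rows is mapped back to its sentence's original position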
# run as during training with teacher forcing
if compute_loss and batch.trg is not None:
batch_loss, _, _, _ = model(return_type="loss", **vars(batch))
if n_gpu > 1:
batch_loss = batch_loss.mean() # average on multi-gpu
total_loss += batch_loss
total_ntokens += batch.ntokens
total_nseqs += batch.nseqs
# run as during inference to produce translations
output, attention_scores = run_batch(
model=model, batch=batch, beam_size=beam_size,
beam_alpha=beam_alpha, max_output_length=max_output_length,
n_best=n_best)
# sort outputs back to original order
for reverse_index in sort_reverse_index:
all_outputs.append(output[reverse_index])
valid_attention_scores.append(
attention_scores[reverse_index]
if attention_scores is not None else [])
assert len(all_outputs) == len(data)
if compute_loss and total_ntokens > 0:
# total validation loss
valid_loss = total_loss
# exponent of token-level negative log prob
valid_ppl = torch.exp(total_loss / total_ntokens)
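            # e.g. a summed token-level NLL of 693.1 over 1000 tokens
            # gives valid_ppl = exp(0.6931) ~ 2.0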
else:
valid_loss = -1
valid_ppl = -1
# decode back to symbols
decoded_valid = model.trg_vocab.arrays_to_sentences(
arrays=[output for output_group in all_outputs
for output in output_group],
cut_at_eos=True
)
# evaluate with metric on full dataset
join_char = " " if level in ["word", "bpe"] else ""
valid_sources = [join_char.join(s) for s in data.src]
valid_references = [join_char.join(t) for t in data.trg]
valid_hypotheses = [join_char.join(t) for t in decoded_valid]
# post-process
if level == "bpe" and postprocess:
valid_sources = [bpe_postprocess(s, bpe_type=bpe_type)
for s in valid_sources]
valid_references = [bpe_postprocess(v, bpe_type=bpe_type)
for v in valid_references]
valid_hypotheses = [bpe_postprocess(v, bpe_type=bpe_type)
for v in valid_hypotheses]
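            # e.g. subword-nmt: "the new@@ comer" -> "the newcomer";
            # sentencepiece: "▁the ▁newcomer" -> "the newcomer"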
# if references are given, evaluate against them
if valid_references:
assert len(valid_hypotheses) == len(valid_references)
current_valid_score = 0
if eval_metric.lower() == 'bleu':
                # sacrebleu's BLEU with the tokenizer configured in `sacrebleu`
current_valid_score = bleu(
valid_hypotheses, valid_references,
tokenize=sacrebleu["tokenize"])
elif eval_metric.lower() == 'chrf':
current_valid_score = chrf(valid_hypotheses, valid_references,
remove_whitespace=sacrebleu["remove_whitespace"])
elif eval_metric.lower() == 'token_accuracy':
current_valid_score = token_accuracy( # supply List[List[str]]
list(decoded_valid), list(data.trg))
elif eval_metric.lower() == 'sequence_accuracy':
current_valid_score = sequence_accuracy(
valid_hypotheses, valid_references)
else:
current_valid_score = -1
return current_valid_score, valid_loss, valid_ppl, valid_sources, \
valid_sources_raw, valid_references, valid_hypotheses, \
decoded_valid, valid_attention_scores
def parse_test_args(cfg, mode="test"):
"""
    Parse test arguments from the given configuration.

    :param cfg: config object
    :param mode: 'test' or 'translate'
    :return: batch_size, batch_type, use_cuda, device, n_gpu, level,
        eval_metric, max_output_length, beam_size, beam_alpha, postprocess,
        bpe_type, sacrebleu, decoding_description, tokenizer_info
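
    Example (an illustrative sketch; assumes ``cfg = load_config(...)``
    with ``data``, ``training`` and ``testing`` sections)::

        (batch_size, batch_type, use_cuda, device, n_gpu, level,
         eval_metric, max_output_length, beam_size, beam_alpha,
         postprocess, bpe_type, sacrebleu, decoding_description,
         tokenizer_info) = parse_test_args(cfg, mode="test")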
"""
if "test" not in cfg["data"].keys():
raise ValueError("Test data must be specified in config.")
batch_size = cfg["training"].get(
"eval_batch_size", cfg["training"].get("batch_size", 1))
batch_type = cfg["training"].get(
"eval_batch_type", cfg["training"].get("batch_type", "sentence"))
use_cuda = (cfg["training"].get("use_cuda", False)
and torch.cuda.is_available())
device = torch.device("cuda" if use_cuda else "cpu")
if mode == 'test':
n_gpu = torch.cuda.device_count() if use_cuda else 0
k = cfg["testing"].get("beam_size", 1)
batch_per_device = batch_size*k // n_gpu if n_gpu > 1 else batch_size*k
logger.info("Process device: %s, n_gpu: %d, "
"batch_size per device: %d (with beam_size)",
device, n_gpu, batch_per_device)
eval_metric = cfg["training"]["eval_metric"]
elif mode == 'translate':
        # batch_size must be at least n_gpu; interactive translation
        # uses batch_size = 1, so cap n_gpu at 1
        n_gpu = 1 if use_cuda else 0
logger.debug("Process device: %s, n_gpu: %d", device, n_gpu)
eval_metric = ""
level = cfg["data"]["level"]
max_output_length = cfg["training"].get("max_output_length", None)
    # whether to use beam search for decoding; beam_size < 2 means greedy
if "testing" in cfg.keys():
beam_size = cfg["testing"].get("beam_size", 1)
beam_alpha = cfg["testing"].get("alpha", -1)
postprocess = cfg["testing"].get("postprocess", True)
bpe_type = cfg["testing"].get("bpe_type", "subword-nmt")
sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
if "sacrebleu" in cfg["testing"].keys():
sacrebleu["remove_whitespace"] = cfg["testing"]["sacrebleu"] \
.get("remove_whitespace", True)
sacrebleu["tokenize"] = cfg["testing"]["sacrebleu"] \
.get("tokenize", "13a")
else:
beam_size = 1
beam_alpha = -1
postprocess = True
bpe_type = "subword-nmt"
sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
decoding_description = "Greedy decoding" if beam_size < 2 else \
"Beam search decoding with beam size = {} and alpha = {}". \
format(beam_size, beam_alpha)
tokenizer_info = f"[{sacrebleu['tokenize']}]" \
if eval_metric == "bleu" else ""
return batch_size, batch_type, use_cuda, device, n_gpu, level, \
eval_metric, max_output_length, beam_size, beam_alpha, \
postprocess, bpe_type, sacrebleu, decoding_description, \
tokenizer_info
# pylint: disable-msg=logging-too-many-args
def test(cfg_file,
ckpt: str,
batch_class: Batch = Batch,
output_path: str = None,
save_attention: bool = False,
datasets: dict = None) -> None:
"""
    Main test function. Handles loading a model from checkpoint, generating
    translations, and storing them along with attention plots.
:param cfg_file: path to configuration file
:param ckpt: path to checkpoint to load
:param batch_class: class type of batch
    :param output_path: path to output file
    :param save_attention: whether to save the computed attention weights
    :param datasets: datasets to predict
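
    Example (an illustrative sketch; the paths are hypothetical)::

        # decodes the dev and test sets with the latest checkpoint
        # found in model_dir
        test("configs/transformer_small.yaml", ckpt=None,
             output_path="out/predictions")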
"""
cfg = load_config(cfg_file)
model_dir = cfg["training"]["model_dir"]
if len(logger.handlers) == 0:
_ = make_logger(model_dir, mode="test") # version string returned
# when checkpoint is not specified, take latest (best) from model dir
if ckpt is None:
ckpt = get_latest_checkpoint(model_dir)
try:
step = ckpt.split(model_dir+"/")[1].split(".ckpt")[0]
except IndexError:
step = "best"
# load the data
if datasets is None:
_, dev_data, test_data, src_vocab, trg_vocab = load_data(
data_cfg=cfg["data"], datasets=["dev", "test"])
data_to_predict = {"dev": dev_data, "test": test_data}
    else:  # avoid loading the data again
data_to_predict = {"dev": datasets["dev"], "test": datasets["test"]}
src_vocab = datasets["src_vocab"]
trg_vocab = datasets["trg_vocab"]
# parse test args
batch_size, batch_type, use_cuda, device, n_gpu, level, eval_metric, \
max_output_length, beam_size, beam_alpha, postprocess, \
bpe_type, sacrebleu, decoding_description, tokenizer_info \
= parse_test_args(cfg, mode="test")
# load model state from disk
model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)
# build model and load parameters into it
model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
model.load_state_dict(model_checkpoint["model_state"])
if use_cuda:
model.to(device)
# multi-gpu eval
if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = _DataParallel(model)
for data_set_name, data_set in data_to_predict.items():
if data_set is None:
continue
dataset_file = cfg["data"][data_set_name] + "." + cfg["data"]["trg"]
logger.info("Decoding on %s set (%s)...", data_set_name, dataset_file)
#pylint: disable=unused-variable
score, loss, ppl, sources, sources_raw, references, hypotheses, \
hypotheses_raw, attention_scores = validate_on_data(
model, data=data_set, batch_size=batch_size,
batch_class=batch_class, batch_type=batch_type, level=level,
max_output_length=max_output_length, eval_metric=eval_metric,
use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
beam_alpha=beam_alpha, postprocess=postprocess,
bpe_type=bpe_type, sacrebleu=sacrebleu, n_gpu=n_gpu)
#pylint: enable=unused-variable
if "trg" in data_set.fields:
logger.info("%4s %s%s: %6.2f [%s]",
data_set_name, eval_metric, tokenizer_info,
score, decoding_description)
else:
logger.info("No references given for %s -> no evaluation.",
data_set_name)
if save_attention:
if attention_scores:
attention_name = "{}.{}.att".format(data_set_name, step)
attention_path = os.path.join(model_dir, attention_name)
logger.info("Saving attention plots. This might take a while..")
store_attention_plots(attentions=attention_scores,
targets=hypotheses_raw,
sources=data_set.src,
indices=range(len(hypotheses)),
output_prefix=attention_path)
logger.info("Attention plots saved to: %s", attention_path)
else:
logger.warning("Attention scores could not be saved. "
"Note that attention scores are not available "
"when using beam search. "
"Set beam_size to 1 for greedy decoding.")
if output_path is not None:
output_path_set = "{}.{}".format(output_path, data_set_name)
with open(output_path_set, mode="w", encoding="utf-8") as out_file:
for hyp in hypotheses:
out_file.write(hyp + "\n")
logger.info("Translations saved to: %s", output_path_set)
def translate(cfg_file: str,
ckpt: str,
output_path: str = None,
batch_class: Batch = Batch,
n_best: int = 1) -> None:
"""
    Interactive translation function.

    Loads the model from a checkpoint and translates either piped stdin
    input or sentences entered interactively at a prompt.
    The input has to be pre-processed in the same way as the data the
    model was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout, or written to ``output_path``
    if given.
:param cfg_file: path to configuration file
:param ckpt: path to checkpoint to load
:param output_path: path to output file
:param batch_class: class type of batch
    :param n_best: number of candidates to display
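
    Example (an illustrative sketch; paths are hypothetical). When stdin
    is not a terminal, e.g. after ``cat test.bpe.src | python3 -m joeynmt
    translate configs/small.yaml``, the piped input is translated
    non-interactively; the programmatic equivalent is::

        translate("configs/small.yaml", ckpt=None,
                  output_path="out/hyps.txt", n_best=1)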
"""
def _load_line_as_data(line):
""" Create a dataset from one line via a temporary file. """
# write src input to temporary file
tmp_name = "tmp"
tmp_suffix = ".src"
tmp_filename = tmp_name+tmp_suffix
with open(tmp_filename, "w") as tmp_file:
tmp_file.write("{}\n".format(line))
test_data = MonoDataset(path=tmp_name, ext=tmp_suffix,
field=src_field)
# remove temporary file
if os.path.exists(tmp_filename):
os.remove(tmp_filename)
return test_data
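    # e.g. _load_line_as_data("ein kleines Haus .") writes "tmp.src",
    # builds a one-sentence MonoDataset from it, then deletes the file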
def _translate_data(test_data):
""" Translates given dataset, using parameters from outer scope. """
# pylint: disable=unused-variable
score, loss, ppl, sources, sources_raw, references, hypotheses, \
hypotheses_raw, attention_scores = validate_on_data(
model, data=test_data, batch_size=batch_size,
batch_class=batch_class, batch_type=batch_type, level=level,
max_output_length=max_output_length, eval_metric="",
use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
beam_alpha=beam_alpha, postprocess=postprocess,
bpe_type=bpe_type, sacrebleu=sacrebleu, n_gpu=n_gpu, n_best=n_best)
return hypotheses
cfg = load_config(cfg_file)
model_dir = cfg["training"]["model_dir"]
_ = make_logger(model_dir, mode="translate")
# version string returned
# when checkpoint is not specified, take oldest from model dir
if ckpt is None:
ckpt = get_latest_checkpoint(model_dir)
# read vocabs
src_vocab_file = cfg["data"].get("src_vocab", model_dir + "/src_vocab.txt")
trg_vocab_file = cfg["data"].get("trg_vocab", model_dir + "/trg_vocab.txt")
src_vocab = Vocabulary(file=src_vocab_file)
trg_vocab = Vocabulary(file=trg_vocab_file)
data_cfg = cfg["data"]
level = data_cfg["level"]
lowercase = data_cfg["lowercase"]
tok_fun = lambda s: list(s) if level == "char" else s.split()
src_field = Field(init_token=None, eos_token=EOS_TOKEN,
pad_token=PAD_TOKEN, tokenize=tok_fun,
batch_first=True, lower=lowercase,
unk_token=UNK_TOKEN,
include_lengths=True)
src_field.vocab = src_vocab
# parse test args
batch_size, batch_type, use_cuda, device, n_gpu, level, _, \
max_output_length, beam_size, beam_alpha, postprocess, \
bpe_type, sacrebleu, _, _ = parse_test_args(cfg, mode="translate")
# load model state from disk
model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)
# build model and load parameters into it
model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
model.load_state_dict(model_checkpoint["model_state"])
if use_cuda:
model.to(device)
if not sys.stdin.isatty():
        # input stream given (piped or redirected file)
test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
all_hypotheses = _translate_data(test_data)
if output_path is not None:
            # write to output file if given
def write_to_file(output_path_set, hypotheses):
with open(output_path_set, mode="w", encoding="utf-8") \
as out_file:
for hyp in hypotheses:
out_file.write(hyp + "\n")
logger.info("Translations saved to: %s.", output_path_set)
if n_best > 1:
for n in range(n_best):
file_name, file_extension = os.path.splitext(output_path)
write_to_file(
"{}-{}{}".format(
file_name, n,
file_extension if file_extension else ""
),
[all_hypotheses[i]
for i in range(n, len(all_hypotheses), n_best)]
)
else:
write_to_file("{}".format(output_path), all_hypotheses)
else:
# print to stdout
for hyp in all_hypotheses:
print(hyp)
else:
# enter interactive mode
batch_size = 1
batch_type = "sentence"
while True:
try:
src_input = input("\nPlease enter a source sentence "
"(pre-processed): \n")
if not src_input.strip():
break
# every line has to be made into dataset
test_data = _load_line_as_data(line=src_input)
hypotheses = _translate_data(test_data)
print("JoeyNMT: Hypotheses ranked by score")
for i, hyp in enumerate(hypotheses):
print("JoeyNMT #{}: {}".format(i + 1, hyp))
except (KeyboardInterrupt, EOFError):
print("\nBye.")
break