# coding: utf-8
"""
Training module
"""
import heapq
import math
import time
from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from joeynmt.batch import Batch
from joeynmt.builders import build_gradient_clipper, build_optimizer, build_scheduler
from joeynmt.config import (
TestConfig,
TrainConfig,
log_config,
parse_global_args,
set_validation_args,
)
from joeynmt.helpers import (
delete_ckpt,
load_checkpoint,
store_attention_plots,
symlink_update,
write_list_to_file,
)
from joeynmt.helpers_for_ddp import (
ddp_cleanup,
ddp_reduce,
ddp_setup,
ddp_synchronize,
get_logger,
use_ddp,
)
from joeynmt.model import Model
from joeynmt.prediction import predict, prepare, test
logger = get_logger(__name__)
class TrainManager:
"""
Manages training loop, validations, learning rate scheduling
and early stopping.
"""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
rank: int,
model: Model,
model_dir: Path,
device: torch.device,
n_gpu: int = 0,
num_workers: int = 0,
autocast: Dict = None,
seed: int = 42,
train_args: TrainConfig = None,
dev_args: TestConfig = None
) -> None:
"""
Creates a new TrainManager for a model, specified as in configuration.
:param rank: process rank (0 for the master process)
:param model: torch module defining the model
:param model_dir: directory to save ckpts
:param device: torch device
:param n_gpu: number of gpus. 0 if cpu.
:param num_workers: number of multiprocess workers.
:param autocast: autocast context
:param seed: random seed
:param train_args: config args for training
:param dev_args: config args for validation
"""
self.rank = rank
self.args = train_args # config for training
self.dev_cfg = dev_args # config for greedy decoding
self.seed = seed
self.model_dir = model_dir
# access these variables only when rank == 0
if self.rank == 0:
# tensorboard
self.tb_writer = SummaryWriter(
log_dir=(model_dir / "tensorboard").as_posix()
)
# save/delete checkpoints
self.ckpt_queue: List[Tuple[float, Path]] = [] # heap queue
# model
self.model = model
# CPU / GPU
self.device = device
self.n_gpu = n_gpu
self.num_workers = num_workers
# optimization
self.clip_grad_fun = build_gradient_clipper(cfg=self.args._asdict())
self.optimizer = build_optimizer(
cfg=self.args._asdict(), parameters=self.model.parameters()
)
# fp16
self.scaler = torch.cuda.amp.GradScaler(enabled=autocast["enabled"]) \
if self.device.type == "cuda" else None
self.autocast = autocast
# learning rate scheduling
self.scheduler, self.scheduler_step_at = build_scheduler(
cfg=self.args._asdict(),
scheduler_mode="min" if self.args.minimize_metric else "max",
optimizer=self.optimizer,
hidden_size=self.model.encoder._output_size,
)
# Placeholder so that we can use the train_iter in other functions.
self.train_iter, self.train_iter_state = None, None
# initialize training statistics
self.stats = self.TrainStatistics(minimize_metric=self.args.minimize_metric)
# load model parameters
if self.args.load_model is not None:
self.init_from_checkpoint(
self.args.load_model,
reset_best_ckpt=self.args.reset_best_ckpt,
reset_scheduler=self.args.reset_scheduler,
reset_optimizer=self.args.reset_optimizer,
reset_iter_state=self.args.reset_iter_state,
)
for layer_name, load_path in [
("encoder", self.args.load_encoder),
("decoder", self.args.load_decoder),
]:
if load_path is not None:
self.init_layers(path=load_path, layer=layer_name)
def _save_checkpoint(self, new_best: bool, score: float) -> None:
"""
Save the model's current parameters and the training state to a checkpoint.
The training state contains the total number of training steps, the total number
of training tokens, the best checkpoint score and iteration so far, and
optimizer and scheduler states.
:param new_best: This boolean signals which symlink we will use for the new
checkpoint. If it is true, we update best.ckpt.
:param score: Validation score which is used as key of the heap queue. If the
score is float('nan'), the queue won't be updated.
"""
assert self.rank == 0, self.rank # execute this func only in the master process
model_path = Path(self.model_dir) / f"{self.stats.steps}.ckpt"
state = {
"model_state": self.model.state_dict(),
"optimizer_state": self.optimizer.state_dict(),
"scaler_state": (
self.scaler.state_dict() if self.scaler is not None else None
),
"scheduler_state": (
self.scheduler.state_dict() if self.scheduler is not None else None
),
"train_iter_state": self.train_iter.batch_sampler.get_state(),
"stats_state": self.stats.state_dict(),
}
torch.save(state, model_path.as_posix())
logger.info("Checkpoint saved in %s.", model_path)
# update symlink
symlink_target = Path(f"{self.stats.steps}.ckpt")
# last symlink
last_path = Path(self.model_dir) / "latest.ckpt"
prev_path = symlink_update(symlink_target, last_path) # update always
# best symlink
best_path = Path(self.model_dir) / "best.ckpt"
if new_best:
prev_path = symlink_update(symlink_target, best_path)
assert best_path.resolve().stem == str(self.stats.best_ckpt_iter)
# push to and pop from the heap queue
to_delete = None
if not math.isnan(score) and self.args.keep_best_ckpts > 0:
if len(self.ckpt_queue) < self.args.keep_best_ckpts: # no pop, push only
heapq.heappush(self.ckpt_queue, (score, model_path))
else: # push + pop the worst one in the queue
if self.args.minimize_metric:
# pylint: disable=protected-access
heapq._heapify_max(self.ckpt_queue)
to_delete = heapq._heappop_max(self.ckpt_queue)
heapq.heappush(self.ckpt_queue, (score, model_path))
# pylint: enable=protected-access
else:
to_delete = heapq.heappushpop(self.ckpt_queue, (score, model_path))
if to_delete is not None:
assert to_delete[1] != model_path # don't delete the last ckpt
if to_delete[1].stem != best_path.resolve().stem:
delete_ckpt(to_delete[1]) # don't delete the best ckpt
assert len(self.ckpt_queue) <= self.args.keep_best_ckpts
# remove old symlink target if not in queue after push/pop
if prev_path is not None and prev_path.stem not in [
c[1].stem for c in self.ckpt_queue
]:
delete_ckpt(prev_path)
def init_from_checkpoint(
self,
path: Path,
reset_best_ckpt: bool = False,
reset_scheduler: bool = False,
reset_optimizer: bool = False,
reset_iter_state: bool = False,
) -> None:
"""
Initialize the trainer from a given checkpoint file.
This checkpoint file contains not only model parameters, but also
scheduler and optimizer states, see `self._save_checkpoint`.
:param path: path to checkpoint
:param reset_best_ckpt: reset tracking of the best checkpoint,
use for domain adaptation with a new dev set.
:param reset_scheduler: reset the learning rate scheduler, and do not
use the one stored in the checkpoint.
:param reset_optimizer: reset the optimizer, and do not use the one
stored in the checkpoint.
:param reset_iter_state: reset the sampler's internal state and do not
use the one stored in the checkpoint.
"""
logger.info("Loading model from %s", path)
map_location = {"cuda:0": f"cuda:{self.rank}"} if use_ddp() else self.device
model_checkpoint = load_checkpoint(path=path, map_location=map_location)
# restore model and optimizer parameters
self.model.load_state_dict(model_checkpoint["model_state"])
if not reset_optimizer:
self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
if "scaler_state" in model_checkpoint and self.scaler is not None:
self.scaler.load_state_dict(model_checkpoint["scaler_state"])
else:
logger.info("Reset optimizer.")
if not reset_scheduler:
if (
model_checkpoint["scheduler_state"] is not None
and self.scheduler is not None
):
self.scheduler.load_state_dict(model_checkpoint["scheduler_state"])
else:
logger.info("Reset scheduler.")
if not reset_best_ckpt:
# for backwards compatibility:
stats_state = model_checkpoint.get(
"stats_state", {
"epochs": model_checkpoint.get("epochs", 1),
"steps": model_checkpoint.get("steps", 0),
"total_tokens": model_checkpoint.get("total_tokens", 0),
"total_correct": model_checkpoint.get("total_correct", 0),
"best_ckpt_score": model_checkpoint.get("best_ckpt_score", 0.0),
"best_ckpt_iter": model_checkpoint.get("best_ckpt_iteration", 0),
}
)
# restore TrainStatistics
self.stats.load_state_dict(stats_state)
else:
logger.info("Reset tracking of the best checkpoint.")
if not reset_iter_state:
# restore counters
assert "train_iter_state" in model_checkpoint
self.train_iter_state = model_checkpoint["train_iter_state"].cpu()
# TODO: save / restore epoch no, since DistributedSampler resets generator
# based on the initial seed + epoch no
else:
# reset counters if 'reset_iter_state: True' is set explicitly in the config
logger.info("Reset data iterator (random seed: %d).", self.seed)
def init_layers(self, path: Path, layer: str) -> None:
"""
Initialize encoder or decoder layers from a given checkpoint file.
:param path: path to checkpoint
:param layer: layer name; 'encoder' or 'decoder' expected
"""
assert path is not None
layer_state_dict = OrderedDict()
logger.info("Loading %s laysers from %s", layer, path)
map_location = {"cuda:0": f"cuda:{self.rank}"} if use_ddp() else self.device
ckpt = load_checkpoint(path=path, map_location=map_location)
for k, v in ckpt["model_state"].items():
if k.startswith(layer):
layer_state_dict[k] = v
self.model.load_state_dict(layer_state_dict, strict=False)
def train_and_validate(self, train_data: Dataset, valid_data: Dataset) -> None:
"""
Train the model and validate it from time to time on the validation set.
:param train_data: training data
:param valid_data: validation data
"""
# pylint: disable=too-many-branches,too-many-statements
# dataloader
self.train_iter = train_data.make_iter(
batch_size=self.args.batch_size,
batch_type=self.args.batch_type,
seed=self.seed,
shuffle=self.args.shuffle,
num_workers=self.num_workers,
device=self.device,
eos_index=self.model.eos_index,
pad_index=self.model.pad_index,
)
# TODO: set iter state in the first epoch loop (?)
if self.train_iter_state is not None:
self.train_iter.batch_sampler.set_state(self.train_iter_state)
# training config
batch_size_per_device = self.args.batch_size
effective_batch_size = self.args.batch_size * self.args.batch_multiplier
if use_ddp():
effective_batch_size = effective_batch_size * self.n_gpu
elif self.n_gpu > 1:
batch_size_per_device = batch_size_per_device // self.n_gpu
logger.info(
"Train config:\n"
"\tdevice: %s\n"
"\tn_gpu: %d\n"
"\tddp training: %r\n"
"\t16-bits training: %r\n"
"\tgradient accumulation: %d\n"
"\tbatch size per device: %d\n"
"\teffective batch size (w. parallel & accumulation): %d",
self.device.type,
self.n_gpu,
use_ddp(),
self.autocast["enabled"],
self.args.batch_multiplier,
batch_size_per_device,
effective_batch_size,
)
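# Worked example of the numbers logged above (illustrative values):
# batch_size=4096 (batch_type "token"), batch_multiplier=2, n_gpu=4
#   - DDP: each process loads 4096 tokens, so one optimizer step covers
#     4096 * 2 * 4 = 32768 tokens (effective batch size).
#   - single-process multi-GPU (DataParallel): the 4096-token batch is split
#     across devices (~1024 tokens each); the effective batch size stays
#     4096 * 2 = 8192 tokens per optimizer step.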
#################################################################
# simplify accumulation logic:
#################################################################
# for epoch in range(epochs):
# model.zero_grad()
# epoch_loss = 0.0
# batch_loss = 0.0
# for i, batch in enumerate(train_iter):
#
# # gradient accumulation:
# # loss.backward() will be called in _train_step()
# batch_loss += _train_step(batch)
#
# if (i + 1) % args.batch_multiplier == 0:
# optimizer.step() # update!
# model.zero_grad() # reset gradients
# steps += 1 # increment counter
#
# epoch_loss += batch_loss # accumulate batch loss
# batch_loss = 0 # reset batch loss
#
# # leftovers are just ignored.
# # (see `drop_last` arg in train_iter.batch_sampler)
#################################################################
# pylint: disable=too-many-nested-blocks
try:
for epoch_no in range(self.stats.epochs, self.args.epochs + 1, 1):
ddp_synchronize() # wait for all processes before starting an epoch
logger.info("EPOCH %d", epoch_no)
self.stats.epochs = epoch_no
if self.scheduler_step_at == "epoch":
self.scheduler.step(epoch=epoch_no)
# reshuffle/resample data
train_data.seed = self.seed + epoch_no
valid_data.seed = self.seed + epoch_no
self.train_iter.batch_sampler.set_seed(self.seed + epoch_no)
self.model.train()
self.model.zero_grad()
# reset statistics for each epoch
# Note: access these variables only when rank == 0
if self.rank == 0:
start_tokens = self.stats.total_tokens
start_correct = self.stats.total_correct
epoch_nseqs, epoch_ntokens, epoch_loss = 0, 0, 0
total_train_duration, total_valid_duration = 0, 0
total_batch_loss = 0
start = time.time()
batch: Batch # yield a joeynmt Batch object
for i, batch in enumerate(self.train_iter):
# get batch loss
batch_loss, correct_tokens = self._train_step(batch)
batch_nseqs = ddp_reduce(batch.nseqs, self.device, torch.long)
batch_ntokens = ddp_reduce(batch.ntokens, self.device, torch.long)
# increment token counter
if self.rank == 0:
total_batch_loss += batch_loss
epoch_nseqs += batch_nseqs.item()
epoch_ntokens += batch_ntokens.item()
self.stats.total_tokens += batch_ntokens.item()
self.stats.total_correct += correct_tokens
# update the model parameters!
if (i + 1) % self.args.batch_multiplier == 0:
# clip gradients (in-place)
if self.clip_grad_fun is not None:
self.clip_grad_fun(parameters=self.model.parameters())
# make gradient step
if self.scaler is None:
self.optimizer.step()
else:
self.scaler.step(self.optimizer)
self.scaler.update()
# decay lr
if self.scheduler_step_at == "step":
self.scheduler.step(self.stats.steps)
# reset gradients
self.model.zero_grad()
# increment step counter
self.stats.steps += 1
if self.stats.steps >= self.args.max_updates:
self.stats.is_max_update = True
# log learning progress (only in the master process)
if (
self.stats.steps % self.args.logging_freq == 0
and self.rank == 0
):
elapsed_time = time.time() - start - total_valid_duration
total_train_duration += elapsed_time
self._log_scores(
epoch_no, elapsed_time, start_tokens, start_correct,
total_batch_loss
)
# restart counter
start = time.time()
start_tokens = self.stats.total_tokens
start_correct = self.stats.total_correct
total_valid_duration = 0
# increment statistics
if self.rank == 0:
epoch_loss += total_batch_loss # accumulate loss
total_batch_loss = 0 # reset batch loss
# validate on the entire dev set
if self.stats.steps % self.args.validation_freq == 0:
# valid_duration includes time for model saving etc.
if self.rank == 0:
valid_start_time = time.time()
valid_data.seed = self.seed + self.stats.steps # set seed
self._validate(valid_data)
if self.rank == 0:
total_valid_duration += time.time() - valid_start_time
if ((self.stats.is_min_lr or self.stats.is_max_update)
and self.rank == 0): # noqa: E129
break # break batch loop (TODO: DDP?)
if ((self.stats.is_min_lr or self.stats.is_max_update)
and self.rank == 0): # noqa: E129
log_str = (
f"minimum lr {self.args.learning_rate_min}"
if self.stats.is_min_lr else
f"maximum num. of updates {self.args.max_updates}"
)
logger.info("Training ended since %s was reached.", log_str)
break # break epoch loop (TODO: DDP?)
if self.rank == 0:
total_train_duration += time.time() - start - total_valid_duration
logger.info(
"Epoch %3d, total training loss: %.2f, "
"num. of seqs: %d, num. of tokens: %d, %.4f[sec]", epoch_no,
epoch_loss, epoch_nseqs, epoch_ntokens, total_train_duration
)
else:
# epoch loop did not encounter a `break` (neither min_lr nor max_update)
logger.info("Training ended after %3d epochs.", epoch_no)
except KeyboardInterrupt:
logger.info("Interrupt at epoch %d, step %d.", epoch_no, self.stats.steps)
# TODO: in DDP, need to kill all processes (?)
else:
# KeyboardInterrupt not triggered
logger.info(
"Best validation result (greedy) at step %8d: %6.2f %s.",
self.stats.best_ckpt_iter, self.stats.best_ckpt_score,
self.args.early_stopping_metric
)
finally:
ddp_synchronize() # wait for all processes
if self.rank == 0:
self._save_checkpoint(False, float("nan")) # save current weights
self.tb_writer.close() # close Tensorboard writer
def _train_step(self, batch: Batch) -> Tuple[float, int]:
"""
Train the model on one batch: Compute the loss.
:param batch: training batch
:return:
- normalized loss for the batch
- number of correct tokens in the batch (sum)
"""
# reactivate training
self.model.train()
# sort batch now by src length
batch.sort_by_src_length()
with torch.autocast(**self.autocast):
# get loss (run as during training with teacher forcing)
batch_loss, _, _, correct_tokens = self.model(
return_type="loss", **vars(batch)
)
# normalize batch loss
norm_batch_loss = batch.normalize(
batch_loss,
normalization=self.args.normalization,
n_gpu=self.n_gpu,
n_accumulation=self.args.batch_multiplier
)
# sum over multiple gpus
sum_correct_tokens = batch.normalize(correct_tokens, "sum", self.n_gpu)
# accumulate gradients
with self.model.no_sync() if use_ddp() else nullcontext():
if self.scaler is None:
norm_batch_loss.backward()
else:
self.scaler.scale(norm_batch_loss).backward()
# gather
norm_batch_loss = ddp_reduce(norm_batch_loss).item()
sum_correct_tokens = ddp_reduce(sum_correct_tokens).item()
return norm_batch_loss, sum_correct_tokens
def _validate(self, valid_data: Dataset) -> None:
"""Validation"""
# run prediction on dev set
prediction = predict(
model=self.model,
data=valid_data,
compute_loss=True,
device=self.device,
rank=self.rank,
n_gpu=self.n_gpu,
normalization=self.args.normalization,
args=self.dev_cfg,
autocast=self.autocast,
)
if self.rank == 0: # in the master process
assert prediction is not None
(
valid_scores,
valid_references,
valid_hypotheses,
valid_hypotheses_raw,
_, # valid_sequence_scores,
valid_attention_scores,
) = prediction
# write in tensorboard
for eval_metric, score in valid_scores.items():
if not math.isnan(score):
self.tb_writer.add_scalar(
f"valid/{eval_metric}", score, self.stats.steps
)
ckpt_score = valid_scores[self.args.early_stopping_metric]
# TODO: need to update scheduler even when rank != 0 (?)
if self.scheduler_step_at == "validation":
self.scheduler.step(metrics=ckpt_score)
# update new best
new_best = self.stats.is_best(ckpt_score)
if new_best:
self.stats.best_ckpt_score = ckpt_score
self.stats.best_ckpt_iter = self.stats.steps
logger.info(
"Hooray! New best validation result [%s]!",
self.args.early_stopping_metric,
)
# save checkpoints
is_better = (
self.stats.is_better(ckpt_score, self.ckpt_queue)
if len(self.ckpt_queue) > 0 else True
)
if self.args.keep_best_ckpts < 0 or is_better:
self._save_checkpoint(new_best, ckpt_score)
# append to validation report
self._add_report(valid_scores=valid_scores, new_best=new_best)
# log sample sentences
self._log_examples(
references=valid_references,
hypotheses=valid_hypotheses,
hypotheses_raw=valid_hypotheses_raw,
data=valid_data,
)
# store validation set outputs
write_list_to_file(
self.model_dir / f"{self.stats.steps}.hyps", valid_hypotheses
)
# store attention plots for selected valid sentences
if valid_attention_scores:
store_attention_plots(
attentions=valid_attention_scores,
targets=valid_hypotheses_raw,
sources=valid_data.get_list(
lang=valid_data.src_lang, tokenized=True, subsampled=True
),
indices=self.args.print_valid_sents,
output_prefix=(self.model_dir /
f"att.{self.stats.steps}").as_posix(),
tb_writer=self.tb_writer,
steps=self.stats.steps,
)
def _add_report(self, valid_scores: dict, new_best: bool = False) -> None:
"""
Append a one-line report to validation logging file.
:param valid_scores: validation evaluation scores, keyed by eval metric
:param new_best: whether this is a new best model
"""
current_lr = self.optimizer.param_groups[0]["lr"]
valid_file = self.model_dir / "validations.txt"
with valid_file.open("a", encoding="utf-8") as opened_file:
score_str = "\t".join([f"Steps: {self.stats.steps}"] + [
f"{eval_metric}: {score:.5f}"
for eval_metric, score in valid_scores.items() if not math.isnan(score)
] + [f"LR: {current_lr:.8f}", "*" if new_best else ""])
opened_file.write(f"{score_str}\n")
def _log_examples(
self,
hypotheses: List[str],
references: List[str],
hypotheses_raw: List[List[str]],
data: Dataset,
) -> None:
"""
Log the `print_valid_sents` sentences from given examples.
:param hypotheses: decoded hypotheses (list of strings)
:param references: decoded references (list of strings)
:param hypotheses_raw: raw hypotheses (list of list of tokens)
:param data: dev Dataset
"""
for p in self.args.print_valid_sents:
if p >= len(hypotheses):
continue
logger.info("Example #%d", p)
# detokenized text
detokenized_src = data.tokenizer[data.src_lang].post_process(data.src[p])
logger.info("\tSource: %s", detokenized_src)
logger.info("\tReference: %s", references[p])
logger.info("\tHypothesis: %s", hypotheses[p])
# tokenized text
tokenized_src = data.tokenizer[data.src_lang](data.src[p], is_train=False)
tokenized_trg = data.tokenizer[data.trg_lang](data.trg[p], is_train=False)
logger.debug("\tTokenized source: %s", tokenized_src)
logger.debug("\tTokenized reference: %s", tokenized_trg)
logger.debug("\tTokenized hypothesis: %s", hypotheses_raw[p][:-1])
def _log_scores(
self, epoch_no: int, elapsed_time: float, start_tokens: int, start_correct: int,
total_batch_loss: float
) -> None:
"""Log training progress"""
elapsed_tok = self.stats.total_tokens - start_tokens
elapsed_correct = self.stats.total_correct - start_correct
steps = self.stats.steps
self.tb_writer.add_scalar("train/batch_loss", total_batch_loss, steps)
self.tb_writer.add_scalar(
"train/batch_acc", elapsed_correct / elapsed_tok, steps
)
# check current_lr
current_lr = self.optimizer.param_groups[0]["lr"]
if current_lr < self.args.learning_rate_min:
self.stats.is_min_lr = True
self.tb_writer.add_scalar("train/learning_rate", current_lr, steps)
logger.info(
"Epoch %3d, Step: %8d, Batch Loss: %12.6f, Batch Acc: %.6f, "
"Tokens per Sec: %8.0f, Lr: %.6f", epoch_no, steps, total_batch_loss,
elapsed_correct / elapsed_tok, elapsed_tok / elapsed_time, current_lr
)
class TrainStatistics:
"""
Train Statistics
:param epochs: epoch counter
:param steps: global update step counter
:param is_min_lr: stop by reaching learning rate minimum
:param is_max_update: stop by reaching max num of updates
:param total_tokens: number of total tokens seen so far
:param best_ckpt_iter: store iteration point of best ckpt
:param best_ckpt_score: store best validation score so far
:param minimize_metric: minimize or maximize score
:param total_correct: number of correct tokens seen so far
"""
def __init__(self, minimize_metric: bool = True) -> None:
self.epochs = 1
self.steps = 0
self.is_min_lr = False
self.is_max_update = False
self.total_tokens = 0
self.best_ckpt_iter = 0
self.minimize_metric = minimize_metric
self.best_ckpt_score = float('inf') if minimize_metric else float('-inf')
self.total_correct = 0
def is_best(self, score) -> bool:
if self.minimize_metric:
is_best = score < self.best_ckpt_score
else:
is_best = score > self.best_ckpt_score
return is_best
def is_better(self, score: float, heap_queue: list) -> bool:
assert len(heap_queue) > 0
if self.minimize_metric:
is_better = score < heapq.nlargest(1, heap_queue)[0][0]
else:
is_better = score > heapq.nsmallest(1, heap_queue)[0][0]
return is_better
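# Illustrative behaviour (assumed scores; ckpt_queue entries are (score, path)):
#   with minimize_metric=False, is_best(28.0) is True once 28.0 exceeds
#   best_ckpt_score (initialised to -inf), and is_better(27.9, queue) is True
#   if 27.9 beats the smallest, i.e. worst, score currently kept in the queue;
#   minimize_metric=True mirrors this with "<" and the largest kept score.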
def state_dict(self) -> Dict:
"""Returns a dictionary of values necessary to reconstruct stats"""
return {
"epochs": self.epochs,
"steps": self.steps,
"total_tokens": self.total_tokens,
"total_correct": self.total_correct,
"best_ckpt_score": self.best_ckpt_score,
"best_ckpt_iter": self.best_ckpt_iter,
}
def load_state_dict(self, state_dict: Dict) -> None:
"""Given a state_dict, this function reconstruct the state"""
self.epochs = state_dict["epochs"]
self.steps = state_dict["steps"]
self.total_tokens = state_dict["total_tokens"]
self.total_correct = state_dict["total_correct"]
self.best_ckpt_score = state_dict["best_ckpt_score"]
self.best_ckpt_iter = state_dict["best_ckpt_iter"]
def train(rank: int, world_size: int, cfg: Dict, skip_test: bool = False) -> None:
"""
Main training function. After training, also test on test data if given.
:param rank: ddp local rank
:param world_size: ddp world size
:param cfg: configuration dict
:param skip_test: whether to skip the test run after training
"""
if cfg.pop("use_ddp", False):
# initialize ddp
# TODO: make `master_addr` and `master_port` configurable
ddp_setup(rank, world_size, master_addr="localhost", master_port=12355)
# need to assign file handlers again, after multi-processes are spawned...
get_logger(__name__, log_file=Path(cfg["model_dir"]) / "train.log")
# write all entries of config to the log
log_config(cfg)
# parse args
args = parse_global_args(cfg, rank=rank, mode="train")
# prepare model and datasets
model, train_data, dev_data, test_data = prepare(args, rank=rank, mode="train")
dev_args = set_validation_args(args.test)
# for training management, e.g. early stopping and model selection
trainer = TrainManager(
model=model,
model_dir=args.model_dir,
device=args.device,
n_gpu=args.n_gpu,
rank=rank,
num_workers=args.num_workers,
autocast=args.autocast,
seed=args.seed,
train_args=args.train,
dev_args=dev_args
)
# train the model
trainer.train_and_validate(train_data=train_data, valid_data=dev_data)
if not skip_test:
ddp_synchronize() # wait for all processes
# predict with the best model on validation and test
# (if test data is available)
# load model checkpoint
ckpt = args.model_dir / "best.ckpt" # or from args.test.load_model?
map_location = {"cuda:0": f"cuda:{rank}"} if use_ddp() else args.device
model_checkpoint = load_checkpoint(ckpt, map_location=map_location)
model.load_state_dict(model_checkpoint["model_state"])
prepared = {"dev": dev_data, "test": test_data, "model": model}
test(
cfg=cfg,
output_path=(args.model_dir / f"{ckpt.stem}.hyps").as_posix(),
prepared=prepared,
)
else:
logger.info("Skipping test after training.")
ddp_cleanup()