Source code for joeynmt.builders

# coding: utf-8
"""
Collection of builder functions
"""
from functools import partial
from typing import Callable, Dict, Generator, Optional

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    ExponentialLR,
    ReduceLROnPlateau,
    StepLR,
    _LRScheduler,
)

from joeynmt.config import ConfigurationError
from joeynmt.helpers_for_ddp import get_logger

logger = get_logger(__name__)


def build_activation(activation: str = "relu") -> Callable:
    """
    Returns the activation function

    Every supported name maps to an ``nn.Module`` subclass (the class itself,
    not an instance), so all options share one calling convention:
    instantiate the returned class, then apply the resulting module.

    :param activation: one of "relu", "gelu", "tanh", "swish"
    :return: activation module class
    :raises ConfigurationError: if `activation` is not a supported name
    """
    # pylint: disable=no-else-return
    if activation == "relu":
        return nn.ReLU
    elif activation == "gelu":
        return nn.GELU
    elif activation == "tanh":
        # nn.Tanh (module class) instead of torch.tanh (plain function), so
        # that "tanh" can be instantiated exactly like the other options.
        return nn.Tanh
    elif activation == "swish":
        return nn.SiLU
    else:
        raise ConfigurationError(
            "Invalid activation function. Valid options: "
            "'relu', 'gelu', 'tanh', 'swish'."
        )
def build_gradient_clipper(cfg: Dict) -> Optional[Callable]:
    """
    Define the function for gradient clipping as specified in configuration.
    If not specified, returns None.

    Current options:
        - "clip_grad_val": clip the gradients if they exceed this value,
            see `torch.nn.utils.clip_grad_value_`
        - "clip_grad_norm": clip the gradients if their norm exceeds this value,
            see `torch.nn.utils.clip_grad_norm_`

    A missing key is treated the same as an explicit ``None`` value.

    :param cfg: dictionary with training configurations
    :return: clipping function (in-place) or None if no gradient clipping
    :raises ConfigurationError: if both options are set at the same time
    """
    # `.get` so that an absent key means "not specified" instead of KeyError
    clip_value = cfg.get("clip_grad_val")
    max_norm = cfg.get("clip_grad_norm")

    if clip_value is not None and max_norm is not None:
        raise ConfigurationError(
            "You can only specify either clip_grad_val or clip_grad_norm."
        )

    clip_grad_fun = None
    if clip_value is not None:
        clip_grad_fun = partial(nn.utils.clip_grad_value_, clip_value=clip_value)
    elif max_norm is not None:
        clip_grad_fun = partial(nn.utils.clip_grad_norm_, max_norm=max_norm)
    return clip_grad_fun
def build_optimizer(cfg: Dict, parameters: Generator) -> Optimizer:
    """
    Create an optimizer for the given parameters as specified in config.

    Except for the weight decay and initial learning rate,
    default optimizer settings are used.

    Currently supported configuration settings for "optimizer":
        - "sgd" (default): see `torch.optim.SGD`
        - "adam": see `torch.optim.adam`
        - "adamw": see `torch.optim.adamw`
        - "adagrad": see `torch.optim.adagrad`
        - "adadelta": see `torch.optim.adadelta`
        - "rmsprop": see `torch.optim.RMSprop`

    The initial learning rate is set according to "learning_rate" in the config.
    The weight decay is set according to "weight_decay" in the config.
    If they are not specified, the initial learning rate is set to 3.0e-4, the
    weight decay to 0.

    Additional optional settings: "adam_betas" (for "adam" and "adamw",
    default (0.9, 0.999)) and "momentum" (for "sgd", default 0.0).

    Note that the scheduler state is saved in the checkpoint, so if you load
    a model for further training you have to use the same type of scheduler.

    :param cfg: configuration dictionary
    :param parameters: parameters handed to the torch optimizer constructor
    :return: optimizer
    :raises ConfigurationError: if the optimizer name is not supported
    """
    optimizer_name = cfg.get("optimizer", "sgd").lower()

    kwargs = {
        "lr": cfg.get("learning_rate", 3.0e-4),
        "weight_decay": cfg.get("weight_decay", 0),
    }

    if optimizer_name == "adam":
        kwargs["betas"] = cfg.get("adam_betas", (0.9, 0.999))
        optimizer = torch.optim.Adam(parameters, **kwargs)
    elif optimizer_name == "adamw":
        # same default betas as Adam / torch.optim.AdamW; a beta1 of 0.0
        # would silently switch off the first-moment running average
        kwargs["betas"] = cfg.get("adam_betas", (0.9, 0.999))
        optimizer = torch.optim.AdamW(parameters, **kwargs)
    elif optimizer_name == "adagrad":
        optimizer = torch.optim.Adagrad(parameters, **kwargs)
    elif optimizer_name == "adadelta":
        optimizer = torch.optim.Adadelta(parameters, **kwargs)
    elif optimizer_name == "rmsprop":
        optimizer = torch.optim.RMSprop(parameters, **kwargs)
    elif optimizer_name == "sgd":
        # default
        kwargs["momentum"] = cfg.get("momentum", 0.0)
        optimizer = torch.optim.SGD(parameters, **kwargs)
    else:
        raise ConfigurationError(
            "Invalid optimizer. Valid options: 'adam', "
            "'adamw', 'adagrad', 'adadelta', 'rmsprop', 'sgd'."
        )

    # log the effective settings, e.g. "AdamW(lr=0.0003, weight_decay=0, ...)"
    logger.info(
        "%s(%s)",
        optimizer.__class__.__name__,
        ", ".join([f"{k}={v}" for k, v in kwargs.items()]),
    )
    return optimizer
def build_scheduler(
    cfg: Dict,
    optimizer: Optimizer,
    scheduler_mode: str,
    hidden_size: int = 0,
) -> (Optional[_LRScheduler], Optional[str]):
    """
    Create a learning rate scheduler if specified in config and determine when a
    scheduler step should be executed.

    Current options:
        - "plateau": see `torch.optim.lr_scheduler.ReduceLROnPlateau`
        - "decaying": see `torch.optim.lr_scheduler.StepLR`
        - "exponential": see `torch.optim.lr_scheduler.ExponentialLR`
        - "noam": see `joeynmt.builders.NoamScheduler`
        - "warmupexponentialdecay": see
          `joeynmt.builders.WarmupExponentialDecayScheduler`
        - "warmupinversesquareroot": see
          `joeynmt.builders.WarmupInverseSquareRootScheduler`

    If no scheduler is specified (no "scheduling" entry, or it is None),
    returns (None, "none") which will result in a constant learning rate.

    :param cfg: training configuration
    :param optimizer: optimizer for the scheduler, determines the set of parameters
        which the scheduler sets the learning rate for
    :param scheduler_mode: "min" or "max", depending on whether the validation
        score should be minimized or maximized. Only relevant for "plateau".
    :param hidden_size: encoder hidden size (required for NoamScheduler)
    :return:
        - scheduler: scheduler object,
        - scheduler_step_at: either "validation", "epoch", "step" or "none"
    :raises ConfigurationError: if a scheduler name is given but not supported
    """
    scheduler, scheduler_step_at = None, None
    scheduler_name = cfg.get("scheduling", None)
    kwargs = {}

    if scheduler_name == "plateau":
        # learning rate scheduler
        kwargs = {
            "mode": scheduler_mode,
            "verbose": False,
            "threshold_mode": "abs",
            "eps": 0.0,
            "factor": cfg.get("decrease_factor", 0.1),
            "patience": cfg.get("patience", 10),
        }
        scheduler = ReduceLROnPlateau(optimizer=optimizer, **kwargs)
        # scheduler step is executed after every validation
        scheduler_step_at = "validation"
    elif scheduler_name == "decaying":
        kwargs = {"step_size": cfg.get("decaying_step_size", 1)}
        scheduler = StepLR(optimizer=optimizer, **kwargs)
        # scheduler step is executed after every epoch
        scheduler_step_at = "epoch"
    elif scheduler_name == "exponential":
        kwargs = {"gamma": cfg.get("decrease_factor", 0.99)}
        scheduler = ExponentialLR(optimizer=optimizer, **kwargs)
        # scheduler step is executed after every epoch
        scheduler_step_at = "epoch"
    elif scheduler_name == "noam":
        scheduler = NoamScheduler(
            optimizer=optimizer,
            hidden_size=hidden_size,
            factor=cfg.get("learning_rate_factor", 1),
            warmup=cfg.get("learning_rate_warmup", 4000),
        )
        scheduler_step_at = "step"
    elif scheduler_name == "warmupexponentialdecay":
        scheduler = WarmupExponentialDecayScheduler(
            # the optimizer is a required constructor argument
            optimizer=optimizer,
            min_rate=cfg.get("learning_rate_min", 1.0e-5),
            decay_rate=cfg.get("learning_rate_decay", 0.1),
            warmup=cfg.get("learning_rate_warmup", 4000),
            peak_rate=cfg.get("learning_rate_peak", 1.0e-3),
            decay_length=cfg.get("learning_rate_decay_length", 10000),
        )
        scheduler_step_at = "step"
    elif scheduler_name == "warmupinversesquareroot":
        lr = cfg.get("learning_rate", 1.0e-3)
        peak_rate = cfg.get("learning_rate_peak", lr)
        scheduler = WarmupInverseSquareRootScheduler(
            optimizer=optimizer,
            peak_rate=peak_rate,
            min_rate=cfg.get("learning_rate_min", 1.0e-5),
            warmup=cfg.get("learning_rate_warmup", 10000),
        )
        scheduler_step_at = "step"
    elif scheduler_name is not None:
        # only an unknown *name* is an error; an absent scheduler is valid
        # and falls through to the constant-learning-rate case below
        raise ConfigurationError(
            "Invalid scheduler. Valid options: 'plateau', "
            "'decaying', 'exponential', 'noam', "
            "'warmupexponentialdecay', 'warmupinversesquareroot'."
        )

    if scheduler is None:
        scheduler_step_at = "none"
    else:
        assert scheduler_step_at in {"validation", "epoch", "step", "none"}

        # print log
        if scheduler_name in [
            "noam",
            "warmupexponentialdecay",
            "warmupinversesquareroot",
        ]:
            # these classes define their own __repr__
            logger.info(scheduler)
        else:
            logger.info(
                "%s(%s)",
                scheduler.__class__.__name__,
                ", ".join([f"{k}={v}" for k, v in kwargs.items()]),
            )
    return scheduler, scheduler_step_at
class BaseScheduler:
    """
    Common base for the step-wise learning rate schedulers.

    Subclasses provide `_compute_rate()`; `step()` writes the computed value
    into the "lr" entry of every parameter group of the optimizer.
    """

    def __init__(self, optimizer: torch.optim.Optimizer):
        """
        :param optimizer: optimizer whose parameter groups get their "lr" set
        """
        self._step = 0
        self._rate = 0
        # a single dict instance is kept and refreshed on each state_dict() call
        self._state_dict = {"step": 0, "rate": 0}
        self.optimizer = optimizer

    def state_dict(self):
        """Return the (refreshed) dictionary holding step counter and rate"""
        self._state_dict.update(step=self._step, rate=self._rate)
        return self._state_dict

    def load_state_dict(self, state_dict):
        """Restore step counter and rate from the given dictionary"""
        self._step, self._rate = state_dict["step"], state_dict["rate"]

    def step(self, step):
        """
        Advance to the given step and set the new rate on the optimizer.

        :param step: step number; the internal counter becomes `step + 1`
        """
        # sync with trainer.stats.steps
        self._step = step + 1
        new_rate = self._compute_rate()
        for group in self.optimizer.param_groups:
            group["lr"] = new_rate
        self._rate = new_rate

    def _compute_rate(self):
        """Learning rate for the current step; to be overridden"""
        raise NotImplementedError
class NoamScheduler(BaseScheduler):
    """
    The Noam learning rate scheduler used in "Attention is all you need"
    See Eq. 3 in https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        hidden_size: int,
        optimizer: torch.optim.Optimizer,
        factor: float = 1.0,
        warmup: int = 4000,
    ):
        """
        Warm-up, followed by learning rate decay.

        :param hidden_size: model size entering the rate as hidden_size**(-0.5)
        :param optimizer: optimizer whose learning rate is scheduled
        :param factor: multiplicative factor applied to the computed rate
        :param warmup: number of warmup steps
        """
        super().__init__(optimizer)
        self.hidden_size = hidden_size
        self.factor = factor
        self.warmup = warmup

    def _compute_rate(self):
        """Implement `lrate` above"""
        current = self._step
        decay_term = current**(-0.5)
        warmup_term = current * self.warmup**(-1.5)
        size_scale = self.hidden_size**(-0.5)
        return self.factor * (size_scale * min(decay_term, warmup_term))

    def state_dict(self):
        """Return step/rate plus the hyper-parameters of this scheduler"""
        super().state_dict()
        for key in ("warmup", "factor", "hidden_size"):
            self._state_dict[key] = getattr(self, key)
        return self._state_dict

    def load_state_dict(self, state_dict):
        """Restore step/rate and the hyper-parameters from `state_dict`"""
        super().load_state_dict(state_dict)
        for key in ("warmup", "factor", "hidden_size"):
            setattr(self, key, state_dict[key])

    def __repr__(self):
        settings = (
            f"warmup={self.warmup}, "
            f"factor={self.factor}, hidden_size={self.hidden_size}"
        )
        return f"{self.__class__.__name__}({settings})"
class WarmupExponentialDecayScheduler(BaseScheduler):
    """
    A learning rate scheduler similar to Noam, but modified:
    Keep the warm up period but make it so that the decay rate can be tuneable.
    The decay is exponential up to a given minimum rate.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        peak_rate: float = 1.0e-3,
        decay_length: int = 10000,
        warmup: int = 4000,
        decay_rate: float = 0.5,
        min_rate: float = 1.0e-5,
    ):
        """
        Linear warm-up, then exponential decay bounded below by `min_rate`.

        :param optimizer: optimizer whose learning rate is scheduled
        :param peak_rate: maximum learning rate at peak after warmup
        :param decay_length: decay length after warmup
        :param warmup: number of warmup steps
        :param decay_rate: decay rate after warmup
        :param min_rate: minimum learning rate
        """
        super().__init__(optimizer)
        self.warmup = warmup
        self.decay_length = decay_length
        self.peak_rate = peak_rate
        self.decay_rate = decay_rate
        self.min_rate = min_rate

    def _compute_rate(self):
        """Implement `lrate` above"""
        current, warmup_steps = self._step, self.warmup
        if current >= warmup_steps:
            # decay phase: peak_rate * decay_rate ** (steps since warmup / length)
            exponent = (current - warmup_steps) / self.decay_length
            rate = self.peak_rate * (self.decay_rate**exponent)
        else:
            # warmup phase: linear ramp towards peak_rate
            rate = current * self.peak_rate / warmup_steps
        return max(rate, self.min_rate)

    def state_dict(self):
        """Return step/rate plus the hyper-parameters of this scheduler"""
        super().state_dict()
        for key in ("warmup", "decay_length", "peak_rate", "decay_rate", "min_rate"):
            self._state_dict[key] = getattr(self, key)
        return self._state_dict

    def load_state_dict(self, state_dict):
        """Restore step/rate and the hyper-parameters from `state_dict`"""
        super().load_state_dict(state_dict)
        for key in ("warmup", "decay_length", "peak_rate", "decay_rate", "min_rate"):
            setattr(self, key, state_dict[key])

    def __repr__(self):
        settings = (
            f"warmup={self.warmup}, "
            f"decay_length={self.decay_length}, "
            f"decay_rate={self.decay_rate}, "
            f"peak_rate={self.peak_rate}, "
            f"min_rate={self.min_rate}"
        )
        return f"{self.__class__.__name__}({settings})"
class WarmupInverseSquareRootScheduler(BaseScheduler):
    """
    Decay the LR based on the inverse square root of the update number.
    In the warmup phase, we linearly increase the learning rate.
    After warmup, we decrease the learning rate as follows:
    ```
    decay_factor = peak_rate * sqrt(warmup) # constant value
    lr = decay_factor / sqrt(step)
    ```
    cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
    """  # noqa

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        peak_rate: float = 1.0e-3,
        warmup: int = 10000,
        min_rate: float = 1.0e-5,
    ):
        """
        Linear warm-up, then inverse square root decay bounded by `min_rate`.

        :param optimizer: optimizer whose learning rate is scheduled
        :param peak_rate: maximum learning rate at peak after warmup
        :param warmup: number of warmup steps
        :param min_rate: minimum learning rate
        """
        super().__init__(optimizer)
        self.warmup = warmup
        self.min_rate = min_rate
        self.peak_rate = peak_rate
        # constant value, computed once from the constructor arguments
        self.decay_rate = peak_rate * (warmup**0.5)

    def _compute_rate(self):
        """Implement `lrate` above"""
        current, warmup_steps = self._step, self.warmup
        if current >= warmup_steps:
            # decay prop. to the inverse square root of the update number
            rate = self.decay_rate * (current**-0.5)
        else:
            # linear warmup
            rate = current * self.peak_rate / warmup_steps
        return max(rate, self.min_rate)

    def state_dict(self):
        """Return step/rate plus the hyper-parameters of this scheduler"""
        super().state_dict()
        for key in ("warmup", "peak_rate", "decay_rate", "min_rate"):
            self._state_dict[key] = getattr(self, key)
        return self._state_dict

    def load_state_dict(self, state_dict):
        """Restore step/rate and the hyper-parameters from `state_dict`"""
        super().load_state_dict(state_dict)
        for key in ("warmup", "decay_rate", "peak_rate", "min_rate"):
            setattr(self, key, state_dict[key])

    def __repr__(self):
        settings = (
            f"warmup={self.warmup}, "
            f"decay_rate={self.decay_rate:.6f}, peak_rate={self.peak_rate}, "
            f"min_rate={self.min_rate}"
        )
        return f"{self.__class__.__name__}({settings})"