# Source code for sccross.models.utils

r"""
Training infrastructure (trainer, model and training plugins),
utility functions and probability distributions
"""


import torch.distributions as D
import torch.nn.functional as F


import functools
import os

import numpy as np
import pynvml
from torch.nn.modules.batchnorm import _NormBase

import pathlib
import shutil

import ignite.contrib.handlers.tensorboard_logger as tb
import parse

from torch.optim.lr_scheduler import ReduceLROnPlateau
import tempfile

from typing import Any, Iterable, List, Mapping, Optional
import dill
import ignite
import torch

from ..utils import DelayedKeyboardInterrupt, config, logged, EPS
from abc import abstractmethod

# Short aliases for the ignite engine events referenced by the trainer and
# the training plugins in this module (each alias is defined exactly once).
EPOCH_STARTED = ignite.engine.Events.EPOCH_STARTED
EPOCH_COMPLETED = ignite.engine.Events.EPOCH_COMPLETED
ITERATION_COMPLETED = ignite.engine.Events.ITERATION_COMPLETED
EXCEPTION_RAISED = ignite.engine.Events.EXCEPTION_RAISED
TERMINATE = ignite.engine.Events.TERMINATE
COMPLETED = ignite.engine.Events.COMPLETED


@logged
class Trainer:

    r"""
    Abstract trainer class

    Wraps a network module and drives its optimization through ignite
    engines built from :meth:`train_step` and :meth:`val_step`.

    Parameters
    ----------
    net
        Network module to be trained

    Note
    ----
    Subclasses should populate ``required_losses``, and additionally
    define optimizers here.
    """

    def __init__(self, net: torch.nn.Module) -> None:
        self.net = net
        # Names of the losses averaged and reported per epoch; every name
        # listed here is looked up in the dict returned by the step functions
        self.required_losses: List[str] = []

    @abstractmethod
    def train_step(
            self, engine: ignite.engine.Engine, data: List[torch.Tensor]
    ) -> Mapping[str, torch.Tensor]:
        r"""
        A single training step

        Parameters
        ----------
        engine
            Training engine
        data
            Data of the training step

        Returns
        -------
        loss_dict
            Dict containing training loss values
        """
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def val_step(
            self, engine: ignite.engine.Engine, data: List[torch.Tensor]
    ) -> Mapping[str, torch.Tensor]:
        r"""
        A single validation step

        Parameters
        ----------
        engine
            Validation engine
        data
            Data of the validation step

        Returns
        -------
        loss_dict
            Dict containing validation loss values
        """
        raise NotImplementedError  # pragma: no cover

    def report_metrics(
            self, train_state: ignite.engine.State,
            val_state: Optional[ignite.engine.State]
    ) -> None:
        r"""
        Report loss values during training

        Metrics are logged only on epochs that are a multiple of
        ``config.PRINT_LOSS_INTERVAL``; values are rounded to three
        decimals for display.

        Parameters
        ----------
        train_state
            Training engine state
        val_state
            Validation engine state (``None`` if there is no validation)
        """
        # Non-zero remainder => not a reporting epoch
        if train_state.epoch % config.PRINT_LOSS_INTERVAL:
            return
        train_metrics = {
            key: float(f"{val:.3f}")
            for key, val in train_state.metrics.items()
        }
        val_metrics = {
            key: float(f"{val:.3f}")
            for key, val in val_state.metrics.items()
        } if val_state else None
        self.logger.info(
            "[Epoch %d] train=%s, val=%s, %.1fs elapsed",
            train_state.epoch, train_metrics, val_metrics,
            train_state.times["EPOCH_COMPLETED"]  # Also includes validator time
        )

    def fit(
            self, train_loader: Iterable, val_loader: Optional[Iterable] = None,
            max_epochs: int = 100, random_seed: int = 0,
            directory: Optional[os.PathLike] = None,
            plugins: Optional[List["TrainingPlugin"]] = None
    ) -> None:
        r"""
        Fit network

        Builds a training engine (and a validation engine if ``val_loader``
        is given), attaches loss averaging, reporting and the supplied
        plugins, then runs the training engine for ``max_epochs``.

        Parameters
        ----------
        train_loader
            Training data loader
        val_loader
            Validation data loader
        max_epochs
            Maximal number of epochs
        random_seed
            Random seed
        directory
            Training directory (a fresh temporary directory prefixed with
            ``config.TMP_PREFIX`` is created if not specified)
        plugins
            Optional list of training plugins
        """
        interrupt_delayer = DelayedKeyboardInterrupt()
        directory = pathlib.Path(directory or tempfile.mkdtemp(prefix=config.TMP_PREFIX))
        self.logger.info("Using training directory: \"%s\"", directory)

        # Construct engines
        train_engine = ignite.engine.Engine(self.train_step)
        val_engine = ignite.engine.Engine(self.val_step) if val_loader else None

        # The interrupt guard is entered at the start of every epoch and on
        # completion; the matching exits are registered further below.
        # NOTE(review): presumably this defers Ctrl-C while an epoch and its
        # handlers are running - confirm against DelayedKeyboardInterrupt.
        delay_interrupt = interrupt_delayer.__enter__
        train_engine.add_event_handler(EPOCH_STARTED, delay_interrupt)
        train_engine.add_event_handler(COMPLETED, delay_interrupt)

        # Exception handling
        train_engine.add_event_handler(ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())

        @train_engine.on(EXCEPTION_RAISED)
        def _handle_exception(engine, e):
            # A user interrupt terminates training gracefully (if allowed by
            # configuration); every other exception is re-raised unchanged
            if isinstance(e, KeyboardInterrupt) and config.ALLOW_TRAINING_INTERRUPTION:
                self.logger.info("Stopping training due to user interrupt...")
                engine.terminate()
            else:
                raise e

        # Compute metrics
        for item in self.required_losses:
            # ``item=item`` binds the current loss name as a default argument
            # so that each lambda extracts its own loss
            ignite.metrics.Average(
                output_transform=lambda output, item=item: output[item]
            ).attach(train_engine, item)
            if val_engine:
                ignite.metrics.Average(
                    output_transform=lambda output, item=item: output[item]
                ).attach(val_engine, item)

        if val_engine:
            @train_engine.on(EPOCH_COMPLETED)
            def _validate(engine):
                val_engine.run(
                    val_loader, max_epochs=engine.state.epoch
                )  # Bumps max_epochs by 1 per training epoch, so validator resumes for 1 epoch

        @train_engine.on(EPOCH_COMPLETED)
        def _report_metrics(engine):
            self.report_metrics(engine.state, val_engine.state if val_engine else None)

        for plugin in plugins or []:
            plugin.attach(
                net=self.net, trainer=self,
                train_engine=train_engine, val_engine=val_engine,
                train_loader=train_loader, val_loader=val_loader,
                directory=directory
            )

        # Registered after validation, reporting and plugin handlers.
        # NOTE(review): assumes ignite fires handlers in registration order,
        # so the guard is released only after those handlers ran - confirm.
        restore_interrupt = lambda: interrupt_delayer.__exit__(None, None, None)
        train_engine.add_event_handler(EPOCH_COMPLETED, restore_interrupt)
        train_engine.add_event_handler(COMPLETED, restore_interrupt)

        # Start engines
        torch.manual_seed(random_seed)
        train_engine.run(train_loader, max_epochs=max_epochs)

        torch.cuda.empty_cache()  # Works even if GPU is unavailable

    def get_losses(self, loader: Iterable) -> Mapping[str, float]:
        r"""
        Get loss values for given data

        Runs :meth:`val_step` over ``loader`` for a single epoch and
        averages every loss listed in ``required_losses``.

        Parameters
        ----------
        loader
            Data loader

        Returns
        -------
        loss_dict
            Dict containing loss values
        """
        engine = ignite.engine.Engine(self.val_step)
        for item in self.required_losses:
            ignite.metrics.Average(
                output_transform=lambda output, item=item: output[item]
            ).attach(engine, item)
        engine.run(loader, max_epochs=1)
        torch.cuda.empty_cache()  # Works even if GPU is unavailable
        return engine.state.metrics

    def state_dict(self) -> Mapping[str, Any]:
        r"""
        State dict

        The base implementation holds no state and returns an empty dict.

        Returns
        -------
        state_dict
            State dict
        """
        return {}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        r"""
        Load state from a state dict

        The base implementation does nothing.

        Parameters
        ----------
        state_dict
            State dict
        """
@logged
class Model:

    r"""
    Abstract model class

    The network is an instance of the class attribute ``NET_TYPE`` and the
    trainer (created by :meth:`compile`) is an instance of ``TRAINER_TYPE``.

    Parameters
    ----------
    *args
        Positional arguments are passed to the network constructor
    **kwargs
        Keyword arguments are passed to the network constructor

    Note
    ----
    Subclasses may override arguments for API definition.
    """

    NET_TYPE = torch.nn.Module
    TRAINER_TYPE = Trainer

    def __init__(self, *args, **kwargs) -> None:
        self._net = self.NET_TYPE(*args, **kwargs)
        self._trainer: Optional[Trainer] = None  # Constructed upon compile

    @property
    def net(self) -> torch.nn.Module:
        r"""
        Neural network module in the model (read-only)
        """
        return self._net

    @property
    def trainer(self) -> Trainer:
        r"""
        Trainer of the neural network module (read-only)

        Raises
        ------
        RuntimeError
            If :meth:`compile` has not been called yet
        """
        if self._trainer is None:
            raise RuntimeError(
                "No trainer has been registered! "
                "Please call `.compile()` first."
            )
        return self._trainer

    def compile(self, *args, **kwargs) -> None:
        r"""
        Prepare model for training

        Constructs a ``TRAINER_TYPE`` instance for the network. A warning is
        logged if a previous trainer exists, which is then replaced.

        Parameters
        ----------
        *args
            Positional arguments are passed to the trainer constructor
        **kwargs
            Keyword arguments are passed to the trainer constructor

        Note
        ----
        Subclasses may override arguments for API definition.
        """
        if self._trainer:
            self.logger.warning(
                "`compile` has already been called. "
                "Previous trainer will be overwritten!"
            )
        self._trainer = self.TRAINER_TYPE(self.net, *args, **kwargs)

    def fit(self, *args, **kwargs) -> None:
        r"""
        Alias of ``.trainer.fit``.

        Parameters
        ----------
        *args
            Positional arguments are passed to the ``.trainer.fit`` method
        **kwargs
            Keyword arguments are passed to the ``.trainer.fit`` method

        Note
        ----
        Subclasses may override arguments for API definition.
        """
        self.trainer.fit(*args, **kwargs)

    def get_losses(self, *args, **kwargs) -> Mapping[str, float]:
        r"""
        Alias of ``.trainer.get_losses``.

        Parameters
        ----------
        *args
            Positional arguments are passed to the ``.trainer.get_losses`` method
        **kwargs
            Keyword arguments are passed to the ``.trainer.get_losses`` method

        Returns
        -------
        loss_dict
            Dict containing loss values
        """
        return self.trainer.get_losses(*args, **kwargs)

    def save(self, fname: os.PathLike) -> None:
        r"""
        Save model to file

        The trainer is detached and the network device is set to CPU for
        the duration of the dump; both are restored afterwards, even if
        writing the file fails.

        Parameters
        ----------
        fname
            Specifies path to the file

        Note
        ----
        Only the network is saved but not the trainer
        """
        fname = pathlib.Path(fname)
        # Read both backups before mutating anything, so a failure while
        # reading them leaves the model untouched
        trainer_backup = self._trainer
        device_backup = self.net.device
        self._trainer = None
        try:
            self.net.device = torch.device("cpu")
            with fname.open("wb") as f:
                dill.dump(self, f, protocol=4, byref=False, recurse=True)
        finally:
            # Restore state regardless of whether serialization succeeded;
            # the trainer goes first as a plain attribute assignment
            self._trainer = trainer_backup
            self.net.device = device_backup

    @staticmethod
    def load(fname: os.PathLike) -> "Model":
        r"""
        Load model from file

        The network device of the loaded model is set to the result of
        ``autodevice()``.

        Parameters
        ----------
        fname
            Specifies path to the file

        Returns
        -------
        model
            Loaded model
        """
        fname = pathlib.Path(fname)
        with fname.open("rb") as f:
            # SECURITY: dill, like pickle, can execute arbitrary code while
            # deserializing - only load files from trusted sources
            model = dill.load(f)
        model.net.device = autodevice()
        return model
@logged
class TrainingPlugin:

    r"""
    Plugin used to extend the training process with certain functions

    Plugins are passed to ``Trainer.fit``, which calls :meth:`attach` on
    each of them with keyword arguments before the training engine is run.
    """

    @abstractmethod
    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: Optional[ignite.engine.Engine],
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        r"""
        Attach custom handlers to training or validation engine

        Parameters
        ----------
        net
            Network module
        trainer
            Trainer object
        train_engine
            Training engine
        val_engine
            Validation engine (``None`` when training without a
            validation data loader)
        train_loader
            Training data loader
        val_loader
            Validation data loader
        directory
            Training directory
        """
        raise NotImplementedError  # pragma: no cover
#----------------------------- Utility functions -------------------------------
def freeze_running_stats(m: torch.nn.Module) -> None:
    r"""
    Put a normalization layer into evaluation mode

    Modules that are not instances of ``_NormBase`` are left untouched,
    so the function can be applied to arbitrary modules.

    Parameters
    ----------
    m
        Network module
    """
    if not isinstance(m, _NormBase):
        return
    m.eval()
def get_default_numpy_dtype() -> type:
    r"""
    Get numpy dtype matching that of the pytorch default dtype

    The name of the pytorch default dtype (with the ``torch.`` prefix
    stripped) is looked up as an attribute of the numpy module.

    Returns
    -------
    dtype
        Default numpy dtype
    """
    torch_dtype_name = str(torch.get_default_dtype())
    numpy_dtype_name = torch_dtype_name.replace("torch.", "")
    return getattr(np, numpy_dtype_name)
@logged
@functools.lru_cache(maxsize=1)
def autodevice() -> torch.device:
    r"""
    Get torch computation device automatically
    based on GPU availability and memory usage

    Unless ``config.CPU_ONLY`` is set, NVML is queried for the free memory
    of every GPU, GPUs listed in ``config.MASKED_GPUS`` are excluded, and
    one of the GPUs with the most free memory is picked at random. The CPU
    is used if NVML fails, reports no devices, or all GPUs are masked.

    The result is cached, so the selection happens once per process.

    Returns
    -------
    device
        Computation device

    Note
    ----
    When a GPU is selected, the environment variable
    ``CUDA_VISIBLE_DEVICES`` is set to its index.
    """
    used_device = -1  # -1 stands for "no usable GPU"
    if not config.CPU_ONLY:
        try:
            pynvml.nvmlInit()
            free_mems = np.array([
                pynvml.nvmlDeviceGetMemoryInfo(
                    pynvml.nvmlDeviceGetHandleByIndex(i)
                ).free for i in range(pynvml.nvmlDeviceGetCount())
            ])
            # If NVML reports zero devices the array is empty and ``max()``
            # would raise ValueError; keep the CPU fallback in that case
            if free_mems.size:
                for item in config.MASKED_GPUS:
                    free_mems[item] = -1  # Masked GPUs can never win
                best_devices = np.where(free_mems == free_mems.max())[0]
                used_device = np.random.choice(best_devices, 1)[0]
                # A negative maximum means every GPU was masked
                if free_mems[used_device] < 0:
                    used_device = -1
        except pynvml.NVMLError:
            # NVML unavailable or failing: fall back to CPU
            pass
    if used_device == -1:
        autodevice.logger.info("Using CPU as computation device.")
        return torch.device("cpu")
    autodevice.logger.info("Using GPU %d as computation device.", used_device)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(used_device)
    return torch.device("cuda")
class Tensorboard(TrainingPlugin):

    r"""
    Training logging via tensorboard

    Averaged losses of the training engine (and of the validation engine,
    if present) are written at the end of every epoch to the
    ``tensorboard`` subdirectory of the training directory.
    """

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: Optional[ignite.engine.Engine],
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        r"""
        Attach tensorboard loggers to the engines

        Parameters
        ----------
        net
            Network module (unused)
        trainer
            Trainer object, provides the names of the losses to log
        train_engine
            Training engine
        val_engine
            Validation engine (``None`` when training without validation)
        train_loader
            Training data loader (unused)
        val_loader
            Validation data loader (unused)
        directory
            Training directory
        """
        tb_directory = directory / "tensorboard"
        # Start from a clean log directory, discarding logs of earlier runs
        if tb_directory.exists():
            shutil.rmtree(tb_directory)

        tb_logger = tb.TensorboardLogger(
            log_dir=tb_directory,
            flush_secs=config.TENSORBOARD_FLUSH_SECS
        )
        tb_logger.attach(
            train_engine,
            log_handler=tb.OutputHandler(
                tag="train", metric_names=trainer.required_losses
            ), event_name=EPOCH_COMPLETED
        )
        if val_engine:
            tb_logger.attach(
                val_engine,
                log_handler=tb.OutputHandler(
                    tag="val", metric_names=trainer.required_losses
                ), event_name=EPOCH_COMPLETED
            )
        # Close the logger once training has completed
        train_engine.add_event_handler(COMPLETED, tb_logger.close)
@logged
class EarlyStopping(TrainingPlugin):

    r"""
    Early stop model training when loss no longer decreases

    Checkpoints of the network and trainer are saved to the training
    directory while training; when training completes or terminates,
    the most recent usable checkpoint is restored.

    Parameters
    ----------
    monitor
        Loss to monitor
    patience
        Patience to stop early
    burnin
        Burn-in epochs to skip before initializing early stopping
    wait_n_lrs
        Wait n learning rate scheduling events before starting early stopping
    """

    def __init__(
            self, monitor: str, patience: int,
            burnin: int = 0, wait_n_lrs: int = 0
    ) -> None:
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.burnin = burnin
        self.wait_n_lrs = wait_n_lrs

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: Optional[ignite.engine.Engine],
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        r"""
        Attach checkpointing, early stopping and checkpoint restoration
        handlers to the training engine

        Parameters
        ----------
        net
            Network module to checkpoint and restore
        trainer
            Trainer object to checkpoint and restore
        train_engine
            Training engine
        val_engine
            Validation engine; its metrics are monitored when available,
            otherwise those of the training engine
        train_loader
            Training data loader (unused)
        val_loader
            Validation data loader (unused)
        directory
            Training directory where checkpoints are stored
        """
        # Remove checkpoints left in the directory before this run
        for item in directory.glob("checkpoint_*.pt"):
            item.unlink()

        score_engine = val_engine if val_engine else train_engine
        # Negated monitored loss is used as the score
        score_function = lambda engine: -score_engine.state.metrics[self.monitor]
        # With ``wait_n_lrs`` set, the filter additionally reads
        # ``engine.state.n_lrs``, which is initialized by ``LRScheduler.attach``
        event_filter = (
            lambda engine, event: event > self.burnin and engine.state.n_lrs >= self.wait_n_lrs
        ) if self.wait_n_lrs else (
            lambda engine, event: event > self.burnin
        )
        event = EPOCH_COMPLETED(event_filter=event_filter)  # pylint: disable=not-callable
        # The checkpoint handler is registered before the early stopping handler
        train_engine.add_event_handler(
            event, ignite.handlers.Checkpoint(
                {"net": net, "trainer": trainer},
                ignite.handlers.DiskSaver(
                    directory, atomic=True, create_dir=True, require_empty=False
                ), score_function=score_function,
                filename_pattern="checkpoint_{global_step}.pt",
                n_saved=config.CHECKPOINT_SAVE_NUMBERS,
                global_step_transform=ignite.handlers.global_step_from_engine(train_engine)
            )
        )
        train_engine.add_event_handler(
            event, ignite.handlers.EarlyStopping(
                patience=self.patience,
                score_function=score_function,
                trainer=train_engine
            )
        )

        @train_engine.on(COMPLETED | TERMINATE)
        def _(engine):
            # Did the last step output contain any non-finite value?
            nan_flag = any(
                not bool(torch.isfinite(item).all())
                for item in (engine.state.output or {}).values()
            )
            # Epoch numbers of the checkpoints on disk, most recent first
            ckpts = sorted([
                parse.parse("checkpoint_{epoch:d}.pt", item.name).named["epoch"]
                for item in directory.glob("checkpoint_*.pt")
            ], reverse=True)
            # A checkpoint written in the very epoch that produced
            # non-finite output is not trusted
            if ckpts and nan_flag and train_engine.state.epoch == ckpts[0]:
                self.logger.warning(
                    "The most recent checkpoint \"%d\" can be corrupted by NaNs, "
                    "will thus be discarded.", ckpts[0]
                )
                ckpts = ckpts[1:]
            if ckpts:
                self.logger.info("Restoring checkpoint \"%d\"...", ckpts[0])
                loaded = torch.load(directory / f"checkpoint_{ckpts[0]}.pt")
                net.load_state_dict(loaded["net"])
                trainer.load_state_dict(loaded["trainer"])
            else:
                self.logger.info(
                    "No usable checkpoint found. "
                    "Skipping checkpoint restoration."
                )
@logged
class LRScheduler(TrainingPlugin):

    r"""
    Reduce learning rate on loss plateau

    One ``ReduceLROnPlateau`` scheduler is created per optimizer; all of
    them are stepped with the same monitored loss.

    Parameters
    ----------
    *optims
        Optimizers
    monitor
        Loss to monitor (required, keyword-only)
    patience
        Patience to reduce learning rate (required, keyword-only)
    burnin
        Burn-in epochs to skip before initializing learning rate scheduling

    Raises
    ------
    ValueError
        If ``monitor`` or ``patience`` is not specified
    """

    def __init__(
            self, *optims: torch.optim.Optimizer, monitor: Optional[str] = None,
            patience: Optional[int] = None, burnin: int = 0
    ) -> None:
        super().__init__()
        # ``monitor`` and ``patience`` follow ``*optims`` and therefore need
        # defaults syntactically; ``None`` marks them as missing
        if monitor is None:
            raise ValueError("`monitor` must be specified!")
        self.monitor = monitor
        if patience is None:
            raise ValueError("`patience` must be specified!")
        self.schedulers = [
            ReduceLROnPlateau(optim, patience=patience, verbose=True)
            for optim in optims
        ]
        self.burnin = burnin

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: Optional[ignite.engine.Engine],
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        r"""
        Attach the learning rate scheduling handler to the training engine

        Also initializes the counter ``train_engine.state.n_lrs``, which is
        incremented on every learning rate reduction.

        Parameters
        ----------
        net
            Network module (unused)
        trainer
            Trainer object (unused)
        train_engine
            Training engine
        val_engine
            Validation engine; its metrics are monitored when available,
            otherwise those of the training engine
        train_loader
            Training data loader (unused)
        val_loader
            Validation data loader (unused)
        directory
            Training directory (unused)
        """
        score_engine = val_engine if val_engine else train_engine
        event_filter = lambda engine, event: event > self.burnin
        for scheduler in self.schedulers:
            scheduler.last_epoch = self.burnin
        train_engine.state.n_lrs = 0  # Number of learning rate reductions so far

        @train_engine.on(EPOCH_COMPLETED(event_filter=event_filter))  # pylint: disable=not-callable
        def _():
            # Collects, per scheduler, whether this step changed the learning
            # rate of the first parameter group
            update_flags = set()
            for scheduler in self.schedulers:
                old_lr = scheduler.optimizer.param_groups[0]["lr"]
                scheduler.step(score_engine.state.metrics[self.monitor])
                new_lr = scheduler.optimizer.param_groups[0]["lr"]
                update_flags.add(new_lr != old_lr)
            # All schedulers must agree: either every one reduced or none did
            if len(update_flags) != 1:
                raise RuntimeError("Learning rates are out of sync!")
            if update_flags.pop():
                train_engine.state.n_lrs += 1
                self.logger.info("Learning rate reduction: step %d", train_engine.state.n_lrs)
#-------------------------------- Distributions --------------------------------
class MSE(D.Distribution):

    r"""
    A "sham" distribution that outputs negative MSE on ``log_prob``

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def __init__(self, loc: torch.Tensor) -> None:
        super().__init__(validate_args=False)
        self.loc = loc

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Negative of ``F.mse_loss(loc, value)``; not a true log density
        """
        return -F.mse_loss(self.loc, value)

    @property
    def mean(self) -> torch.Tensor:
        r"""
        Mean of the distribution (``loc``)
        """
        return self.loc
class RMSE(MSE):

    r"""
    A "sham" distribution that outputs negative RMSE on ``log_prob``

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative root mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Negative square root of ``F.mse_loss(loc, value)``;
            not a true log density
        """
        return -F.mse_loss(self.loc, value).sqrt()
class ZIN(D.Normal):

    r"""
    Zero-inflated normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the normal distribution
    scale
        Scale of the normal distribution
    """

    def __init__(
            self, zi_logits: torch.Tensor,
            loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        super().__init__(loc, scale)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Log probability under the zero-inflated normal

        Entries with ``|value| < EPS`` are treated as zeros and receive the
        mixture of the zero-inflation component and the normal density;
        all other entries receive the normal log density weighted by the
        probability of not being zero-inflated.

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Element-wise log probability
        """
        normal_log_prob = super().log_prob(value)
        result = torch.empty_like(normal_log_prob)

        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_at_zero = self.zi_logits[is_zero]
        logits_at_nonzero = self.zi_logits[is_nonzero]

        # Zeros: log(N(x) + exp(logit)) - softplus(logit)
        zero_mixture = normal_log_prob[is_zero].exp() + logits_at_zero.exp() + EPS
        result[is_zero] = zero_mixture.log() - F.softplus(logits_at_zero)
        # Non-zeros: log N(x) - softplus(logit)
        result[is_nonzero] = normal_log_prob[is_nonzero] - F.softplus(logits_at_nonzero)
        return result
class ZILN(D.LogNormal):

    r"""
    Zero-inflated log-normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the log-normal distribution
    scale
        Scale of the log-normal distribution
    """

    def __init__(
            self, zi_logits: torch.Tensor,
            loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        super().__init__(loc, scale)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Log probability under the zero-inflated log-normal

        Entries with ``|value| < EPS`` are treated as zeros and receive the
        log probability of zero inflation; all other entries receive the
        log-normal log density weighted by the probability of not being
        zero-inflated.

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Element-wise log probability
        """
        result = torch.empty_like(value)

        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_at_zero = self.zi_logits[is_zero]
        logits_at_nonzero = self.zi_logits[is_nonzero]

        # Zeros: logit - softplus(logit)
        result[is_zero] = logits_at_zero - F.softplus(logits_at_zero)
        # Non-zeros: the log-normal density is evaluated on that subset only
        nonzero_dist = D.LogNormal(self.loc[is_nonzero], self.scale[is_nonzero])
        nonzero_log_prob = nonzero_dist.log_prob(value[is_nonzero])
        result[is_nonzero] = nonzero_log_prob - F.softplus(logits_at_nonzero)
        return result
class ZINB(D.NegativeBinomial):

    r"""
    Zero-inflated negative binomial distribution

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    total_count
        Total count of the negative binomial distribution
    logits
        Logits of the negative binomial distribution
    """

    def __init__(
            self, zi_logits: torch.Tensor,
            total_count: torch.Tensor, logits: torch.Tensor = None
    ) -> None:
        super().__init__(total_count, logits=logits)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Log probability under the zero-inflated negative binomial

        Entries with ``|value| < EPS`` are treated as zeros and receive the
        mixture of the zero-inflation component and the negative binomial
        probability; all other entries receive the negative binomial log
        probability weighted by the probability of not being zero-inflated.

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Element-wise log probability
        """
        nb_log_prob = super().log_prob(value)
        result = torch.empty_like(nb_log_prob)

        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_at_zero = self.zi_logits[is_zero]
        logits_at_nonzero = self.zi_logits[is_nonzero]

        # Zeros: log(NB(x) + exp(logit)) - softplus(logit)
        zero_mixture = nb_log_prob[is_zero].exp() + logits_at_zero.exp() + EPS
        result[is_zero] = zero_mixture.log() - F.softplus(logits_at_zero)
        # Non-zeros: log NB(x) - softplus(logit)
        result[is_nonzero] = nb_log_prob[is_nonzero] - F.softplus(logits_at_nonzero)
        return result