sccross.models.utils

Probability distributions

Functions

autodevice

Get torch computation device automatically based on GPU availability and memory usage

freeze_running_stats

Selectively stops normalization layers from updating running stats

get_default_numpy_dtype

Get the NumPy dtype matching the PyTorch default dtype

Classes

EarlyStopping

Stop model training early when the loss no longer decreases
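
The early-stopping idea can be illustrated with a minimal, framework-free sketch. The class name `EarlyStoppingSketch` and the parameters `patience` and `min_delta` are assumptions made for illustration; they are not taken from the sccross API, whose actual constructor arguments are not shown on this page.

```python
class EarlyStoppingSketch:
    """Hypothetical helper, not the sccross EarlyStopping class."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs tolerated without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses that plateau after the third epoch.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]
stopper = EarlyStoppingSketch(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

With these made-up losses the loop breaks before the final value is seen, which shows the usual trade-off: a small `patience` saves time but can miss a late improvement.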

LRScheduler

Reduce learning rate on loss plateau
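
Reducing the learning rate on a plateau can be sketched without any framework. The class `PlateauLRSketch` and its parameters `factor`, `patience` and `min_lr` are hypothetical names chosen for this example; the behaviour is similar in spirit to PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`, but this page does not confirm how the sccross plugin is implemented.

```python
class PlateauLRSketch:
    """Hypothetical helper, not the sccross LRScheduler class."""

    def __init__(self, lr=1.0, factor=0.5, patience=2, min_lr=1e-6):
        self.lr = lr
        self.factor = factor        # multiplicative decay applied on a plateau
        self.patience = patience    # non-improving epochs tolerated before decay
        self.min_lr = min_lr        # lower bound for the learning rate
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss and return the learning rate to use next."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr


# Made-up losses: improvement, a four-epoch plateau, then improvement again.
scheduler = PlateauLRSketch(lr=1.0, factor=0.5, patience=2)
history = [scheduler.step(loss) for loss in [1.0, 0.9, 0.9, 0.9, 0.9, 0.8]]
```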

MSE

A “sham” distribution whose log_prob returns the negative MSE

Model

Abstract model class

RMSE

A “sham” distribution whose log_prob returns the negative RMSE
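
The “sham” distribution idea behind the MSE and RMSE entries can be sketched as follows: an object exposes a `log_prob` method like a real distribution, so a training loop that maximises log-likelihood ends up minimising a squared-error loss instead. The class names `ShamMSE` and `ShamRMSE`, the `loc` argument, and the choice to average over all elements are assumptions for illustration, not the confirmed sccross implementation.

```python
import math


class ShamMSE:
    """Hypothetical stand-in: log_prob is the negative mean squared error."""

    def __init__(self, loc):
        self.loc = list(loc)  # predicted values, playing the role of a mean

    def log_prob(self, value):
        squared = [(v - m) ** 2 for v, m in zip(value, self.loc)]
        return -sum(squared) / len(squared)


class ShamRMSE(ShamMSE):
    """Hypothetical stand-in: log_prob is the negative root mean squared error."""

    def log_prob(self, value):
        mse = -super().log_prob(value)  # undo the sign to recover the MSE
        return -math.sqrt(mse)


prediction = [0.0, 0.0]
observed = [3.0, 4.0]
neg_mse = ShamMSE(prediction).log_prob(observed)
neg_rmse = ShamRMSE(prediction).log_prob(observed)
```

Neither value is a normalised log-density, which is why the entries call these distributions “sham”.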

Tensorboard

Training logging via TensorBoard

Trainer

Abstract trainer class

TrainingPlugin

Plugin used to extend the training process with certain functions

ZILN

Zero-inflated log-normal distribution with subsetting support

ZIN

Zero-inflated normal distribution with subsetting support
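
A zero-inflated normal log-probability can be sketched under one common convention: an observation is an exact zero with probability `zero_prob`, and otherwise follows a normal distribution. The function names, the `zero_prob` parameter and the `eps` tolerance for detecting zeros are assumptions for this example; the parameterisation and the “subsetting support” of the sccross ZIN class are not described on this page and are not modelled here.

```python
import math


def normal_log_pdf(x, loc, scale):
    """Log-density of a normal distribution with mean loc and std scale."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def zin_log_prob(x, loc, scale, zero_prob, eps=1e-7):
    """Hypothetical zero-inflated normal log-probability (not the sccross API).

    Values within eps of zero are attributed to the point mass at zero;
    all other values are attributed to the normal component.
    """
    if abs(x) < eps:
        return math.log(zero_prob)
    return math.log(1.0 - zero_prob) + normal_log_pdf(x, loc, scale)


at_zero = zin_log_prob(0.0, loc=1.0, scale=1.0, zero_prob=0.25)
at_mean = zin_log_prob(1.0, loc=1.0, scale=1.0, zero_prob=0.25)
```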

ZINB

Zero-inflated negative binomial distribution
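
The zero-inflated negative binomial log-probability can be sketched with the standard library alone. The sketch assumes the mean/inverse-dispersion parameterisation (`mu`, `theta`) and a zero-inflation probability `zero_prob`; these names and this parameterisation are assumptions, since the page does not state which one the sccross ZINB class uses.

```python
import math


def nb_log_pmf(k, mu, theta):
    """Negative binomial log-pmf with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(k + theta) - math.lgamma(theta) - math.lgamma(k + 1)
        + theta * math.log(theta / (theta + mu))
        + k * math.log(mu / (theta + mu))
    )


def zinb_log_pmf(k, mu, theta, zero_prob):
    """Hypothetical ZINB log-pmf (not the sccross API).

    A zero count can come from the extra point mass or from the negative
    binomial itself; positive counts come only from the negative binomial.
    """
    nb = nb_log_pmf(k, mu, theta)
    if k == 0:
        return math.log(zero_prob + (1.0 - zero_prob) * math.exp(nb))
    return math.log(1.0 - zero_prob) + nb


# Probability of a zero count, and a check that the pmf sums to about 1.
p_zero = math.exp(zinb_log_pmf(0, mu=2.0, theta=1.0, zero_prob=0.5))
total = sum(math.exp(zinb_log_pmf(k, 2.0, 1.0, 0.5)) for k in range(500))
```

Here the plain negative binomial assigns probability 1/3 to a zero count, and zero inflation raises that to 0.5 + 0.5 × 1/3 = 2/3.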