sccross.models.sccross.SCCROSSTrainer

class sccross.models.sccross.SCCROSSTrainer(net, lam_data=None, lam_kl=None, lam_graph=None, lam_align=None, lam_sup=None, normalize_u=None, domain_weight=None, optim=None, lr=None, **kwargs)

Bases: sccross.models.utils.Trainer
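Judging by their names, the `lam_*` constructor arguments look like weights on individual loss terms (data, KL, graph, alignment and supervision). The sketch below shows one common way such weights are combined into a single scalar objective: a plain weighted sum. It is a hypothetical illustration, not scCross's implementation. The helper `weighted_total`, the short loss names and all numeric values are assumptions, and the real trainer may treat some terms differently (for example, an adversarial alignment term is often not simply added).

```python
from typing import Mapping

# Hypothetical loss weights, keyed by short names that mirror the
# SCCROSSTrainer arguments; the values are arbitrary examples, not defaults.
WEIGHTS = {
    "data": 1.0,   # cf. lam_data
    "kl": 1.0,     # cf. lam_kl
    "graph": 0.5,  # cf. lam_graph
    "align": 0.1,  # cf. lam_align
    "sup": 0.1,    # cf. lam_sup
}


def weighted_total(losses: Mapping[str, float], weights: Mapping[str, float]) -> float:
    """Combine named loss terms into one scalar as a plain weighted sum.

    Assumption: every term is added with its weight. A loss name with no
    matching weight raises KeyError, so naming mismatches surface early.
    """
    return sum(weights[name] * value for name, value in losses.items())


# Example per-batch loss values (made up for illustration).
losses = {"data": 2.0, "kl": 0.5, "graph": 1.0, "align": 0.4, "sup": 0.0}
total = weighted_total(losses, WEIGHTS)
print(round(total, 4))
```

Failing loudly on an unknown loss name, rather than silently defaulting its weight to zero, keeps a renamed or newly added loss term from being dropped from the objective without notice.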

Methods

compute_losses

Return type

Mapping[str, Tensor]

compute_losses_first

Return type

Mapping[str, Tensor]

fit

Fit the network

format_data

Format data tensors

get_losses

Get loss values for the given data

load_state_dict

Load state from a state dict

state_dict

Return the state dict

train_step

A single training step

val_step

A single validation step
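`get_losses` is documented only as getting loss values for the given data, and `compute_losses` returns a `Mapping[str, Tensor]`. A common way to turn per-batch loss mappings into dataset-level values is a batch-size-weighted average, sketched below. This is an assumption for illustration, not a description of how scCross aggregates losses: the helper `mean_losses`, the loss names `x_nll` and `kl`, and the use of plain floats in place of tensors are all hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, Mapping, Tuple


def mean_losses(batches: Iterable[Tuple[int, Mapping[str, float]]]) -> Dict[str, float]:
    """Average per-batch loss mappings, weighting each batch by its size.

    Each item is (batch_size, {loss_name: value}). In a PyTorch trainer the
    values would typically be scalar tensors converted with ``.item()``
    before accumulation (an assumption; plain floats are used here).
    """
    sums: Dict[str, float] = defaultdict(float)
    n_samples = 0
    for size, losses in batches:
        n_samples += size
        for name, value in losses.items():
            # Undo the per-batch mean so batches of different sizes count fairly.
            sums[name] += size * value
    if n_samples == 0:
        raise ValueError("no samples to average over")
    return {name: s / n_samples for name, s in sums.items()}


# Two made-up batches; the second is smaller, as a final batch often is.
batches = [
    (100, {"x_nll": 1.2, "kl": 0.3}),
    (50, {"x_nll": 0.6, "kl": 0.6}),
]
print(mean_losses(batches))
```

Weighting by batch size keeps a short final batch from pulling the average as hard as a full one, so the result matches the mean over all samples rather than the mean over batches.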

Attributes

BURNIN_NOISE_EXAG

freeze_u

Whether to freeze cell embeddings

logger