Source code for sccross.models

"""
    scCross is a dDeep Learning-Based Model for integration, cross-dataset cross-modality generation and matched muti-omics simulation of single-cell multi-omics data. Our model can also maintain in-silico perturbations in cross-modality generation and can use in-silico perturbations to find key genes.
    Part of the sccross' code is adapted from MIT licensed projects GLUE and SCDIFF2.
    Thanks for these projects:

    Author: Zhi-Jie Cao
    Project: GLUE
    Ref: Cao Z J, Gao G. Multi-omics single-cell data integration and regulatory inference with graph-linked embedding[J].
    Nature Biotechnology, 2022, 40(10): 1458-1466.

    Author: Jun Ding
    Project: SCDIFF2
    Ref: Ding, J., Aronow, B. J., Kaminski, N., Kitzmiller, J., Whitsett, J. A., & Bar-Joseph, Z.
    (2018). Reconstructing differentiation networks and their regulation from time series
    single-cell expression data. Genome research, 28(3), 383-395.

"""


import os
from typing import Mapping


import numpy as np
from anndata import AnnData

from ..utils import logged, Kws
from .utils import Model


from .sccross import (AUTO, SCCROSSModel,
                     configure_dataset)


[docs]def load_model(fname: os.PathLike) -> Model: r""" Load model from file Parameters ---------- fname Specifies path to the file """ return Model.load(fname)
[docs]@logged def fit_SCCROSS( adatas: Mapping[str, AnnData], model: type = SCCROSSModel, init_kws: Kws = None, compile_kws: Kws = None, fit_kws: Kws = None, balance_kws: Kws = None ) -> SCCROSSModel: init_kws = init_kws or {} compile_kws = compile_kws or {} fit_kws = fit_kws or {} fit_SCCROSS.logger.info("Pretraining SCCROSS model...") pretrain_init_kws = init_kws.copy() pretrain_init_kws.update({"shared_batches": False}) pretrain_fit_kws = fit_kws.copy() pretrain_fit_kws.update({"align_burnin": np.inf, "safe_burnin": False}) if "directory" in pretrain_fit_kws: pretrain_fit_kws["directory"] = \ os.path.join(pretrain_fit_kws["directory"], "pretrain") pretrain = model(adatas, **pretrain_init_kws) pretrain.compile(**compile_kws) pretrain.fit(adatas, **pretrain_fit_kws) if "directory" in pretrain_fit_kws: pretrain.save(os.path.join(pretrain_fit_kws["directory"], "pretrain.dill")) fit_SCCROSS.logger.info("Fine-tuning SCCROSS model...") finetune_fit_kws = fit_kws.copy() if "directory" in finetune_fit_kws: finetune_fit_kws["directory"] = \ os.path.join(finetune_fit_kws["directory"], "fine-tune") finetune = model(adatas, **init_kws) finetune.adopt_pretrained_model(pretrain) finetune.compile(**compile_kws) finetune.fit(adatas, **finetune_fit_kws) if "directory" in finetune_fit_kws: finetune.save(os.path.join(finetune_fit_kws["directory"], "fine-tune.dill")) return finetune