# Source code for sccross.models.layers



import collections
from abc import abstractmethod
from typing import Optional, Tuple

import torch
import torch.distributions as D
import torch.nn.functional as F

from ..utils import EPS


from .utils import ZILN, ZIN, ZINB









class DataEncoder(torch.nn.Module):

    r"""
    Abstract data encoder

    Stacks ``h_depth`` blocks of Linear -> LeakyReLU(0.2) -> BatchNorm1d ->
    Dropout and maps the result to the location and scale of a normal
    distribution. Subclasses supply :meth:`compute_l` and :meth:`normalize`.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 2, h_dim: int = 256,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        width = in_features
        for layer in range(self.h_depth):
            # Sub-modules are registered under explicit per-layer names
            # (rather than in a container) so that ``state_dict`` keys stay
            # ``linear_{i}.*``, ``bn_{i}.*``, ...
            setattr(self, f"linear_{layer}", torch.nn.Linear(width, h_dim))
            setattr(self, f"act_{layer}", torch.nn.LeakyReLU(negative_slope=0.2))
            setattr(self, f"bn_{layer}", torch.nn.BatchNorm1d(h_dim))
            setattr(self, f"dropout_{layer}", torch.nn.Dropout(p=dropout))
            width = h_dim
        # With ``h_depth == 0`` the heads read the input directly.
        self.loc = torch.nn.Linear(width, out_features)
        self.std_lin = torch.nn.Linear(width, out_features)

    @abstractmethod
    def compute_l(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        r"""
        Compute normalizer

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        l
            Normalizer
        """
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def normalize(
            self, x: torch.Tensor, l: Optional[torch.Tensor]
    ) -> torch.Tensor:
        r"""
        Normalize data

        Parameters
        ----------
        x
            Input data
        l
            Normalizer

        Returns
        -------
        xnorm
            Normalized data
        """
        raise NotImplementedError  # pragma: no cover

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor, xalt: torch.Tensor,
            lazy_normalizer: bool = True
    ) -> Tuple[D.Normal, Optional[torch.Tensor]]:
        r"""
        Encode data to sample latent distribution

        Parameters
        ----------
        x
            Input data
        xalt
            Alternative input data
        lazy_normalizer
            Accepted for call compatibility only. This implementation
            ignores it: the normalizer is computed from ``x`` on every
            call, whether or not ``xalt`` is empty.

        Returns
        -------
        u
            Sample latent distribution
        normalizer
            Data normalizer, i.e. the value of ``compute_l(x)``

        Note
        ----
        Normalization is always computed on ``x``. If ``xalt`` has no
        elements, the normalized ``x`` is used as input to the encoder
        network, otherwise ``xalt`` is used instead.
        """
        # Both input paths need the normalizer, so compute it once up front.
        l = self.compute_l(x)
        ptr = xalt if xalt.numel() else self.normalize(x, l)
        for layer in range(self.h_depth):
            ptr = getattr(self, f"linear_{layer}")(ptr)
            ptr = getattr(self, f"act_{layer}")(ptr)
            ptr = getattr(self, f"bn_{layer}")(ptr)
            ptr = getattr(self, f"dropout_{layer}")(ptr)
        loc = self.loc(ptr)
        # softplus keeps the scale positive; EPS keeps it away from zero.
        std = F.softplus(self.std_lin(ptr)) + EPS
        return D.Normal(loc, std), l
class VanillaDataEncoder(DataEncoder):

    r"""
    Vanilla data encoder

    Feeds the input through unchanged (no normalization) while still
    reporting the per-row sum as the normalizer.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def compute_l(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute normalizer

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        l
            Sum of ``x`` over dimension 1, kept as a size-1 dimension
        """
        return x.sum(dim=1, keepdim=True)

    def normalize(
            self, x: torch.Tensor, l: Optional[torch.Tensor]
    ) -> torch.Tensor:
        r"""
        Normalize data (identity)

        Parameters
        ----------
        x
            Input data
        l
            Normalizer (unused)

        Returns
        -------
        xnorm
            ``x`` itself, returned unchanged
        """
        return x
class NBDataEncoder(torch.nn.Module):

    r"""
    Data encoder for negative binomial data

    Stacks ``h_depth`` blocks of Linear -> LeakyReLU(0.2) -> BatchNorm1d ->
    Dropout and maps the result to the location and scale of a normal
    distribution.

    Parameters
    ----------
    in_features
        Input dimensionality (of the tensor actually fed to the network)
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    # Target row total used by :meth:`normalize`.
    TOTAL_COUNT = 1e4

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 2, h_dim: int = 256,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        width = in_features
        for layer in range(self.h_depth):
            # Explicit per-layer attribute names keep the ``state_dict`` keys
            # (``linear1_{i}.*``, ``bn1_{i}.*``, ...) stable.
            setattr(self, f"linear1_{layer}", torch.nn.Linear(width, h_dim))
            setattr(self, f"act1_{layer}", torch.nn.LeakyReLU(negative_slope=0.2))
            setattr(self, f"bn1_{layer}", torch.nn.BatchNorm1d(h_dim))
            setattr(self, f"dropout1_{layer}", torch.nn.Dropout(p=dropout))
            width = h_dim
        self.loc1 = torch.nn.Linear(width, out_features)
        self.std_lin1 = torch.nn.Linear(width, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor, xalt: torch.Tensor,
            lazy_normalizer: bool = True
    ) -> Tuple[D.Normal, Optional[torch.Tensor]]:
        r"""
        Encode data to sample latent distribution

        Parameters
        ----------
        x
            Input data
        xalt
            Alternative input data
        lazy_normalizer
            Accepted for call compatibility only. This implementation
            ignores it: the normalizer is computed from ``x`` on every
            call.

        Returns
        -------
        u
            Sample latent distribution
        normalizer
            Data normalizer, i.e. the value of ``compute_l(x)``

        Note
        ----
        Normalization is always computed on ``x``. If ``xalt`` has no
        elements, the normalized ``x`` is used as input to the encoder
        network, otherwise ``xalt`` is used instead.
        """
        l = self.compute_l(x)
        # Fall back to the normalized counts when no alternative input is
        # supplied, as the docstring promises; a non-empty ``xalt`` is used
        # exactly as before.
        ptr = xalt if xalt.numel() else self.normalize(x, l)
        for layer in range(self.h_depth):
            ptr = getattr(self, f"linear1_{layer}")(ptr)
            ptr = getattr(self, f"act1_{layer}")(ptr)
            ptr = getattr(self, f"bn1_{layer}")(ptr)
            ptr = getattr(self, f"dropout1_{layer}")(ptr)
        loc = self.loc1(ptr)
        # softplus keeps the scale positive; EPS keeps it away from zero.
        std = F.softplus(self.std_lin1(ptr)) + EPS
        return D.Normal(loc, std), l

    def compute_l(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute normalizer

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        l
            Sum of ``x`` over dimension 1, kept as a size-1 dimension
        """
        return x.sum(dim=1, keepdim=True)

    def normalize(
            self, x: torch.Tensor, l: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Normalize data

        Rescales ``x`` by ``TOTAL_COUNT / l`` and applies ``log1p``.

        Parameters
        ----------
        x
            Input data
        l
            Normalizer

        Returns
        -------
        xnorm
            Normalized data

        Note
        ----
        No guard is applied for ``l == 0``; such rows yield non-finite
        values.
        """
        return (x * (self.TOTAL_COUNT / l)).log1p()
class ZEncoder(torch.nn.Module):

    r"""
    Single-layer encoder producing a normal distribution

    Two parallel linear heads map the input to the location and the
    (softplus-transformed) scale of the output distribution.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Stored as ``self.h_depth``; no hidden layers are built from it
    h_dim
        Unused
    dropout
        Unused
    """

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 1, h_dim: int = 16,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        self.loc = torch.nn.Linear(in_features, out_features)
        self.std_lin = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor
    ) -> D.Normal:
        r"""
        Map input to a normal distribution

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        dist
            Normal distribution with location ``loc(x)`` and scale
            ``softplus(std_lin(x)) + EPS``
        """
        loc = self.loc(x)
        # softplus keeps the scale positive; EPS keeps it away from zero.
        std = F.softplus(self.std_lin(x)) + EPS
        return D.Normal(loc, std)
class ZDecoder(torch.nn.Module):

    r"""
    Single-layer decoder producing a normal distribution

    Two parallel linear heads map the input to the location and the
    (softplus-transformed) scale of the output distribution.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Stored as ``self.h_depth``; no hidden layers are built from it
    h_dim
        Unused
    dropout
        Unused
    """

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 1, h_dim: int = 50,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        self.loc = torch.nn.Linear(in_features, out_features)
        self.std_lin = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor
    ) -> D.Normal:
        r"""
        Map input to a normal distribution

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        dist
            Normal distribution with location ``loc(x)`` and scale
            ``softplus(std_lin(x)) + EPS``
        """
        loc = self.loc(x)
        # softplus keeps the scale positive; EPS keeps it away from zero.
        std = F.softplus(self.std_lin(x)) + EPS
        return D.Normal(loc, std)
class DataDecoder(torch.nn.Module):

    r"""
    Abstract data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    """

    def __init__(self, out_features: int, n_batches: int = 1) -> None:  # pylint: disable=unused-argument
        # The base class holds no state; the arguments only fix the
        # constructor signature shared by concrete decoders.
        super().__init__()

    @abstractmethod
    def forward(  # pylint: disable=arguments-differ
            self, u: torch.Tensor, v: torch.Tensor,
            b: torch.Tensor, l: Optional[torch.Tensor]
    ) -> D.Distribution:
        r"""
        Decode data from sample and feature latent

        Parameters
        ----------
        u
            Sample latent
        v
            Feature latent
        b
            Batch index
        l
            Optional normalizer

        Returns
        -------
        recon
            Data reconstruction distribution

        Note
        ----
        The concrete decoders defined alongside this class
        (``ZILNDataDecoder``, ``NBDataDecoder``) override ``forward`` with
        the signature ``(u, b, l)``, i.e. without ``v``.
        """
        raise NotImplementedError  # pragma: no cover
class ZILNDataDecoder(DataDecoder):

    r"""
    Zero-inflated log-normal data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    in_features
        Dimensionality of the sample latent ``u``. Defaults to 50, the
        value that was previously hard-coded.
    """

    def __init__(
            self, out_features: int, n_batches: int = 1,
            in_features: int = 50
    ) -> None:
        super().__init__(out_features, n_batches=n_batches)
        # Per-batch, per-feature parameters, all zero-initialised and
        # selected by row with the batch index in ``forward``.
        self.scale_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.zi_logits = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.std_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.h_depth = 1
        # Attribute name ``linear_`` is kept so ``state_dict`` keys
        # (``linear_.weight`` / ``linear_.bias``) do not change.
        self.linear_ = torch.nn.Linear(in_features, out_features)

    def forward(
            self, u: torch.Tensor, b: torch.Tensor,
            l: Optional[torch.Tensor]
    ) -> ZILN:
        r"""
        Decode data from sample latent

        Parameters
        ----------
        u
            Sample latent
        b
            Batch index, used to select rows of the per-batch parameters
        l
            Optional normalizer (unused by this decoder)

        Returns
        -------
        recon
            ``ZILN`` distribution built from the per-batch zero-inflation
            logits, ``softplus(scale_lin[b]) * linear_(u) + bias[b]`` and
            ``softplus(std_lin[b]) + EPS``
        """
        projected = self.linear_(u)
        scale = F.softplus(self.scale_lin[b])
        loc = scale * projected + self.bias[b]
        # softplus keeps the scale positive; EPS keeps it away from zero.
        std = F.softplus(self.std_lin[b]) + EPS
        return ZILN(self.zi_logits[b].expand_as(loc), loc, std)
class NBDataDecoder(DataDecoder):

    r"""
    Negative binomial data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    in_features
        Dimensionality of the sample latent ``u``. Defaults to 50, the
        value that was previously hard-coded.
    """

    def __init__(
            self, out_features: int, n_batches: int = 1,
            in_features: int = 50
    ) -> None:
        super().__init__(out_features, n_batches=n_batches)
        # Per-batch, per-feature parameters, all zero-initialised and
        # selected by row with the batch index in ``forward``.
        self.scale_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.log_theta = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.h_depth = 1
        # Attribute name ``linear_`` is kept so ``state_dict`` keys
        # (``linear_.weight`` / ``linear_.bias``) do not change.
        self.linear_ = torch.nn.Linear(in_features, out_features)

    def forward(
            self, u: torch.Tensor, b: torch.Tensor,
            l: torch.Tensor
    ) -> D.NegativeBinomial:
        r"""
        Decode data from sample latent

        Parameters
        ----------
        u
            Sample latent
        b
            Batch index, used to select rows of the per-batch parameters
        l
            Normalizer; multiplies the softmax-normalized output

        Returns
        -------
        recon
            Negative binomial distribution with ``total_count``
            ``exp(log_theta[b])`` and ``logits`` ``log(mu + EPS)``, where
            ``mu = softmax(scale * linear_(u) + bias[b], dim=1) * l``
        """
        scale = F.softplus(self.scale_lin[b])
        projected = self.linear_(u)
        logit_mu = scale * projected + self.bias[b]
        # Softmax over features turns each row into proportions, which the
        # normalizer ``l`` then rescales.
        mu = F.softmax(logit_mu, dim=1) * l
        log_theta = self.log_theta[b]
        # NOTE(review): torch's NegativeBinomial has mean
        # ``total_count * exp(logits)``, so with ``logits = log(mu)`` the
        # distribution mean is ``theta * mu`` rather than ``mu``. Confirm
        # whether ``- log_theta`` was intended in the logits; the numerics
        # are deliberately left unchanged here.
        return D.NegativeBinomial(
            log_theta.exp(),
            logits=(mu + EPS).log()
        )
class Discriminator(torch.nn.Sequential, torch.nn.Module):

    r"""
    Domain discriminator

    A stack of ``h_depth`` Linear -> LeakyReLU(0.2) -> Dropout blocks
    followed by a linear prediction layer named ``pred``.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    n_batches
        Number of batches. When non-zero, a one-hot encoding of the batch
        index is appended to the input, so the first layer takes
        ``in_features + n_batches`` columns.
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int, n_batches: int = 0,
            h_depth: int = 2, h_dim: Optional[int] = 256,
            dropout: float = 0.2
    ) -> None:
        layers = collections.OrderedDict()
        width = in_features + n_batches
        for layer in range(h_depth):
            layers[f"linear_{layer}"] = torch.nn.Linear(width, h_dim)
            layers[f"act_{layer}"] = torch.nn.LeakyReLU(negative_slope=0.2)
            layers[f"dropout_{layer}"] = torch.nn.Dropout(p=dropout)
            width = h_dim
        layers["pred"] = torch.nn.Linear(width, out_features)
        super().__init__(layers)
        # Plain attribute, assigned once the module machinery is initialised.
        self.n_batches = n_batches

    def forward(self, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-differ
        r"""
        Predict domain from (optionally batch-augmented) input

        Parameters
        ----------
        x
            Input data
        b
            Batch index; ignored when ``n_batches`` is zero

        Returns
        -------
        pred
            Output of the final ``pred`` layer
        """
        if self.n_batches:
            # ``F.one_hot`` yields an integer tensor; cast it explicitly to
            # the input dtype instead of relying on implicit type promotion
            # inside ``torch.cat``.
            b_one_hot = F.one_hot(b, num_classes=self.n_batches).to(x.dtype)
            x = torch.cat([x, b_one_hot], dim=1)
        return super().forward(x)
class Discriminator_gen(torch.nn.Sequential, torch.nn.Module):

    r"""
    Domain discriminator

    A stack of ``h_depth`` Linear -> LeakyReLU(0.2) -> Dropout blocks
    followed by a linear prediction layer named ``pred``. Unlike
    ``Discriminator``, :meth:`forward` takes no batch index.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    n_batches
        Extra input columns expected by the first layer, which takes
        ``in_features + n_batches`` columns. :meth:`forward` does not
        append anything, so with a non-zero value the caller must supply
        an input that is already that wide.
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int, n_batches: int = 0,
            h_depth: int = 2, h_dim: Optional[int] = 256,
            dropout: float = 0.2
    ) -> None:
        layers = collections.OrderedDict()
        width = in_features + n_batches
        for layer in range(h_depth):
            layers[f"linear_{layer}"] = torch.nn.Linear(width, h_dim)
            layers[f"act_{layer}"] = torch.nn.LeakyReLU(negative_slope=0.2)
            layers[f"dropout_{layer}"] = torch.nn.Dropout(p=dropout)
            width = h_dim
        layers["pred"] = torch.nn.Linear(width, out_features)
        super().__init__(layers)
        # Plain attribute, assigned once the module machinery is initialised.
        self.n_batches = n_batches

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-differ
        r"""
        Predict domain from input

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        pred
            Output of the final ``pred`` layer
        """
        return super().forward(x)
class Classifier(torch.nn.Linear):

    r"""
    Linear label classifier

    A plain ``torch.nn.Linear`` layer under a dedicated name; it adds no
    behavior of its own.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    """
class Prior(torch.nn.Module):

    r"""
    Prior distribution

    Parameters
    ----------
    loc
        Mean of the normal distribution
    std
        Standard deviation of the normal distribution
    """

    def __init__(
            self, loc: float = 0.0, std: float = 1.0
    ) -> None:
        super().__init__()
        dtype = torch.get_default_dtype()
        # Stored as buffers (not parameters): they are part of the module
        # state and ``state_dict`` but are not trainable.
        self.register_buffer("loc", torch.as_tensor(loc, dtype=dtype))
        self.register_buffer("std", torch.as_tensor(std, dtype=dtype))

    def forward(self) -> D.Normal:
        r"""
        Build the prior distribution

        Returns
        -------
        prior
            Normal distribution with the stored location and scale
        """
        return D.Normal(self.loc, self.std)