import collections
from abc import abstractmethod
from typing import Optional, Tuple
import torch
import torch.distributions as D
import torch.nn.functional as F
from ..utils import EPS
from .utils import ZILN, ZIN, ZINB
class DataEncoder(torch.nn.Module):

    r"""
    Abstract data encoder

    Subclasses must implement :meth:`compute_l` and :meth:`normalize`.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 2, h_dim: int = 256,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        ptr_dim = in_features
        # Hidden layers are registered under indexed names (linear_{i},
        # act_{i}, bn_{i}, dropout_{i}); forward() looks them up by the same
        # names, and the names are also the state_dict keys.
        for layer in range(self.h_depth):
            setattr(self, f"linear_{layer}", torch.nn.Linear(ptr_dim, h_dim))
            setattr(self, f"act_{layer}", torch.nn.LeakyReLU(negative_slope=0.2))
            setattr(self, f"bn_{layer}", torch.nn.BatchNorm1d(h_dim))
            setattr(self, f"dropout_{layer}", torch.nn.Dropout(p=dropout))
            ptr_dim = h_dim
        self.loc = torch.nn.Linear(ptr_dim, out_features)
        self.std_lin = torch.nn.Linear(ptr_dim, out_features)

    # NOTE: torch.nn.Module does not use ABCMeta, so @abstractmethod below is
    # documentation only: instantiation is not blocked, and calling an
    # unimplemented method raises NotImplementedError.
    @abstractmethod
    def compute_l(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        r"""
        Compute normalizer

        Parameters
        ----------
        x
            Input data

        Returns
        -------
        l
            Normalizer
        """
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def normalize(
            self, x: torch.Tensor, l: Optional[torch.Tensor]
    ) -> torch.Tensor:
        r"""
        Normalize data

        Parameters
        ----------
        x
            Input data
        l
            Normalizer

        Returns
        -------
        xnorm
            Normalized data
        """
        raise NotImplementedError  # pragma: no cover

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor, xalt: torch.Tensor,
            lazy_normalizer: bool = True
    ) -> Tuple[D.Normal, Optional[torch.Tensor]]:
        r"""
        Encode data to sample latent distribution

        Parameters
        ----------
        x
            Input data
        xalt
            Alternative input data
        lazy_normalizer
            Currently has no effect: the normalizer of ``x`` is always
            computed and returned, also when ``xalt`` is non-empty.
            Kept for interface compatibility.

        Returns
        -------
        u
            Sample latent distribution
        normalizer
            Data normalizer (``self.compute_l(x)``)

        Note
        ----
        The normalizer is always computed on ``x``.
        If ``xalt`` is empty, the normalized ``x`` is used as input to the
        encoder network; otherwise ``xalt`` is used as-is.
        """
        l = self.compute_l(x)
        ptr = xalt if xalt.numel() else self.normalize(x, l)
        for layer in range(self.h_depth):
            ptr = getattr(self, f"linear_{layer}")(ptr)
            ptr = getattr(self, f"act_{layer}")(ptr)
            ptr = getattr(self, f"bn_{layer}")(ptr)
            ptr = getattr(self, f"dropout_{layer}")(ptr)
        loc = self.loc(ptr)
        # softplus keeps the scale positive; EPS keeps it strictly > 0.
        std = F.softplus(self.std_lin(ptr)) + EPS
        return D.Normal(loc, std), l
class VanillaDataEncoder(DataEncoder):

    r"""
    Vanilla data encoder.

    The input is passed to the encoder network unchanged (``normalize`` is
    the identity), while the normalizer reported alongside the latent
    distribution is the row-wise sum of ``x``.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def compute_l(self, x: torch.Tensor) -> torch.Tensor:
        # Sum over dim 1, keeping that dim so the result broadcasts
        # against x (a column vector for 2-D x).
        row_total = torch.sum(x, dim=1, keepdim=True)
        return row_total

    def normalize(  # pylint: disable=unused-argument
            self, x: torch.Tensor, l: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Identity transform: the normalizer is intentionally not applied.
        return x
class NBDataEncoder(torch.nn.Module):

    r"""
    Data encoder for negative binomial data

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    #: Target total count each row is rescaled to in :meth:`normalize`.
    TOTAL_COUNT = 1e4

    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 2, h_dim: int = 256,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        ptr_dim = in_features
        # Layer names carry a "1" (linear1_{i}, act1_{i}, ...); forward()
        # looks them up by name and they are the state_dict keys, so they
        # must not be renamed.
        for layer in range(self.h_depth):
            setattr(self, f"linear1_{layer}", torch.nn.Linear(ptr_dim, h_dim))
            setattr(self, f"act1_{layer}", torch.nn.LeakyReLU(negative_slope=0.2))
            setattr(self, f"bn1_{layer}", torch.nn.BatchNorm1d(h_dim))
            setattr(self, f"dropout1_{layer}", torch.nn.Dropout(p=dropout))
            ptr_dim = h_dim
        self.loc1 = torch.nn.Linear(ptr_dim, out_features)
        self.std_lin1 = torch.nn.Linear(ptr_dim, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor, xalt: torch.Tensor,
            lazy_normalizer: bool = True
    ) -> Tuple[D.Normal, Optional[torch.Tensor]]:
        r"""
        Encode data to sample latent distribution

        Parameters
        ----------
        x
            Input data
        xalt
            Alternative input data
        lazy_normalizer
            Currently has no effect: the normalizer of ``x`` is always
            computed and returned. Kept for interface compatibility.

        Returns
        -------
        u
            Sample latent distribution
        normalizer
            Data normalizer (row sums of ``x``)

        Note
        ----
        The normalizer is always computed on ``x``.
        If ``xalt`` is empty, the normalized ``x`` is used as input to the
        encoder network; otherwise ``xalt`` is used as-is.
        """
        l = self.compute_l(x)
        # Previously xalt was fed unconditionally, which fails for an empty
        # xalt despite the documented fallback; mirror DataEncoder instead.
        ptr = xalt if xalt.numel() else self.normalize(x, l)
        for layer in range(self.h_depth):
            ptr = getattr(self, f"linear1_{layer}")(ptr)
            ptr = getattr(self, f"act1_{layer}")(ptr)
            ptr = getattr(self, f"bn1_{layer}")(ptr)
            ptr = getattr(self, f"dropout1_{layer}")(ptr)
        loc = self.loc1(ptr)
        # softplus keeps the scale positive; EPS keeps it strictly > 0.
        std = F.softplus(self.std_lin1(ptr)) + EPS
        return D.Normal(loc, std), l

    def compute_l(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Row-wise total of ``x`` (summed over dim 1, dim kept).
        """
        return x.sum(dim=1, keepdim=True)

    def normalize(
            self, x: torch.Tensor, l: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Scale each row to :attr:`TOTAL_COUNT` and apply ``log1p``.

        Rows whose normalizer is zero are mapped to ``log1p(0) = 0``
        instead of NaN (``0 * inf``).
        """
        nonzero = l != 0
        # Divide by a safe denominator so neither the forward value nor the
        # gradient of zero-total rows becomes inf/NaN.
        safe_l = torch.where(nonzero, l, torch.ones_like(l))
        ratio = self.TOTAL_COUNT / safe_l
        scale = torch.where(nonzero, ratio, torch.zeros_like(ratio))
        return (x * scale).log1p()
class ZEncoder(torch.nn.Module):

    r"""
    Single-layer Gaussian encoder.

    Two independent linear maps of the input give the mean and the
    (softplus-transformed) standard deviation of a normal distribution.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Stored as ``self.h_depth``; no hidden layers are built
    h_dim
        Unused; accepted for interface compatibility
    dropout
        Unused; accepted for interface compatibility
    """

    def __init__(  # pylint: disable=unused-argument
            self, in_features: int, out_features: int,
            h_depth: int = 1, h_dim: int = 16,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        self.loc = torch.nn.Linear(in_features, out_features)
        self.std_lin = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor
    ) -> D.Normal:
        r"""
        Encode ``x`` into a normal distribution.

        Returns
        -------
        u
            ``Normal(loc(x), softplus(std_lin(x)) + EPS)`` (a single
            distribution, not a tuple)
        """
        loc = self.loc(x)
        std = F.softplus(self.std_lin(x)) + EPS
        return D.Normal(loc, std)
class ZDecoder(torch.nn.Module):

    r"""
    Single-layer Gaussian decoder.

    Two independent linear maps of the input give the mean and the
    (softplus-transformed) standard deviation of a normal distribution.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    h_depth
        Stored as ``self.h_depth``; no hidden layers are built
    h_dim
        Unused; accepted for interface compatibility
    dropout
        Unused; accepted for interface compatibility
    """

    def __init__(  # pylint: disable=unused-argument
            self, in_features: int, out_features: int,
            h_depth: int = 1, h_dim: int = 50,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        self.loc = torch.nn.Linear(in_features, out_features)
        self.std_lin = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor
    ) -> D.Normal:
        r"""
        Decode ``x`` into a normal distribution.

        Returns
        -------
        recon
            ``Normal(loc(x), softplus(std_lin(x)) + EPS)`` (a single
            distribution, not a tuple)
        """
        loc = self.loc(x)
        std = F.softplus(self.std_lin(x)) + EPS
        return D.Normal(loc, std)
class DataDecoder(torch.nn.Module):

    r"""
    Abstract data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    """

    def __init__(self, out_features: int, n_batches: int = 1) -> None:  # pylint: disable=unused-argument
        super().__init__()

    # NOTE: torch.nn.Module does not use ABCMeta, so @abstractmethod is
    # documentation only; calling forward on this base raises
    # NotImplementedError.
    @abstractmethod
    def forward(  # pylint: disable=arguments-differ
            self, u: torch.Tensor, v: torch.Tensor,
            b: torch.Tensor, l: Optional[torch.Tensor]
    ) -> D.Distribution:
        r"""
        Decode data from sample and feature latent

        Parameters
        ----------
        u
            Sample latent
        v
            Feature latent
        b
            Batch index
        l
            Optional normalizer

        Returns
        -------
        recon
            Data reconstruction distribution. Concrete decoders return
            different distribution types (e.g. ``D.NegativeBinomial``),
            hence the general ``D.Distribution`` annotation.
        """
        raise NotImplementedError  # pragma: no cover
class ZILNDataDecoder(DataDecoder):

    r"""
    Zero-inflated log-normal data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    in_features
        Dimensionality of the latent input ``u`` (default 50, the value
        that was previously hard-coded)
    """

    def __init__(
            self, out_features: int, n_batches: int = 1,
            in_features: int = 50
    ) -> None:
        super().__init__(out_features, n_batches=n_batches)
        # Per-batch, per-feature parameters, all initialized to zero.
        self.scale_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.zi_logits = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.std_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.h_depth = 1
        # The attribute name ``linear_`` is the state_dict key; keep it.
        self.linear_ = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, u: torch.Tensor,
            b: torch.Tensor, l: Optional[torch.Tensor]  # pylint: disable=unused-argument
    ) -> ZILN:
        r"""
        Decode the sample latent into a zero-inflated log-normal.

        Parameters
        ----------
        u
            Sample latent (last dim must equal ``in_features``)
        b
            Batch index, used to select per-batch parameter rows
        l
            Ignored

        Returns
        -------
        recon
            ``ZILN(zi_logits[b], loc, std)`` with
            ``loc = softplus(scale_lin[b]) * linear_(u) + bias[b]``
        """
        ptr = self.linear_(u)
        scale = F.softplus(self.scale_lin[b])
        loc = scale * ptr + self.bias[b]
        std = F.softplus(self.std_lin[b]) + EPS
        return ZILN(self.zi_logits[b].expand_as(loc), loc, std)
class NBDataDecoder(DataDecoder):

    r"""
    Negative binomial data decoder

    Parameters
    ----------
    out_features
        Output dimensionality
    n_batches
        Number of batches
    in_features
        Dimensionality of the latent input ``u`` (default 50, the value
        that was previously hard-coded)
    """

    def __init__(
            self, out_features: int, n_batches: int = 1,
            in_features: int = 50
    ) -> None:
        super().__init__(out_features, n_batches=n_batches)
        # Per-batch, per-feature parameters, all initialized to zero.
        self.scale_lin = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.log_theta = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.h_depth = 1
        # The attribute name ``linear_`` is the state_dict key; keep it.
        self.linear_ = torch.nn.Linear(in_features, out_features)

    def forward(  # pylint: disable=arguments-differ
            self, u: torch.Tensor,
            b: torch.Tensor, l: torch.Tensor
    ) -> D.NegativeBinomial:
        r"""
        Decode the sample latent into a negative binomial.

        Parameters
        ----------
        u
            Sample latent (last dim must equal ``in_features``)
        b
            Batch index, used to select per-batch parameter rows
        l
            Library size multiplied into the softmax output; presumably
            shape (n, 1) as produced by the encoders' ``compute_l`` — confirm

        Returns
        -------
        recon
            Negative binomial with ``total_count = theta = exp(log_theta[b])``
            and mean ``mu = softmax(softplus(scale_lin[b]) * linear_(u)
            + bias[b]) * l`` (up to ``EPS``)
        """
        scale = F.softplus(self.scale_lin[b])
        logit_mu = scale * self.linear_(u) + self.bias[b]
        # Distribute each row's library size across features.
        mu = F.softmax(logit_mu, dim=1) * l
        log_theta = self.log_theta[b]
        # torch's NegativeBinomial has mean total_count * exp(logits). With
        # total_count = theta, logits = log(mu) - log(theta) makes the mean
        # equal mu; without the "- log_theta" term the mean is theta * mu.
        return D.NegativeBinomial(
            log_theta.exp(),
            logits=(mu + EPS).log() - log_theta
        )
class Discriminator(torch.nn.Sequential, torch.nn.Module):

    r"""
    Domain discriminator

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    n_batches
        Number of batches. When positive, a one-hot batch encoding of width
        ``n_batches`` is appended to the input in :meth:`forward`, and the
        first layer is widened accordingly.
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int, n_batches: int = 0,
            h_depth: int = 2, h_dim: Optional[int] = 256,
            dropout: float = 0.2
    ) -> None:
        self.n_batches = n_batches
        od = collections.OrderedDict()
        ptr_dim = in_features + self.n_batches
        for layer in range(h_depth):
            od[f"linear_{layer}"] = torch.nn.Linear(ptr_dim, h_dim)
            od[f"act_{layer}"] = torch.nn.LeakyReLU(negative_slope=0.2)
            od[f"dropout_{layer}"] = torch.nn.Dropout(p=dropout)
            ptr_dim = h_dim
        od["pred"] = torch.nn.Linear(ptr_dim, out_features)
        super().__init__(od)

    def forward(self, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-differ
        r"""
        Predict domain logits for ``x``; ``b`` (batch index) is used only
        when ``n_batches`` is positive.
        """
        if self.n_batches:
            # one_hot returns an integer tensor; cast to x's dtype so the
            # concatenation does not rely on torch.cat dtype promotion.
            b_one_hot = F.one_hot(b, num_classes=self.n_batches).to(x.dtype)
            x = torch.cat([x, b_one_hot], dim=1)
        return super().forward(x)
class Discriminator_gen(torch.nn.Sequential, torch.nn.Module):

    r"""
    Domain discriminator

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    n_batches
        Number of batches. When positive, the first layer expects
        ``in_features + n_batches`` inputs, and :meth:`forward` must be given
        the batch index ``b`` so a one-hot encoding can be appended.
    h_depth
        Hidden layer depth
    h_dim
        Hidden layer dimensionality
    dropout
        Dropout rate
    """

    def __init__(
            self, in_features: int, out_features: int, n_batches: int = 0,
            h_depth: int = 2, h_dim: Optional[int] = 256,
            dropout: float = 0.2
    ) -> None:
        self.n_batches = n_batches
        od = collections.OrderedDict()
        ptr_dim = in_features + self.n_batches
        for layer in range(h_depth):
            od[f"linear_{layer}"] = torch.nn.Linear(ptr_dim, h_dim)
            od[f"act_{layer}"] = torch.nn.LeakyReLU(negative_slope=0.2)
            od[f"dropout_{layer}"] = torch.nn.Dropout(p=dropout)
            ptr_dim = h_dim
        od["pred"] = torch.nn.Linear(ptr_dim, out_features)
        super().__init__(od)

    def forward(  # pylint: disable=arguments-differ
            self, x: torch.Tensor, b: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Predict domain logits for ``x``.

        Parameters
        ----------
        x
            Input data
        b
            Optional batch index. Required when ``n_batches`` is positive
            (otherwise the first layer receives too few input columns and
            raises); ignored when ``n_batches`` is 0.
        """
        if self.n_batches and b is not None:
            # one_hot returns an integer tensor; match x's dtype before cat.
            b_one_hot = F.one_hot(b, num_classes=self.n_batches).to(x.dtype)
            x = torch.cat([x, b_one_hot], dim=1)
        return super().forward(x)
class Classifier(torch.nn.Linear):

    r"""
    Linear label classifier.

    A named alias of :class:`torch.nn.Linear`: the constructor and the
    forward pass are inherited without modification.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality
    """
class Prior(torch.nn.Module):

    r"""
    Fixed normal prior.

    Parameters
    ----------
    loc
        Mean of the normal distribution
    std
        Standard deviation of the normal distribution
    """

    def __init__(
            self, loc: float = 0.0, std: float = 1.0
    ) -> None:
        super().__init__()
        default_dtype = torch.get_default_dtype()
        # Registered as buffers (in this order: loc, then std) rather than
        # parameters, so they are not returned by .parameters().
        for buffer_name, raw_value in (("loc", loc), ("std", std)):
            self.register_buffer(
                buffer_name, torch.as_tensor(raw_value, dtype=default_dtype)
            )

    def forward(self) -> D.Normal:
        r"""
        Return the prior as ``Normal(loc, std)``.
        """
        return D.Normal(loc=self.loc, scale=self.std)