import collections
from typing import Iterable, List
import torch
from torch import nn as nn
from torch.distributions import Normal
from torch.nn import ModuleList
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
def identity(x):
return x
def reparameterize_gaussian(mu, var):
return Normal(mu, var.sqrt()).rsample()
def one_hot(index: torch.Tensor, n_cat: int) -> torch.Tensor:
"""One hot a tensor of categories."""
onehot = torch.zeros(index.size(0), n_cat, device=index.device)
onehot.scatter_(1, index.type(torch.long), 1)
return onehot.type(torch.float32)
class FCLayers(nn.Module):
"""
A helper class to build fully-connected layers for a neural network.
Parameters
----------
n_in
The dimensionality of the input
n_out
The dimensionality of the output
n_cat_list
A list containing, for each category of interest,
the number of categories. Each category will be
included using a one-hot encoding.
n_layers
The number of fully-connected hidden layers
n_hidden
The number of nodes per hidden layer
dropout_rate
Dropout rate to apply to each of the hidden layers
use_batch_norm
Whether to have `BatchNorm` layers or not
use_layer_norm
Whether to have `LayerNorm` layers or not
use_activation
Whether to have layer activation or not
bias
Whether to learn bias in linear layers or not
inject_covariates
Whether to inject covariates in each layer, or just the first (default).
activation_fn
Which activation function to use
"""
def __init__(
self,
n_in: int,
n_out: int,
n_cat_list: Iterable[int] = None,
n_layers: int = 1,
n_hidden: int = 128,
dropout_rate: float = 0.1,
use_batch_norm: bool = True,
use_layer_norm: bool = False,
use_activation: bool = True,
bias: bool = True,
inject_covariates: bool = True,
activation_fn: nn.Module = nn.ReLU,
):
super().__init__()
self.inject_covariates = inject_covariates
layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]
if n_cat_list is not None:
# n_cat = 1 will be ignored
self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
else:
self.n_cat_list = []
cat_dim = sum(self.n_cat_list)
self.fc_layers = nn.Sequential(
collections.OrderedDict(
[
(
"Layer {}".format(i),
nn.Sequential(
nn.Linear(
n_in + cat_dim * self.inject_into_layer(i),
n_out,
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001)
if use_batch_norm
else None,
nn.LayerNorm(n_out, elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
),
)
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:])
)
]
)
)
def inject_into_layer(self, layer_num) -> bool:
"""Helper to determine if covariates should be injected."""
user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates)
return user_cond
def set_online_update_hooks(self, hook_first_layer=True):
self.hooks = []
def _hook_fn_weight(grad):
categorical_dims = sum(self.n_cat_list)
new_grad = torch.zeros_like(grad)
if categorical_dims > 0:
new_grad[:, -categorical_dims:] = grad[:, -categorical_dims:]
return new_grad
def _hook_fn_zero_out(grad):
return grad * 0
for i, layers in enumerate(self.fc_layers):
for layer in layers:
if i == 0 and not hook_first_layer:
continue
if isinstance(layer, nn.Linear):
if self.inject_into_layer(i):
w = layer.weight.register_hook(_hook_fn_weight)
else:
w = layer.weight.register_hook(_hook_fn_zero_out)
self.hooks.append(w)
b = layer.bias.register_hook(_hook_fn_zero_out)
self.hooks.append(b)
def forward(self, x: torch.Tensor, *cat_list: int):
"""
Forward computation on ``x``.
Parameters
----------
x
tensor of values with shape ``(n_in,)``
cat_list
list of category membership(s) for this sample
x: torch.Tensor
Returns
-------
py:class:`torch.Tensor`
tensor of shape ``(n_out,)``
"""
one_hot_cat_list = [] # for generality in this list many indices useless.
if len(self.n_cat_list) > len(cat_list):
raise ValueError(
"nb. categorical args provided doesn't match init. params."
)
for n_cat, cat in zip(self.n_cat_list, cat_list):
if n_cat and cat is None:
raise ValueError("cat not provided while n_cat != 0 in init. params.")
if n_cat > 1: # n_cat = 1 will be ignored - no additional information
if cat.size(1) != n_cat:
one_hot_cat = one_hot(cat, n_cat)
else:
one_hot_cat = cat # cat has already been one_hot encoded
one_hot_cat_list += [one_hot_cat]
for i, layers in enumerate(self.fc_layers):
for layer in layers:
if layer is not None:
if isinstance(layer, nn.BatchNorm1d):
if x.dim() == 3:
x = torch.cat(
[(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
)
else:
x = layer(x)
else:
if isinstance(layer, nn.Linear) and self.inject_into_layer(i):
if x.dim() == 3:
one_hot_cat_list_layer = [
o.unsqueeze(0).expand(
(x.size(0), o.size(0), o.size(1))
)
for o in one_hot_cat_list
]
else:
one_hot_cat_list_layer = one_hot_cat_list
x = torch.cat((x, *one_hot_cat_list_layer), dim=-1)
x = layer(x)
return x
class MaskedLinear(nn.Linear):
"""
same as Linear except has a configurable mask on the weights
"""
def __init__(self, in_features, out_features, mask, bias=True):
super().__init__(in_features, out_features, bias)
self.register_buffer('mask', mask)
def forward(self, input):
#mask = Variable(self.mask, requires_grad=False)
if self.bias is None:
return F.linear(input, self.weight*self.mask)
else:
return F.linear(input, self.weight*self.mask, self.bias)
class MaskedLinearLayers(FCLayers):
"""
This incorporates the one-hot encoding for for category input.
A helper class to build Masked Linear layers compatible with FClayer
Parameters
----------
n_in
The dimensionality of the input
n_out
The dimensionality of the output
mask
The mask, should be dimension n_out * n_in
mask_first
wheather mask linear layer should be before or after fully-connected layers, default is true;
False is useful to construct an decoder with the oposite strucutre (mask linear after fully connected)
n_cat_list
A list containing, for each category of interest,
the number of categories. Each category will be
included using a one-hot encoding.
n_layers
The number of fully-connected hidden layers
n_hidden
The number of nodes per hidden layer
dropout_rate
Dropout rate to apply to each of the hidden layers
use_batch_norm
Whether to have `BatchNorm` layers or not
use_layer_norm
Whether to have `LayerNorm` layers or not
use_activation
Whether to have layer activation or not
bias
Whether to learn bias in linear layers or not
inject_covariates
Whether to inject covariates in each layer, or just the first (default).
activation_fn
Which activation function to use
"""
def __init__(
self,
n_in: int,
n_out: int,
mask: torch.Tensor = None,
mask_first: bool = True,
n_cat_list: Iterable[int] = None,
n_layers: int = 1,
n_hidden: int = 128,
dropout_rate: float = 0.1,
use_batch_norm: bool = True,
use_layer_norm: bool = False,
use_activation: bool = True,
bias: bool = True,
inject_covariates: bool = True,
activation_fn: nn.Module = nn.ReLU
):
super().__init__(
n_in=n_in,
n_out=n_out,
n_cat_list=n_cat_list,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
use_activation=use_activation,
bias=bias,
inject_covariates=inject_covariates,
activation_fn=activation_fn
)
self.mask = mask ## out_features, in_features
#if mask is None:
#print("No mask input, use all fully connected layers")
if mask is not None:
if mask_first:
layers_dim = [n_in] + [mask.shape[0]] + (n_layers - 1) * [n_hidden] + [n_out]
else:
layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [mask.shape[0]] + [n_out]
else:
layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]
if n_cat_list is not None:
# n_cat = 1 will be ignored
self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
else:
self.n_cat_list = []
cat_dim = sum(self.n_cat_list)
# concatnat one hot encoding to mask if available
if cat_dim>0:
mask_input = torch.cat((self.mask, torch.ones(cat_dim, self.mask.shape[1])), dim=0)
else:
mask_input = self.mask
self.fc_layers = nn.Sequential(
collections.OrderedDict(
[
(
"Layer {}".format(i),
nn.Sequential(
nn.Linear(
n_in + cat_dim * self.inject_into_layer(i),
n_out,
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001)
if use_batch_norm
else None,
nn.LayerNorm(n_out, elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
),
)
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:])
)
]
)
)
if mask is not None:
if mask_first:
# change the first layer to be MaskedLinear
self.fc_layers[0] = nn.Sequential(
MaskedLinear(
layers_dim[0] + cat_dim * self.inject_into_layer(0),
layers_dim[1],
mask_input,
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
nn.BatchNorm1d(layers_dim[1], momentum=0.01, eps=0.001)
if use_batch_norm
else None,
nn.LayerNorm(layers_dim[1], elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
)
else:
# change the last layer to be MaskedLinear
self.fc_layers[-1] = nn.Sequential(
MaskedLinear(
layers_dim[-2] + cat_dim * self.inject_into_layer(0),
layers_dim[-1],
torch.transpose(mask_input,0,1),
bias=bias,
),
# non-default params come from defaults in original Tensorflow implementation
nn.BatchNorm1d(layers_dim[-1], momentum=0.01, eps=0.001)
if use_batch_norm
else None,
nn.LayerNorm(layers_dim[-1], elementwise_affine=False)
if use_layer_norm
else None,
activation_fn() if use_activation else None,
nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
)
def forward(self, x: torch.Tensor, *cat_list: int):
"""
Forward computation on ``x``.
Parameters
----------
x
tensor of values with shape ``(n_in,)``
cat_list
list of category membership(s) for this sample
x: torch.Tensor
Returns
-------
py:class:`torch.Tensor`
tensor of shape ``(n_out,)``
"""
one_hot_cat_list = [] # for generality in this list many indices useless.
if len(self.n_cat_list) > len(cat_list):
raise ValueError(
"nb. categorical args provided doesn't match init. params."
)
for n_cat, cat in zip(self.n_cat_list, cat_list):
if n_cat and cat is None:
raise ValueError("cat not provided while n_cat != 0 in init. params.")
if n_cat > 1: # n_cat = 1 will be ignored - no additional information
if cat.size(1) != n_cat:
one_hot_cat = one_hot(cat, n_cat)
else:
one_hot_cat = cat # cat has already been one_hot encoded
one_hot_cat_list += [one_hot_cat]
for i, layers in enumerate(self.fc_layers):
for layer in layers:
if layer is not None:
if isinstance(layer, nn.BatchNorm1d):
if x.dim() == 3:
x = torch.cat(
[(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
)
else:
x = layer(x)
else:
if (isinstance(layer, nn.Linear) or isinstance(layer, MaskedLinear)) and self.inject_into_layer(i):
if x.dim() == 3:
one_hot_cat_list_layer = [
o.unsqueeze(0).expand(
(x.size(0), o.size(0), o.size(1))
)
for o in one_hot_cat_list
]
else:
one_hot_cat_list_layer = one_hot_cat_list
x = torch.cat((x, *one_hot_cat_list_layer), dim=-1)
x = layer(x)
return x
[docs]class DeltaTopicEncoder(nn.Module):
"""
A two-headed encoder that maps the two inputs into a shared latent space through a stack of individual and shared fully-connected layers.
Parameters
----------
n_input_list
List of the dimension of two input tensors
n_output
The dimensionality of the output
mask
The mask to apply to the first layer (experimental)
mask_first
Transpose the mask if set to false (experimental)
n_hidden
The number of nodes per hidden layer
n_layers_individual
The number of fully-connected hidden layers for the individual encoder
n_layers_shared
The number of fully-connected hidden layers for the shared encoder
dropout_rate
Dropout rate to apply to each of the hidden layers
use_batch_norm
Whether to have `BatchNorm` layers or not
log_variational
Whether to apply log(1+x) transformation to the input
combine_method
the method to combine the two latent space, either "add" or "concatenate"
"""
[docs] def __init__(
self,
n_input_list: List[int],
n_output: int,
mask: torch.Tensor = None,
mask_first: bool = True,
n_hidden: int = 128,
n_layers_individual: int = 1,
n_layers_shared: int = 2,
n_cat_list: Iterable[int] = None,
dropout_rate: float = 0.1,
use_batch_norm: bool = True,
log_variational: bool = True,
combine_method: str = "add",
):
super().__init__()
self.log_variational = log_variational
self.combine_method = combine_method
self.encoders = ModuleList(
[
MaskedLinearLayers(
n_in=n_input_list[i],
n_out=n_hidden,
n_cat_list=n_cat_list,
mask=mask,
mask_first=mask_first,
n_layers=n_layers_individual,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
use_batch_norm=use_batch_norm,
)
for i in range(len(n_input_list))
]
)
if self.combine_method == 'concat':
dim_encoder_shared = n_hidden + n_hidden
elif self.combine_method == 'add':
dim_encoder_shared = n_hidden
else:
raise ValueError("combine method must choose from concat or add")
self.encoder_shared = FCLayers(
n_in=dim_encoder_shared,
n_out=n_hidden,
n_cat_list=n_cat_list,
n_layers=n_layers_shared,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
)
self.mean_encoder = nn.Linear(n_hidden, n_output)
self.var_encoder = nn.Linear(n_hidden, n_output)
[docs] def forward(self, x: torch.Tensor, y: torch.Tensor, *cat_list: int):
'''
Forward pass for DeltaTopicEncoder
Parameters
----------
x
First input tensor, e.g., spliced RNA count
y
Second input tensorm, e.g., unsplice RNA count
'''
if self.log_variational:
x_ = torch.log(1 + x)
y_ = torch.log(1 + y)
q_x = self.encoders[0](x_, *cat_list)
q_y = self.encoders[1](y_, *cat_list)
if self.combine_method == 'concat':
q = torch.cat([q_x, q_y], dim=-1)
elif self.combine_method == 'add':
q = (q_x + q_y)/2.
else:
raise ValueError("combine method must choose from concat or add")
q = self.encoder_shared(q, *cat_list)
q_m = self.mean_encoder(q)
q_v = torch.exp(torch.clamp(self.var_encoder(q), -4.0, 4.0)/2.)
latent = reparameterize_gaussian(q_m, q_v)
return q_m, q_v, latent
[docs]class DeltaTopicDecoder(nn.Module):
"""
Decoder network for DeltaTopic, a generative network with spike and slab prior for rho and delta.
Parameters
----------
n_input
The dimensionality of the input
n_output
The dimensionality of the output
pip0_rho
posterior inclusion probability prior for rho
pip0_delta
posterior inclusion probability prior for delta
v0_rho
variance for rho slab
vo_delta
variance for delta slab
"""
[docs] def __init__(
self,
n_input: int,
n_output: int,
pip0_rho = 0.1,
pip0_delta = 0.1,
v0_rho = 1,
v0_delta = 1,
):
super().__init__()
self.n_input = n_input # topics
self.n_output = n_output # genes
# gene-level bias, shared across all topics
self.bias_gene = nn.Parameter(torch.zeros(1, n_output))
# for shared effect(rho)
self.logit_0_rho = nn.Parameter(torch.logit(torch.ones(1)* pip0_rho, eps=1e-6), requires_grad = False)
self.lnvar_0_rho = nn.Parameter(torch.log(torch.ones(1) * v0_rho), requires_grad = False)
self.slab_mean_rho = nn.Parameter(torch.randn(n_input, n_output) * torch.sqrt(torch.ones(1) * v0_rho))
self.slab_lnvar_rho = nn.Parameter(torch.ones(n_input, n_output) * torch.log(torch.ones(1) * v0_rho))
self.spike_logit_rho = nn.Parameter(torch.zeros(n_input, n_output) * self.logit_0_rho)
# delta effect
self.logit_0_delta = nn.Parameter(torch.logit(torch.ones(1)*pip0_delta, eps=1e-6), requires_grad = False)
self.lnvar_0_delta = nn.Parameter(torch.log(torch.ones(1)*v0_delta), requires_grad = False)
self.slab_mean_delta = nn.Parameter(torch.randn(n_input, n_output) * torch.sqrt(torch.ones(1) * v0_delta))
self.slab_lnvar_delta = nn.Parameter(torch.ones(n_input, n_output) * torch.log(torch.ones(1) * v0_delta))
self.spike_logit_delta = nn.Parameter(torch.zeros(n_input, n_output) * self.logit_0_delta)
# Log softmax operations
self.log_softmax = nn.LogSoftmax(dim=-1)
[docs] def forward(
self,
z: torch.Tensor,
):
'''
forward pass for DeltaTopicDecoder
Parameters
----------
z
the input the of the decoder, e.g., latent variable from DeltaTopicEncoder
'''
theta = self.soft_max(z)
rho = self.get_beta(self.spike_logit_rho, self.slab_mean_rho, self.slab_lnvar_rho, self.bias_gene)
rho_kl = self.sparse_kl_loss(self.logit_0_rho, self.lnvar_0_rho, self.spike_logit_rho, self.slab_mean_rho, self.slab_lnvar_rho)
delta = self.get_beta(self.spike_logit_delta, self.slab_mean_delta, self.slab_lnvar_delta, self.bias_gene)
delta_kl = self.sparse_kl_loss(self.logit_0_delta, self.lnvar_0_delta, self.spike_logit_delta, self.slab_mean_delta, self.slab_lnvar_delta)
return rho, delta, rho_kl, delta_kl, theta
[docs] def get_rho_delta(
self,
):
'''
Helper function to get rho and delta
'''
rho = self.get_beta(self.spike_logit_rho, self.slab_mean_rho, self.slab_lnvar_rho, self.bias_gene)
delta = self.get_beta(self.spike_logit_delta, self.slab_mean_delta, self.slab_lnvar_delta, self.bias_gene)
return rho, delta
[docs] def get_beta(self,
spike_logit: torch.Tensor,
slab_mean: torch.Tensor,
slab_lnvar: torch.Tensor,
bias_gene: torch.Tensor,
):
'''
Get a spike and slab sample using reparameterization trick
Parameters
----------
spike_logit
logit of spike probability
slab_mean
mean of slab
slab_lnvar
log variance of slab
bias_gene
gene-level bias
'''
pip = torch.sigmoid(spike_logit)
mean = slab_mean * pip
var = pip * (1 - pip) * torch.square(slab_mean)
var = var + pip * torch.exp(slab_lnvar)
eps = torch.randn_like(var)
return mean + eps * torch.sqrt(var) - bias_gene
[docs] def soft_max(self,
z: torch.Tensor,
):
'''
softmax function
Parameters
----------
z
input tensor
'''
return torch.exp(self.log_softmax(z))
[docs] def sparse_kl_loss(
self,
logit_0,
lnvar_0,
spike_logit,
slab_mean,
slab_lnvar,
):
'''
Compute KL divergence between spike and slab piors and posteriors
Parameters
----------
logit_0
logit of spike probability (prior)
lnvar_0
log variance of slab (prior)
spike_logit
logit of spike probability (posterior)
slab_mean
mean of slab (posterior)
slab_lnvar
log variance of slab (posterior)
'''
## PIP KL between p and p0
## p * ln(p / p0) + (1-p) * ln(1-p/1-p0)
## = p * ln(p / 1-p) + ln(1-p) +
## p * ln(1-p0 / p0) - ln(1-p0)
## = sigmoid(logit) * logit - softplus(logit)
## - sigmoid(logit) * logit0 + softplus(logit0)
pip_hat = torch.sigmoid(spike_logit)
kl_pip_1 = pip_hat * (spike_logit - logit_0)
kl_pip = kl_pip_1 - nn.functional.softplus(spike_logit) + nn.functional.softplus(logit_0)
## Gaussian KL between N(μ,ν) and N(0, v0)
sq_term = torch.exp(-lnvar_0) * (torch.square(slab_mean) + torch.exp(slab_lnvar))
kl_g = -0.5 * (1. + slab_lnvar - lnvar_0 - sq_term)
## Combine both logit and Gaussian KL
return torch.sum(kl_pip + pip_hat * kl_g) # return a number sum over [N_topics, N_genes]
[docs]class BALSAMDecoder(nn.Module):
"""
Decoder network for BALSAM model, a generative network with spike and slab prior for beta parameter.
Parameters
----------
n_input: int
The input dimension of the decoder, e.g., number of topics.
n_output: int
The output dimension of decoder, e.g., tumber of genes.
pip0: float
The prior probability of spike in the spike and slab prior.
v0: float
The prior variance of slab in the spike and slab prior.
"""
[docs] def __init__(
self,
n_input: int,
n_output: int,
pip0 = 0.1,
v0 = 1,
):
super().__init__()
self.n_input = n_input # topics
self.n_output = n_output # genes
# for shared effect(rho)
self.logit_0 = nn.Parameter(torch.logit(torch.ones(1)* pip0, eps=1e-6), requires_grad = False)
self.lnvar_0 = nn.Parameter(torch.log(torch.ones(1) * v0), requires_grad = False)
self.bias_d = nn.Parameter(torch.zeros(1, n_output))
self.slab_mean = nn.Parameter(torch.randn(n_input, n_output) * torch.sqrt(torch.ones(1) * v0))
self.slab_lnvar = nn.Parameter(torch.ones(n_input, n_output) * torch.log(torch.ones(1) * v0))
self.spike_logit = nn.Parameter(torch.zeros(n_input, n_output) * self.logit_0)
# Log softmax operations
self.log_softmax = nn.LogSoftmax(dim=-1)
[docs] def forward(
self,
z: torch.Tensor,
):
'''
forward pass of the decoder network.
Parameters
----------
z: torch.Tensor
The input tensor of the decoder, e.g., the latent representation from the encoder.
'''
theta = self.soft_max(z)
rho = self.get_beta(self.spike_logit, self.slab_mean, self.slab_lnvar, self.bias_d)
rho_kl = self.sparse_kl_loss(self.logit_0, self.lnvar_0, self.spike_logit, self.slab_mean, self.slab_lnvar)
return rho, rho_kl, theta
[docs] def get_rho(
self,
):
'''
A helper function to get rho.
'''
rho = self.get_beta(self.spike_logit, self.slab_mean, self.slab_lnvar, self.bias_d)
return rho
[docs] def get_beta(self,
spike_logit: torch.Tensor,
slab_mean: torch.Tensor,
slab_lnvar: torch.Tensor,
bias_d: torch.Tensor,
):
'''
Sample beta using the repameterization trick
Parameters
----------
spike_logit: torch.Tensor
The logit of spike probability.
slab_mean: torch.Tensor
The mean of slab.
slab_lnvar: torch.Tensor
The log variance of slab.
bias_d: torch.Tensor
The bias term in the GLM model.
'''
pip = torch.sigmoid(spike_logit)
mean = slab_mean * pip
var = pip * (1 - pip) * torch.square(slab_mean)
var = var + pip * torch.exp(slab_lnvar)
eps = torch.randn_like(var)
return mean + eps * torch.sqrt(var) - bias_d
[docs] def soft_max(self,
z: torch.Tensor,
):
'''
softmax function
Parameters
----------
z: torch.Tensor
The input tensor.
'''
return torch.exp(self.log_softmax(z))
[docs] def sparse_kl_loss(
self,
logit_0,
lnvar_0,
spike_logit,
slab_mean,
slab_lnvar,
):
'''
Compute the KL divergence between spike and slab prior and the posterior.
Parameters
----------
logit_0: torch.Tensor
The logit prior of spike probability.
lnvar_0: torch.Tensor
The log variance prior of slab.
spike_logit: torch.Tensor
The logit of spike probability.
slab_mean: torch.Tensor
The mean of slab.
slab_lnvar: torch.Tensor
The log variance of slab.
'''
## PIP KL between p and p0
## p * ln(p / p0) + (1-p) * ln(1-p/1-p0)
## = p * ln(p / 1-p) + ln(1-p) +
## p * ln(1-p0 / p0) - ln(1-p0)
## = sigmoid(logit) * logit - softplus(logit)
## - sigmoid(logit) * logit0 + softplus(logit0)
pip_hat = torch.sigmoid(spike_logit)
kl_pip_1 = pip_hat * (spike_logit - logit_0)
kl_pip = kl_pip_1 - nn.functional.softplus(spike_logit) + nn.functional.softplus(logit_0)
## Gaussian KL between N(μ,ν) and N(0, v0)
sq_term = torch.exp(-lnvar_0) * (torch.square(slab_mean) + torch.exp(slab_lnvar))
kl_g = -0.5 * (1. + slab_lnvar - lnvar_0 - sq_term)
## Combine both logit and Gaussian KL
return torch.sum(kl_pip + pip_hat * kl_g) # return a number sum over [N_topics, N_genes]
[docs]class BALSAMEncoder(nn.Module):
"""
Encoder for BALSAM model, encodes the input data into a latent topic representation.
Parameters
----------
n_input: int
The number of input features.
n_output: int
The number of output features.
n_hidden: int
The number of hidden units.
n_layers_individual: int
The number of layers in the network.
use_batch_norm: bool
Whether to use batch normalization.
log_variational: bool
Whether to apply log(1+x) to the input.
"""
[docs] def __init__(
self,
n_input: int,
n_output: int,
n_hidden: int = 128,
n_layers_individual: int = 3,
use_batch_norm: bool = True,
log_variational: bool = True,
):
super().__init__()
self.log_variational = log_variational
self.encoder = FCLayers(
n_in=n_input,
n_out=n_hidden,
n_cat_list=None,
n_layers=n_layers_individual,
n_hidden=n_hidden,
dropout_rate=0,
use_batch_norm = use_batch_norm
)
self.mean_encoder = nn.Linear(n_hidden, n_output)
self.var_encoder = nn.Linear(n_hidden, n_output)
[docs] def forward(self, x: torch.Tensor, *cat_list: int):
'''
forward pass of the encoder
'''
if self.log_variational:
x_ = torch.log(1 + x)
q = self.encoder(x_, *cat_list)
q_m = self.mean_encoder(q)
q_v = torch.exp(torch.clamp(self.var_encoder(q), -4.0, 4.0)/2.)
latent = reparameterize_gaussian(q_m, q_v)
return q_m, q_v, latent