Base components
Building blocks for the modules: the encoder and decoder networks.
- class DeltaTopic.nn.base_components.BALSAMDecoder(*args: Any, **kwargs: Any)[source]
Bases: Module
Decoder network for BALSAM model, a generative network with spike and slab prior for beta parameter.
- Parameters:
n_input (int) – The input dimension of the decoder, e.g., number of topics.
n_output (int) – The output dimension of the decoder, e.g., number of genes.
pip0 (float) – The prior probability of spike in the spike and slab prior.
v0 (float) – The prior variance of slab in the spike and slab prior.
- forward(z: torch.Tensor)[source]
Forward pass of the decoder network.
- Parameters:
z (torch.Tensor) – The input tensor of the decoder, e.g., the latent representation from the encoder.
- get_beta(spike_logit: torch.Tensor, slab_mean: torch.Tensor, slab_lnvar: torch.Tensor, bias_d: torch.Tensor)[source]
Sample beta using the reparameterization trick.
- Parameters:
spike_logit (torch.Tensor) – The logit of spike probability.
slab_mean (torch.Tensor) – The mean of slab.
slab_lnvar (torch.Tensor) – The log variance of slab.
bias_d (torch.Tensor) – The bias term in the GLM model.
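As a rough illustration of the idea behind get_beta, the sketch below draws a differentiable spike-and-slab sample by moment matching: the spike-and-slab distribution is replaced by a Gaussian with the same mean and variance, which is then sampled with the usual Gaussian reparameterization. This is one common approach and not necessarily what DeltaTopic does. The function name sample_spike_slab is hypothetical, and bias_d is left out because the documentation does not say how the bias enters the sample.

```python
import torch


def sample_spike_slab(spike_logit, slab_mean, slab_lnvar):
    """Hypothetical helper: one moment-matched Gaussian draw of beta = s * w,
    with s ~ Bernoulli(pip) and w ~ Normal(slab_mean, exp(slab_lnvar))."""
    pip = torch.sigmoid(spike_logit)  # inclusion probability
    mean = pip * slab_mean            # E[beta]
    # Var[beta] = pip * slab_var + pip * (1 - pip) * slab_mean^2
    var = pip * torch.exp(slab_lnvar) + pip * (1.0 - pip) * slab_mean ** 2
    eps = torch.randn_like(mean)      # noise that does not depend on the parameters
    return mean + eps * torch.sqrt(var)


# Usage: gradients reach all three variational parameters.
spike_logit = torch.zeros(3, 5, requires_grad=True)
slab_mean = torch.ones(3, 5, requires_grad=True)
slab_lnvar = torch.zeros(3, 5, requires_grad=True)
beta = sample_spike_slab(spike_logit, slab_mean, slab_lnvar)
beta.sum().backward()
```

Because the noise eps is drawn independently of the parameters, the sample is a deterministic, differentiable function of spike_logit, slab_mean and slab_lnvar, which is the point of the reparameterization trick.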
- soft_max(z: torch.Tensor)[source]
Apply the softmax function.
- Parameters:
z (torch.Tensor) – The input tensor.
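For reference, a softmax over the last dimension can be written out by hand as below. This is only a sketch: the function name topic_proportions is hypothetical, and the choice of the last dimension is an assumption, since the documentation does not state which axis soft_max normalizes.

```python
import torch


def topic_proportions(z):
    """Hypothetical helper: softmax over the last dimension (assumed axis)."""
    # Subtracting the row maximum does not change the result
    # but keeps exp() from overflowing on large inputs.
    shifted = z - z.max(dim=-1, keepdim=True).values
    exp_z = torch.exp(shifted)
    return exp_z / exp_z.sum(dim=-1, keepdim=True)


# Usage: each row of the output is non-negative and sums to one.
z = torch.randn(4, 10)
theta = topic_proportions(z)
```

In practice torch.softmax(z, dim=-1) gives the same result.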
- sparse_kl_loss(logit_0, lnvar_0, spike_logit, slab_mean, slab_lnvar)[source]
Compute the KL divergence between spike and slab prior and the posterior.
- Parameters:
logit_0 (torch.Tensor) – The logit prior of spike probability.
lnvar_0 (torch.Tensor) – The log variance prior of slab.
spike_logit (torch.Tensor) – The logit of spike probability.
slab_mean (torch.Tensor) – The mean of slab.
slab_lnvar (torch.Tensor) – The log variance of slab.
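The sketch below shows one standard closed form for such a KL term, assuming a posterior of the form Bernoulli(pip) times Normal(slab_mean, exp(slab_lnvar)) and a prior of the form Bernoulli(pip0) times Normal(0, exp(lnvar_0)). The function name spike_slab_kl, the zero prior mean, and the final sum over all entries are assumptions made for illustration; the library's exact formula and reduction may differ.

```python
import torch
import torch.nn.functional as F


def spike_slab_kl(logit_0, lnvar_0, spike_logit, slab_mean, slab_lnvar):
    """Hypothetical helper: KL(posterior || prior) for a spike-and-slab pair."""
    pip = torch.sigmoid(spike_logit)
    # Slab part: KL(Normal(slab_mean, exp(slab_lnvar)) || Normal(0, exp(lnvar_0)))
    kl_slab = 0.5 * (
        lnvar_0 - slab_lnvar
        + (torch.exp(slab_lnvar) + slab_mean ** 2) / torch.exp(lnvar_0)
        - 1.0
    )
    # Spike part: KL(Bernoulli(pip) || Bernoulli(pip0)), written with logits
    kl_spike = pip * (F.logsigmoid(spike_logit) - F.logsigmoid(logit_0)) + (
        1.0 - pip
    ) * (F.logsigmoid(-spike_logit) - F.logsigmoid(-logit_0))
    # The slab term only counts when the weight is included (probability pip).
    return torch.sum(pip * kl_slab + kl_spike)


# Usage: priors are passed as tensors so they broadcast against the posterior.
logit_0 = torch.tensor(-2.0)
lnvar_0 = torch.tensor(0.0)
kl = spike_slab_kl(
    logit_0, lnvar_0,
    torch.full((4, 3), 1.0), torch.ones(4, 3), torch.zeros(4, 3),
)
```

The term vanishes when the posterior parameters coincide with the prior and is positive otherwise, so it acts as a sparsity-inducing regularizer when logit_0 is strongly negative.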
- class DeltaTopic.nn.base_components.BALSAMEncoder(*args: Any, **kwargs: Any)[source]
Bases: Module
Encoder for BALSAM model, encodes the input data into a latent topic representation.
- Parameters:
n_input (int) – The number of input features.
n_output (int) – The number of output features.
n_hidden (int) – The number of hidden units.
n_layers_individual (int) – The number of layers in the network.
use_batch_norm (bool) – Whether to use batch normalization.
log_variational (bool) – Whether to apply log(1+x) to the input.
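To make the parameters above concrete, here is a minimal sketch of an encoder with the same constructor arguments. The class EncoderSketch is a hypothetical stand-in, not the DeltaTopic implementation: the default values, the ReLU activations, the two linear heads, and the returned tuple are all assumptions.

```python
import torch
from torch import nn


class EncoderSketch(nn.Module):
    """Hypothetical stand-in for a variational topic encoder."""

    def __init__(self, n_input, n_output, n_hidden=128, n_layers_individual=1,
                 use_batch_norm=True, log_variational=True):
        super().__init__()
        self.log_variational = log_variational
        layers, dim = [], n_input
        for _ in range(n_layers_individual):
            layers.append(nn.Linear(dim, n_hidden))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(n_hidden))
            layers.append(nn.ReLU())
            dim = n_hidden
        self.body = nn.Sequential(*layers)
        self.mean_head = nn.Linear(dim, n_output)   # posterior mean
        self.lnvar_head = nn.Linear(dim, n_output)  # posterior log variance

    def forward(self, x):
        if self.log_variational:
            x = torch.log1p(x)  # log(1 + x) stabilizes raw count data
        h = self.body(x)
        mean, lnvar = self.mean_head(h), self.lnvar_head(h)
        # Gaussian reparameterization: z = mean + eps * std
        z = mean + torch.randn_like(mean) * torch.exp(0.5 * lnvar)
        return mean, lnvar, z


# Usage: 8 cells, 20 input features, 5 latent topics.
encoder = EncoderSketch(n_input=20, n_output=5, n_hidden=16)
counts = torch.randint(0, 50, (8, 20)).float()
mean, lnvar, z = encoder(counts)
```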
- class DeltaTopic.nn.base_components.DeltaTopicDecoder(*args: Any, **kwargs: Any)[source]
Bases: Module
Decoder network for DeltaTopic, a generative network with spike and slab prior for rho and delta.
- Parameters:
n_input – The dimensionality of the input
n_output – The dimensionality of the output
pip0_rho – posterior inclusion probability prior for rho
pip0_delta – posterior inclusion probability prior for delta
v0_rho – variance for rho slab
v0_delta – variance for delta slab
- forward(z: torch.Tensor)[source]
forward pass for DeltaTopicDecoder
- Parameters:
z – the input of the decoder, e.g., the latent variable from DeltaTopicEncoder
- get_beta(spike_logit: torch.Tensor, slab_mean: torch.Tensor, slab_lnvar: torch.Tensor, bias_gene: torch.Tensor)[source]
Get a spike and slab sample using the reparameterization trick
- Parameters:
spike_logit – logit of spike probability
slab_mean – mean of slab
slab_lnvar – log variance of slab
bias_gene – gene-level bias
- sparse_kl_loss(logit_0, lnvar_0, spike_logit, slab_mean, slab_lnvar)[source]
Compute KL divergence between spike and slab priors and posteriors
- Parameters:
logit_0 – logit of spike probability (prior)
lnvar_0 – log variance of slab (prior)
spike_logit – logit of spike probability (posterior)
slab_mean – mean of slab (posterior)
slab_lnvar – log variance of slab (posterior)
- class DeltaTopic.nn.base_components.DeltaTopicEncoder(*args: Any, **kwargs: Any)[source]
Bases: Module
A two-headed encoder that maps the two inputs into a shared latent space through a stack of individual and shared fully-connected layers.
- Parameters:
n_input_list – List of the dimension of two input tensors
n_output – The dimensionality of the output
mask – The mask to apply to the first layer (experimental)
mask_first – Transpose the mask if set to false (experimental)
n_hidden – The number of nodes per hidden layer
n_layers_individual – The number of fully-connected hidden layers for the individual encoder
n_layers_shared – The number of fully-connected hidden layers for the shared encoder
dropout_rate – Dropout rate to apply to each of the hidden layers
use_batch_norm – Whether to have BatchNorm layers or not
log_variational – Whether to apply log(1+x) transformation to the input
combine_method – the method to combine the two latent spaces, either “add” or “concatenate”
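The effect of combine_method on layer sizes can be seen in a small sketch. The class TwoHeadSketch below is hypothetical and much simpler than the documented encoder: it uses one layer per head, omits the mask, dropout and batch norm options, and returns a single tensor, all of which are simplifying assumptions.

```python
import torch
from torch import nn


class TwoHeadSketch(nn.Module):
    """Hypothetical two-headed encoder: individual layers, then shared layers."""

    def __init__(self, n_input_list, n_output, n_hidden=64, combine_method="add"):
        super().__init__()
        if combine_method not in ("add", "concatenate"):
            raise ValueError("combine_method must be 'add' or 'concatenate'")
        self.combine_method = combine_method
        # One individual encoder per input.
        self.heads = nn.ModuleList(
            [nn.Sequential(nn.Linear(n, n_hidden), nn.ReLU()) for n in n_input_list]
        )
        # Adding keeps the width; concatenating doubles it for two heads.
        n_shared_in = n_hidden if combine_method == "add" else n_hidden * len(n_input_list)
        self.shared = nn.Sequential(
            nn.Linear(n_shared_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_output)
        )

    def forward(self, x_a, x_b):
        h_a, h_b = self.heads[0](x_a), self.heads[1](x_b)
        if self.combine_method == "add":
            h = h_a + h_b
        else:
            h = torch.cat([h_a, h_b], dim=-1)
        return self.shared(h)


# Usage: two inputs with different feature counts mapped to 3 latent dimensions.
encoder = TwoHeadSketch([30, 20], n_output=3, combine_method="concatenate")
latent = encoder(torch.rand(4, 30), torch.rand(4, 20))
```

With “add” the two heads must produce tensors of equal width, while “concatenate” lets the shared layers see both representations side by side at the cost of a wider first shared layer.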