Base components

Building blocks for the modules: defines the encoder and decoder networks.

class DeltaTopic.nn.base_components.BALSAMDecoder(*args: Any, **kwargs: Any)[source]

Bases: Module

Decoder network for BALSAM model, a generative network with spike and slab prior for beta parameter.

Parameters:
  • n_input (int) – The input dimension of the decoder, e.g., number of topics.

  • n_output (int) – The output dimension of the decoder, e.g., number of genes.

  • pip0 (float) – The prior probability of spike in the spike and slab prior.

  • v0 (float) – The prior variance of slab in the spike and slab prior.
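The methods below take the prior as logit_0 and lnvar_0 rather than pip0 and v0. A minimal sketch of that conversion, assuming the logit and natural log are the intended mappings (the source does not say how the decoder stores its priors); the function name and the values 0.1 and 1.0 are hypothetical:

```python
import math
from typing import Tuple


def prior_to_unconstrained(pip0: float, v0: float) -> Tuple[float, float]:
    """Map a prior spike probability and slab variance to logit / log-variance form.

    Hypothetical helper for illustration; not part of DeltaTopic.
    """
    logit_0 = math.log(pip0 / (1.0 - pip0))  # logit of the prior spike probability
    lnvar_0 = math.log(v0)                   # log of the prior slab variance
    return logit_0, lnvar_0


# Example values are assumptions, not documented defaults.
logit_0, lnvar_0 = prior_to_unconstrained(pip0=0.1, v0=1.0)
```

Working on the unconstrained scale keeps the probability inside (0, 1) and the variance positive without explicit clipping.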

forward(z: torch.Tensor)[source]

Forward pass of the decoder network.

Parameters:

z (torch.Tensor) – The input tensor of the decoder, e.g., the latent representation from the encoder.

get_beta(spike_logit: torch.Tensor, slab_mean: torch.Tensor, slab_lnvar: torch.Tensor, bias_d: torch.Tensor)[source]

Sample beta using the reparameterization trick.

Parameters:
  • spike_logit (torch.Tensor) – The logit of spike probability.

  • slab_mean (torch.Tensor) – The mean of slab.

  • slab_lnvar (torch.Tensor) – The log variance of slab.

  • bias_d (torch.Tensor) – The bias term in the GLM model.
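One common way to reparameterize a spike-and-slab variable is to match the first two moments of the mixture with a Gaussian and add scaled standard-normal noise. The sketch below shows that arithmetic for a single scalar in plain Python. It is an assumption about the general technique, not a copy of the library's get_beta; the function name is hypothetical, and the bias term is left out because the source does not say how it enters.

```python
import math
import random


def sample_spike_slab(spike_logit: float, slab_mean: float,
                      slab_lnvar: float, rng: random.Random) -> float:
    """Draw one moment-matched Gaussian sample of a spike-and-slab effect.

    Hypothetical helper for illustration; bias handling is omitted.
    """
    pip = 1.0 / (1.0 + math.exp(-spike_logit))   # inclusion probability
    mean = pip * slab_mean                       # E[z * w], z ~ Bernoulli(pip)
    # Var[z * w] = pip * (1 - pip) * mean_w^2 + pip * var_w
    var = pip * (1.0 - pip) * slab_mean ** 2 + pip * math.exp(slab_lnvar)
    eps = rng.gauss(0.0, 1.0)                    # noise independent of the parameters
    return mean + eps * math.sqrt(var)


rng = random.Random(0)
draw = sample_spike_slab(spike_logit=0.0, slab_mean=1.0, slab_lnvar=0.0, rng=rng)
```

Because the randomness sits entirely in eps, the sample is a deterministic, differentiable function of spike_logit, slab_mean, and slab_lnvar; the same expressions apply elementwise to tensors.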

get_rho()[source]

A helper function to get rho.

soft_max(z: torch.Tensor)[source]

Softmax function.

Parameters:

z (torch.Tensor) – The input tensor.
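For reference, a numerically stable softmax can be sketched as follows. This is plain Python over a single vector; which axis the library normalizes over is not stated here, so treating the input as one vector is an assumption, and the function name is hypothetical.

```python
import math
from typing import List


def stable_softmax(z: List[float]) -> List[float]:
    """Softmax over one vector, shifted by its maximum to avoid overflow."""
    z_max = max(z)
    exps = [math.exp(v - z_max) for v in z]  # largest exponent is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]


probs = stable_softmax([1.0, 2.0, 3.0])
```

Subtracting the maximum leaves the result unchanged mathematically but keeps large inputs from overflowing in exp.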

sparse_kl_loss(logit_0, lnvar_0, spike_logit, slab_mean, slab_lnvar)[source]

Compute the KL divergence between spike and slab prior and the posterior.

Parameters:
  • logit_0 (torch.Tensor) – The logit prior of spike probability.

  • lnvar_0 (torch.Tensor) – The log variance prior of slab.

  • spike_logit (torch.Tensor) – The logit of spike probability.

  • slab_mean (torch.Tensor) – The mean of slab.

  • slab_lnvar (torch.Tensor) – The log variance of slab.
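A per-coefficient version of this divergence can be sketched as the sum of a Bernoulli KL term for the spike and a Gaussian KL term for the slab, weighted by the posterior inclusion probability. The code below assumes a zero-mean slab prior and works on scalars; how the library reduces over entries (for example, by summing) is not stated in this page, and the function names are hypothetical.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def spike_slab_kl(logit_0: float, lnvar_0: float, spike_logit: float,
                  slab_mean: float, slab_lnvar: float) -> float:
    """KL(posterior || prior) for one spike-and-slab coefficient (illustrative sketch)."""
    pip = math.exp(log_sigmoid(spike_logit))  # posterior inclusion probability
    # Bernoulli KL between posterior and prior inclusion probabilities
    kl_spike = (pip * (log_sigmoid(spike_logit) - log_sigmoid(logit_0))
                + (1.0 - pip) * (log_sigmoid(-spike_logit) - log_sigmoid(-logit_0)))
    # Gaussian KL between N(slab_mean, exp(slab_lnvar)) and N(0, exp(lnvar_0))
    sq_term = math.exp(-lnvar_0) * (slab_mean ** 2 + math.exp(slab_lnvar))
    kl_slab = -0.5 * (1.0 + slab_lnvar - lnvar_0 - sq_term)
    return kl_spike + pip * kl_slab


kl_same = spike_slab_kl(-2.0, 0.0, -2.0, 0.0, 0.0)   # posterior equals prior
kl_diff = spike_slab_kl(-2.0, 0.0, 1.0, 0.5, -1.0)   # posterior differs from prior
```

The divergence is zero when the posterior matches the prior and positive otherwise, which is a quick sanity check for any implementation.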

class DeltaTopic.nn.base_components.BALSAMEncoder(*args: Any, **kwargs: Any)[source]

Bases: Module

Encoder for BALSAM model, encodes the input data into a latent topic representation.

Parameters:
  • n_input (int) – The number of input features.

  • n_output (int) – The number of output features.

  • n_hidden (int) – The number of hidden units.

  • n_layers_individual (int) – The number of layers in the network.

  • use_batch_norm (bool) – Whether to use batch normalization.

  • log_variational (bool) – Whether to apply log(1+x) to the input.
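As a small illustration of the log_variational option, the log(1+x) transform compresses large counts while leaving zeros at zero. The counts below are made-up values, and the snippet uses plain Python rather than tensors.

```python
import math

# Hypothetical raw counts for four genes in one cell.
counts = [0, 1, 9, 99]

# log(1 + x), computed with log1p for accuracy near zero.
transformed = [math.log1p(c) for c in counts]
```

A count of 99 maps to about 4.6, so highly expressed genes no longer dominate the input scale.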

forward(x: torch.Tensor, *cat_list: int)[source]

Forward pass of the encoder.

class DeltaTopic.nn.base_components.DeltaTopicDecoder(*args: Any, **kwargs: Any)[source]

Bases: Module

Decoder network for DeltaTopic, a generative network with spike and slab prior for rho and delta.

Parameters:
  • n_input – The dimensionality of the input

  • n_output – The dimensionality of the output

  • pip0_rho – posterior inclusion probability prior for rho

  • pip0_delta – posterior inclusion probability prior for delta

  • v0_rho – variance for rho slab

  • v0_delta – variance for delta slab

forward(z: torch.Tensor)[source]

Forward pass for DeltaTopicDecoder

Parameters:

z – The input of the decoder, e.g., the latent variable from DeltaTopicEncoder

get_beta(spike_logit: torch.Tensor, slab_mean: torch.Tensor, slab_lnvar: torch.Tensor, bias_gene: torch.Tensor)[source]

Get a spike and slab sample using the reparameterization trick

Parameters:
  • spike_logit – logit of spike probability

  • slab_mean – mean of slab

  • slab_lnvar – log variance of slab

  • bias_gene – gene-level bias

get_rho_delta()[source]

Helper function to get rho and delta

soft_max(z: torch.Tensor)[source]

Softmax function

Parameters:

z – input tensor

sparse_kl_loss(logit_0, lnvar_0, spike_logit, slab_mean, slab_lnvar)[source]

Compute the KL divergence between the spike and slab priors and posteriors

Parameters:
  • logit_0 – logit of spike probability (prior)

  • lnvar_0 – log variance of slab (prior)

  • spike_logit – logit of spike probability (posterior)

  • slab_mean – mean of slab (posterior)

  • slab_lnvar – log variance of slab (posterior)

class DeltaTopic.nn.base_components.DeltaTopicEncoder(*args: Any, **kwargs: Any)[source]

Bases: Module

A two-headed encoder that maps the two inputs into a shared latent space through a stack of individual and shared fully-connected layers.

Parameters:
  • n_input_list – List of the dimensions of the two input tensors

  • n_output – The dimensionality of the output

  • mask – The mask to apply to the first layer (experimental)

  • mask_first – Transpose the mask if set to False (experimental)

  • n_hidden – The number of nodes per hidden layer

  • n_layers_individual – The number of fully-connected hidden layers for the individual encoder

  • n_layers_shared – The number of fully-connected hidden layers for the shared encoder

  • dropout_rate – Dropout rate to apply to each of the hidden layers

  • use_batch_norm – Whether to have BatchNorm layers or not

  • log_variational – Whether to apply log(1+x) transformation to the input

  • combine_method – The method used to combine the two latent spaces, either “add” or “concatenate”
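The two combine_method options differ in the shape of the result: adding keeps the latent dimension, while concatenating doubles it. A minimal sketch of that difference, using plain lists in place of tensors; the function name and the error handling are hypothetical, and the library's own implementation is not shown in this page.

```python
from typing import List


def combine_latents(q_x: List[float], q_y: List[float],
                    combine_method: str = "add") -> List[float]:
    """Combine two latent vectors by elementwise sum or by concatenation."""
    if combine_method == "add":
        if len(q_x) != len(q_y):
            raise ValueError("'add' requires latent vectors of equal length")
        return [a + b for a, b in zip(q_x, q_y)]
    if combine_method == "concatenate":
        return list(q_x) + list(q_y)
    raise ValueError(f"unknown combine_method: {combine_method!r}")


added = combine_latents([1.0, 2.0], [3.0, 4.0], "add")
stacked = combine_latents([1.0, 2.0], [3.0, 4.0], "concatenate")
```

With "concatenate", any layer that consumes the combined representation needs twice the input width, which is worth keeping in mind when choosing n_hidden and n_layers_shared.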

forward(x: torch.Tensor, y: torch.Tensor, *cat_list: int)[source]

Forward pass for DeltaTopicEncoder

Parameters:
  • x – First input tensor, e.g., spliced RNA count

  • y – Second input tensor, e.g., unspliced RNA count