DeltaTopic.nn.base_components.BALSAMDecoder

class DeltaTopic.nn.base_components.BALSAMDecoder(*args: Any, **kwargs: Any)

Decoder network for the BALSAM model: a generative network with a spike-and-slab prior on the beta parameter.

Parameters:
  • n_input (int) – The input dimension of the decoder, e.g., the number of topics.

  • n_output (int) – The output dimension of the decoder, e.g., the number of genes.

  • pip0 (float) – The prior probability of the spike in the spike-and-slab prior.

  • v0 (float) – The prior variance of the slab in the spike-and-slab prior.
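To make pip0 and v0 concrete, the sketch below draws a beta matrix from the prior alone. It is a minimal NumPy illustration, not the library's code, and it assumes that pip0 is the probability that an entry's spike indicator is on, so that the entry is active and drawn from the slab N(0, v0), while inactive entries are exactly zero. The function name sample_spike_slab_prior is hypothetical.

```python
import numpy as np


def sample_spike_slab_prior(n_input, n_output, pip0=0.1, v0=1.0, seed=0):
    """Draw an (n_input, n_output) matrix from a spike-and-slab prior.

    Assumption: with probability pip0 an entry is active and drawn from the
    slab N(0, v0); otherwise it is exactly zero.
    """
    rng = np.random.default_rng(seed)
    # Bernoulli indicators: True means the entry is active (slab).
    active = rng.random((n_input, n_output)) < pip0
    # Slab values: zero-mean Gaussian with variance v0 (std = sqrt(v0)).
    slab = rng.normal(0.0, np.sqrt(v0), size=(n_input, n_output))
    return np.where(active, slab, 0.0)


beta = sample_spike_slab_prior(n_input=10, n_output=2000, pip0=0.1, v0=1.0)
print(beta.shape)          # (10, 2000)
print((beta != 0).mean())  # roughly pip0, i.e. about 0.1
```

A small pip0 makes most topic-gene weights exactly zero a priori, while v0 controls how large the surviving weights tend to be.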

__init__(n_input: int, n_output: int, pip0=0.1, v0=1)

Methods

__init__(n_input, n_output[, pip0, v0])

forward(z)

Forward pass of the decoder network.

get_beta(spike_logit, slab_mean, slab_lnvar, ...)

Sample beta using the reparameterization trick.

get_rho()

A helper function to get rho.

soft_max(z)

Softmax function.

sparse_kl_loss(logit_0, lnvar_0, ...)

Compute the KL divergence between the posterior and the spike-and-slab prior.
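The documentation does not say how forward(z) combines z with beta. One common choice for topic-model decoders is log-linear: normalize z into topic proportions with a softmax, multiply by the topic-by-gene weights, and apply a second softmax over genes. The sketch below implements that assumed form in NumPy; decoder_forward is a hypothetical name, and whether BALSAMDecoder.forward matches it exactly is not confirmed here.

```python
import numpy as np


def softmax(x, axis=-1):
    # Shift by the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def decoder_forward(z, beta):
    """Hypothetical log-linear topic decoder (assumed form).

    z    : (batch, n_input) unnormalized topic scores for each cell
    beta : (n_input, n_output) topic-by-gene weights
    Returns a (batch, n_output) array whose rows are distributions over genes.
    """
    theta = softmax(z, axis=1)      # topic proportions, rows sum to 1
    logits = theta @ beta           # mix topic weights into gene scores
    return softmax(logits, axis=1)  # normalize over genes


rng = np.random.default_rng(0)
z = rng.standard_normal((4, 10))       # 4 cells, n_input = 10 topics
beta = rng.standard_normal((10, 200))  # n_output = 200 genes
out = decoder_forward(z, beta)
print(out.shape)  # (4, 200)
```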
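For get_beta, the reparameterization trick writes a Gaussian slab draw as slab_mean + exp(0.5 * slab_lnvar) * eps with eps ~ N(0, 1), so the sample stays a differentiable function of the variational parameters. The sketch below assumes a mean-field posterior in which the spike acts as a soft gate sigmoid(spike_logit); the remaining arguments hidden behind "..." in the documented signature are ignored, and sample_beta is a hypothetical name.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sample_beta(spike_logit, slab_mean, slab_lnvar, rng):
    """Reparameterized draw of beta (assumed mean-field spike-and-slab form).

    Randomness enters only through eps ~ N(0, 1), so the sample is a
    deterministic function of slab_mean, slab_lnvar and spike_logit.
    """
    eps = rng.standard_normal(np.shape(slab_mean))
    # Slab sample: mean + std * eps, with std = exp(0.5 * log-variance).
    slab = slab_mean + np.exp(0.5 * slab_lnvar) * eps
    # Assumption: the spike is applied as a soft gate in [0, 1].
    return sigmoid(spike_logit) * slab


rng = np.random.default_rng(0)
n_input, n_output = 10, 200
beta = sample_beta(
    spike_logit=np.full((n_input, n_output), -2.0),
    slab_mean=rng.standard_normal((n_input, n_output)),
    slab_lnvar=np.full((n_input, n_output), -1.0),
    rng=rng,
)
print(beta.shape)  # (10, 200)
```

Other implementations replace the soft gate with a relaxed Bernoulli (Gumbel-softmax) sample; the slab part of the trick is the same either way.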
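soft_max(z) is documented only as a softmax. A standard, numerically stable version subtracts the maximum before exponentiating, which leaves the result unchanged because softmax is shift-invariant. The NumPy sketch below (hypothetical name stable_softmax, with the last axis assumed as the default) shows the idea; the method's actual axis and implementation are not stated here.

```python
import numpy as np


def stable_softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = np.asarray(z, dtype=float)
    # Subtracting the max does not change the result (softmax is
    # shift-invariant) but keeps np.exp from overflowing on large inputs.
    shifted = z - z.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


print(stable_softmax([1000.0, 1000.0]))  # [0.5 0.5]
```

Without the shift, np.exp(1000.0) overflows to inf and the naive ratio becomes nan.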
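For sparse_kl_loss, a common closed form for a mean-field spike-and-slab posterior splits the KL divergence into a Bernoulli term for the spike indicator and a Gaussian term for the slab, the latter weighted by the posterior inclusion probability alpha = sigmoid(spike_logit). Per entry, KL = alpha log(alpha / pip0) + (1 - alpha) log((1 - alpha) / (1 - pip0)) + alpha KL(N(m, s^2) || N(0, v0)). The sketch below assumes that form; spike_kl's argument names are hypothetical and need not match the elided documented signature (logit_0, lnvar_0, ...).

```python
import numpy as np


def spike_slab_kl(spike_logit, slab_mean, slab_lnvar, pip0=0.1, v0=1.0):
    """KL(q || p) summed over entries, for an assumed mean-field spike-and-slab.

    q: gamma ~ Bernoulli(alpha), alpha = sigmoid(spike_logit);
       beta | gamma=1 ~ N(slab_mean, exp(slab_lnvar)), beta | gamma=0 = 0.
    p: gamma ~ Bernoulli(pip0); beta | gamma=1 ~ N(0, v0), beta | gamma=0 = 0.
    """
    x = np.asarray(spike_logit, dtype=float)
    alpha = 1.0 / (1.0 + np.exp(-x))
    log_alpha = -np.logaddexp(0.0, -x)    # log sigmoid(x), computed stably
    log_1m_alpha = -np.logaddexp(0.0, x)  # log(1 - sigmoid(x))
    # KL between the Bernoulli spike indicators.
    kl_spike = (alpha * (log_alpha - np.log(pip0))
                + (1.0 - alpha) * (log_1m_alpha - np.log1p(-pip0)))
    # KL(N(m, s2) || N(0, v0)); only incurred when the entry is in the slab.
    s2 = np.exp(slab_lnvar)
    kl_slab = 0.5 * (np.log(v0) - slab_lnvar + (s2 + slab_mean ** 2) / v0 - 1.0)
    return float(np.sum(kl_spike + alpha * kl_slab))
```

When the posterior equals the prior (alpha = pip0, m = 0, s^2 = v0), every term vanishes and the function returns approximately 0; any mismatch makes it strictly positive.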