DeltaTopic.nn.base_components.BALSAMDecoder
- class DeltaTopic.nn.base_components.BALSAMDecoder(*args: Any, **kwargs: Any)
Decoder network for the BALSAM model: a generative network with a spike-and-slab prior on the beta parameter.
- Parameters:
n_input (int) – The input dimension of the decoder, e.g., the number of topics.
n_output (int) – The output dimension of the decoder, e.g., the number of genes.
pip0 (float) – The prior probability of the spike in the spike-and-slab prior.
v0 (float) – The prior variance of the slab in the spike-and-slab prior.
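The following minimal NumPy sketch illustrates what pip0 and v0 control. It assumes the standard spike-and-slab form beta = s * w, with s ~ Bernoulli(pip0) and w ~ Normal(0, v0). Reading pip0 as the probability that an entry is nonzero (the usual "prior inclusion probability" convention) is an assumption. The helper sample_spike_slab_prior is hypothetical and not part of DeltaTopic, which is built on PyTorch.

```python
import numpy as np


def sample_spike_slab_prior(n_input, n_output, pip0, v0, seed=0):
    """Draw an (n_input x n_output) beta matrix from a spike-and-slab prior.

    Hypothetical helper (not the DeltaTopic API). Assumes each entry is
    beta = s * w, with s ~ Bernoulli(pip0) and w ~ Normal(0, v0).
    """
    rng = np.random.default_rng(seed)
    # Indicator: True where the entry is drawn from the slab (nonzero).
    s = rng.random((n_input, n_output)) < pip0
    # Slab values: Gaussian with variance v0 (standard deviation sqrt(v0)).
    w = rng.normal(0.0, np.sqrt(v0), size=(n_input, n_output))
    # Entries not selected by the indicator sit exactly at zero.
    return np.where(s, w, 0.0)


# e.g. 10 topics x 2000 genes
beta = sample_spike_slab_prior(n_input=10, n_output=2000, pip0=0.1, v0=1.0)
print(beta.shape)          # (10, 2000)
print(np.mean(beta != 0))  # roughly 0.1, i.e. about pip0
```

A small pip0 yields a sparse topic-by-gene matrix, and v0 sets the spread of the entries that survive.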
Methods
__init__(n_input, n_output[, pip0, v0])
forward(z) – Forward pass of the decoder network.
get_beta(spike_logit, slab_mean, slab_lnvar, ...) – Sample beta using the reparameterization trick.
get_rho() – A helper function to get rho.
soft_max(z) – Softmax function.
sparse_kl_loss(logit_0, lnvar_0, ...) – Compute the KL divergence between the spike-and-slab prior and the posterior.
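The signatures of get_beta and sparse_kl_loss are truncated above. The NumPy sketch below therefore only illustrates the standard math these methods name, under stated assumptions.

It assumes a per-entry variational posterior with inclusion probability alpha = sigmoid(spike_logit) and slab Normal(slab_mean, exp(slab_lnvar)). It also assumes the prior parameters are logit_0 = logit(pip0) and lnvar_0 = log(v0).

The sampler uses a moment-matched Gaussian (local reparameterization), which is one common choice and not necessarily what get_beta does. The names softmax_rows, sample_beta and spike_slab_kl are hypothetical.

```python
import numpy as np


def log_sigmoid(x):
    # log(sigmoid(x)) = -log(1 + exp(-x)), computed stably
    return -np.logaddexp(0.0, -x)


def softmax_rows(z):
    """Softmax over the last axis, e.g. topic logits -> topic proportions."""
    z = z - np.max(z, axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def sample_beta(spike_logit, slab_mean, slab_lnvar, rng):
    """One reparameterized draw of beta (hypothetical stand-in for get_beta).

    Matches the mean and variance of the mixture
    alpha * N(mu, var) + (1 - alpha) * delta_0 with a single Gaussian,
    then samples it as mean + sd * eps, eps ~ N(0, 1).
    """
    alpha = np.exp(log_sigmoid(spike_logit))
    mean = alpha * slab_mean
    var = alpha * (1.0 - alpha) * slab_mean**2 + alpha * np.exp(slab_lnvar)
    eps = rng.standard_normal(np.shape(mean))
    return mean + np.sqrt(var) * eps


def spike_slab_kl(spike_logit, slab_mean, slab_lnvar, logit_0, lnvar_0):
    """KL(q || p) summed over entries, for q and p both spike-and-slab."""
    alpha = np.exp(log_sigmoid(spike_logit))
    # KL between the Bernoulli inclusion indicators
    kl_bern = alpha * (log_sigmoid(spike_logit) - log_sigmoid(logit_0)) + (
        1.0 - alpha
    ) * (log_sigmoid(-spike_logit) - log_sigmoid(-logit_0))
    # KL between the Gaussian slabs N(mu, var) and N(0, v0)
    kl_gauss = 0.5 * (
        lnvar_0 - slab_lnvar
        + (np.exp(slab_lnvar) + slab_mean**2) / np.exp(lnvar_0)
        - 1.0
    )
    # The slab term only counts when the entry is included (weight alpha)
    return np.sum(kl_bern + alpha * kl_gauss)


rng = np.random.default_rng(0)
K, G = 10, 2000  # e.g. topics x genes
spike_logit = rng.normal(size=(K, G))
slab_mean = rng.normal(size=(K, G))
slab_lnvar = np.full((K, G), -2.0)
beta = sample_beta(spike_logit, slab_mean, slab_lnvar, rng)
theta = softmax_rows(rng.normal(size=(5, K)))  # 5 cells, each row sums to 1
kl = spike_slab_kl(spike_logit, slab_mean, slab_lnvar,
                   logit_0=np.log(0.1 / 0.9), lnvar_0=np.log(1.0))
```

Working in logits and log-variances keeps every parameter unconstrained, which suits gradient-based training. The KL is exactly zero when the posterior parameters equal the prior ones.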