DeltaTopic.nn.base_components.DeltaTopicDecoder
- class DeltaTopic.nn.base_components.DeltaTopicDecoder(*args: Any, **kwargs: Any)
Decoder network for DeltaTopic: a generative network with spike-and-slab priors on rho and delta.
- Parameters:
n_input – The dimensionality of the input
n_output – The dimensionality of the output
pip0_rho – Prior on the posterior inclusion probability (PIP) for rho
pip0_delta – Prior on the posterior inclusion probability (PIP) for delta
v0_rho – Variance of the slab for rho
v0_delta – Variance of the slab for delta
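To make the roles of the inclusion-probability and slab-variance parameters concrete, here is a minimal sketch of drawing weights from a spike-and-slab prior in plain Python. It assumes a point-mass spike at zero and a zero-mean Gaussian slab; the function name `sample_spike_slab_prior` and the values 0.1 and 1.0 are hypothetical and not taken from DeltaTopic.

```python
import math
import random


def sample_spike_slab_prior(pip0, v0, rng):
    """Draw one weight from a spike-and-slab prior (illustrative only).

    pip0: prior inclusion probability, i.e. the chance the weight is nonzero.
    v0:   variance of the Gaussian slab (assumed zero-mean in this sketch).
    """
    if rng.random() < pip0:
        return rng.gauss(0.0, math.sqrt(v0))  # slab: the weight is included
    return 0.0  # spike: the weight is exactly zero


rng = random.Random(0)
draws = [sample_spike_slab_prior(pip0=0.1, v0=1.0, rng=rng) for _ in range(10_000)]
frac_nonzero = sum(d != 0.0 for d in draws) / len(draws)
print(f"fraction of nonzero weights: {frac_nonzero:.3f}")
```

A smaller inclusion probability yields a sparser set of weights, while the slab variance controls how large the included weights tend to be.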
Methods
__init__(n_input, n_output[, pip0_rho, ...])
forward(z) – Forward pass for DeltaTopicDecoder
get_beta(spike_logit, slab_mean, slab_lnvar, ...) – Get a spike-and-slab sample using the reparameterization trick
Helper function to get rho and delta
soft_max(z) – Softmax function
sparse_kl_loss(logit_0, lnvar_0, ...) – Compute the KL divergence between the spike-and-slab priors and posteriors
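The method summaries above mention a reparameterized spike-and-slab sample and a KL term between spike-and-slab priors and posteriors. The scalar sketch below shows one common way such quantities can be computed; it is not taken from the DeltaTopic source. It assumes a moment-matched Gaussian approximation for the sample and a zero-mean slab prior. The function names `spike_slab_sample` and `spike_slab_kl`, the explicit `eps` argument, and the arguments `logit`, `mean`, and `lnvar` are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def spike_slab_sample(spike_logit, slab_mean, slab_lnvar, eps):
    """Reparameterized draw: a deterministic function of the parameters
    and of external noise eps, which is assumed to come from N(0, 1)."""
    pip = sigmoid(spike_logit)  # inclusion probability
    mean = pip * slab_mean  # mean of the spike/slab mixture
    # Variance of the mixture of a point mass at 0 and N(slab_mean, exp(slab_lnvar)).
    var = pip * (1.0 - pip) * slab_mean ** 2 + pip * math.exp(slab_lnvar)
    return mean + eps * math.sqrt(var)


def spike_slab_kl(logit_0, lnvar_0, logit, mean, lnvar):
    """KL(posterior || prior) for one weight, assuming a zero-mean slab prior."""
    pip, pip0 = sigmoid(logit), sigmoid(logit_0)
    # Bernoulli KL between the posterior and prior inclusion probabilities.
    kl_pip = pip * math.log(pip / pip0) + (1.0 - pip) * math.log((1.0 - pip) / (1.0 - pip0))
    # Gaussian KL between N(mean, exp(lnvar)) and N(0, exp(lnvar_0)).
    kl_gauss = 0.5 * (
        lnvar_0 - lnvar + (math.exp(lnvar) + mean ** 2) / math.exp(lnvar_0) - 1.0
    )
    # The slab term only counts to the extent the weight is included.
    return kl_pip + pip * kl_gauss


logit_0 = math.log(0.1 / 0.9)  # prior inclusion probability of 0.1 (illustrative)
lnvar_0 = 0.0  # prior slab variance of 1.0 (illustrative)
kl_same = spike_slab_kl(logit_0, lnvar_0, logit_0, 0.0, 0.0)
kl_diff = spike_slab_kl(logit_0, lnvar_0, 2.0, 1.5, -1.0)
print(kl_same, kl_diff)
```

Passing `eps` in explicitly keeps the randomness outside the function, which is the point of the reparameterization trick: in an autograd framework, gradients can then flow through the logit, mean, and log-variance.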