DeltaTopic.nn.base_components.DeltaTopicDecoder

class DeltaTopic.nn.base_components.DeltaTopicDecoder(*args: Any, **kwargs: Any)

Decoder network for DeltaTopic: a generative network with spike-and-slab priors on rho and delta.

Parameters:
  • n_input – The dimensionality of the input

  • n_output – The dimensionality of the output

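  • pip0_rho – prior inclusion probability for rho (prior probability that an element is drawn from the slab rather than the spike)

  • pip0_delta – prior inclusion probability for delta

  • v0_rho – prior variance of the slab component for rho

  • v0_delta – prior variance of the slab component for delta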

__init__(n_input: int, n_output: int, pip0_rho=0.1, pip0_delta=0.1, v0_rho=1, v0_delta=1)
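As a rough illustration of what the prior hyperparameters mean, the sketch below draws from a spike-and-slab prior with the default values pip0=0.1 and v0=1: each element is exactly zero with probability 1 - pip0 and otherwise Gaussian with variance v0. It also converts pip0 and v0 into a log-odds and a log-variance. That conversion is an assumption suggested by the `logit_0` and `lnvar_0` argument names of `sparse_kl_loss`, not something this page documents. The helpers `logit` and `sample_spike_slab_prior` are hypothetical, and plain Python stands in for the package's tensors.

```python
import math
import random
from typing import List, Optional

def logit(p: float) -> float:
    """Log-odds of a probability p in (0, 1)."""
    return math.log(p / (1.0 - p))

def sample_spike_slab_prior(n: int, pip0: float = 0.1, v0: float = 1.0,
                            rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical helper: n draws, 0 with prob 1 - pip0, else N(0, v0)."""
    if rng is None:
        rng = random.Random()
    sd = math.sqrt(v0)
    out = []
    for _ in range(n):
        if rng.random() < pip0:          # slab: element is "included"
            out.append(rng.gauss(0.0, sd))
        else:                            # spike: element is exactly zero
            out.append(0.0)
    return out

# Defaults mirror the __init__ signature: pip0_rho=0.1, v0_rho=1
logit_0 = logit(0.1)        # assumed prior inclusion log-odds, log(1/9)
lnvar_0 = math.log(1.0)     # assumed prior slab log-variance, 0.0
draws = sample_spike_slab_prior(10_000, pip0=0.1, v0=1.0, rng=random.Random(0))
frac_nonzero = sum(1 for d in draws if d != 0.0) / len(draws)
```

With 10,000 draws the fraction of non-zero elements should land close to pip0.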

Methods

__init__(n_input, n_output[, pip0_rho, ...])

forward(z)

Forward pass for DeltaTopicDecoder.
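A minimal sketch of one plausible decoder forward pass, under stated assumptions: the latent `z` (length `n_input`, read here as the number of topics) is mapped to topic proportions with a softmax, and each topic's row of a weight matrix is turned into a distribution over `n_output` outputs, which is then mixed by those proportions. Producing one output from `rho` and a second from `rho + delta` is likewise an assumption for illustration; the actual `forward` may combine them differently. `softmax` and `decode` are hypothetical helpers.

```python
import math
from typing import List

def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def decode(z: List[float], beta: List[List[float]]) -> List[float]:
    """Mix per-topic distributions softmax(beta[k]) by theta = softmax(z)."""
    theta = softmax(z)                      # topic proportions, sum to 1
    out = [0.0] * len(beta[0])              # one entry per output dimension
    for theta_k, row in zip(theta, beta):
        for g, p in enumerate(softmax(row)):
            out[g] += theta_k * p
    return out

# Hypothetical weights: 2 topics (n_input=2) x 3 outputs (n_output=3)
rho = [[1.0, 0.0, -1.0], [0.0, 2.0, 0.0]]
delta = [[0.0, 0.5, 0.0], [0.0, 0.0, -0.5]]
rho_plus_delta = [[r + d for r, d in zip(rr, dr)] for rr, dr in zip(rho, delta)]

z = [0.3, -0.2]
out_rho = decode(z, rho)                    # assumed to use rho alone
out_rho_delta = decode(z, rho_plus_delta)   # assumed to use rho + delta
```

Because every per-topic row and the topic proportions each sum to one, each decoded output is itself a probability vector.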

get_beta(spike_logit, slab_mean, slab_lnvar, ...)

Draw a spike-and-slab sample using the reparameterization trick.
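The reparameterization trick keeps a sample differentiable with respect to the variational parameters by drawing the noise separately. The sketch below shows one common variant, assumed here rather than taken from the source: the Gaussian slab is sampled as `mean + exp(0.5 * lnvar) * eps` with `eps ~ N(0, 1)`, and the result is scaled by the inclusion probability `sigmoid(spike_logit)` instead of being multiplied by a hard Bernoulli draw, which would not be differentiable. This is a continuous relaxation, so elements are never exactly zero. The remaining arguments of `get_beta` are elided on this page, so the hypothetical `spike_slab_sample` takes only the three named ones.

```python
import math
import random
from typing import List, Optional

def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def spike_slab_sample(spike_logit: List[float], slab_mean: List[float],
                      slab_lnvar: List[float],
                      rng: Optional[random.Random] = None) -> List[float]:
    """Relaxed spike-and-slab sample: P(included) * reparameterized slab."""
    if rng is None:
        rng = random.Random()
    beta = []
    for a, mu, lv in zip(spike_logit, slab_mean, slab_lnvar):
        eps = rng.gauss(0.0, 1.0)             # parameter-free noise
        slab = mu + math.exp(0.5 * lv) * eps  # differentiable in mu and lv
        beta.append(sigmoid(a) * slab)        # soft spike: scale by P(included)
    return beta

# Example: three elements, sampled with a fixed seed
beta = spike_slab_sample([0.0, 5.0, -50.0], [2.0, -1.0, 3.0],
                         [-50.0, 0.0, 0.0], rng=random.Random(42))
```

With a tiny slab variance the first element collapses to `sigmoid(0) * 2.0 = 1.0`, and a very negative spike logit drives the third element to essentially zero.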

get_rho_delta()

Helper function that returns rho and delta.
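If a helper like this returns point estimates, one natural choice under a spike-and-slab posterior is the posterior mean: for an element with inclusion probability `pip = sigmoid(logit)` and slab mean `mu`, it is `pip * mu`. The sketch below assumes exactly that. The names `posterior_mean` and `get_rho_delta_sketch`, and the parameter dictionary layout, are hypothetical; the actual `get_rho_delta` may instead return samples or the raw variational parameters.

```python
import math
from typing import Dict, List

def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def posterior_mean(spike_logit: List[float], slab_mean: List[float]) -> List[float]:
    """E[beta] = P(included) * E[slab] for each element (spike contributes 0)."""
    return [sigmoid(a) * m for a, m in zip(spike_logit, slab_mean)]

def get_rho_delta_sketch(params: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Hypothetical helper returning posterior-mean rho and delta."""
    return {
        "rho": posterior_mean(params["rho_logit"], params["rho_mean"]),
        "delta": posterior_mean(params["delta_logit"], params["delta_mean"]),
    }

estimates = get_rho_delta_sketch({
    "rho_logit": [0.0], "rho_mean": [2.0],          # pip = 0.5
    "delta_logit": [-100.0], "delta_mean": [3.0],   # pip is essentially 0
})
```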

soft_max(z)

Softmax function.
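Softmax maps a vector of logits to a probability vector. A common way to implement it stably, assumed here rather than read from the source, is to subtract the maximum logit before exponentiating: this leaves the result unchanged but prevents overflow for large inputs. `stable_softmax` is a hypothetical stand-alone name.

```python
import math
from typing import List

def stable_softmax(z: List[float]) -> List[float]:
    """softmax(z)_i = exp(z_i - max(z)) / sum_j exp(z_j - max(z))."""
    m = max(z)                                # shift by the max to avoid overflow
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]

large = stable_softmax([1000.0, 1000.0])      # math.exp(1000.0) alone would overflow
skewed = stable_softmax([0.0, math.log(3.0)])  # weights 1 : 3
```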

sparse_kl_loss(logit_0, lnvar_0, ...)

Compute the KL divergence between the spike-and-slab priors and posteriors.
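When the posterior for an element is a mixture of a point mass at zero (weight 1 - pip) and N(mu, s2) (weight pip), and the prior has the same form with pip0 and N(0, v0), treating the inclusion indicator as a latent variable gives a standard closed form: KL = pip·log(pip/pip0) + (1 - pip)·log((1 - pip)/(1 - pip0)) + pip·KL(N(mu, s2) || N(0, v0)). The spike components cancel because both are the same point mass. The sketch below assumes that form, with probabilities passed as logits and variances as log-variances to match the `logit_0` and `lnvar_0` argument names. The zero prior mean and the name `spike_slab_kl` are assumptions, and the real `sparse_kl_loss` takes further arguments that are elided here.

```python
import math
from typing import List

def softplus(x: float) -> float:
    """log(1 + exp(x)), computed without overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def log_sigmoid(x: float) -> float:
    """log(sigmoid(x)); note log(1 - sigmoid(x)) == log_sigmoid(-x)."""
    return -softplus(-x)

def spike_slab_kl(logit: List[float], mean: List[float], lnvar: List[float],
                  logit_0: float, lnvar_0: float) -> float:
    """Summed KL(q || p) between spike-and-slab posteriors q and a shared prior p.

    q_j = pip_j * N(mean_j, exp(lnvar_j)) + (1 - pip_j) * point mass at 0
    p   = pip0  * N(0, exp(lnvar_0))      + (1 - pip0)  * point mass at 0
    """
    total = 0.0
    for a, mu, lv in zip(logit, mean, lnvar):
        pip = math.exp(log_sigmoid(a))
        # KL between Bernoulli(pip) and Bernoulli(pip0) on the inclusion indicator
        kl_bern = (pip * (log_sigmoid(a) - log_sigmoid(logit_0))
                   + (1.0 - pip) * (log_sigmoid(-a) - log_sigmoid(-logit_0)))
        # KL(N(mu, exp(lv)) || N(0, exp(lnvar_0))), weighted by P(included)
        kl_gauss = 0.5 * (lnvar_0 - lv
                          + (math.exp(lv) + mu * mu) / math.exp(lnvar_0) - 1.0)
        total += kl_bern + pip * kl_gauss
    return total

kl_same = spike_slab_kl([0.3], [0.0], [0.2], logit_0=0.3, lnvar_0=0.2)    # q == p
kl_shift = spike_slab_kl([0.0], [1.0], [0.0], logit_0=0.0, lnvar_0=0.0)   # mean moved
kl_other = spike_slab_kl([1.0], [0.5], [-1.0],
                         logit_0=math.log(0.1 / 0.9), lnvar_0=0.0)
```

When the posterior equals the prior the divergence is zero; shifting only the slab mean to 1 with pip = 0.5 gives 0.5 · 0.5 = 0.25.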