DeltaTopic.nn.module.BALSAM_module
- class DeltaTopic.nn.module.BALSAM_module(*args: Any, **kwargs: Any)
BALSAM module.
- Parameters:
n_genes – number of genes
n_latent – dimension of latent space
n_layers_encoder_individual – number of individual layers in the encoder
dim_hidden_encoder – dimension of the hidden layers in the encoder
pip0_rho – scaling factor for rho loss, default 0.1
kl_weight_beta – scaling factor for KL, default 1.0
log_variational – if True, apply log(data + 1) to the input before encoding, for numerical stability. This is not a normalization step.
- __init__(n_genes: int, n_latent: int = 32, n_layers_encoder_individual: int = 2, dim_hidden_encoder: int = 128, log_variational: bool = True, pip0_rho: float = 0.1, kl_weight_beta: float = 1.0)
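The log_variational flag applies log(data + 1) to the input before it reaches the encoder. A minimal sketch of that preprocessing step is shown below. It works on a plain list of counts, whereas the real module operates on tensors, and the helper name `preprocess_counts` is hypothetical.

```python
import math
from typing import List, Sequence


def preprocess_counts(x: Sequence[float], log_variational: bool = True) -> List[float]:
    """Hypothetical helper: optionally apply log(x + 1) elementwise before encoding."""
    if log_variational:
        # log1p is accurate for small counts and maps a zero count to 0.0
        return [math.log1p(v) for v in x]
    # Without the flag the counts pass through unchanged
    return list(x)


print(preprocess_counts([0.0, 3.0]))  # [0.0, log(4)]
```

The transform only compresses the dynamic range of the counts. It does not rescale each cell to a fixed total, which is why the parameter description distinguishes it from normalization.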
Methods
__init__(n_genes[, n_latent, ...])
dir_llik(xx, aa) – Compute the Dirichlet log-likelihood.
forward(tensors[, ...]) – Forward pass through the network.
generative(z) – Run the generative model.
get_reconstruction_loss(x) – Return the reconstruction loss for a batch of data.
inference(x) – Run the inference (recognition) model.
loss(tensors, inference_outputs, ...[, ...]) – Aggregate the likelihood and KL divergences to form the loss function.
sample(*args, **kwargs) – Generate samples from the learned model.
sample_from_posterior_z(x[, deterministic, ...]) – Sample from the posterior of z.
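The method table names its quantities but gives no formulas. The stdlib-only sketch below illustrates the standard pieces those names suggest. It relies on assumptions that this page does not confirm:

- dir_llik is taken to be the textbook Dirichlet log-density of a probability vector xx with concentrations aa. The actual method may instead use a Dirichlet-multinomial form over counts.
- The posterior over z is taken to be a diagonal Gaussian with a standard-normal prior.
- loss is taken to add kl_weight_beta times the KL term to the reconstruction loss.

All function names are hypothetical, and the sketch does not show how the pip0_rho term enters the loss.

```python
import math
import random
from typing import List, Optional, Sequence


def dir_log_density(xx: Sequence[float], aa: Sequence[float]) -> float:
    """Textbook log Dir(xx | aa); assumes xx is on the simplex and aa > 0 (hypothetical stand-in for dir_llik)."""
    log_norm = math.lgamma(sum(aa)) - sum(math.lgamma(a) for a in aa)
    return log_norm + sum((a - 1.0) * math.log(x) for x, a in zip(xx, aa))


def sample_gaussian_z(mu: Sequence[float], log_var: Sequence[float],
                      deterministic: bool = False,
                      rng: Optional[random.Random] = None) -> List[float]:
    """Reparameterized draw z = mu + sigma * eps; returns the mean when deterministic (assumed Gaussian posterior)."""
    if deterministic:
        return list(mu)
    rng = rng or random.Random()
    # sigma = exp(0.5 * log_var); eps ~ N(0, 1) drawn independently per dimension
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


def weighted_loss(reconstruction_loss: float, kl: float, kl_weight_beta: float = 1.0) -> float:
    """Hypothetical aggregation: reconstruction loss plus beta-weighted KL."""
    return reconstruction_loss + kl_weight_beta * kl


# Usage: Dir(1, 1, 1) is uniform on the 2-simplex, with density Gamma(3) = 2 everywhere
print(dir_log_density([0.2, 0.3, 0.5], [1.0, 1.0, 1.0]))  # ≈ 0.6931 (log 2)
```

The deterministic switch in sample_gaussian_z mirrors the deterministic argument listed for sample_from_posterior_z. Returning the posterior mean is a common convention for that flag, but this page does not confirm it.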