DeltaTopic.nn.module.BALSAM_module

class DeltaTopic.nn.module.BALSAM_module(*args: Any, **kwargs: Any)[source]

BALSAM module

Parameters:
  • n_genes – number of genes

  • n_latent – dimension of latent space

  • n_layers_encoder_individual – number of individual layers in the encoder

  • dim_hidden_encoder – dimension of the hidden layers in the encoder

  • pip0_rho – scaling factor for rho loss, default 0.1

  • kl_weight_beta – scaling factor for KL, default 1.0

  • log_variational – if True, apply log(1 + x) to the input data before encoding, for numerical stability. This is a transform, not a normalization.
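To illustrate what the log_variational option describes, here is a minimal sketch of the log(1 + x) transform on a row of raw counts, using only the Python standard library. The helper name log1p_counts is hypothetical and is not part of DeltaTopic; the module itself presumably works on tensors rather than lists.

```python
import math

def log1p_counts(counts):
    """Apply log(1 + x) element-wise to a list of non-negative counts."""
    return [math.log1p(c) for c in counts]

raw = [0, 1, 9, 99]
transformed = log1p_counts(raw)

# Zero counts stay at zero and large counts are compressed,
# but the values are not rescaled to a common total (no normalization).
print(transformed)
```

The transform only compresses the dynamic range of the counts; any library-size normalization would be a separate step.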

__init__(n_genes: int, n_latent: int = 32, n_layers_encoder_individual: int = 2, dim_hidden_encoder: int = 128, log_variational: bool = True, pip0_rho: float = 0.1, kl_weight_beta: float = 1.0)[source]
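As a rough sketch of how a scaling factor such as kl_weight_beta typically enters a variational loss, the snippet below assumes the common form "reconstruction + beta * KL". This form is an assumption, not taken from the DeltaTopic source; the function weighted_loss is hypothetical, and the pip0_rho term is left out because this page does not say how it is applied.

```python
def weighted_loss(reconstruction_loss, kl_divergence, kl_weight_beta=1.0):
    """Combine a reconstruction term and a KL term (assumed form)."""
    return reconstruction_loss + kl_weight_beta * kl_divergence

# With the documented default kl_weight_beta = 1.0 the KL term is unscaled.
total = weighted_loss(10.0, 2.0)

# A smaller beta down-weights the KL term relative to reconstruction.
downweighted = weighted_loss(10.0, 2.0, kl_weight_beta=0.5)

print(total, downweighted)
```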

Methods

__init__(n_genes[, n_latent, ...])

dir_llik(xx, aa)

Compute the Dirichlet log-likelihood.

forward(tensors[, ...])

Forward pass through the network.

generative(z)

Run the generative model.

get_reconstruction_loss(x)

Returns the reconstruction loss for a batch of data.

inference(x)

Run the inference (recognition) model.

loss(tensors, inference_outputs, ...[, ...])

Aggregates the likelihood and KL divergence terms to form the loss function.

sample(*args, **kwargs)

Generate samples from the learned model.

sample_from_posterior_z(x[, deterministic, ...])

Sample from the posterior distribution of z.
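For readers unfamiliar with the quantity named in dir_llik, the following is a minimal sketch of the textbook Dirichlet log-density, written with the standard library only. It is not the module's implementation: dir_llik(xx, aa) may use a different parameterization (for example a Dirichlet-multinomial form for count data), and the function name and argument names below are hypothetical.

```python
import math

def dirichlet_log_density(x, alpha):
    """Log-density of Dirichlet(alpha) at a point x on the probability simplex."""
    # Log normalizing constant: log Gamma(sum alpha) - sum log Gamma(alpha_k)
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    # Kernel: sum (alpha_k - 1) * log x_k
    return log_norm + sum((a - 1.0) * math.log(xi) for a, xi in zip(alpha, x))

# With alpha = (1, 1, 1) the density is flat over the simplex,
# so the log-density equals log Gamma(3) = log 2 at every interior point.
uniform = dirichlet_log_density([0.2, 0.3, 0.5], [1.0, 1.0, 1.0])
print(uniform)
```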