DeltaTopic.nn.base_components.BALSAMEncoder
- class DeltaTopic.nn.base_components.BALSAMEncoder(*args: Any, **kwargs: Any)
Encoder for the BALSAM model; encodes the input data into a latent topic representation.
- Parameters:
n_input (int) – The number of input features.
n_output (int) – The number of output features.
n_hidden (int) – The number of hidden units.
n_layers_individual (int) – The number of layers in the network.
use_batch_norm (bool) – Whether to use batch normalization.
log_variational (bool) – Whether to apply log(1+x) to the input.
- __init__(n_input: int, n_output: int, n_hidden: int = 128, n_layers_individual: int = 3, use_batch_norm: bool = True, log_variational: bool = True)
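The constructor parameters can be illustrated with the minimal, stdlib-only sketch below. Only the log(1 + x) transform controlled by log_variational is documented above. The layer layout returned by layer_shapes (n_layers_individual hidden layers of width n_hidden followed by a separate output head of size n_output) is an assumption made for illustration, and both preprocess and layer_shapes are hypothetical helpers, not part of DeltaTopic.

```python
import math
from typing import List, Tuple


def preprocess(x: List[float], log_variational: bool = True) -> List[float]:
    """Apply log(1 + x) elementwise when log_variational is True (documented behaviour)."""
    if log_variational:
        return [math.log1p(v) for v in x]
    return list(x)


def layer_shapes(
    n_input: int,
    n_output: int,
    n_hidden: int = 128,
    n_layers_individual: int = 3,
) -> List[Tuple[int, int]]:
    """Hypothetical (in, out) sizes of each linear layer.

    Assumption, not the confirmed DeltaTopic design: n_layers_individual hidden
    layers (n_input -> n_hidden, then n_hidden -> n_hidden), followed by an
    output head n_hidden -> n_output.
    """
    if n_layers_individual < 1:
        raise ValueError("n_layers_individual must be at least 1")
    dims = [n_input] + [n_hidden] * n_layers_individual
    shapes = list(zip(dims[:-1], dims[1:]))
    shapes.append((n_hidden, n_output))  # assumed separate output head
    return shapes


if __name__ == "__main__":
    print(preprocess([0.0, 1.0, 9.0]))
    # Defaults from the signature above: n_hidden=128, n_layers_individual=3
    print(layer_shapes(n_input=2000, n_output=32))
```

With the documented defaults, the assumed layout gives three hidden layers of width 128 and a final 128 -> n_output projection.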
Methods
__init__(n_input, n_output[, n_hidden, ...])
forward(x, *cat_list) – Forward pass of the encoder.
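forward(x, *cat_list) takes the input plus a variable number of extra positional arguments. This page does not say how cat_list is used. The sketch below assumes a convention like that of scVI-style fully connected layers, where each categorical covariate is given as an integer index, one-hot encoded, and concatenated to the input of every layer. All names (one_hot, make_layers, toy_forward, n_cat_list) are hypothetical, batch normalization is omitted, and a single sample is processed in pure Python, so read it as an illustration of the data flow rather than DeltaTopic's implementation.

```python
import math
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weight rows, bias)


def one_hot(index: int, n_categories: int) -> List[float]:
    vec = [0.0] * n_categories
    vec[index] = 1.0
    return vec


def make_layers(dims: Sequence[int], n_cov: int, seed: int = 0) -> List[Layer]:
    """Random weights; each layer sees its input plus n_cov covariate columns."""
    rng = random.Random(seed)
    layers: List[Layer] = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        w = [[rng.gauss(0.0, 0.1) for _ in range(d_in + n_cov)] for _ in range(d_out)]
        layers.append((w, [0.0] * d_out))
    return layers


def toy_forward(
    x: List[float],
    layers: List[Layer],
    n_cat_list: Sequence[int],
    *cat_list: int,
    log_variational: bool = True,
) -> List[float]:
    # Documented preprocessing: log(1 + x) when log_variational is True
    h = [math.log1p(v) for v in x] if log_variational else list(x)
    # Assumption: each categorical covariate is one-hot encoded
    cov: List[float] = []
    for idx, n in zip(cat_list, n_cat_list):
        cov.extend(one_hot(idx, n))
    for i, (w, b) in enumerate(layers):
        inp = h + cov  # assumption: covariates injected at every layer
        h = [sum(wj * xj for wj, xj in zip(row, inp)) + bj for row, bj in zip(w, b)]
        if i < len(layers) - 1:
            h = [max(0.0, v) for v in h]  # ReLU on hidden layers only (assumption)
    return h


if __name__ == "__main__":
    # 5 inputs, two hidden layers of width 4, 3 outputs; one covariate with 2 categories
    layers = make_layers([5, 4, 4, 3], n_cov=2)
    z = toy_forward([0.0, 2.0, 1.0, 0.0, 7.0], layers, [2], 1)
    print(len(z))  # 3
```

Note that n_cov passed to make_layers must equal sum(n_cat_list), since every layer's weight rows include one column per one-hot covariate entry.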