DeltaTopic.nn.TrainingPlan.TrainingPlan

class DeltaTopic.nn.TrainingPlan.TrainingPlan(*args: Any, **kwargs: Any)

Lightning module task to train deltaTopic modules.

Parameters:
  • module – A module instance from class BaseModuleClass.

  • lr – Learning rate used for optimization.

  • weight_decay – Weight decay used in optimization.

  • eps – Epsilon term used by the optimizer for numerical stability.

  • optimizer – Optimizer to use; one of “Adam” or “AdamW”.

  • n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

  • reduce_lr_on_plateau – Whether to monitor lr_scheduler_metric and reduce the learning rate when it plateaus.

  • lr_factor – Factor to reduce learning rate.

  • lr_patience – Number of epochs with no improvement after which learning rate will be reduced.

  • lr_threshold – Threshold for measuring the new optimum.

  • lr_scheduler_metric – Which metric to track for learning rate reduction.

  • lr_min – Minimum learning rate allowed.

  • **loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
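The note on **loss_kwargs implies that the plan stores user-supplied loss arguments and injects the current KL weight itself at each step. A minimal sketch of that pattern follows; the class name HypotheticalLossKwargsPlan, the kwargs_for_step helper, the example option some_loss_option, and the choice to reject a user-supplied kl_weight are illustrative assumptions, not the DeltaTopic implementation.

```python
from typing import Any, Dict


class HypotheticalLossKwargsPlan:
    """Illustrative only: how loss kwargs might be stored and merged with kl_weight."""

    def __init__(self, **loss_kwargs: Any) -> None:
        # Assumption: reject kl_weight up front, since the plan manages it itself.
        if "kl_weight" in loss_kwargs:
            raise ValueError("kl_weight is set automatically; do not pass it in loss_kwargs")
        self.loss_kwargs: Dict[str, Any] = dict(loss_kwargs)

    def kwargs_for_step(self, kl_weight: float) -> Dict[str, Any]:
        # Copy the user-supplied kwargs, then add the current KL weight on top.
        merged = dict(self.loss_kwargs)
        merged["kl_weight"] = kl_weight
        return merged


plan = HypotheticalLossKwargsPlan(some_loss_option=2.0)  # hypothetical loss option
print(plan.kwargs_for_step(kl_weight=0.25))  # → {'some_loss_option': 2.0, 'kl_weight': 0.25}
```

Copying the stored kwargs on each call keeps the user's original settings unchanged while the KL weight varies from step to step.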

__init__(module: BaseModuleClass, lr: float = 0.001, weight_decay: float = 1e-06, eps: float = 0.01, optimizer: Literal['Adam', 'AdamW'] = 'Adam', n_steps_kl_warmup: int | None = None, n_epochs_kl_warmup: int | None = 400, reduce_lr_on_plateau: bool = False, lr_factor: float = 0.6, lr_patience: int = 30, lr_threshold: float = 0.0, lr_scheduler_metric: Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation', 'elbo_train'] = 'elbo_train', lr_min: float = 0, **loss_kwargs)
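With reduce_lr_on_plateau=True, the defaults in this signature (lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_min=0) control how the learning rate decays. The pure-Python simulation below is a sketch that assumes the plan uses a plateau scheduler in "min" mode with an absolute threshold, in the style of PyTorch's ReduceLROnPlateau; the function simulate_plateau is hypothetical and omits cooldown and the scheduler's small eps check.

```python
import math
from typing import List


def simulate_plateau(metrics: List[float], lr: float = 0.001, lr_factor: float = 0.6,
                     lr_patience: int = 30, lr_threshold: float = 0.0,
                     lr_min: float = 0.0) -> List[float]:
    """Return the learning rate in effect after each epoch's metric is observed."""
    best = math.inf   # lower is better for a loss such as elbo_train
    bad_epochs = 0    # consecutive epochs without sufficient improvement
    history = []
    for value in metrics:
        if value < best - lr_threshold:
            # Improvement by more than the absolute threshold: reset the counter.
            best = value
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > lr_patience:
            # Patience exhausted: shrink the rate, but never below lr_min.
            lr = max(lr * lr_factor, lr_min)
            bad_epochs = 0
        history.append(lr)
    return history


# With lr_patience=2, the third consecutive non-improving epoch triggers a reduction.
print(simulate_plateau([1.0, 0.9, 0.9, 0.9, 0.9], lr_patience=2))
```

Because the counter must exceed lr_patience, the default lr_patience=30 means the rate is first cut on the 31st epoch without improvement.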

Methods

__init__(module[, lr, weight_decay, eps, ...])

configure_optimizers()

forward(*args, **kwargs)

Passthrough to model.forward().
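forward is described as a passthrough to the wrapped module. A minimal sketch of that delegation pattern, using a stand-in object instead of a real BaseModuleClass instance; both StandInModule and PassthroughPlan are hypothetical names, and in PyTorch calling a module invokes its forward method.

```python
from typing import Any


class StandInModule:
    """Hypothetical stand-in for a BaseModuleClass instance."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # A real module would run its forward computation here.
        return ("module called", args, kwargs)


class PassthroughPlan:
    """Hypothetical plan whose forward simply delegates to the wrapped module."""

    def __init__(self, module: Any) -> None:
        self.module = module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # No extra logic: positional and keyword arguments are forwarded unchanged.
        return self.module(*args, **kwargs)


plan = PassthroughPlan(StandInModule())
print(plan.forward(1, 2, flag=True))  # → ('module called', (1, 2), {'flag': True})
```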

training_epoch_end(outputs)

training_step(batch, batch_idx[, optimizer_idx])

validation_epoch_end(outputs)

Aggregate validation step information.
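validation_epoch_end aggregates the outputs of the individual validation_step calls. The sketch below assumes, hypothetically, that each step returns per-batch sums (reconstruction_loss_sum, kl_local_sum, n_obs) plus a dataset-level kl_global term, and converts them into per-observation metrics named after the validation options of lr_scheduler_metric. These keys and the ELBO formula (reconstruction loss plus local and global KL, divided by the number of observations) are assumptions modelled on common VAE training plans.

```python
from typing import Dict, List


def aggregate_validation(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Hypothetical aggregation of per-batch validation outputs into epoch metrics."""
    n_obs = sum(o["n_obs"] for o in outputs)
    rec = sum(o["reconstruction_loss_sum"] for o in outputs)
    kl_local = sum(o["kl_local_sum"] for o in outputs)
    # Assumption: the global KL term is identical for every batch, so count it once.
    kl_global = outputs[0]["kl_global"]
    return {
        "elbo_validation": (rec + kl_local + kl_global) / n_obs,
        "reconstruction_loss_validation": rec / n_obs,
        "kl_local_validation": kl_local / n_obs,
    }


batches = [
    {"reconstruction_loss_sum": 120.0, "kl_local_sum": 8.0, "kl_global": 4.0, "n_obs": 64},
    {"reconstruction_loss_sum": 60.0, "kl_local_sum": 4.0, "kl_global": 4.0, "n_obs": 32},
]
result = aggregate_validation(batches)
print(result["reconstruction_loss_validation"])  # → 1.875
```

Summing raw totals before dividing, rather than averaging per-batch means, keeps the metrics correct when the last batch is smaller than the others.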

validation_step(batch, batch_idx)

Attributes

kl_weight

Scaling factor on KL divergence during training.
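kl_weight follows the warm-up parameters documented above: it rises from 0 to 1 over n_epochs_kl_warmup epochs, or over n_steps_kl_warmup minibatches when the epoch-based setting is None. A minimal sketch, assuming a linear ramp (the exact shape of the schedule is not stated here) and a hypothetical helper name sketch_kl_weight:

```python
from typing import Optional


def sketch_kl_weight(epoch: int, step: int,
                     n_epochs_kl_warmup: Optional[int] = 400,
                     n_steps_kl_warmup: Optional[int] = None) -> float:
    """Hypothetical linear KL warm-up from 0 to 1."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warm-up takes precedence whenever it is set.
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        # Step-based warm-up is only consulted when the epoch-based one is None.
        return min(1.0, step / n_steps_kl_warmup)
    # Assumption: with no warm-up configured, the KL term gets full weight.
    return 1.0


print(sketch_kl_weight(epoch=100, step=0))  # → 0.25 with the default 400-epoch warm-up
```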

n_obs_training

Number of observations in the training set.
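n_obs_training also determines how many minibatches make up one epoch, which is what links n_steps_kl_warmup to n_epochs_kl_warmup. A back-of-the-envelope sketch, assuming a hypothetical batch size of 128 and a data loader that keeps the last partial batch; the helper names are illustrative.

```python
import math


def steps_per_epoch(n_obs_training: int, batch_size: int) -> int:
    # Assumption: the last, possibly partial, batch still counts as one step.
    return math.ceil(n_obs_training / batch_size)


def equivalent_warmup_steps(n_obs_training: int, batch_size: int,
                            n_epochs_kl_warmup: int = 400) -> int:
    # Number of minibatch steps covered by an epoch-based KL warm-up.
    return n_epochs_kl_warmup * steps_per_epoch(n_obs_training, batch_size)


print(steps_per_epoch(10_000, 128))          # → 79
print(equivalent_warmup_steps(10_000, 128))  # → 31600
```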