TrainingPlan
A PyTorch Lightning wrapper that defines the training and validation steps, optimizers, and data loaders.
- class DeltaTopic.nn.TrainingPlan.TrainingPlan(*args: Any, **kwargs: Any)
Bases: LightningModule
Lightning module task to train deltaTopic modules.
- Parameters:
module – A module instance from class BaseModuleClass.
lr – Learning rate used for optimization.
weight_decay – Weight decay used in optimization.
eps – Epsilon term added to the optimizer's denominator for numerical stability.
optimizer – One of “Adam” (torch.optim.Adam) or “AdamW” (torch.optim.AdamW).
n_steps_kl_warmup – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only used when n_epochs_kl_warmup is None.
n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
reduce_lr_on_plateau – Whether to reduce the learning rate when the validation-set lr_scheduler_metric plateaus.
lr_factor – Factor to reduce learning rate.
lr_patience – Number of epochs with no improvement after which learning rate will be reduced.
lr_threshold – Threshold for measuring the new optimum.
lr_scheduler_metric – Which metric to track for learning rate reduction.
lr_min – Minimum learning rate allowed.
**loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
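To illustrate how lr, weight_decay, eps, and the optimizer choice interact, here is a minimal scalar sketch of a single update step following the standard Adam and AdamW rules as PyTorch implements them: torch.optim.Adam folds weight decay into the gradient as an L2 penalty, while torch.optim.AdamW shrinks the weight directly. The helper name adam_step and the default betas are illustrative assumptions, not part of DeltaTopic's API; TrainingPlan configures a real PyTorch optimizer rather than anything like this.

```python
import math


def adam_step(param, grad, m, v, t, lr, weight_decay=0.0, eps=1e-8,
              betas=(0.9, 0.999), decoupled=False):
    """One scalar Adam (decoupled=False) or AdamW (decoupled=True) update.

    Hypothetical helper for illustration; returns (new_param, new_m, new_v).
    """
    beta1, beta2 = betas
    if decoupled:
        # AdamW: shrink the weight directly, independent of the gradient.
        param = param * (1.0 - lr * weight_decay)
    else:
        # Adam: fold weight decay into the gradient as an L2 penalty.
        grad = grad + weight_decay * param
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Bias correction for the zero-initialised moments (t starts at 1).
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # eps keeps the denominator away from zero.
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Same starting weight and gradient, weight_decay=0.1 for both variants.
p_adam, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1)
p_adamw, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1,
                          decoupled=True)
print(round(p_adam, 4), round(p_adamw, 4))  # 0.9 0.89
```

With weight_decay set to 0 the two variants coincide; the difference only appears once decay is applied, which is why the optimizer and weight_decay settings should be chosen together.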
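The reduce_lr_on_plateau options (lr_factor, lr_patience, lr_threshold, lr_min) describe a reduce-on-plateau schedule. Below is a minimal pure-Python sketch of that logic, modeled on the semantics of PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau in "min" mode, simplified to omit cooldown and the minimum-change check. The class name PlateauSketch is hypothetical, and whether DeltaTopic compares against lr_threshold absolutely or relatively is an assumption, so both modes are supported.

```python
import math
from dataclasses import dataclass, field


@dataclass
class PlateauSketch:
    """Hypothetical reduce-on-plateau tracker for a metric to minimise."""
    lr: float
    lr_factor: float
    lr_patience: int
    lr_threshold: float
    lr_min: float
    threshold_mode: str = "abs"  # "abs" or "rel"; DeltaTopic's choice is assumed
    best: float = field(default=math.inf, init=False)
    num_bad_epochs: int = field(default=0, init=False)

    def _improved(self, metric: float) -> bool:
        # A new optimum must beat the best value by more than lr_threshold.
        if self.threshold_mode == "rel":
            return metric < self.best * (1.0 - self.lr_threshold)
        return metric < self.best - self.lr_threshold

    def step(self, metric: float) -> float:
        """Record one epoch's validation metric and return the current lr."""
        if self._improved(metric):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only once more than lr_patience epochs have failed to improve.
        if self.num_bad_epochs > self.lr_patience:
            self.lr = max(self.lr * self.lr_factor, self.lr_min)
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=1e-3, lr_factor=0.5, lr_patience=2,
                      lr_threshold=0.0, lr_min=1e-4)
history = [sched.step(m) for m in [1.0, 0.9, 0.95, 0.95, 0.95]]
print(history)  # [0.001, 0.001, 0.001, 0.001, 0.0005]
```

With lr_patience=2, the first two non-improving epochs are tolerated and the learning rate drops on the third; repeated plateaus keep multiplying by lr_factor until lr_min is reached.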
- property kl_weight
Scaling factor on KL divergence during training.
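The warm-up rules described under n_epochs_kl_warmup and n_steps_kl_warmup amount to a linear ramp of kl_weight from 0 to 1. A minimal sketch of that rule follows. The function name compute_kl_weight is hypothetical, the precedence (epochs over steps) is taken from the parameter descriptions above rather than from DeltaTopic's source, and returning a constant 1.0 when neither warm-up is set is an assumption.

```python
from typing import Optional


def compute_kl_weight(epoch: int, global_step: int,
                      n_epochs_kl_warmup: Optional[int],
                      n_steps_kl_warmup: Optional[int]) -> float:
    """Linear KL warm-up from 0 to 1, then constant (hypothetical helper)."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warm-up takes precedence when both are set.
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        # Step-based warm-up counts minibatches instead of epochs.
        return min(1.0, global_step / n_steps_kl_warmup)
    # Assumption: with no warm-up configured, the KL term gets full weight.
    return 1.0


print([compute_kl_weight(e, 0, 4, None) for e in range(6)])
# [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Ramping the KL weight up from 0 lets the model first learn useful reconstructions before the KL term pulls the posterior toward the prior.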
- property n_obs_training
Number of observations in the training set.
Setting this property also updates the loss kwargs used for loss rescaling.
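The loss-kwargs update tied to n_obs_training can be sketched as a property setter that forwards the training-set size to the module's loss only when that loss accepts it. This mirrors the pattern in scvi-tools' TrainingPlan, which this class's parameters closely resemble; that DeltaTopic does the same, and the n_obs keyword name, are assumptions. ToyModule and TrainingPlanSketch are hypothetical stand-ins.

```python
import inspect


class ToyModule:
    """Hypothetical module whose loss can rescale by the training-set size."""

    def loss(self, tensors, inference_outputs, generative_outputs,
             kl_weight=1.0, n_obs=1):
        return None  # body omitted; only the signature matters here


class TrainingPlanSketch:
    """Hypothetical stand-in showing only the n_obs_training bookkeeping."""

    def __init__(self, module, **loss_kwargs):
        self.module = module
        self.loss_kwargs = dict(loss_kwargs)
        # Names of the parameters the module's loss method accepts.
        self._loss_args = set(inspect.signature(module.loss).parameters)
        self._n_obs_training = None

    @property
    def n_obs_training(self):
        """Number of observations in the training set."""
        return self._n_obs_training

    @n_obs_training.setter
    def n_obs_training(self, n_obs):
        # Forward n_obs to the loss only if the loss can accept it.
        if "n_obs" in self._loss_args:
            self.loss_kwargs["n_obs"] = n_obs
        self._n_obs_training = n_obs


plan = TrainingPlanSketch(ToyModule())
plan.n_obs_training = 5000
print(plan.loss_kwargs)  # {'n_obs': 5000}
```

Checking the loss signature first means modules whose loss has no n_obs argument are left untouched instead of receiving an unexpected keyword.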