import torch
from inspect import getfullargspec
from typing import Union, Literal
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from DeltaTopic.nn.base_model import BaseModuleClass
[docs]class TrainingPlan(pl.LightningModule):
"""
Lightning module task to train deltaTopic modules.
Parameters
----------
module
A module instance from class ``BaseModuleClass``.
lr
Learning rate used for optimization.
weight_decay
Weight decay used in optimizatoin.
eps
eps used for optimization.
optimizer
One of "Adam" (:class:`~torch.optim.Adam`), "AdamW" (:class:`~torch.optim.AdamW`).
n_steps_kl_warmup
Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
Only activated when `n_epochs_kl_warmup` is set to None.
n_epochs_kl_warmup
Number of epochs to scale weight on KL divergences from 0 to 1.
Overrides `n_steps_kl_warmup` when both are not `None`.
reduce_lr_on_plateau
Whether to monitor validation loss and reduce learning rate when validation set
`lr_scheduler_metric` plateaus.
lr_factor
Factor to reduce learning rate.
lr_patience
Number of epochs with no improvement after which learning rate will be reduced.
lr_threshold
Threshold for measuring the new optimum.
lr_scheduler_metric
Which metric to track for learning rate reduction.
lr_min
Minimum learning rate allowed
**loss_kwargs
Keyword args to pass to the loss method of the `module`.
`kl_weight` should not be passed here and is handled automatically.
"""
[docs] def __init__(
self,
module: BaseModuleClass,
lr: float = 1e-3,
weight_decay: float = 1e-6,
eps: float = 0.01,
optimizer: Literal["Adam", "AdamW"] = "Adam",
n_steps_kl_warmup: Union[int, None] = None,
n_epochs_kl_warmup: Union[int, None] = 400,
reduce_lr_on_plateau: bool = False,
lr_factor: float = 0.6,
lr_patience: int = 30,
lr_threshold: float = 0.0,
lr_scheduler_metric: Literal[
"elbo_validation", "reconstruction_loss_validation", "kl_local_validation", "elbo_train"
] = "elbo_train",
lr_min: float = 0,
**loss_kwargs,
):
super(TrainingPlan, self).__init__()
self.module = module
self.lr = lr
self.weight_decay = weight_decay
self.eps = eps
self.optimizer_name = optimizer
self.n_steps_kl_warmup = n_steps_kl_warmup
self.n_epochs_kl_warmup = n_epochs_kl_warmup
self.reduce_lr_on_plateau = reduce_lr_on_plateau
self.lr_factor = lr_factor
self.lr_patience = lr_patience
self.lr_scheduler_metric = lr_scheduler_metric
self.lr_threshold = lr_threshold
self.lr_min = lr_min
self.loss_kwargs = loss_kwargs
self._n_obs_training = None
# automatic handling of kl weight
self._loss_args = getfullargspec(self.module.loss)[0]
if "kl_weight" in self._loss_args:
self.loss_kwargs.update({"kl_weight": self.kl_weight})
@property
def n_obs_training(self):
"""
Number of observations in the training set.
This will update the loss kwargs for loss rescaling.
"""
return self._n_obs_training
@n_obs_training.setter
def n_obs_training(self, n_obs: int):
if "n_obs" in self._loss_args:
self.loss_kwargs.update({"n_obs": n_obs})
self._n_obs_training = n_obs
[docs] def forward(self, *args, **kwargs):
"""Passthrough to `model.forward()`."""
return self.module(*args, **kwargs)
def training_step(self, batch, batch_idx, optimizer_idx=0):
if "kl_weight" in self.loss_kwargs:
self.loss_kwargs.update({"kl_weight": self.kl_weight})
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
reconstruction_loss = scvi_loss.reconstruction_loss
reconstruction_loss_spliced = scvi_loss.reconstruction_loss_spliced
reconstruction_loss_unspliced = scvi_loss.reconstruction_loss_unspliced
kl_beta = scvi_loss.kl_beta
kl_rho = scvi_loss.kl_rho
kl_delta = scvi_loss.kl_delta
# pytorch lightning automatically backprops on "loss"
self.log("train_loss", scvi_loss.loss, on_epoch=True)
return {
"loss": scvi_loss.loss,
"reconstruction_loss_sum": reconstruction_loss.sum(),
"kl_local_sum": scvi_loss.kl_local.sum(),
#"kl_global": scvi_loss.kl_global,
"kl_beta_sum": kl_beta.sum(),
"kl_rho_sum": kl_rho.sum(),
"kl_delta_sum": kl_delta.sum(),
"reconstruction_loss_spliced_sum": reconstruction_loss_spliced.sum(),
"reconstruction_loss_unspliced_sum": reconstruction_loss_unspliced.sum(),
"n_obs": reconstruction_loss.shape[0],
}
def training_epoch_end(self, outputs):
n_obs, elbo, rec_loss, kl_local, rec_loss_spliced, rec_loss_unspliced, kl_beta, kl_rho, kl_delta = 0, 0, 0, 0, 0, 0, 0, 0, 0
for tensors in outputs:
elbo += tensors["reconstruction_loss_sum"] + tensors["kl_local_sum"]
rec_loss += tensors["reconstruction_loss_sum"]
rec_loss_spliced += tensors['reconstruction_loss_spliced_sum']
rec_loss_unspliced += tensors['reconstruction_loss_unspliced_sum']
kl_local += tensors["kl_local_sum"]
kl_beta += tensors["kl_beta_sum"]
kl_rho += tensors["kl_rho_sum"]
kl_delta += tensors["kl_delta_sum"]
n_obs += tensors["n_obs"]
# kl global same for each minibatch
#kl_global = outputs[0]["kl_global"]
#elbo += kl_global
self.log("elbo_train", elbo / n_obs)
self.log("reconstruction_loss_train", rec_loss / n_obs)
self.log("kl_local_train", kl_local / n_obs)
self.log("kl_beta_train", kl_beta / n_obs)
self.log("kl_rho_train", kl_rho / n_obs)
self.log("kl_delta_train", kl_delta / n_obs)
self.log("reconstruction_loss_spliced_train", rec_loss_spliced / n_obs)
self.log("reconstruction_loss_unspliced_train", rec_loss_unspliced / n_obs)
#self.log("kl_global_train", kl_global)
def validation_step(self, batch, batch_idx):
_, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
reconstruction_loss = scvi_loss.reconstruction_loss
reconstruction_loss_spliced = scvi_loss.reconstruction_loss_spliced
reconstruction_loss_unspliced = scvi_loss.reconstruction_loss_unspliced
kl_beta = scvi_loss.kl_beta
kl_rho = scvi_loss.kl_rho
kl_delta = scvi_loss.kl_delta
self.log("validation_loss", scvi_loss.loss, on_epoch=True)
return {
"reconstruction_loss_sum": reconstruction_loss.sum(),
"kl_local_sum": scvi_loss.kl_local.sum(),
#"kl_global": scvi_loss.kl_global,
"kl_beta_sum": kl_beta.sum(),
"kl_rho_sum": kl_rho.sum(),
"kl_delta_sum": kl_delta.sum(),
"reconstruction_loss_spliced_sum": reconstruction_loss_spliced.sum(),
"reconstruction_loss_unspliced_sum": reconstruction_loss_unspliced.sum(),
"n_obs": reconstruction_loss.shape[0],
}
[docs] def validation_epoch_end(self, outputs):
"""Aggregate validation step information."""
#n_obs, elbo, rec_loss, kl_local = 0, 0, 0, 0
n_obs, elbo, rec_loss, kl_local, rec_loss_spliced, rec_loss_unspliced, kl_beta, kl_rho, kl_delta = 0, 0, 0, 0, 0, 0, 0, 0, 0
for tensors in outputs:
elbo += tensors["reconstruction_loss_sum"] + tensors["kl_local_sum"]
rec_loss += tensors["reconstruction_loss_sum"]
rec_loss_spliced = tensors['reconstruction_loss_spliced_sum']
rec_loss_unspliced = tensors['reconstruction_loss_unspliced_sum']
kl_local += tensors["kl_local_sum"]
kl_beta += tensors["kl_beta_sum"]
kl_rho += tensors["kl_rho_sum"]
kl_delta += tensors["kl_delta_sum"]
n_obs += tensors["n_obs"]
# kl global same for each minibatch
#kl_global = outputs[0]["kl_global"]
#elbo += kl_global
self.log("elbo_validation", elbo / n_obs)
self.log("reconstruction_loss_validation", rec_loss / n_obs)
self.log("kl_local_validation", kl_local / n_obs)
self.log("kl_beta_validation", kl_beta / n_obs)
self.log("kl_rho_validation", kl_rho / n_obs)
self.log("kl_delta_validation", kl_delta / n_obs)
self.log("reconstruction_loss_spliced_validation", rec_loss_spliced / n_obs)
self.log("reconstruction_loss_unspliced_validation", rec_loss_unspliced / n_obs)
#self.log("kl_global_validation", kl_global)
def configure_optimizers(self):
params = filter(lambda p: p.requires_grad, self.module.parameters())
if self.optimizer_name == "Adam":
optim_cls = torch.optim.Adam
elif self.optimizer_name == "AdamW":
optim_cls = torch.optim.AdamW
else:
raise ValueError("Optimizer not understood.")
optimizer = optim_cls(
params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
)
config = {"optimizer": optimizer}
if self.reduce_lr_on_plateau:
scheduler = ReduceLROnPlateau(
optimizer,
patience=self.lr_patience,
factor=self.lr_factor,
threshold=self.lr_threshold,
min_lr=self.lr_min,
threshold_mode="abs",
verbose=True,
)
config.update(
{
"lr_scheduler": scheduler,
"monitor": self.lr_scheduler_metric,
},
)
return config
@property
def kl_weight(self):
"""Scaling factor on KL divergence during training."""
epoch_criterion = self.n_epochs_kl_warmup is not None
step_criterion = self.n_steps_kl_warmup is not None
if epoch_criterion:
kl_weight = min(1.0, self.current_epoch / self.n_epochs_kl_warmup)
elif step_criterion:
kl_weight = min(1.0, self.global_step / self.n_steps_kl_warmup)
else:
kl_weight = 1.0
return kl_weight