import logging
import os
import pickle
import warnings
import torch
import numpy as np
from torch import nn
from anndata import AnnData, read
from typing import List, Optional, Union, Dict
from DeltaTopic.nn.util import _get_var_names_from_setup_anndata, parse_use_gpu_arg, _CONSTANTS, DataSplitter, TrainRunner, BaseModelClass
from DeltaTopic.nn.TrainingPlan import TrainingPlan
from DeltaTopic.nn.module import BALSAM_module, DeltaTopic_module
logger = logging.getLogger(__name__)
def _unpack_tensors(tensors):
x = tensors[_CONSTANTS.X_KEY].squeeze_(0)
unspliced = tensors[_CONSTANTS.PROTEIN_EXP_KEY].squeeze_(0)
return x, unspliced
def _unpack_tensors_BETM(tensors):
x = tensors[_CONSTANTS.X_KEY].squeeze_(0)
return x
[docs]class BALSAM(BaseModelClass):
"""
Bayesian Latent topic analysis with Sparse Association Matrix (BALSAM).
Parameters
----------
adata
AnnData object that has been registered via :meth:`~DeltaTopic.nn.util.setup_anndata`.
n_latent
Dimensionality of the latent space
**model_kwargs
Keyword args for :class:`~DeltaTopic.nn.module.BALSAM_module`
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> DeltaTopic.nn.util.setup_anndata(adata)
>>> model = DeltaTopic.nn.modelhub.BALSAM(adata)
>>> model.train(100)
"""
[docs] def __init__(
self,
adata_seq: AnnData,
n_latent: int = 32,
**model_kwargs,
):
super(BALSAM, self).__init__()
self.n_latent = n_latent
self.adata = adata_seq
self.module = BALSAM_module(
n_genes = self.adata.n_vars,
n_latent=n_latent,
**model_kwargs,
)
self._model_summary_string = (
"BALSAM with the following params: \nn_latent: {}, n_genes: {}"
).format(n_latent, self.adata.n_vars)
[docs] def train(
self,
max_epochs: Optional[int] = 1000,
lr: float = 1e-3,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 0.9,
validation_size: Optional[float] = None,
batch_size: int = 128,
n_steps_kl_warmup: Union[int, None] = None,
n_epochs_kl_warmup: Union[int, None] = None,
plan_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Trains the model using amortized variational inference.
Parameters
----------
max_epochs
Number of passes through the dataset.
lr
Learning rate for optimization.
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Size of the test set. If `None`, defaults to 1 - `train_size`. If
`train_size + validation_size < 1`, the remaining cells belong to a test set.
batch_size
Minibatch size to use during training.
n_steps_kl_warmup
Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults
to `floor(0.75 * adata.n_obs)`.
n_epochs_kl_warmup
Number of epochs to scale weight on KL divergences from 0 to 1.
Overrides `n_steps_kl_warmup` when both are not `None`.
"""
n_steps_kl_warmup = (
n_steps_kl_warmup
if n_steps_kl_warmup is not None
else int(0.75 * self.adata.n_obs)
)
update_dict = {
"lr": lr,
"n_epochs_kl_warmup": n_epochs_kl_warmup,
"n_steps_kl_warmup": n_steps_kl_warmup,
}
if plan_kwargs is not None:
plan_kwargs.update(update_dict)
else:
plan_kwargs = update_dict
if max_epochs is None:
n_cells = self.adata.n_obs
max_epochs = np.min([round((20000 / n_cells) * 400), 400])
plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
data_splitter = DataSplitter(
self.adata,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
use_gpu=use_gpu,
)
training_plan = TrainingPlan(self.module, **plan_kwargs)
runner = TrainRunner(
self,
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
**kwargs,
)
return runner()
@torch.no_grad()
def get_latent_representation(
self,
adata: AnnData = None,
deterministic: bool = True,
output_softmax_z: bool = True,
batch_size: int = 128,
):
"""
Return the latent space (topic proportions).
Parameters
----------
adatas
adata registered with setup_anndata.
deterministic
If true, use the mean of the encoder instead of a stochastic sample
output_softmax_z
If true, output probability, otherwise output z (unnormalized probability).
batch_size
Minibatch size for data loading into model.
"""
if adata is None:
adata = self.adata
scdl = self._make_data_loader(adata, batch_size=batch_size)
self.module.eval()
latent_z = []
for tensors in scdl:
(
sample_batch
) = _unpack_tensors_BETM(tensors)
z_dict = self.module.sample_from_posterior_z(sample_batch, deterministic=deterministic, output_softmax_z=output_softmax_z)
latent_z.append(z_dict["z"])
latent_z = torch.cat(latent_z).cpu().detach().numpy()
print(f'Deterministic: {deterministic}, output_softmax_z: {output_softmax_z}' )
return latent_z
@torch.no_grad()
def get_parameters(
self,
save_dir = None,
overwrite = False,
):
"""
Save the spike and slab parameters to the specificed directory.
Parameters
----------
save_dir
Save directory.
overwrite
If true, overwrite the existing files.
"""
self.module.eval()
decoder = self.module.decoder
if not os.path.exists(os.path.join(save_dir,"model_parameters")) or overwrite:
os.makedirs(os.path.join(save_dir,"model_parameters"), exist_ok=overwrite)
np.savetxt(os.path.join(
save_dir,"model_parameters", "spike_logit_rho.txt"
), decoder.spike_logit.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_mean_rho.txt"
), decoder.slab_mean.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_lnvar_rho.txt"
), decoder.slab_lnvar.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "bias_gene.txt"
), decoder.bias_d.cpu().numpy())
[docs] def save(
self,
dir_path: str,
overwrite: bool = False,
save_anndata: bool = False,
**anndata_write_kwargs,
):
"""
Save model parameters to the specified directory.
Parameters
----------
dir_path
Path to a directory.
overwrite
Overwrite existing data or not. If `False` and directory
already exists at `dir_path`, error will be raised.
save_anndata
If True, also saves the anndata
anndata_write_kwargs
Kwargs for anndata write function
"""
if not os.path.exists(dir_path) or overwrite:
os.makedirs(dir_path, exist_ok=overwrite)
else:
raise ValueError(
"{} already exists. Please provide an unexisting directory for saving.".format(
dir_path
)
)
if save_anndata:
save_path = os.path.join(
dir_path, "adata.h5ad"
)
self.adata.write(save_path)
varnames_save_path = os.path.join(
dir_path, "var_names.csv"
)
var_names = self.adata.var_names.astype(str)
var_names = var_names.to_numpy()
np.savetxt(varnames_save_path, var_names, fmt="%s")
model_save_path = os.path.join(dir_path, "model_params.pt")
torch.save(self.module.state_dict(), model_save_path)
[docs]class DeltaTopic(BaseModelClass):
"""
Dynamically-Encoded Latent Transcriptomic pattern Analysis by Topic modelling (DeltaTopic).
Parameters
----------
adata
AnnData object that has been registered via :meth:`~DeltaTopic.nn.util.setup_anndata`.
n_latent
Dimensionality of the latent space
**model_kwargs
Keyword args for :class:`~DeltaTopic.nn.module.DeltaTopic_module`
Examples
--------
>>> adata= anndata.read_h5ad(path_to_anndata_spliced)
>>> X_unspliced = sc.read(path_to_anndata_spliced)
>>> adata.obsm["unspliced_expression"] = (X_unspliced.X.copy()
>>> DeltaTopic.nn.util.setup_anndata(adata, layer="counts", unspliced_obsm_key = "unspliced_expression")
>>> model = DeltaTopic.nn.modelhub.DeltaTopic(adata)
>>> model.train(100)
"""
[docs] def __init__(
self,
adata_seq: AnnData,
n_latent: int = 32,
**model_kwargs,
):
super(DeltaTopic, self).__init__()
self.n_latent = n_latent
self.adata = adata_seq
self.module = DeltaTopic_module(
n_genes = self.adata.n_vars,
n_latent=n_latent,
**model_kwargs,
)
self._model_summary_string = (
"DeltaTopic with the following params: \nn_latent: {}, n_genes: {} "
).format(n_latent, self.adata.n_vars)
[docs] def train(
self,
max_epochs: Optional[int] = 1000,
lr: float = 1e-3,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 0.9,
validation_size: Optional[float] = None,
batch_size: int = 128,
n_steps_kl_warmup: Union[int, None] = None,
n_epochs_kl_warmup: Union[int, None] = None,
plan_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Trains the model using amortized variational inference.
Parameters
----------
max_epochs
Number of passes through the dataset.
lr
Learning rate for optimization.
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Size of the test set. If `None`, defaults to 1 - `train_size`. If
`train_size + validation_size < 1`, the remaining cells belong to a test set.
batch_size
Minibatch size to use during training.
n_steps_kl_warmup
Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults
to `floor(0.75 * adata.n_obs)`.
n_epochs_kl_warmup
Number of epochs to scale weight on KL divergences from 0 to 1.
Overrides `n_steps_kl_warmup` when both are not `None`.
"""
n_steps_kl_warmup = (
n_steps_kl_warmup
if n_steps_kl_warmup is not None
else int(0.75 * self.adata.n_obs)
)
update_dict = {
"lr": lr,
"n_epochs_kl_warmup": n_epochs_kl_warmup,
"n_steps_kl_warmup": n_steps_kl_warmup,
}
if plan_kwargs is not None:
plan_kwargs.update(update_dict)
else:
plan_kwargs = update_dict
if max_epochs is None:
n_cells = self.adata.n_obs
max_epochs = np.min([round((20000 / n_cells) * 400), 400])
plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
data_splitter = DataSplitter(
self.adata,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
use_gpu=use_gpu,
)
training_plan = TrainingPlan(self.module, **plan_kwargs)
runner = TrainRunner(
self,
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
**kwargs,
)
return runner()
@torch.no_grad()
def get_latent_representation(
self,
adata: AnnData = None,
deterministic: bool = True,
output_softmax_z: bool = True,
batch_size: int = 128,
):
"""
Return the latent space (topic proportions) for spliced and unspliced.
Parameters
----------
adatas
List of adata_spliced and adata_unspliced.
deterministic
If true, use the mean of the encoder instead of a stochastic sample.
output_softmax_z
if true, output probability, otherwise output z.
batch_size
Minibatch size for data loading into model.
"""
if adata is None:
adata = self.adata
scdl = self._make_data_loader(adata, batch_size=batch_size)
self.module.eval()
latent_z = []
for tensors in scdl:
(
sample_batch,
sample_batch_unspliced,
) = _unpack_tensors(tensors)
z_dict = self.module.sample_from_posterior_z(sample_batch, sample_batch_unspliced, deterministic=deterministic, output_softmax_z=output_softmax_z)
latent_z.append(z_dict["z"])
latent_z = torch.cat(latent_z).cpu().detach().numpy()
print(f'Deterministic: {deterministic}, output_softmax_z: {output_softmax_z}' )
return latent_z
@torch.no_grad()
def get_parameters(
self,
save_dir = None,
overwrite = False,
):
"""
Save the spike and slab parameters to the specified directory.
Parameters
----------
save_dir
Directory to save the parameters.
overwrite
If true, overwrite the existing parameters.
"""
self.module.eval()
decoder = self.module.decoder
if not os.path.exists(os.path.join(save_dir,"model_parameters")) or overwrite:
os.makedirs(os.path.join(save_dir,"model_parameters"), exist_ok=overwrite)
np.savetxt(os.path.join(
save_dir,"model_parameters", "spike_logit_delta.txt"
), decoder.spike_logit_delta.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "spike_logit_rho.txt"
), decoder.spike_logit_rho.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_mean_delta.txt"
), decoder.slab_mean_delta.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_mean_rho.txt"
), decoder.slab_mean_rho.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_lnvar_delta.txt"
), decoder.slab_lnvar_delta.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "slab_lnvar_rho.txt"
), decoder.slab_lnvar_rho.cpu().numpy())
np.savetxt(os.path.join(
save_dir,"model_parameters", "bias_gene.txt"
), decoder.bias_gene.cpu().numpy())
@torch.no_grad()
def get_reconstruction_error(
self,
adata: Optional[AnnData] = None,
batch_size: Optional[int] = 128,
):
"""
Return the reconstruction error for the data.
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
batch_size
Minibatch size for data loading into model.
"""
if adata is None:
adata = self.adata
scdl = self._make_data_loader(adata, batch_size=batch_size)
self.module.eval()
reconstruction_loss_spliced_sum = 0
reconstruction_loss_unspliced_sum = 0
n_spliced = 0
n_unspliced = 0
for tensors in scdl:
(
sample_batch_spliced,
sample_batch_unspliced,
*_,
) = _unpack_tensors(tensors)
reconstruction_loss_spliced, reconstruction_loss_unspliced = self.module.get_reconstruction_loss(sample_batch_spliced, sample_batch_unspliced)
reconstruction_loss_spliced_sum += torch.sum(reconstruction_loss_spliced)
reconstruction_loss_unspliced_sum += torch.sum(reconstruction_loss_unspliced)
n_spliced += sample_batch_spliced.shape[0]
n_unspliced += sample_batch_unspliced.shape[0]
recon_spliced = reconstruction_loss_spliced_sum/n_spliced
recon_unspliced = reconstruction_loss_unspliced_sum/n_unspliced
return recon_spliced.cpu().numpy(), recon_unspliced.cpu().numpy()
[docs] def save(
self,
dir_path: str,
overwrite: bool = False,
save_anndata: bool = False,
**anndata_write_kwargs,
):
"""
Save the state of the model.
Neither the trainer optimizer state nor the trainer history are saved.
Parameters
----------
dir_path
Path to a directory.
overwrite
Overwrite existing data or not. If `False` and directory
already exists at `dir_path`, error will be raised.
save_anndata
If True, also saves the anndata
anndata_write_kwargs
Kwargs for anndata write function
"""
# save the model state dict and the trainer state dict only
if not os.path.exists(dir_path) or overwrite:
os.makedirs(dir_path, exist_ok=overwrite)
else:
raise ValueError(
"{} already exists. Please provide an unexisting directory for saving.".format(
dir_path
)
)
if save_anndata:
save_path = os.path.join(
dir_path, "adata.h5ad"
)
self.adata.write(save_path)
varnames_save_path = os.path.join(
dir_path, "var_names.csv"
)
var_names = self.adata.var_names.astype(str)
var_names = var_names.to_numpy()
np.savetxt(varnames_save_path, var_names, fmt="%s")
model_save_path = os.path.join(dir_path, "model_params.pt")
attr_save_path = os.path.join(dir_path, "attr.pkl")
torch.save(self.module.state_dict(), model_save_path)
[docs] @classmethod
def load(
cls,
dir_path: str,
adata_seq: Optional[AnnData] = None,
use_gpu: Optional[Union[str, int, bool]] = None,
):
"""
Instantiate a model from the saved output.
Parameters
----------
adata_seq
AnnData organized in the same way as data used to train model.
dir_path
Path to saved outputs.
use_gpu
Load model on default GPU if available (if None or True),
or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
Returns
-------
Model with loaded state dictionaries.
"""
model_path = os.path.join(dir_path, "model_params.pt")
setup_dict_path = os.path.join(dir_path, "attr.pkl")
seq_data_path = os.path.join(dir_path, "adata.h5ad")
path_data_path = os.path.join(dir_path, "adata_pathways.h5ad")
seq_var_names_path = os.path.join(dir_path, "var_names.csv")
if adata_seq is None and os.path.exists(seq_data_path):
adata_seq = read(seq_data_path)
elif adata_seq is None and not os.path.exists(seq_data_path):
raise ValueError(
"Save path contains no saved anndata and no adata was passed."
)
if os.path.exists(path_data_path):
adata_path = read(path_data_path)
elif not os.path.exists(path_data_path):
adata_path = None
print("no pathways saved")
adata = adata_seq
seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str)
var_names = seq_var_names
saved_var_names = var_names
user_var_names = adata.var_names.astype(str)
if not np.array_equal(saved_var_names, user_var_names):
warnings.warn(
"var_names for adata passed in does not match var_names of "
"adata used to train the model. For valid results, the vars "
"need to be the same and in the same order as the adata used to train the model."
)
with open(setup_dict_path, "rb") as handle:
attr_dict = pickle.load(handle)
scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
transfer_anndata_setup(scvi_setup_dicts, adata_seq)
# get the parameters for the class init signiture
init_params = attr_dict.pop("init_params_")
# new saving and loading, enable backwards compatibility
if "non_kwargs" in init_params.keys():
# grab all the parameters execept for kwargs (is a dict)
non_kwargs = init_params["non_kwargs"]
kwargs = init_params["kwargs"]
# expand out kwargs
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
else:
# grab all the parameters execept for kwargs (is a dict)
non_kwargs = {
k: v for k, v in init_params.items() if not isinstance(v, dict)
}
kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
# the default init require this way of loading models
if adata_path is not None:
model = cls(adata_seq, **non_kwargs, adata_pathway=adata_path, **kwargs)
elif adata_path is None:
model = cls(adata_seq, **non_kwargs, **kwargs)
for attr, val in attr_dict.items():
setattr(model, attr, val)
_, device = parse_use_gpu_arg(use_gpu)
model.module.load_state_dict(torch.load(model_path, map_location=device))
model.module.eval()
model.to_device(device)
return model