Source code for DeltaTopic.nn.module

# -*- coding: utf-8 -*-
"""Main module."""
from typing import List, Tuple
import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl
from DeltaTopic.nn.util import _CONSTANTS
from DeltaTopic.nn.base_model import BaseModuleClass, LossRecorder, auto_move_data
from DeltaTopic.nn.base_components import DeltaTopicEncoder, BALSAMDecoder, BALSAMEncoder, DeltaTopicDecoder

torch.backends.cudnn.benchmark = True

[docs]class BALSAM_module(BaseModuleClass): """ BALASM module Parameters ---------- n_genes number of genes n_latent dimension of latent space n_layers_encoder_individual number of individual layers in the encoder dim_hidden_encoder dimension of the hidden layers in the encoder pip0_rho scaling factor for rho loss, default 0.1 kl_weight_beta: scaling factor for KL, default 1.0 log_variational Log(data+1) prior to encoding for numerical stability. Not normalization. """
[docs] def __init__( self, n_genes: int, n_latent: int = 32, n_layers_encoder_individual: int = 2, dim_hidden_encoder: int = 128, log_variational: bool = True, pip0_rho: float = 0.1, kl_weight_beta: float = 1.0, ): super().__init__() self.n_input = n_genes self.n_latent = n_latent self.log_variational = log_variational self.pip0_rho = pip0_rho self.kl_weight_beta = kl_weight_beta self.z_encoder = BALSAMEncoder( n_input=self.n_input, n_output=self.n_latent, n_hidden=dim_hidden_encoder, n_layers_individual=n_layers_encoder_individual, log_variational = self.log_variational, ) self.decoder = BALSAMDecoder(self.n_latent , self.n_input, pip0 = self.pip0_rho, )
def dir_llik(self, xx: torch.Tensor, aa: torch.Tensor, ) -> torch.Tensor: ''' Compute the Dirichlet log-likelihood. ''' reconstruction_loss = None term1 = (torch.lgamma(torch.sum(aa, dim=-1)) - torch.lgamma(torch.sum(aa + xx, dim=-1))) #[n_batch] term2 = torch.sum(torch.where(xx > 0, torch.lgamma(aa + xx) - torch.lgamma(aa), torch.zeros_like(xx)), dim=-1) #[n_batch reconstruction_loss = term1 + term2 #[n_batch return reconstruction_loss def _get_inference_input(self, tensors): return dict(x=tensors[_CONSTANTS.X_KEY]) def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] return dict(z=z) @auto_move_data def inference(self, x: torch.Tensor) -> dict: x_ = x qz_m, qz_v, z = self.z_encoder(x_) return dict(qz_m=qz_m, qz_v=qz_v, z=z) @auto_move_data def generative(self, z) -> dict: rho, rho_kl, theta = self.decoder(z) return dict(rho = rho, rho_kl = rho_kl, theta = theta) def sample_from_posterior_z( self, x: torch.Tensor, deterministic: bool = True, output_softmax_z: bool = True, ): """Sample from the posterior z """ inference_out = self.inference(x) if deterministic: z = inference_out["qz_m"] else: z = inference_out["z"] if output_softmax_z: generative_outputs = self.generative(z) z = generative_outputs["theta"] return dict(z=z) @auto_move_data def get_reconstruction_loss( self, x: torch.Tensor, ) -> torch.Tensor: """ Returns the reconstruction loss for a batch of data. Parameters ---------- x tensor of values with shape ``(batch_size, n_input)`` Returns ------- type tensor of means of the scaled frequencies """ inference_out = self.inference(x) z = inference_out["z"] gen_out = self.generative(z) theta = gen_out["theta"] rho = gen_out["rho"] log_aa = torch.clamp(torch.mm(theta, rho), -10, 10) aa = torch.exp(log_aa) reconstruction_loss = -self.dir_llik(x, aa) return reconstruction_loss def loss( self, tensors, inference_outputs, generative_outputs, # this is important to include kl_weight=1.0, #kl_weight_beta = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Agrregates the likelihood and KL divergences to form the loss function. """ kl_weight_beta = self.kl_weight_beta x = tensors[_CONSTANTS.X_KEY] qz_m = inference_outputs["qz_m"] qz_v = inference_outputs["qz_v"] rho_kl = generative_outputs["rho_kl"] # [batch_size] reconstruction_loss = self.get_reconstruction_loss(x) # KL Divergence for z [batch_size] mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( dim=1 ) # suming over all the latent dimensinos # kl_divergence for beta, rho_kl, tensor of torch.size([]) <- torch.sum([N_topics, N_genes]) kl_divergence_beta = rho_kl kl_local = kl_divergence_z loss = torch.mean(reconstruction_loss + kl_weight * kl_local) + kl_weight_beta * kl_divergence_beta/x.shape[1] return LossRecorder(loss, reconstruction_loss, kl_local, reconstruction_loss_spliced=reconstruction_loss, reconstruction_loss_unspliced=torch.Tensor(0), kl_beta = kl_divergence_beta, kl_rho = rho_kl, kl_delta = torch.Tensor(0))
[docs]class DeltaTopic_module(BaseModuleClass): """ DeltaTopic module. Parameters ---------- n_genes number of genes n_latent dimension of latent space n_layers_encoder_individual number of individual layers in the encoder dim_hidden_encoder dimension of the hidden layers in the encoder pip0_rho scaling factor for rho loss, default 0.1 pip0_delta scaling factor for delta loss, default 0.1 kl_weight_beta: scaling factor for KL, default 1.0 log_variational Log(data+1) prior to encoding for numerical stability. Not normalization. """
[docs] def __init__( self, n_genes: int, n_latent: int = 10, n_layers_encoder_individual: int = 2, dim_hidden_encoder: int = 128, pip0_rho: float = 0.1, pip0_delta: float = 0.1, kl_weight_beta: float = 1.0, log_variational: bool = True, ): super().__init__() dim_input_list = [n_genes, n_genes] self.n_input_list = dim_input_list self.total_genes = n_genes self.n_latent = n_latent self.pip0_rho = pip0_rho self.pip0_delta = pip0_delta self.log_variational = log_variational self.kl_weight_beta = kl_weight_beta self.z_encoder = DeltaTopicEncoder( n_input_list=dim_input_list, n_output=self.n_latent, n_hidden=dim_hidden_encoder, n_layers_individual=n_layers_encoder_individual, log_variational = self.log_variational, ) # TODO: use self.total_genes is dangerous, if we have dfferent sets of genes in spliced and unspliced self.decoder = DeltaTopicDecoder(self.n_latent , self.total_genes, pip0_rho = self.pip0_rho, pip0_delta = self.pip0_delta, )
def dir_llik(self, xx: torch.Tensor, aa: torch.Tensor, ) -> torch.Tensor: ''' Return the Dirichlet log-likelihood for a batch. ''' reconstruction_loss = None term1 = (torch.lgamma(torch.sum(aa, dim=-1)) - torch.lgamma(torch.sum(aa + xx, dim=-1))) #[n_batch] term2 = torch.sum(torch.where(xx > 0, torch.lgamma(aa + xx) - torch.lgamma(aa), torch.zeros_like(xx)), dim=-1) #[n_batch reconstruction_loss = term1 + term2 #[n_batch return reconstruction_loss def _get_inference_input(self, tensors): return dict(x=tensors[_CONSTANTS.X_KEY],y = tensors[_CONSTANTS.PROTEIN_EXP_KEY]) def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] return dict(z=z) @auto_move_data def inference(self, x: torch.Tensor, y: torch.Tensor) -> dict: x_ = x y_ = y q_m, q_v, z = self.z_encoder(x_, y_) return dict(z = z, q_m = q_m, q_v = q_v) @auto_move_data def generative(self, z) -> dict: rho, delta, rho_kl, delta_kl, theta = self.decoder(z) return dict(rho = rho, delta = delta, rho_kl = rho_kl, delta_kl = delta_kl, theta = theta) def sample_from_posterior_z( self, x: torch.Tensor, y: torch.Tensor, deterministic: bool = True, output_softmax_z: bool = True, ): """ sample from the posterior of latent space z """ inference_out = self.inference(x,y) if deterministic: # average of the two means WITHOUT sampling z = inference_out["q_m"] else: # sampling z = inference_out["z"] if output_softmax_z: generative_outputs = self.generative(z) z = generative_outputs["theta"] return dict(z=z) @auto_move_data def get_reconstruction_loss( self, x: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: """ Returns the reconstruction loss for the given batch. Parameters ---------- x tensor of values with shape ``(batch_size, n_input)`` y tensor of values with shape ``(batch_size, n_input)`` """ inference_out = self.inference(x, y) z = inference_out["z"] gen_out = self.generative(z) theta = gen_out["theta"] rho = gen_out["rho"] log_aa_spliced = torch.clamp(torch.mm(theta, rho), -10, 10) aa_spliced = torch.exp(log_aa_spliced) delta = gen_out["delta"] log_aa_unspliced = torch.clamp(torch.mm(theta, rho + delta), -10, 10) aa_unspliced = torch.exp(log_aa_unspliced) reconstruction_loss_spliced = -self.dir_llik(x, aa_spliced) reconstruction_loss_unspliced = -self.dir_llik(y, aa_unspliced) return reconstruction_loss_spliced, reconstruction_loss_unspliced def loss( self, tensors, inference_outputs, generative_outputs, # this is important to include kl_weight=1.0, kl_weight_beta = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Aggregate the kl and likelihood to form the loss. """ kl_weight_beta = self.kl_weight_beta x = tensors[_CONSTANTS.X_KEY] y = tensors[_CONSTANTS.PROTEIN_EXP_KEY] q_m = inference_outputs["q_m"] q_v = inference_outputs["q_v"] rho_kl = generative_outputs["rho_kl"] delta_kl = generative_outputs["delta_kl"] # [batch_size] reconstruction_loss_spliced, reconstruction_loss_unspliced = self.get_reconstruction_loss(x, y) # KL Divergence for z [batch_size] mean = torch.zeros_like(q_m) scale = torch.ones_like(q_v) kl_divergence = kl(Normal(q_m, torch.sqrt(q_v)), Normal(mean, scale)).sum( dim=1 ) # suming over all the topics # kl_divergence for beta, rho_kl, tensor of torch.size([]) <- torch.sum([N_topics, N_genes]) kl_divergence_beta = rho_kl + delta_kl kl_local = kl_divergence reconstruction_loss = reconstruction_loss_spliced + reconstruction_loss_unspliced loss = torch.mean(reconstruction_loss + kl_weight * kl_local) + kl_weight_beta * kl_divergence_beta/x.shape[1] return LossRecorder(loss, reconstruction_loss, kl_local, reconstruction_loss_spliced=reconstruction_loss_spliced, reconstruction_loss_unspliced=reconstruction_loss_unspliced, kl_beta = kl_divergence_beta, kl_rho = rho_kl, kl_delta = delta_kl)