# Source code for starfysh.starfysh

from __future__ import print_function

import numpy as np
import pandas as pd
import os
import random

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import torch.nn as nn

from torchvision import transforms
from torchvision.utils import make_grid
from torch.distributions import constraints, Distribution, Normal, Gamma, Poisson, Dirichlet
from torch.distributions import kl_divergence as kl

# Module import
from starfysh import LOGGER


# Fix the Python, PyTorch and NumPy RNG seeds at import time so that model
# initialisation and sampling are reproducible across runs.
# (`CUDA_LAUNCH_BLOCKING` is already exported next to the `torch` import above,
# so it is not set a second time here.)
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)


# TODO: inherit `AVAE` (expr model) w/ `AVAE_PoE` (expr + histology model), update latest PoE model

class AVAE(nn.Module):
    """
    Auxiliary Variational AutoEncoder (AVAE): expression-only deconvolution
    model (no H&E image integration).

    Model design
        p(x|z)=f(z)
        p(z|x)~N(0,1)
        q(z|x)~g(x)
    """

    def __init__(
        self,
        adata,
        gene_sig,
        win_loglib,
    ) -> None:
        """
        Auxiliary Variational AutoEncoder (AVAE) - Core model for
        spatial deconvolution without H&E image integration

        Parameters
        ----------
        adata : sc.AnnData
            ST raw expression count (dim: [S, G])

        gene_sig : pd.DataFrame
            Signature gene sets for each annotated cell type

        win_loglib : array-like of float
            Log-library size smoothed with neighboring spots
            (wrapped with `torch.Tensor(...)`)
        """
        super().__init__()
        self.win_loglib = torch.Tensor(win_loglib)

        self.c_in = adata.shape[1]  # c_in : Num. input features (# input genes)
        self.c_bn = 10  # c_bn : latent number, numbers of bottle-necks
        self.c_hidden = 256
        self.c_kn = gene_sig.shape[1]  # c_kn : number of signature columns (cell types)
        self.eps = 1e-5  # for r.v. w/ numerical constraints

        # NOTE: the creation order of parameters/modules below is kept as-is
        # so that seeded initialisation and `state_dict` ordering do not change.
        self.alpha = torch.nn.Parameter(torch.rand(self.c_kn) * 1e3, requires_grad=True)
        self.qs_logm = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)
        self.qu_m = torch.nn.Parameter(torch.randn(self.c_kn, self.c_bn), requires_grad=True)
        self.qu_logv = torch.nn.Parameter(torch.zeros(self.c_kn, self.c_bn), requires_grad=True)

        # Encoder for the cell-type proportion mean `qc_m`
        self.c_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU()
        )
        self.c_enc_m = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_kn, bias=True),
            nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001),
            nn.Softmax(dim=-1)
        )

        # Encoder for the library-size variable `ql`
        self.l_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.l_enc_m = nn.Linear(self.c_hidden, 1)
        self.l_enc_logv = nn.Linear(self.c_hidden, 1)

        # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)]
        self.z_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
        self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)

        # gene dispersion (exponentiated before use in `get_loss`)
        self.px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)

        # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v]
        self.px_hidden_decoder = nn.Sequential(
            nn.Linear(self.c_bn, self.c_hidden, bias=True),
            nn.ReLU(),
        )
        self.px_scale_decoder = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_in),
            nn.ReLU()
        )

    def reparameterize(self, mu, log_var):
        """
        Draw a sample `mu + eps * std` with `eps ~ N(0, 1)`.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # `randn_like` as we need the same size
        return mu + (eps * std)

    def inference(self, x):
        """
        Encode expression counts into the variational posteriors.

        Parameters
        ----------
        x : torch.Tensor
            Expression counts; indexed as [spots, genes] by the reshapes below

        Returns
        -------
        dict
            Means / log-variances / samples for `qc`, `qz`, `ql`, `qu`, plus
            the `qs_logm` parameter.
        """
        # All three encoders read the same log-transformed counts
        x_n = torch.log(1 + x)
        n_spots = x_n.shape[0]

        # l is inferred from logarithmized x
        hidden = self.l_enc(x_n)
        ql_m = self.l_enc_m(hidden)
        ql_logv = self.l_enc_logv(hidden)
        ql = self.reparameterize(ql_m, ql_logv)
        ql = torch.clamp(ql, min=self.eps)  # non-negative constraints

        hidden = self.c_enc(x_n)
        qc_m = self.c_enc_m(hidden)
        qc = Dirichlet(self.alpha * qc_m + self.eps).rsample()[:, :, None]

        # Per-cell-type latent stats, weighted by the sampled proportions `qc`
        # and summed over the cell-type axis
        hidden = self.z_enc(x_n)
        qz_m_ct = self.z_enc_m(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        qz_m_ct = qc * qz_m_ct
        qz_m = qz_m_ct.sum(axis=1)

        qz_logv_ct = self.z_enc_logv(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        qz_logv_ct = qc * qz_logv_ct
        qz_logv = qz_logv_ct.sum(axis=1)
        qz = self.reparameterize(qz_m, qz_logv)

        qu_m = self.qu_m
        qu_logv = self.qu_logv
        qu = self.reparameterize(qu_m, qu_logv)

        return dict(
            qc_m=qc_m,
            qc=qc,
            qz_m=qz_m,
            qz_m_ct=qz_m_ct,
            qz_logv=qz_logv,
            qz_logv_ct=qz_logv_ct,
            qz=qz,
            ql_m=ql_m,
            ql_logv=ql_logv,
            ql=ql,
            qu_m=qu_m,
            qu_logv=qu_logv,
            qu=qu,
            qs_logm=self.qs_logm,
        )

    def generative(
        self,
        inference_outputs,
        xs_k,
    ):
        """
        Decode the latent sample into NegBinom rate and Dirichlet prior params.

        Parameters
        ----------
        inference_outputs : dict
            Output of `inference`; reads `qz` and `ql`
        xs_k : torch.Tensor
            Signature means used to build the Dirichlet prior concentration
        """
        qz = inference_outputs['qz']
        ql = inference_outputs['ql']

        hidden = self.px_hidden_decoder(qz)
        px_scale = self.px_scale_decoder(hidden)
        # TODO: verify whether taking exponential for `ql` term is valid?
        px_rate = torch.exp(ql) * px_scale

        xs_k = xs_k / torch.exp(ql) * torch.exp(ql.mean(axis=1, keepdims=True))
        pc_p = self.alpha * xs_k + self.eps

        # NOTE(review): clamping `px_r` with itself as lower bound leaves the
        # values unchanged (`AVAE_PoE.generative` clamps at `self.eps`).
        # Kept as-is to avoid altering training; confirm the intended bound.
        with torch.no_grad():
            self.px_r.clamp_(min=self.px_r)

        return dict(
            px_rate=px_rate,
            px_r=self.px_r,
            pc_p=pc_p,
            xs_k=xs_k,
        )

    def get_loss(
        self,
        generative_outputs,
        inference_outputs,
        x,
        x_peri,
        library,
        device
    ):
        """
        Compute the training objective.

        Returns
        -------
        tuple
            (loss, reconst_loss, kl_divergence_u, kl_divergence_z,
             kl_divergence_c, kl_divergence_n), each moved to `device`.
        """
        qc = inference_outputs["qc"]
        qc_m = inference_outputs["qc_m"]
        qs_logm = self.qs_logm
        qu = inference_outputs["qu"]
        qu_m = inference_outputs["qu_m"]
        qu_logv = inference_outputs["qu_logv"]
        qz_m = inference_outputs["qz_m"]
        qz_logv = inference_outputs["qz_logv"]
        ql_m = inference_outputs["ql_m"]
        ql_logv = inference_outputs['ql_logv']
        ql = inference_outputs['ql']

        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        pc_p = generative_outputs["pc_p"]

        kl_divergence_u = kl(
            Normal(qu_m, torch.exp(qu_logv / 2)),
            Normal(torch.zeros_like(qu_m), torch.ones_like(qu_m))
        ).sum(dim=1).mean()

        # Prior on z: proportion-weighted mixture of per-cell-type stats
        mean_pz = (qu.unsqueeze(0) * qc).sum(axis=1)
        std_pz = (torch.exp(qs_logm / 2).unsqueeze(0) * qc).sum(axis=1)
        kl_divergence_z = kl(
            Normal(qz_m, torch.exp(qz_logv / 2)),
            Normal(mean_pz, std_pz)
        ).sum(dim=1).mean()

        kl_divergence_n = kl(
            Normal(ql_m, torch.sqrt(torch.exp(ql_logv))),
            Normal(library, torch.ones_like(ql))
        ).sum(dim=1).mean()

        # Proportion prior: rows flagged `x_peri[:, 0] == 1` get full weight;
        # remaining rows are down-weighted (1e-1 / 1e-2) depending on whether
        # their library size is below the 20% quantile of `win_loglib`.
        # NOTE(review): branch nesting reconstructed from flattened source -
        # this is the reading where `kl_divergence_c` is always defined.
        is_flagged = x_peri[:, 0] == 1
        if is_flagged.sum() > 0:
            kl_divergence_c = kl(
                Dirichlet(qc_m[is_flagged] * self.alpha),
                Dirichlet(pc_p[is_flagged])
            ).mean()

            is_unflagged = x_peri[:, 0] == 0
            if is_unflagged.sum() > 0:
                # para-dependent
                lib_cutoff = torch.quantile(self.win_loglib, 0.2)
                low_lib = is_unflagged & (library[:, 0] < lib_cutoff)
                high_lib = is_unflagged & (library[:, 0] >= lib_cutoff)

                if low_lib.sum() > 0:
                    kl_divergence_c = kl_divergence_c + 1e-1 * kl(
                        Dirichlet(qc_m[low_lib] * self.alpha),
                        Dirichlet(pc_p[low_lib])
                    ).mean()
                if high_lib.sum() > 0:
                    kl_divergence_c = kl_divergence_c + 1e-2 * kl(
                        Dirichlet(qc_m[high_lib] * self.alpha),
                        Dirichlet(pc_p[high_lib])
                    ).mean()
        else:
            kl_divergence_c = torch.Tensor([0.0])

        reconst_loss = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean()

        reconst_loss = reconst_loss.to(device)
        kl_divergence_u = kl_divergence_u.to(device)
        kl_divergence_z = kl_divergence_z.to(device)
        kl_divergence_c = kl_divergence_c.to(device)
        kl_divergence_n = kl_divergence_n.to(device)

        # NOTE(review): `kl_divergence_u` is reported but not part of `loss`.
        loss = reconst_loss + kl_divergence_z + kl_divergence_c + kl_divergence_n

        return (loss,
                reconst_loss,
                kl_divergence_u,
                kl_divergence_z,
                kl_divergence_c,
                kl_divergence_n
                )
class AVAE_PoE(nn.Module):
    """
    Auxiliary Variational AutoEncoder with joint expression + H&E image
    inference, combined through a product-of-experts (PoE) latent.

    Model design:
        p(x|z)=f(z)
        p(z|x)~N(0,1)
        q(z|x)~g(x)
    """

    def __init__(
        self,
        adata,
        gene_sig,
        patch_r,
        win_loglib,
    ) -> None:
        """
        Auxiliary Variational AutoEncoder (AVAE) with Joint H&E inference
        - Core model for spatial deconvolution w/ H&E image integration

        Parameters
        ----------
        adata : sc.AnnData
            ST raw expression count (dim: [S, G])

        gene_sig : pd.DataFrame
            Signature gene sets for each annotated cell type

        patch_r : int
            Mini-patch size sampled around each spot from raw H&E image

        win_loglib : array-like of float
            Log-library size smoothed with neighboring spots
            (wrapped with `torch.Tensor(...)`)
        """
        super().__init__()
        self.win_loglib = torch.Tensor(win_loglib)

        self.c_in = adata.shape[1]  # c_in : Num. input features (# input genes)
        self.c_bn = 10  # c_bn : latent number, numbers of bottle neck
        self.c_hidden = 256
        self.patch_r = patch_r
        self.c_kn = gene_sig.shape[1]
        self.eps = 1e-5  # for r.v. w/ numerical constraints
        # DEBUG: upper bound for `px_rate`, TODO: solve `ql` & `img_ql` issues in `generative`
        self.MAX = 5000

        # Length of one flattened image patch fed to the image encoders
        img_dim = self.patch_r * self.patch_r * 4 * 3

        # NOTE: the creation order of parameters/modules below is kept as-is
        # so that seeded initialisation and `state_dict` ordering do not change.
        self.alpha = torch.nn.Parameter(torch.rand(self.c_kn) * 1e3, requires_grad=True)

        # --- Expression branch ---
        self.c_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU()
        )
        self.c_enc_m = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_kn, bias=True),
            nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001),
            # explicit `dim`: the implicit choice is deprecated; matches `AVAE`
            nn.Softmax(dim=-1)
        )

        self.l_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.l_enc_m = nn.Linear(self.c_hidden, 1)
        self.l_enc_logv = nn.Linear(self.c_hidden, 1)

        # neural network f1 to get the z, p(z|x), f1(x,\phi_1)=[z_m,torch.exp(z_logv)]
        self.z_enc = nn.Sequential(
            nn.Linear(self.c_in, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.z_enc_m = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
        self.z_enc_logv = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)

        # gene dispersion
        self.px_r = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)

        # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v]
        self.px_hidden_decoder = nn.Sequential(
            nn.Linear(self.c_bn, self.c_hidden, bias=True),
            nn.ReLU(),
        )
        self.px_scale_decoder = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_in),
            nn.ReLU()
        )

        self.px_r_poe = torch.nn.Parameter(torch.randn(self.c_in), requires_grad=True)

        # --- Image branch ---
        self.img_l_enc = nn.Sequential(
            nn.Linear(img_dim, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.img_l_enc_m = nn.Linear(self.c_hidden, 1)
        self.img_l_enc_logv = nn.Linear(self.c_hidden, 1)

        self.img_c_enc = nn.Sequential(
            nn.Linear(img_dim, self.c_hidden, bias=True),  # flatten the images into 1D
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU()
        )
        self.img_c_enc_m = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_kn, bias=True),
            nn.BatchNorm1d(self.c_kn, momentum=0.01, eps=0.001),
            nn.Softmax(dim=-1)
        )

        self.imgVAE_z_enc = nn.Sequential(
            nn.Linear(img_dim, self.c_hidden, bias=True),
            nn.BatchNorm1d(self.c_hidden, momentum=0.01, eps=0.001),
            nn.ReLU(),
        )
        self.imgVAE_mu = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)
        self.imgVAE_logvar = nn.Linear(self.c_hidden, self.c_bn * self.c_kn)

        self.imgVAE_z_fc = nn.Linear(self.c_bn, self.c_hidden)
        self.imgVAE_dec = nn.Sequential(
            nn.Linear(self.c_hidden, img_dim, bias=True),
            nn.BatchNorm1d(img_dim, momentum=0.01, eps=0.001),
        )

        # --- PoE branch ---
        # NOTE(review): this layer is replaced by the `nn.Sequential` assigned
        # to the same attribute right below. It is kept only so the RNG stream
        # (and hence seeded initial weights of later layers) stays unchanged.
        self.POE_z_fc = nn.Linear(self.c_bn, self.c_hidden)

        # neural network g to get the x_m and x_v, p(x|z), g(z,\phi_3)=[x_m,x_v]
        self.POE_z_fc = nn.Sequential(
            nn.Linear(self.c_bn, self.c_hidden, bias=True),
            nn.ReLU(),
        )
        self.POE_px_scale_decoder = nn.Sequential(
            nn.Linear(self.c_hidden, self.c_in),
            nn.ReLU()
        )
        self.POE_dec_img = nn.Sequential(
            nn.Linear(self.c_hidden, img_dim, bias=True),
            nn.BatchNorm1d(img_dim, momentum=0.01, eps=0.001),
        )

    def reparameterize(self, mu, log_var):
        """
        Draw a sample `mu + eps * std` with `eps ~ N(0, 1)`.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # `randn_like` as we need the same size
        return mu + (eps * std)

    def inference(self, x):
        """
        Encode expression counts into the variational posteriors.

        Unlike `AVAE.inference`, the returned `qz_m_ct` / `qz_logv_ct` are the
        raw per-cell-type encoder outputs (not multiplied by `qc`).
        """
        # All three encoders read the same log1p-transformed counts
        x_n = torch.log1p(x)
        n_spots = x_n.shape[0]

        # l is inferred from logarithmized x
        hidden = self.l_enc(x_n)
        ql_m = self.l_enc_m(hidden)
        ql_logv = self.l_enc_logv(hidden)
        ql = self.reparameterize(ql_m, ql_logv)
        ql = torch.clamp(ql, min=self.eps)  # non-negative constraints

        hidden = self.c_enc(x_n)
        qc_m = self.c_enc_m(hidden)
        qc = Dirichlet(self.alpha * qc_m + self.eps).rsample()[:, :, None]

        hidden = self.z_enc(x_n)
        qz_m_ct = self.z_enc_m(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        qz_m = (qc * qz_m_ct).sum(axis=1)

        qz_logv_ct = self.z_enc_logv(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        qz_logv = (qc * qz_logv_ct).sum(axis=1)
        qz = self.reparameterize(qz_m, qz_logv)

        return dict(
            qc_m=qc_m,
            qc=qc,
            qz_m=qz_m,
            qz_m_ct=qz_m_ct,
            qz_logv=qz_logv,
            qz_logv_ct=qz_logv_ct,
            qz=qz,
            ql_m=ql_m,
            ql_logv=ql_logv,
            ql=ql
        )

    def predict_imgVAE(self, x):
        """
        Image-branch encoder/decoder.

        Parameters
        ----------
        x : torch.Tensor
            Flattened image patches; last dim must equal
            `patch_r * patch_r * 4 * 3` (input size of the image encoders)
        """
        n_spots = x.shape[0]

        hidden = self.img_l_enc(x)
        img_ql_m = self.img_l_enc_m(hidden)
        img_ql_logv = self.img_l_enc_logv(hidden)
        img_ql = self.reparameterize(img_ql_m, img_ql_logv)
        img_ql = torch.clamp(img_ql, min=self.eps)  # non-negative constraints

        hidden = self.img_c_enc(x)
        img_qc_m = self.img_c_enc_m(hidden)
        img_qc = Dirichlet(self.alpha * img_qc_m + self.eps).rsample()[:, :, None]

        hidden = self.imgVAE_z_enc(x)
        img_qz_m_ct = self.imgVAE_mu(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        img_qz_m = (img_qc * img_qz_m_ct).sum(axis=1)

        img_qz_logv_ct = self.imgVAE_logvar(hidden).reshape([n_spots, self.c_kn, self.c_bn])
        img_qz_logv = (img_qc * img_qz_logv_ct).sum(axis=1)
        img_qz = self.reparameterize(img_qz_m, img_qz_logv)

        hidden = self.imgVAE_z_fc(img_qz)
        reconstruction = self.imgVAE_dec(hidden)

        return dict(
            reconstruction=reconstruction,
            img_z_mu=img_qz_m,
            img_z_logv=img_qz_logv,
            img_qz_m_ct=img_qz_m_ct,
            img_qz_logv_ct=img_qz_logv_ct,
            img_qc=img_qc,
            img_qc_m=img_qc_m,
            img_ql_m=img_ql_m,
            img_ql_logv=img_ql_logv,
            img_ql=img_ql
        )

    def generative(
        self,
        inference_outputs,
        xs_k,
        img_path_outputs
    ):
        """
        Decode the expression latent into NegBinom rate and Dirichlet prior.

        Parameters
        ----------
        inference_outputs : dict
            Output of `inference`; reads `qz` and `ql`
        xs_k : torch.Tensor
            Z-normed avg. gene exprs
        img_path_outputs : dict
            Output of `predict_imgVAE`; reads `img_ql`
        """
        qz = inference_outputs['qz']
        ql = inference_outputs['ql']
        img_ql = img_path_outputs['img_ql']

        hidden = self.px_hidden_decoder(qz)
        px_scale = self.px_scale_decoder(hidden)
        # TODO: verify whether taking exponential for `ql` term is valid?
        # Rate uses the average of expression- and image-derived library
        # terms, bounded to [eps, MAX].
        px_rate = torch.clamp(
            (ql + img_ql) / 2 * px_scale,
            min=self.eps,
            max=self.MAX
        )

        xs_k = xs_k / torch.exp(ql) * torch.exp(ql.mean(axis=1, keepdims=True))
        pc_p = self.alpha * xs_k + self.eps

        with torch.no_grad():
            self.px_r.clamp_(min=self.eps)

        return dict(
            px_rate=px_rate,
            px_r=self.px_r,
            pc_p=pc_p,
            xs_k=xs_k,
            # DEBUG: save for numerical stability checks
            img_ql=img_ql,
            px_scale=px_scale,
        )

    def predictor_POE(
        self,
        inference_outputs,
        exp_path_outputs,
        img_path_outputs,
    ):
        """
        Combine expression and image posteriors and decode both modalities.

        Parameters
        ----------
        inference_outputs : dict
            Reads `qz_m`, `qz_logv`, `ql`
        exp_path_outputs : dict
            Not read; kept in the signature for existing callers
        img_path_outputs : dict
            Reads `img_z_mu`, `img_z_logv`

        Returns
        -------
        dict
            NOTE: the `var_poe` entry holds `log(var_poe + 0.001)`.
        """
        mu_img = img_path_outputs['img_z_mu']
        logvar_img = img_path_outputs['img_z_logv']
        mu_exp = inference_outputs['qz_m']
        logvar_exp = inference_outputs['qz_logv']
        ql = inference_outputs['ql']

        var_exp = torch.exp(logvar_exp)
        var_img = torch.exp(logvar_img)

        # Precision-weighted combination; the leading `1` in the denominator
        # is a unit-precision term.
        var_poe = 1. / (1 + 1. / var_exp + 1. / var_img)
        mu_poe = var_poe * (
            mu_exp * (1. / (var_exp + self.eps)) +
            mu_img * (1. / (var_img + self.eps))
        )

        logvar_poe = torch.log(var_poe + 0.001)
        z = self.reparameterize(mu_poe, logvar_poe)

        hidden = self.POE_z_fc(z)
        px_scale_rna = self.POE_px_scale_decoder(hidden)
        px_rate_poe = torch.exp(ql) * px_scale_rna
        reconstruction_img = self.POE_dec_img(hidden)

        return dict(
            px_rate_poe=px_rate_poe,
            px_r_poe=self.px_r_poe,
            reconstruction_img=reconstruction_img,
            mu_poe=mu_poe,
            var_poe=logvar_poe
        )

    def get_loss(
        self,
        generative_outputs,
        inference_outputs,
        img_path_outputs,
        poe_path_outputs,
        x,
        x_peri,
        library,
        adata_img,
        device
    ):
        """
        Compute the training objective `Loss_IBJ + alpha_c * Loss_IBM`
        (joint PoE term plus weighted sum of per-modality terms).

        Returns
        -------
        tuple
            (loss, reconst_loss, kl_divergence_z, kl_divergence_c,
             kl_divergence_n); the last four belong to the expression branch.
        """
        beta = 0.001
        alpha_c = 5

        qc_m = inference_outputs["qc_m"]
        qz_m = inference_outputs["qz_m"]
        qz_logv = inference_outputs["qz_logv"]
        ql_m = inference_outputs["ql_m"]
        ql_logv = inference_outputs['ql_logv']
        ql = inference_outputs['ql']

        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        pc_p = generative_outputs["pc_p"]

        px_rate_poe = poe_path_outputs['px_rate_poe']
        px_r_poe = poe_path_outputs['px_r_poe']
        recon_poe_img = poe_path_outputs['reconstruction_img']
        logvar_poe = poe_path_outputs['var_poe']
        mu_poe = poe_path_outputs['mu_poe']

        mu_img = img_path_outputs['img_z_mu']
        logvar_img = img_path_outputs['img_z_logv']
        recon_img = img_path_outputs['reconstruction']
        img_ql_m = img_path_outputs['img_ql_m']
        img_ql_logv = img_path_outputs['img_ql_logv']
        img_qc_m = img_path_outputs['img_qc_m']

        criterion = nn.MSELoss(reduction='sum')

        # --- Joint (PoE) term ---
        reconst_loss_poe_rna = -NegBinom(px_rate_poe, torch.exp(px_r_poe)).log_prob(x).sum(-1).mean()
        bce_loss_poe_img = criterion(recon_poe_img, adata_img)
        Loss_IBJ = (reconst_loss_poe_rna + bce_loss_poe_img) + \
            beta * (-0.5 * torch.sum(1 + logvar_poe - mu_poe.pow(2) - logvar_poe.exp()))

        # --- Expression term ---
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_logv)

        kl_divergence_z = kl(
            Normal(qz_m, torch.sqrt(torch.exp(qz_logv))),
            Normal(mean, scale)
        ).sum(dim=1).mean()

        kl_divergence_n = kl(
            Normal(ql_m, torch.sqrt(torch.exp(ql_logv))),
            Normal(library, torch.ones_like(ql))
        ).sum(dim=1).mean()

        # Proportion prior: rows flagged `x_peri[:, 0] == 1` get full weight;
        # remaining rows are down-weighted (1e-1 / 1e-2) depending on whether
        # their library size is below the 20% quantile of `win_loglib`.
        # NOTE(review): branch nesting reconstructed from flattened source -
        # this is the reading where `kl_divergence_c` is always defined.
        is_flagged = x_peri[:, 0] == 1
        has_flagged = is_flagged.sum() > 0
        if has_flagged:
            kl_divergence_c = kl(
                Dirichlet(qc_m[is_flagged] * self.alpha),
                Dirichlet(pc_p[is_flagged])
            ).mean()

            is_unflagged = x_peri[:, 0] == 0
            if is_unflagged.sum() > 0:
                lib_cutoff = torch.quantile(self.win_loglib, 0.2)
                low_lib = is_unflagged & (library[:, 0] < lib_cutoff)
                high_lib = is_unflagged & (library[:, 0] >= lib_cutoff)

                if low_lib.sum() > 0:
                    kl_divergence_c = kl_divergence_c + 1e-1 * kl(
                        Dirichlet(qc_m[low_lib] * self.alpha),
                        Dirichlet(pc_p[low_lib])
                    ).mean()
                if high_lib.sum() > 0:
                    kl_divergence_c = kl_divergence_c + 1e-2 * kl(
                        Dirichlet(qc_m[high_lib] * self.alpha),
                        Dirichlet(pc_p[high_lib])
                    ).mean()
        else:
            kl_divergence_c = torch.Tensor([0.0])

        reconst_loss = -NegBinom(px_rate, torch.exp(px_r)).log_prob(x).sum(-1).mean()
        loss_exp = reconst_loss.to(device) + \
            kl_divergence_z.to(device) + \
            kl_divergence_c.to(device) + \
            kl_divergence_n.to(device)

        # --- Image term ---
        kl_divergence_n_img = kl(
            Normal(img_ql_m, torch.sqrt(torch.exp(img_ql_logv))),
            Normal(library, torch.ones_like(ql))
        ).sum(dim=1).mean()

        # NOTE(review): unlike the other KL terms this one is not averaged
        # over the batch, so the sum below broadcasts the scalar terms across
        # the batch before `torch.sum`. Kept as-is; confirm it is intended.
        kl_divergence_z_img = kl(
            Normal(mu_img, torch.sqrt(torch.exp(logvar_img))),
            Normal(mean, scale)
        ).sum(dim=1)

        if has_flagged:
            kl_divergence_c_img = kl(
                Dirichlet(img_qc_m[is_flagged] * self.alpha),
                Dirichlet(pc_p[is_flagged])
            ).mean()
        else:
            kl_divergence_c_img = torch.Tensor([0.0])

        bce_loss_img = criterion(recon_img, adata_img)

        bce_loss_img = bce_loss_img.to(device)
        kl_divergence_z_img = kl_divergence_z_img.to(device)
        kl_divergence_n_img = kl_divergence_n_img.to(device)
        kl_divergence_c_img = kl_divergence_c_img.to(device)

        loss_img = torch.sum(bce_loss_img + kl_divergence_z_img + kl_divergence_n_img + kl_divergence_c_img)

        Loss_IBM = (loss_exp + loss_img)
        loss = Loss_IBJ + alpha_c * Loss_IBM

        return (loss,
                reconst_loss,
                kl_divergence_z,
                kl_divergence_c,
                kl_divergence_n
                )
def valid_model(model):
    """
    Legacy validation hook: run one forward pass and return a correlation score.

    NOTE(review): this function reads the module-level names
    `adata_sample_filter`, `gene_sig_exp_m` and `device`, none of which are
    defined in the visible part of this file; calling it without them in
    scope raises `NameError`. The Pearson-correlation fill-in of
    `corr_map_qcm` (against a `proportions` table) is disabled, so the
    returned score is currently always 0.0.

    Parameters
    ----------
    model : nn.Module
        Model exposing `inference(x)` and `generative(inference_outputs, xs_k)`

    Returns
    -------
    float
        Mean of the diagonal of the 3x3 correlation map
    """
    model.eval()

    x_valid = torch.Tensor(np.array(adata_sample_filter.to_df())).to(device)
    gene_sig_exp_valid = torch.Tensor(np.array(gene_sig_exp_m)).to(device)

    # Forward pass only; outputs are not used while the correlation
    # computation is disabled, so no autograd graph is needed.
    with torch.no_grad():
        inference_outputs = model.inference(x_valid)
        model.generative(inference_outputs, gene_sig_exp_valid)

    corr_map_qcm = np.zeros([3, 3])
    return 1/3*(corr_map_qcm[0,0]+corr_map_qcm[1,1]+corr_map_qcm[2,2])
def train(
    model,
    dataloader,
    device,
    optimizer,
):
    """
    Run one training epoch of the expression-only model.

    Parameters
    ----------
    model : nn.Module
        Model exposing `inference`, `generative` and `get_loss` (6 outputs)
    dataloader : iterable
        Yields `(x, xs_k, x_peri, library_i)` batches
    device : torch.device
        Device the batches are moved to
    optimizer : torch.optim.Optimizer
        Optimizer stepping the model parameters

    Returns
    -------
    tuple
        Per-batch averages `(train_loss, train_reconst, train_u, train_z,
        train_c, train_n)` followed by `corr_list` (always an empty list).

    Raises
    ------
    ZeroDivisionError
        If `dataloader` yields no batches
    """
    model.train()

    running_loss = 0.0
    running_u = 0.0
    running_z = 0.0
    running_c = 0.0
    running_n = 0.0
    running_reconst = 0.0
    counter = 0
    corr_list = []

    for counter, (x, xs_k, x_peri, library_i) in enumerate(dataloader, start=1):
        x = x.float().to(device)
        xs_k = xs_k.to(device)
        x_peri = x_peri.to(device)
        library_i = library_i.to(device)

        inference_outputs = model.inference(x)
        generative_outputs = model.generative(inference_outputs, xs_k)

        (loss,
         reconst_loss,
         kl_divergence_u,
         kl_divergence_z,
         kl_divergence_c,
         kl_divergence_n
         ) = model.get_loss(generative_outputs,
                            inference_outputs,
                            x,
                            x_peri,
                            library_i,
                            device
                            )

        optimizer.zero_grad()
        loss.backward()

        running_loss += loss.item()
        running_reconst += reconst_loss.item()
        running_u += kl_divergence_u.item()
        running_z += kl_divergence_z.item()
        running_c += kl_divergence_c.item()
        running_n += kl_divergence_n.item()

        optimizer.step()

    train_loss = running_loss / counter
    train_reconst = running_reconst / counter
    train_u = running_u / counter
    train_z = running_z / counter
    train_c = running_c / counter
    train_n = running_n / counter

    return train_loss, train_reconst, train_u, train_z, train_c, train_n, corr_list
def train_poe(
    model,
    dataloader,
    device,
    optimizer,
):
    """
    Run one training epoch of the expression + image (PoE) model.

    Parameters
    ----------
    model : nn.Module
        Model exposing `inference`, `predict_imgVAE`, `generative`,
        `predictor_POE` and `get_loss` (5 outputs)
    dataloader : iterable
        Yields `(x, x_peri, library_i, adata_img, data_loc, xs_k)` batches;
        `data_loc` is not used here
    device : torch.device
        Device the batches are moved to
    optimizer : torch.optim.Optimizer
        Optimizer stepping the model parameters

    Returns
    -------
    tuple
        `(train_loss, train_reconst, train_u, train_z, train_c, train_n,
        corr_list)`; `train_u` is the constant 0 and `corr_list` is always
        an empty list.

    Raises
    ------
    ZeroDivisionError
        If `dataloader` yields no batches
    """
    # TODO: add `u` in PoE integration
    model.train()

    running_loss = 0.0
    running_z = 0.0
    running_c = 0.0
    running_n = 0.0
    running_reconst = 0.0
    counter = 0
    corr_list = []

    for counter, (x,
                  x_peri,
                  library_i,
                  adata_img,
                  _data_loc,
                  xs_k,
                  ) in enumerate(dataloader, start=1):
        x = x.float().to(device)
        x_peri = x_peri.to(device)
        library_i = library_i.to(device)
        xs_k = xs_k.to(device)

        # Flatten each image patch to 1D and rescale by 1/255
        mini_batch, _ = x.shape
        adata_img = adata_img.reshape(mini_batch, -1).float()
        adata_img = adata_img / 255
        adata_img = adata_img.to(device)

        # gene expression, 1D data
        inference_outputs = model.inference(x)

        # image, 2D data
        img_path_outputs = model.predict_imgVAE(adata_img)
        generative_outputs = model.generative(inference_outputs, xs_k, img_path_outputs)

        # POE
        poe_path_outputs = model.predictor_POE(inference_outputs,
                                               generative_outputs,
                                               img_path_outputs
                                               )

        (loss,
         reconst_loss,
         kl_divergence_z,
         kl_divergence_c,
         kl_divergence_n
         ) = model.get_loss(generative_outputs,
                            inference_outputs,
                            img_path_outputs,
                            poe_path_outputs,
                            x,
                            x_peri,
                            library_i,
                            adata_img,
                            device
                            )

        optimizer.zero_grad()
        loss.backward()

        running_loss += loss.item()
        running_reconst += reconst_loss.item()
        running_z += kl_divergence_z.item()
        running_c += kl_divergence_c.item()
        running_n += kl_divergence_n.item()

        optimizer.step()

    train_loss = running_loss / counter
    train_reconst = running_reconst / counter
    train_u = 0  # TMP, to be deleted after updating PoE model with u->c->z->x
    train_z = running_z / counter
    train_c = running_c / counter
    train_n = running_n / counter

    return train_loss, train_reconst, train_u, train_z, train_c, train_n, corr_list
# Reference:
# https://github.com/YosefLab/scvi-tools/blob/master/scvi/distributions/_negative_binomial.py
class NegBinom(Distribution):
    """
    Gamma-Poisson mixture approximation of Negative Binomial(mean, dispersion)

    lambda ~ Gamma(concentration=theta, rate=theta / mu)
    x ~ Poisson(lambda)
    """
    arg_constraints = {
        'mu': constraints.greater_than_eq(0),
        'theta': constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(self, mu, theta, eps=1e-10):
        """
        Parameters
        ----------
        mu : torch.Tensor
            mean of NegBinom. distribution
            shape - [# genes,]

        theta : torch.Tensor
            dispersion of NegBinom. distribution
            shape - [# genes,]

        eps : float
            Small constant added inside logs / Gamma parameters

        Raises
        ------
        ValueError
            If `mu` or `theta` violate `arg_constraints`
            (argument validation is always enabled)
        """
        # Parameters must be set before the base-class constructor runs,
        # since it validates them against `arg_constraints`.
        self.mu = mu
        self.theta = theta
        self.eps = eps
        super().__init__(validate_args=True)

    def sample(self, sample_shape=torch.Size()):
        """
        Draw counts via the Gamma-Poisson mixture.

        Parameters
        ----------
        sample_shape : torch.Size or tuple
            Extra leading sample dimensions (default: none), matching the
            base `Distribution.sample` signature

        Returns
        -------
        torch.Tensor
            Sampled counts (not differentiable)
        """
        # Counts carry no gradient, so skip graph construction altogether
        with torch.no_grad():
            lambdas = Gamma(
                concentration=self.theta + self.eps,
                rate=(self.theta + self.eps) / (self.mu + self.eps),
            ).sample(sample_shape)
            return Poisson(lambdas).sample()

    def log_prob(self, x):
        """log-likelihood"""
        log_theta_mu = torch.log(self.theta + self.mu + self.eps)
        ll = torch.lgamma(x + self.theta) - \
            torch.lgamma(x + 1) - \
            torch.lgamma(self.theta) + \
            self.theta * (torch.log(self.theta + self.eps) - log_theta_mu) + \
            x * (torch.log(self.mu + self.eps) - log_theta_mu)
        return ll
def model_eval(
    model,
    adata,
    visium_args,
    poe=False,
    device=torch.device('cpu')
):
    """
    Run the model on all spots and store its outputs in `adata`.

    Parameters
    ----------
    model : nn.Module
        Trained model exposing `inference` / `generative`
        (and `predict_imgVAE` when `poe=True`)
    adata : sc.AnnData
        Expression data; read via `adata.to_df()`, written via
        `adata.obsm` / `adata.uns` / `adata.varm`
    visium_args : object
        Provides `sig_mean_znorm` and, for `poe=True`, `get_img_patches()`
    poe : bool
        Whether `model` is the image-integrated (PoE) variant
    device : torch.device
        Device used for the forward pass

    Returns
    -------
    tuple
        `(inference_outputs, generative_outputs)` dicts of tensors
    """
    # TODO: solve NegBinom learnt params (inf. `px_rate` & negative `px_r`) in PoE --> unable to sample `px`
    # For now, only store `px` with non-PoE case
    model.eval()
    model = model.to(device)

    x_in = torch.Tensor(adata.to_df().values).to(device)
    sig_means = torch.Tensor(visium_args.sig_mean_znorm.values).to(device)
    inference_outputs = model.inference(x_in)

    if poe:
        img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)
        img_outputs = model.predict_imgVAE(img_in)
        generative_outputs = model.generative(inference_outputs, sig_means, img_outputs)
    else:
        generative_outputs = model.generative(inference_outputs, sig_means)

    # Sampling is best-effort: invalid distribution parameters raise
    # `ValueError`, in which case `px` is simply not stored.
    try:
        px = NegBinom(
            mu=generative_outputs["px_rate"],
            theta=torch.exp(generative_outputs["px_r"])
        ).sample().detach().cpu().numpy()
        adata.obsm['px'] = px
    except ValueError:
        LOGGER.warning('Invalid Gamma distribution parameters `px_rate` or `px_r`, unable to sample inferred p(x | z)')

    # Save inference & generative outputs in adata
    for rv, tensor in inference_outputs.items():
        val = tensor.detach().cpu().numpy().squeeze()
        # `qu*` / `qs*` entries are not per-spot, so they go to `.uns`
        if "qu" not in rv and "qs" not in rv:
            adata.obsm[rv] = val
        else:
            adata.uns[rv] = val

    for rv, tensor in generative_outputs.items():
        try:
            val = tensor.detach().cpu().numpy().squeeze()
            if rv in ('px_r', 'reconstruction'):
                adata.varm[rv] = val
            else:
                adata.obsm[rv] = val
        except Exception:
            # Best-effort storage: skip entries that cannot be stored, but do
            # not swallow KeyboardInterrupt / SystemExit.
            LOGGER.warning("rv: %s can't be stored", rv)

    return inference_outputs, generative_outputs
def model_ct_exp(
    model,
    adata,
    visium_args,
    poe=False,
    device=torch.device('cpu')
):
    """
    Obtain predicted cell-type specific expression in each spot

    For every entry of `adata.uns['cell_types']`, the latent `qz` is replaced
    by that cell type's slice of `qz_m_ct` before decoding and sampling.

    Parameters
    ----------
    model : nn.Module
        Trained model exposing `inference` / `generative`
        (and `predict_imgVAE` when `poe=True`)
    adata : sc.AnnData
        Expression data; results are also written to
        `adata.obsm['<cell_type>_inferred_exprs']`
    visium_args : object
        Provides `sig_mean_znorm` and, for `poe=True`, `get_img_patches()`
    poe : bool
        Whether `model` is the image-integrated (PoE) variant
    device : torch.device
        Device used for the forward pass

    Returns
    -------
    dict
        Maps cell type -> `pd.DataFrame` of sampled expression
        (index: `adata.obs_names`, columns: `adata.var_names`)
    """
    model.eval()
    model = model.to(device)
    pred_exprs = {}

    # Inputs do not depend on the cell type: build them once
    x_in = torch.Tensor(adata.to_df().values).to(device)
    sig_means = torch.Tensor(visium_args.sig_mean_znorm.values).to(device)
    img_in = None
    if poe:
        # NOTE(review): assumes `get_img_patches()` returns the same patches
        # on every call - TODO confirm
        img_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)

    # Only NumPy copies of the samples are kept, so no autograd graph is needed
    with torch.no_grad():
        for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
            # Get inference outputs (re-run per cell type: the forward pass
            # is stochastic, so each cell type keeps its own draw)
            inference_outputs = model.inference(x_in)
            inference_outputs['qz'] = inference_outputs['qz_m_ct'][:, ct_idx, :]

            # Get generative outputs
            if poe:
                img_outputs = model.predict_imgVAE(img_in)
                generative_outputs = model.generative(inference_outputs, sig_means, img_outputs)
            else:
                generative_outputs = model.generative(inference_outputs, sig_means)

            px = NegBinom(
                mu=generative_outputs["px_rate"],
                theta=torch.exp(generative_outputs["px_r"])
            ).sample()
            px = px.detach().cpu().numpy()

            # Save results in adata.obsm
            px_df = pd.DataFrame(px, index=adata.obs_names, columns=adata.var_names)
            pred_exprs[cell_type] = px_df
            adata.obsm[cell_type + '_inferred_exprs'] = px

    return pred_exprs