# Source code for starfysh.dataloader

import numpy as np
import pandas as pd
import logging
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

# from torchvision import transforms


class VisiumDataset(Dataset):
    """
    Loading preprocessed Visium AnnData, gene signature & Anchor spots for Starfysh training

    Parameters
    ----------
    adata
        Object exposing `obs_names`, `var_names` and `X`. `X` is used as-is
        when it is a `np.ndarray`; otherwise it is densified through its
        `toarray()` method.
    args
        Object exposing `sig_mean_znorm` (row-indexed with `.iloc`),
        `pure_idx` (indexed as `[idx, :]`) and `win_loglib` (indexed as
        `[idx, None]`).
    """

    def __init__(
        self,
        adata,
        args,
    ):
        spots = adata.obs_names
        genes = adata.var_names

        # `toarray()` instead of the `.A` shorthand: the shorthand is not
        # available on every SciPy sparse container, `toarray()` is.
        x = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
        self.expr_mat = pd.DataFrame(x, index=spots, columns=genes)
        self.gexp = args.sig_mean_znorm
        self.anchor_idx = args.pure_idx
        self.library_n = args.win_loglib

    def __len__(self):
        """Return the number of rows (spots) in the expression matrix."""
        return len(self.expr_mat)

    def __getitem__(self, idx):
        """
        Return `(expression, signature, anchors, library size)` for `idx`.

        A tensor index is converted with `tolist()` before it is used.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = torch.Tensor(
            np.array(self.expr_mat.iloc[idx, :], dtype='float')
        )
        # Convert the pandas row to NumPy first: handing a Series straight to
        # `torch.Tensor` makes torch read it via `series[0]`, `series[1]`, ...,
        # which are label lookups and fail when the labels are not 0..n-1.
        signature = torch.Tensor(
            np.array(self.gexp.iloc[idx, :], dtype='float')
        )

        return (
            sample,
            signature,  # z-normed signature exprs
            torch.Tensor(self.anchor_idx[idx, :]),  # anchors
            torch.Tensor(self.library_n[idx, None]),  # library size
        )
class VisiumPoEDataSet(VisiumDataset):
    """
    return the data stack with expression and image

    On top of the parent dataset, one square image patch of side
    `2 * args.params['patch_r']` is cut around every spot at construction
    time and kept in `self.spot_img_stack`. Parts of a patch that fall
    outside the image stay zero.

    Parameters
    ----------
    adata
        Forwarded unchanged to `VisiumDataset.__init__`.
    args
        Besides what the parent class reads, must expose `img` (unpacked as
        `(h, w, d)`), `map_info` (rows read with `.iloc[i]`, columns
        `'imagecol'` / `'imagerow'`), `params['patch_r']` and
        `scalefactor['tissue_hires_scalef']`.
    """

    def __init__(
        self,
        adata,
        args,
    ):
        super().__init__(adata, args)
        self.image = args.img
        self.map_info = args.map_info
        self.r = args.params['patch_r']
        self.spot_img_stack = []

        assert self.image is not None, (
            "Empty paired H&E image, "
            "please use regular `Starfysh` without PoE integration "
            "if your dataset doesn't contain histology image"
        )

        # Retrieve image patch around each spot
        scalef = args.scalefactor['tissue_hires_scalef']  # High-res scale factor
        h, w, d = self.image.shape
        r = self.r

        for i in range(len(self.expr_mat)):
            spot = self.map_info.iloc[i]
            xc = int(np.round(spot['imagecol'] * scalef))
            yc = int(np.round(spot['imagerow'] * scalef))

            # boundary conditions: edge spots
            yl, yr, top = self._clip_window(yc, r, h)
            xl, xr, left = self._clip_window(xc, r, w)

            # The destination window is sized from the source window, so both
            # sides of the assignment always have the same shape -- including
            # when the patch ends exactly on the image border.
            patch = np.zeros((r * 2, r * 2, d))
            patch[top:top + (yr - yl), left:left + (xr - xl)] = \
                self.image[yl:yr, xl:xr]
            self.spot_img_stack.append(patch)

    @staticmethod
    def _clip_window(center, radius, size):
        """
        Clip `[center - radius, center + radius)` to `[0, size)` along one axis.

        Returns `(lo, hi, offset)`: `lo:hi` is the slice to read from the
        image and `offset` is where that slice starts inside the patch.
        `hi >= lo` always holds, so a spot lying fully outside the image
        produces an empty window instead of a wrapped-around slice.
        """
        lo = min(max(center - radius, 0), size)
        hi = min(max(center + radius, lo), size)
        offset = max(0, radius - center)
        return lo, hi, offset

    def __len__(self):
        """Return the number of rows (spots) in the expression matrix."""
        return len(self.expr_mat)

    def __getitem__(self, idx):
        """
        Return `(expression, anchors, library size, image patch, spot label,
        signature)` for `idx`.

        A tensor index is converted with `tolist()` before it is used.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = torch.Tensor(
            np.array(self.expr_mat.iloc[idx, :], dtype='float')
        )
        # Convert the pandas row to NumPy first: handing a Series straight to
        # `torch.Tensor` makes torch read it via `series[0]`, `series[1]`, ...,
        # which are label lookups and fail when the labels are not 0..n-1.
        signature = torch.Tensor(
            np.array(self.gexp.iloc[idx, :], dtype='float')
        )

        return (
            sample,
            torch.Tensor(self.anchor_idx[idx, :]),
            torch.Tensor(self.library_n[idx, None]),
            self.spot_img_stack[idx],
            self.map_info.index[idx],
            signature,
        )