Source code for starfysh.dataloader

import numpy as np
import pandas as pd
import logging
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

# from torchvision import transforms


[docs]class VisiumDataset(Dataset): """ Loading preprocessed Visium AnnData, gene signature & Anchor spots for Starfysh training """ def __init__( self, adata, args, ): spots = adata.obs_names genes = adata.var_names x = adata.X if isinstance(adata.X, np.ndarray) else adata.X.A self.expr_mat = pd.DataFrame(x, index=spots, columns=genes) self.gexp = args.sig_mean_znorm self.anchor_idx = args.pure_idx self.library_n = args.win_loglib def __len__(self): return len(self.expr_mat) def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() sample = torch.Tensor( np.array(self.expr_mat.iloc[idx, :], dtype='float') ) return (sample, torch.Tensor(self.gexp.iloc[idx, :]), torch.Tensor(self.anchor_idx[idx, :]), torch.Tensor(self.library_n[idx,None]), )
[docs]class VisiumPoEDataSet(VisiumDataset): """ return the data stack with expression and image """ def __init__( self, adata, args, ): super(VisiumPoEDataSet, self).__init__(adata, args) self.image = args.img self.map_info = args.map_info self.patch_r = args.params['patch_r'] self.spot_img_stack = [] assert self.image is not None,\ "Empty paired H&E image," \ "please use regular `Starfysh` without PoE integration" \ "if your dataset doesn't contain histology image" for i in range(len(self.expr_mat)): img_xmin = int(self.map_info.iloc[i]['imagecol'])-self.patch_r img_xmax = int(self.map_info.iloc[i]['imagecol'])+self.patch_r img_ymin = int(self.map_info.iloc[i]['imagerow'])-self.patch_r img_ymax = int(self.map_info.iloc[i]['imagerow'])+self.patch_r self.spot_img_stack.append(self.image[img_ymin:img_ymax,img_xmin:img_xmax]) def __len__(self): return len(self.expr_mat) def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() sample = torch.Tensor( np.array(self.expr_mat.iloc[idx, :], dtype='float') ) return (sample, torch.Tensor(self.anchor_idx[idx, :]), torch.Tensor(self.library_n[idx, None]), self.spot_img_stack[idx], self.map_info.index[idx], torch.Tensor(self.gexp.iloc[idx, :]), )