import os
import numpy as np
import pandas as pd
import scanpy as sc
import umap
import skdim
import matplotlib.pyplot as plt
import seaborn as sns
from py_pcha import PCHA
from matplotlib.pyplot import cm
from scipy.spatial.distance import cdist, euclidean
from sklearn.neighbors import NearestNeighbors
from starfysh import LOGGER
[docs]class ArchetypalAnalysis:
# Todo: implement non-linear archetype analysis with VAE, compare explanability with linear implementation
def __init__(
self,
adata_orig,
u=None,
u_3d=None,
verbose=True,
outdir=None,
filename=None,
savefig=False,
):
self.adata = adata_orig.copy()
# Perform dim-reduction with PCA, select the first 30 PCs
sc.pp.pca(self.adata)
self.count = self.adata.obsm['X_pca'][:, :30]
# self.count = self.adata.X.A if isinstance(self.adata.X, sparse.csr_matrix) else self.adata.X
self.n_spots = self.count.shape[0]
self.verbose = verbose
self.outdir = outdir
self.filename = filename
self.savefig = savefig
self.archetype = None
self.major_archetype = None
self.major_idx = None
self.arche_dict = None
self.arche_df = None
self.kmin = 0
self.U = u
self.U_3d = u_3d
[docs] def compute_archetypes(
self,
cn=30,
n_iters=20,
converge=1e-3,
r=20,
display=False
):
"""
Estimate the upper bound of archetype count (k) by calculating intrinsic dimension
Compute hierarchical archetypes (major + raw) with given granularity
Parameters
----------
cn : int
Conditional Number to choose PCs for intrinsic estimator as
lower bound # archetype estimation. Please refer to:
https://scikit-dimension.readthedocs.io/en/latest/skdim.id.FisherS.html#skdim.id.FisherS
n_iters : int
Max. # iterations of AA to find the best k estimation
converge : int
Convergence criteria for AA iteration with diff(explained variance)
` r : int
Resolution parameter to control granularity of major archetypes
If two archetypes reside within r nearest neighbors, the latter
one will be merged.
display : bool
Whether to display Intrinsic Dimension (ID) estimation plots
Returns
-------
archetype : np.ndarray (dim=[K, G])
Raw archetypes as linear combination of subset of spot counts
arche_dict : dict
Hierarchical structure of major_archetype -> its fine-grained neighbor archetypes
major_idx : int
Index of major archetypes among `k` raw candidates after merging\
evs : list
Explained variance with different Ks
"""
# TMP: across-sample comparison: fix # principle components for all samples
if self.verbose:
LOGGER.info('Computing intrinsic dimension to estimate k...')
# Estimate ID
conditional_num = cn
id_model = skdim.id.FisherS(conditional_number=conditional_num,
produce_plots=display,
verbose=self.verbose)
self.kmin = max(1, int(id_model.fit(self.count).dimension_))
# Compute raw archetypes
if self.verbose:
LOGGER.info('Estimating lower bound of # archetype as {0}...'.format(self.kmin))
X = self.count.T
archetypes = []
evs = []
# TODO: speedup with multiprocessing
for i, k in enumerate(range(self.kmin, self.kmin+n_iters)):
archetype, _, _, _, ev = PCHA(X, noc=k, delta=0.1)
evs.append(ev)
archetypes.append(np.array(archetype).T)
if i > 0 and ev - evs[i-1] < converge:
break
self.archetype = archetypes[-1]
if self.U is None:
self.U = self._get_umap(ndim=2)
if self.U_3d is None:
self.U_3d = self._get_umap(ndim=3)
# Merge raw archetypes to get major archetypes
if self.verbose:
LOGGER.info('{0} variance explained by raw archetypes.\nMerging raw archetypes within {1} NNs to get major archetypes'.format(np.round(ev, 4), r))
arche_dict, major_idx = self._merge_archetypes(r)
self.major_archetype = self.archetype[major_idx]
self.major_idx = np.array(major_idx)
self.arche_dict = arche_dict
# temp: return all archetypes for Silhouette score calculation
return archetypes, arche_dict, major_idx, evs
def _merge_archetypes(self, r):
"""
Merge raw archetypes into major ones by removing candidate with `r`-step distance
from its previous identified neighbors
"""
assert self.archetype is not None, "Please compute archetypes first!"
n_archetypes = self.archetype.shape[0]
X_concat = np.vstack([self.count, self.archetype])
nbrs = NearestNeighbors(n_neighbors=r).fit(X_concat)
nn_graph = nbrs.kneighbors(X_concat)[1][self.n_spots:, 1:] # retrieve NN-graph of only archetype spots
idxs_to_remove = set()
arche_dict = {}
for i in range(n_archetypes):
if i not in idxs_to_remove:
query = np.arange(self.n_spots+i, self.n_spots+n_archetypes)
nbrs = np.setdiff1d(
nn_graph[i][np.isin(nn_graph[i], query)] - self.n_spots,
list(idxs_to_remove) # avoid over-assign merged archetypes to multiple major archetypes
)
if len(nbrs) != 0:
arche_dict[i] = np.insert(nbrs, 0, i)
idxs_to_remove.update(nbrs)
major_idx = np.setdiff1d(np.arange(n_archetypes), list(idxs_to_remove))
return arche_dict, major_idx
[docs] def find_archetypal_spots(self, n_neighbors=20, major=True):
"""
Assign N-nearest-neighbor spots to each archetype as `archetypal spots` (archetype community)
Parameters
----------
n_neighbors : int (default=40)
N nearest neighbors of each archetype for archetypal spots
major : bool
Whether to find NNs for only major archetypes
Returns
-------
arche_df : pd.DataFrame
Dataframe of archetypal spots
"""
assert self.archetype is not None, "Please compute archetypes first!"
if self.verbose:
LOGGER.info('Finding {} nearest neighbors for each archetype...'.format(n_neighbors))
nbr_dict = {}
indices = self.major_idx if major else np.arange(self.archetype.shape[0])
for i in indices:
v = self.archetype[i]
X_concat = np.vstack([self.count, v])
nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(X_concat)
nn_graph = nbrs.kneighbors(X_concat)[1][-1, 1:] # find NNs of archetype `v`
nbr_dict['arch_{}'.format(i)] = nn_graph
self.arche_df = pd.DataFrame(nbr_dict)
return self.arche_df
[docs] def find_markers(self, n_markers=30, display=False):
"""
Find marker genes for each archetype community via Wilcoxon rank sum test (in-group vs. out-of-group)
Parameters
----------
n_markers : int
Number of top marker genes to find for each archetype community
Returns
-------
marker_df : pd.DataFrame
Dataframe of marker genes for each archetype community
"""
assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
if self.verbose:
LOGGER.info('Finding {} top marker genes for each archetype...'.format(n_markers))
adata = self.adata.copy()
markers = []
for col in self.arche_df.columns:
# Annotate in-group (current archetype) vs. out-of-group
annots = np.zeros(self.n_spots, dtype=np.int64).astype(str)
annots[self.arche_df[col]] = col
adata.obs[col] = annots
adata.obs[col] = adata.obs[col].astype('category')
# Identify marker genes
sc.tl.rank_genes_groups(adata, col, use_raw=False, method='wilcoxon')
markers.append(adata.uns['rank_genes_groups']['names'][col][:n_markers])
if display:
plt.rcParams['figure.figsize'] = (8, 3)
plt.rcParams['figure.dpi'] = 300
sc.pl.rank_genes_groups_violin(adata, groups=[col], n_genes=n_markers)
return pd.DataFrame(np.stack(markers, axis=1), columns=self.arche_df.columns)
[docs] def assign_archetypes(self, anchor_df, threshold=.20):
"""
Assign best 1-1 mapping of archetype community to its closest anchor community (cell-type specific anchor spots)
With spot overlapping ratio >= threshold
Parameters
----------
anchor_df : pd.DataFrame
Dataframe of anchor spot indices
threshold : float
Threshold to determine anchor-archetype mapping
Returns
-------
map_df : pd.DataFrame
DataFrame of overlapping spot ratio of each anchor `i` to archetype `j`
map_dict : dict
Dictionary of cell type -> mapped archetype
"""
assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
n_nbrs, n_archetypes = self.arche_df.shape
n_cell_types = anchor_df.shape[1]
map_ratio = np.zeros((n_cell_types, n_archetypes))
for i, cell_type in enumerate(anchor_df.columns):
for j, arche_label in enumerate(self.arche_df.columns):
n_overlap = len(set(anchor_df[cell_type]).intersection(set(self.arche_df[arche_label])))
map_ratio[i, j] = n_overlap / n_nbrs
match_idx = map_ratio.argmax(1)
map_df = pd.DataFrame(
map_ratio,
index=anchor_df.columns,
columns=self.arche_df.columns
)
map_dict = {
anchor_df.columns[k]: self.arche_df.columns[v]
for (k, v) in enumerate(match_idx)
if map_df.iloc[k, v] >= threshold
}
return map_df, map_dict
[docs] def find_distant_archetypes(self, anchor_df, map_dict=None, n=3):
"""
Sort and return top n archetypes that are unmapped and farthest from anchor spots of know cell types
They are more likely to represent novel cell types / states
Parameters
----------
anchor_df : pd.DataFrame
Dataframe of anchor spot indices
map_dict : dict
Dictionary of cell type -> mapped archetype
n : int
Number of distant archetypes to return
Returns
-------
distant_archetypes : list
List of archetype labels (farthest --> closest to anchors)
"""
assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
cell_types = anchor_df.columns
arche_lbls = self.arche_df.columns
# Find the unmapped archetypes
if map_dict is None:
_, map_dict = self.assign_archetypes(anchor_df=anchor_df)
unmapped_archetypes = np.setdiff1d(
arche_lbls,
list(set([v for k, v in map_dict.items()]))
)
# Sort unmapped archetypes in descending orders with avg. distance to its 2 closest anchor spot centroid
if n > len(unmapped_archetypes):
LOGGER.warning('Insufficient candidates to find {0} distant archetypes\nSet n={1}'.format(
n, len(unmapped_archetypes)
))
anchor_centroids = self.count[anchor_df[anchor_df.columns]].mean(0)
arche_centroids = self.count[self.arche_df[self.arche_df.columns]].mean(0)
dist_df = pd.DataFrame(
cdist(anchor_centroids, arche_centroids),
index=cell_types,
columns=arche_lbls
)
dist_unmapped = dist_df[unmapped_archetypes].values # subset only distance to `unmapped` archetypes
dist_to_nbrs = np.sort(dist_unmapped, axis=0)[:2].mean(0)
distant_arches = [unmapped_archetypes[idx] for idx in np.argsort(-dist_to_nbrs)][:n] # dist - Discending order
return distant_arches
# -------------------
# Plotting functions
# -------------------
def _get_umap(self, ndim=2, random_state=42):
assert ndim == 2 or ndim == 3, "Invalid dimension for UMAP: {}".format(ndim)
LOGGER.info('Calculating UMAPs for counts + Archetypes...')
reducer = umap.UMAP(n_components=ndim, random_state=random_state)
U = reducer.fit_transform(np.vstack([self.count, self.archetype]))
return U
def _save_fig(self, fig, lgds, default_name):
filename = self.filename if self.filename is not None else default_name
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
fig.savefig(
os.path.join(self.outdir, filename+'.svg'),
bbox_extra_artists=lgds, bbox_inches='tight',
format='svg'
)
[docs] def plot_archetypes(
self,
major=True,
do_3d=False,
lgd_ncol=1,
figsize=(6, 4),
disp_cluster=True,
disp_arche=True
):
"""
Display archetype & archetypal spot communities
"""
assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
n_archetypes = self.arche_df.shape[1]
arche_indices = self.major_idx if major else np.arange(n_archetypes)
U = self.U_3d if do_3d else self.U
colors = cm.tab20(np.linspace(0, 1, n_archetypes))
if do_3d:
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=200, subplot_kw=dict(projection='3d'))
# Color background spots & archetypal spots
ax.scatter(
U[:self.n_spots, 0],
U[:self.n_spots, 1],
U[:self.n_spots, 2],
s=1, alpha=0.7, linewidth=.3,
edgecolors='black', c='lightgray'
)
if disp_cluster:
for i, label in enumerate(self.arche_df.columns):
lbl = int(label.split('_')[-1])
if lbl in arche_indices:
idxs = self.arche_df[label]
ax.scatter(
U[idxs, 0],
U[idxs, 1],
U[idxs, 2],
marker='o', s=3,
color=colors[i], label=label
)
# Highlight archetype
if disp_arche:
ax.scatter(
U[self.n_spots+arche_indices, 0],
U[self.n_spots+arche_indices, 1],
U[self.n_spots+arche_indices, 2],
s=10, c='blue', marker='^'
)
for j, z in zip(arche_indices, U[self.n_spots+arche_indices]):
ax.text(z[0], z[1], z[2], str(j), fontsize=10, c='blue')
lgd = ax.legend(loc='right', bbox_to_anchor=(0.5, 0, 1, 0.5), ncol=lgd_ncol)
ax.grid(False)
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
ax.set_zlabel('UMAP3')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.xaxis.pane.set_edgecolor('black')
ax.yaxis.pane.set_edgecolor('black')
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.view_init(20, 135)
else: # 2D plot
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
# Color background & archetypal spots
ax.scatter(
U[:self.n_spots, 0],
U[:self.n_spots, 1],
alpha=1, s=1, color='lightgray')
if disp_cluster:
for i, label in enumerate(self.arche_df.columns):
lbl = int(label.split('_')[-1])
if lbl in arche_indices:
idxs = self.arche_df[label]
ax.scatter(
U[idxs, 0],
U[idxs, 1],
marker='o', s=3,
color=colors[i], label=label
)
if disp_arche:
ax.scatter(
U[self.n_spots+arche_indices, 0],
U[self.n_spots+arche_indices, 1],
s=10, c='blue', marker='^'
)
for j, z in zip(arche_indices, U[self.n_spots+arche_indices]):
ax.text(z[0], z[1], str(j), fontsize=10, c='blue')
lgd = ax.legend(loc='right', bbox_to_anchor=(2, 0.5), ncol=lgd_ncol)
ax.grid(False)
ax.axis('off')
if self.savefig and self.outdir is not None:
self._save_fig(fig, (lgd,), 'archetypes')
return fig, ax
[docs] def plot_anchor_archetype_clusters(
self,
anchor_df,
cell_types=None,
arche_lbls=None,
lgd_ncol=2,
do_3d=False
):
"""
Joint display subset of anchor spots & archetypal spots (to visualize overlapping degree)
"""
assert self.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"
U = self.U_3d if do_3d else self.U
cell_types = anchor_df.columns if cell_types is None else np.intersect1d(cell_types, anchor_df.columns)
arche_lbls = self.arche_df.columns if arche_lbls is None else np.intersect1d(arche_lbls, self.arche_df.columns)
u_centroids = U[self.arche_df[arche_lbls]].mean(0)
anchor_colors = cm.RdBu_r(np.linspace(0, 1, len(cell_types)))
arche_colors = cm.RdBu_r(np.linspace(0, 1, len(arche_lbls)))
if do_3d:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5), dpi=300, subplot_kw=dict(projection='3d'))
# Display anchors
ax1.scatter(
U[:self.n_spots, 0],
U[:self.n_spots, 1],
U[:self.n_spots, 2],
c='gray', marker='.', s=1, alpha=0.2
)
for c, label in zip(anchor_colors, cell_types):
idxs = anchor_df[label]
ax1.scatter(
U[idxs, 0],
U[idxs, 1],
U[idxs, 2],
color=c, marker='^', s=5,
alpha=0.9, label=label
)
ax1.grid(False)
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax1.set_zticklabels([])
ax1.view_init(30, 45)
lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
# Display archetypal spots
ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], U[:self.n_spots, 2], c='gray', marker='.', s=1, alpha=0.2)
for c, label in zip(arche_colors, arche_lbls):
idxs = self.arche_df[label]
ax2.scatter(U[idxs, 0], U[idxs, 1], U[idxs, 2], color=c, marker='o', s=3, alpha=0.9, label=label)
# Highlight selected archetypes
for label, z in zip(arche_lbls, u_centroids):
idx = int(label.split('_')[-1])
ax2.text(z[0], z[1], z[2], str(idx))
ax2.grid(False)
ax2.set_xticklabels([])
ax2.set_yticklabels([])
ax2.set_zticklabels([])
ax2.view_init(30, 45)
lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1), ncol=lgd_ncol)
else:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=300)
# Display anchors
ax1.scatter(
U[:self.n_spots, 0],
U[:self.n_spots, 1],
c='gray', marker='.', s=1, alpha=0.2
)
for c, label in zip(anchor_colors, cell_types):
idxs = anchor_df[label]
ax1.scatter(
U[idxs, 0],
U[idxs, 1],
color=c, marker='^', s=5,
alpha=0.9, label=label
)
lgd1 = ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -1.75), ncol=lgd_ncol)
# Display archetypal spots
ax2.scatter(U[:self.n_spots, 0], U[:self.n_spots, 1], c='gray', marker='.', s=1, alpha=0.2)
for c, label in zip(arche_colors, arche_lbls):
idxs = self.arche_df[label]
ax2.scatter(
U[idxs, 0],
U[idxs, 1],
color=c, marker='o', s=3,
alpha=0.9, label=label
)
# Highlight selected archetypes
for label, z in zip(arche_lbls, u_centroids):
idx = int(label.split('_')[-1])
ax2.text(z[0], z[1], str(idx))
lgd2 = ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.85), ncol=lgd_ncol)
if self.savefig and self.outdir is not None:
self._save_fig(fig, (lgd1, lgd2), 'anchor_archetypal_spots')
return fig, (ax1, ax2)
[docs] def plot_mapping(self, map_df, figsize=(6, 5)):
"""
Display anchor - archetype mapping (overlapping # spot ratio)
"""
filename = 'cluster' if self.filename is None else self.filename
g = sns.clustermap(
map_df,
method='ward', vmin=0, vmax=1,
figsize=figsize,
xticklabels=True,
yticklabels=True,
annot_kws={'size': 15}
)
text = g.ax_heatmap.set_title('Proportion of Overlapped Spots (k={})'.format(map_df.shape[1]),
fontsize=20, x=0.6, y=1.3)
# g.ax_row_dendrogram.set_visible(False)
# g.ax_col_dendrogram.set_visible(False)
if self.savefig and self.outdir is not None:
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
g.figure.savefig(
os.path.join(self.outdir, filename + '.eps'),
bbox_extra_artists=(text,), bbox_inches='tight', format='eps'
)
return g