# Source code for stereoAlign.metrics.silhouette

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 8/11/23 3:00 PM
# @Author  : zhangchao
# @File    : silhouette.py
# @Email   : zhangchao5@genomics.cn
import numpy as np
import pandas as pd
from scanpy._utils import deprecated_arg_names
from sklearn.metrics.cluster import silhouette_samples, silhouette_score


@deprecated_arg_names({"group_key": "label_key"})
def silhouette(adata, label_key, embed, metric="euclidean", scale=True):
    """Average silhouette width (ASW).

    Wrapper for sklearn's ``silhouette_score``. Raw values range over [-1, 1]:

    * 1 indicates distinct, compact clusters
    * 0 indicates overlapping clusters
    * -1 indicates core-periphery (non-cluster) structure

    By default the score is rescaled to [0, 1] (``scale=True``) via
    ``(asw + 1) / 2``.

    Parameters
    ----------
    adata:
        AnnData-like object exposing ``obs`` (cell annotations) and ``obsm``
        (embeddings).
    label_key:
        Key in ``adata.obs`` of the cell labels used as clusters.
    embed:
        Key in ``adata.obsm`` of the embedding to evaluate (e.g. ``'X_pca'``).
        There is no default; it must always be given.
    metric:
        Type of distance metric forwarded to ``silhouette_score``.
    scale:
        Default True; map the score from [-1, 1] to [0 (worst), 1 (best)].

    Returns
    -------
    The (optionally scaled) average silhouette width.

    Raises
    ------
    KeyError
        If ``embed`` is not a key of ``adata.obsm``. The message lists the
        available keys.

    Notes
    -----
    The function requires an embedding to be stored in ``adata.obsm`` and can
    only be applied to feature and embedding integration outputs. The metric
    cannot be used to evaluate kNN graph outputs.
    """
    if embed not in adata.obsm.keys():
        # Report the available keys in the exception itself rather than
        # printing them to stdout just before raising.
        raise KeyError(
            f"{embed} not in obsm; available keys: {list(adata.obsm.keys())}"
        )

    asw = silhouette_score(
        X=adata.obsm[embed], labels=adata.obs[label_key], metric=metric
    )
    if scale:
        # Linear map [-1, 1] -> [0, 1] so that higher is always better.
        asw = (asw + 1) / 2
    return asw
@deprecated_arg_names({"group_key": "label_key"})
def silhouette_batch(
    adata,
    batch_key,
    label_key,
    embed,
    metric="euclidean",
    return_all=False,
    scale=True,
    verbose=True,
):
    r"""Batch ASW.

    Modified average silhouette width (ASW) of batch.

    This metric measures the silhouette of a given batch. It assumes that a
    silhouette width close to 0 represents perfect overlap of the batches,
    thus the absolute value of the silhouette width is used to measure how
    well batches are mixed. For all cells :math:`i` of a cell type
    :math:`C_j`, the batch ASW of that cell type is:

    .. math::

        batch\,ASW_j = \frac{1}{|C_j|} \sum_{i \in C_j} |silhouette(i)|

    The final score is the average of the per-cell-type scores over the set
    :math:`M` of cell types:

    .. math::

        batch\,ASW = \frac{1}{|M|} \sum_{j \in M} batch\,ASW_j

    For a scaled metric (the default), the absolute silhouette of each cell
    is subtracted from 1 before averaging. Then 0 indicates suboptimal and 1
    indicates optimal batch mixing:

    .. math::

        batch\,ASW_j = \frac{1}{|C_j|} \sum_{i \in C_j} \left(1 - |silhouette(i)|\right)

    Cell types whose cells all come from a single batch, or in which every
    cell forms its own batch, have no defined silhouette. They are skipped
    and do not contribute to :math:`M`.

    Parameters
    ----------
    batch_key:
        Key in ``adata.obs`` of the batch labels to be compared against.
    label_key:
        Key in ``adata.obs`` of the group labels to subset by, e.g. cell type.
    embed:
        Key in ``adata.obsm`` of the embedding to evaluate.
    metric:
        Distance metric forwarded to sklearn's ``silhouette_samples``.
    return_all:
        If True, also return the per-group means and per-cell scores.
        Default False: return only the average silhouette width (ASW).
    scale:
        If True (default), use ``1 - |silhouette|`` so that 1 is best.
    verbose:
        If True (default), print the mean silhouette per group.

    Returns
    -------
    Batch ASW (always). This is ``np.nan`` if every group was skipped.
    Mean score per group as a ``pd.DataFrame`` indexed by group
    (additionally, if ``return_all=True``). This is ``np.nan`` if every group
    was skipped.
    Per-cell scores with columns ``["group", "silhouette_score"]``
    (additionally, if ``return_all=True``). These are ``|silhouette|``, or
    ``1 - |silhouette|`` when ``scale=True``.

    Raises
    ------
    KeyError
        If ``embed`` is not a key of ``adata.obsm``. The message lists the
        available keys.

    Notes
    -----
    The function requires an embedding to be stored in ``adata.obsm`` and can
    only be applied to feature and embedding integration outputs. The metric
    cannot be used to evaluate kNN graph outputs.
    """
    if embed not in adata.obsm.keys():
        # Report the available keys in the exception itself rather than
        # printing them to stdout just before raising.
        raise KeyError(
            f"{embed} not in obsm; available keys: {list(adata.obsm.keys())}"
        )

    labels = adata.obs[label_key]  # hoisted: reused for every group mask
    sil_per_label = []
    for group in labels.unique():
        adata_group = adata[labels == group]
        n_batches = adata_group.obs[batch_key].nunique()

        # silhouette_samples requires 2 <= n_labels <= n_samples - 1. Skip
        # groups with a single batch or one cell per batch. This also skips
        # empty selections (0 batches == 0 cells).
        if n_batches == 1 or n_batches == adata_group.shape[0]:
            continue

        sil = np.abs(
            silhouette_samples(
                adata_group.obsm[embed], adata_group.obs[batch_key], metric=metric
            )
        )
        if scale:
            # Flip so that the highest value is optimal (perfect mixing -> 1).
            sil = 1 - sil
        sil_per_label.extend((group, score) for score in sil)

    sil_df = pd.DataFrame.from_records(
        sil_per_label, columns=["group", "silhouette_score"]
    )

    if not sil_per_label:
        # Every group was skipped, so the metric is undefined.
        sil_means = np.nan
        asw = np.nan
    else:
        sil_means = sil_df.groupby("group").mean()
        asw = sil_means["silhouette_score"].mean()

    if verbose:
        print(f"mean silhouette per group: {sil_means}")

    if return_all:
        return asw, sil_means, sil_df
    return asw