Source code for stereoAlign.metrics.silhouette
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 8/11/23 3:00 PM
# @Author : zhangchao
# @File : silhouette.py
# @Email : zhangchao5@genomics.cn
import numpy as np
import pandas as pd
from scanpy._utils import deprecated_arg_names
from sklearn.metrics.cluster import silhouette_samples, silhouette_score
@deprecated_arg_names({"group_key": "label_key"})
def silhouette(adata, label_key, embed, metric="euclidean", scale=True):
    """Average silhouette width (ASW) of the labels on an embedding.

    Thin wrapper around :func:`sklearn.metrics.silhouette_score`. The raw
    score lies in [-1, 1]:

    * 1 indicates distinct, compact clusters
    * 0 indicates overlapping clusters
    * -1 indicates core-periphery (non-cluster) structure

    With ``scale=True`` (the default) the score is mapped linearly onto
    [0, 1] via ``(asw + 1) / 2``.

    Parameters
    ----------
    adata:
        annotated data object providing ``obs`` and ``obsm``
    label_key:
        key in ``adata.obs`` holding the cell labels (the decorator maps the
        deprecated keyword ``group_key`` onto this argument)
    embed:
        key in ``adata.obsm`` holding the embedding to score
    metric:
        distance metric forwarded to the sklearn silhouette score
    scale:
        if True, rescale the result to lie between 0 (worst) and 1 (best)

    Returns
    -------
    The (optionally rescaled) average silhouette width.

    Raises
    ------
    KeyError
        if ``embed`` is not a key of ``adata.obsm``; the available keys are
        printed before raising.

    Notes
    -----
    The function requires an embedding stored in ``adata.obsm`` and can
    therefore only be applied to feature and embedding integration outputs.
    It cannot be used to evaluate kNN graph outputs.
    """
    available = adata.obsm.keys()
    if embed not in available:
        print(available)
        raise KeyError(f"{embed} not in obsm")

    raw_score = silhouette_score(
        X=adata.obsm[embed],
        labels=adata.obs[label_key],
        metric=metric,
    )
    return (raw_score + 1) / 2 if scale else raw_score
@deprecated_arg_names({"group_key": "label_key"})
def silhouette_batch(
    adata,
    batch_key,
    label_key,
    embed,
    metric="euclidean",
    return_all=False,
    scale=True,
    verbose=True,
):
    r"""Batch ASW

    Modified average silhouette width (ASW) of batch.

    This metric measures the silhouette of the batch labels within each
    group of cells sharing the same label. It assumes that a silhouette
    width close to 0 represents perfect overlap of the batches, thus the
    absolute value of the silhouette width is used to measure how well
    batches are mixed.

    For all cells :math:`i` of a cell type :math:`C_j`, the batch ASW of
    that cell type is:

    .. math::

        batch \, ASW_j = \frac{1}{|C_j|} \sum_{i \in C_j} |silhouette(i)|

    The final score is the average of the per-cell-type scores over the set
    of cell types :math:`M`:

    .. math::

        batch \, ASW = \frac{1}{|M|} \sum_{j \in M} batch \, ASW_j

    For a scaled metric (which is the default), the absolute silhouette
    width of every cell is subtracted from 1 before averaging, so that 0
    indicates strongly separated batches and 1 indicates well mixed batches:

    .. math::

        batch \, ASW_j = \frac{1}{|C_j|} \sum_{i \in C_j} 1 - |silhouette(i)|

    Parameters
    ----------
    adata:
        annotated data object providing ``obs`` and ``obsm``
    batch_key:
        key in ``adata.obs`` of the batch labels to be compared against
    label_key:
        key in ``adata.obs`` of the group labels to subset by, e.g. cell
        type (the decorator maps the deprecated keyword ``group_key`` onto
        this argument)
    embed:
        key in ``adata.obsm`` of the embedding to score
    metric:
        distance metric, see sklearn silhouette score
    return_all:
        if True, return all silhouette scores and label means in addition
        to the average. default False: return only the batch ASW
    scale:
        if True, scale between 0 and 1 (higher is better)
    verbose:
        print the mean silhouette score per group

    Returns
    -------
    Batch ASW (always); ``nan`` if no group could be scored.
    Mean silhouette per group in pd.DataFrame (additionally, if
    ``return_all=True``); ``nan`` instead of a DataFrame if no group could
    be scored.
    Absolute (optionally scaled) silhouette scores per cell with their group
    label in pd.DataFrame (additionally, if ``return_all=True``).

    Raises
    ------
    KeyError
        if ``embed`` is not a key of ``adata.obsm``; the available keys are
        printed before raising.

    Notes
    -----
    The function requires an embedding stored in ``adata.obsm`` and can
    therefore only be applied to feature and embedding integration outputs.
    It cannot be used to evaluate kNN graph outputs.
    """
    if embed not in adata.obsm.keys():
        print(adata.obsm.keys())
        raise KeyError(f"{embed} not in obsm")

    sil_per_label = []
    for group in adata.obs[label_key].unique():
        adata_group = adata[adata.obs[label_key] == group]
        n_batches = adata_group.obs[batch_key].nunique()

        # The silhouette is undefined for a single batch or when every cell
        # is its own batch, so such groups are left out of the average.
        if (n_batches == 1) or (n_batches == adata_group.shape[0]):
            continue

        # Only the magnitude matters: values near 0 mean overlapping batches.
        sil = np.abs(
            silhouette_samples(
                adata_group.obsm[embed], adata_group.obs[batch_key], metric=metric
            )
        )
        if scale:
            # Flip so that the highest number is optimal.
            sil = 1 - sil

        sil_per_label.extend((group, score) for score in sil)

    sil_df = pd.DataFrame.from_records(
        sil_per_label, columns=["group", "silhouette_score"]
    )

    if len(sil_per_label) == 0:
        # No group had a usable batch composition.
        sil_means = np.nan
        asw = np.nan
    else:
        # Average per group first so that large groups do not dominate.
        sil_means = sil_df.groupby("group").mean()
        asw = sil_means["silhouette_score"].mean()

    if verbose:
        print(f"mean silhouette per group: {sil_means}")

    if return_all:
        return asw, sil_means, sil_df
    return asw