# Source code for stereoAlign.alignment.scvi_alignment

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 8/24/23 9:47 AM
# @Author  : zhangchao
# @File    : _scvi_alignment.py
# @Email   : zhangchao5@genomics.cn
from anndata import AnnData

from stereoAlign.utils import check_sanity


def scvi_alignment(
    adata: AnnData,
    batch_key,
    hvg=None,
    return_model=False,
    max_epochs=None,
    *,
    n_latent=30,
    n_hidden=128,
    n_layers=2,
):
    """scVI wrapper function.

    Based on scvi-tools version >=0.16.0 (available through
    `conda <https://docs.scvi-tools.org/en/stable/installation.html>`_).

    .. note::
        scVI expects only non-normalized (count) data on highly variable genes!

    :param adata: preprocessed ``anndata`` object; raw counts must be stored in
        ``adata.layers["counts"]``
    :param batch_key: batch key in ``adata.obs``
    :param hvg: list of highly variable genes to subset to. If ``None``, the
        full dataset will be used
    :param return_model: if ``True``, return the trained ``SCVI`` model instead
        of ``adata``
    :param max_epochs: maximum number of training epochs; if ``None`` the
        argument is not passed to ``SCVI.train`` and the library default applies
    :param n_latent: dimensionality of the latent space (default ``30``)
    :param n_hidden: number of nodes per hidden layer (default ``128``)
    :param n_layers: number of hidden layers (default ``2``)
    :return: ``adata`` (modified in place) with the scVI latent embedding in
        ``adata.obsm["aligned_scvi"]``; or, when ``return_model`` is ``True``,
        the trained ``SCVI`` model (``adata.obsm["aligned_scvi"]`` is written
        in both cases)
    :raises ImportError: if scvi-tools is not installed
    :raises TypeError: if ``adata.layers`` has no ``"counts"`` layer
    """
    try:
        from scvi.model import SCVI
    except ImportError as err:
        raise ImportError("\nplease install scvi:\n\n\tpip install scvi-tools") from err

    check_sanity(adata, batch_key, hvg)

    # Check for counts data layer
    if "counts" not in adata.layers:
        raise TypeError(
            "Adata does not contain a `counts` layer in `adata.layers[`counts`]`"
        )

    # Work on a copy so fields registered by ``setup_anndata`` are not written
    # back to the caller's object. Subset first (when requested) so that only
    # one copy is ever made, instead of copying the full matrix and discarding it.
    net_adata = adata[:, hvg].copy() if hvg is not None else adata.copy()
    SCVI.setup_anndata(net_adata, layer="counts", batch_key=batch_key)

    # Default sizes follow the SCVI github tutorials (scanpy_pbmc3k, harmonization).
    vae = SCVI(
        net_adata,
        gene_likelihood="nb",
        n_layers=n_layers,
        n_latent=n_latent,
        n_hidden=n_hidden,
    )

    # Train on all cells (train_size=1.0); only forward max_epochs if given.
    train_kwargs = {"train_size": 1.0}
    if max_epochs is not None:
        train_kwargs["max_epochs"] = max_epochs
    vae.train(**train_kwargs)

    # Only genes (columns) may have been subset, so the rows of the latent
    # representation line up with the cells of the original ``adata``.
    adata.obsm["aligned_scvi"] = vae.get_latent_representation()

    return vae if return_model else adata