# Source code for scimilarity.cell_annotation

from typing import Optional, Union, List, Set, Tuple

from .cell_search_knn import CellSearchKNN


class CellAnnotation(CellSearchKNN):
    """A class that annotates cells using a cell embedding and then knn search.

    Parameters
    ----------
    model_path: str
        Path to the directory containing model files.
    use_gpu: bool, default: False
        Use GPU instead of CPU.
    filenames: dict, optional, default: None
        Use a dictionary of custom filenames for files instead default.

    Examples
    --------
    >>> ca = CellAnnotation(model_path="/opt/data/model")
    """

    def __init__(
        self,
        model_path: str,
        use_gpu: bool = False,
        filenames: Optional[dict] = None,
    ):
        import os

        super().__init__(
            model_path=model_path,
            use_gpu=use_gpu,
            knn_type="hnswlib",
        )

        self.annotation_path = os.path.join(model_path, "annotation")
        os.makedirs(self.annotation_path, exist_ok=True)

        if filenames is None:
            filenames = {}

        # Resolve knn index and label filenames, falling back to defaults.
        self.filenames["knn"] = os.path.join(
            self.annotation_path, filenames.get("knn", "labelled_kNN.bin")
        )
        self.filenames["celltype_labels"] = os.path.join(
            self.annotation_path,
            filenames.get("celltype_labels", "reference_labels.tsv"),
        )

        # get knn
        self.load_knn_index(self.filenames["knn"])

        # get int2label and int2study: the labels TSV maps row index ->
        # celltype label (column 0) and, optionally, study (column 1).
        self.idx2label = {}
        self.idx2study = {}
        if self.knn is not None:
            with open(self.filenames["celltype_labels"], "r") as fh:
                for i, line in enumerate(fh):
                    token = line.strip().split("\t")
                    self.idx2label[i] = token[0]
                    if len(token) > 1:
                        self.idx2study[i] = token[1]

        self.safelist = None
        self.blocklist = None

    @property
    def classes(self) -> set:
        """Get the set of all viable prediction classes.

        Returns
        -------
        set
            The set of all celltype labels known to the model.
        """
        # BUG FIX: the getter previously omitted ``self``, so accessing
        # ``ca.classes`` raised ``TypeError: classes() takes 0 positional
        # arguments but 1 was given`` (property getters receive the instance).
        # NOTE(review): ``label2int`` is not set in __init__; presumably it is
        # provided by CellSearchKNN — confirm against the parent class.
        return set(self.label2int.keys())
[docs] def reset_knn(self): """Reset the knn such that nothing is marked deleted. Examples -------- >>> ca.reset_knn() """ self.blocklist = None self.safelist = None # hnswlib does not have a marked status, so we need to unmark all for i in self.idx2label: try: # throws an expection if not already marked self.knn.unmark_deleted(i) except: pass
[docs] def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): """Blocklist celltypes. Parameters ---------- labels: List[str], Set[str] A list or set containing blocklist labels. Notes ----- Blocking a celltype will persist for this instance of the class and subsequent predictions will have this blocklist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples -------- >>> ca.blocklist_celltypes(["T cell"]) """ self.reset_knn() self.blocklist = set(labels) self.safelist = None for i, celltype_name in self.idx2label.items(): if celltype_name in self.blocklist: self.knn.mark_deleted(i) # mark blocklist
[docs] def safelist_celltypes(self, labels: Union[List[str], Set[str]]): """Safelist celltypes. Parameters ---------- labels: List[str], Set[str] A list or set containing safelist labels. Notes ----- Safelisting a celltype will persist for this instance of the class and subsequent predictions will have this safelist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples -------- >>> ca.safelist_celltypes(["CD4-positive, alpha-beta T cell"]) """ self.blocklist = None self.safelist = set(labels) for i in self.idx2label: # mark all try: # throws an exception if already marked self.knn.mark_deleted(i) except: pass for i, celltype_name in self.idx2label.items(): if celltype_name in self.safelist: self.knn.unmark_deleted(i) # unmark safelist
    def get_predictions_knn(
        self,
        embeddings: "numpy.ndarray",
        k: int = 50,
        ef: int = 100,
        weighting: bool = False,
        disable_progress: bool = False,
    ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]:
        """Get predictions from knn search results.

        Parameters
        ----------
        embeddings: numpy.ndarray
            Embeddings as a numpy array.
        k: int, default: 50
            The number of nearest neighbors.
        ef: int, default: 100
            The size of the dynamic list for the nearest neighbors.
            See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
        weighting: bool, default: False
            Use distance weighting when getting the consensus prediction.
        disable_progress: bool, default: False
            Disable tqdm progress bar

        Returns
        -------
        predictions: pandas.Series
            A pandas series containing celltype label predictions.
        nn_idxs: numpy.ndarray
            A 2D numpy array of nearest neighbor indices [num_cells x k].
        nn_dists: numpy.ndarray
            A 2D numpy array of nearest neighbor distances [num_cells x k].
        stats: pandas.DataFrame
            Prediction statistics dataframe with columns:
            "hits" is a json string with the count for every class in k cells.
            "min_dist" is the minimum distance.
            "max_dist" is the maximum distance
            "vs2nd" is sum(best) / sum(best + 2nd best).
            "vsAll" is sum(best) / sum(all hits).
            "hits_weighted" is a json string with the weighted count for every
            class in k cells.
            "vs2nd_weighted" is weighted sum(best) / sum(best + 2nd best).
            "vsAll_weighted" is weighted sum(best) / sum(all hits).
            NOTE: when ``k == 1`` the stats lists are never populated, so the
            returned stats DataFrame has zero rows.

        Examples
        --------
        >>> ca = CellAnnotation(model_path="/opt/data/model")
        >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X)
        >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(embeddings)
        """
        from collections import defaultdict
        import json
        import operator
        import numpy as np
        import pandas as pd
        import time
        from tqdm import tqdm

        start_time = time.time()
        nn_idxs, nn_dists = self.get_nearest_neighbors(
            embeddings=embeddings, k=k, ef=ef
        )
        end_time = time.time()
        if not disable_progress:
            print(
                f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min"
            )
        # Per-cell prediction statistics, accumulated in parallel with the
        # predictions list below (one entry per cell, in input order).
        stats = {
            "hits": [],
            "hits_weighted": [],
            "min_dist": [],
            "max_dist": [],
            "vs2nd": [],
            "vsAll": [],
            "vs2nd_weighted": [],
            "vsAll_weighted": [],
        }
        if k == 1:
            # Single neighbor: the prediction is simply that neighbor's label.
            predictions = pd.Series(nn_idxs.flatten()).map(self.idx2label)
        else:
            predictions = []
            for nns, d_nns in tqdm(
                zip(nn_idxs, nn_dists), total=nn_idxs.shape[0], disable=disable_progress
            ):
                # count celltype in nearest neighbors (optionally with distance weights)
                celltype = defaultdict(float)
                celltype_weighted = defaultdict(float)
                for neighbor, dist in zip(nns, d_nns):
                    celltype[self.idx2label[neighbor]] += 1.0
                    # clamp distance to avoid division by zero for exact hits
                    celltype_weighted[self.idx2label[neighbor]] += 1.0 / float(
                        max(dist, 1e-6)
                    )
                # predict based on consensus max occurrence
                if weighting:
                    predictions.append(
                        max(celltype_weighted.items(), key=operator.itemgetter(1))[0]
                    )
                else:
                    predictions.append(
                        max(celltype.items(), key=operator.itemgetter(1))[0]
                    )
                # compute prediction stats
                stats["hits"].append(json.dumps(celltype))
                stats["hits_weighted"].append(json.dumps(celltype_weighted))
                stats["min_dist"].append(np.min(d_nns))
                stats["max_dist"].append(np.max(d_nns))
                hits = sorted(celltype.values(), reverse=True)
                # clamp weighted hits so the ratio denominators stay non-zero
                hits_weighted = [
                    max(x, 1e-6)
                    for x in sorted(celltype_weighted.values(), reverse=True)
                ]
                if len(hits) > 1:
                    stats["vs2nd"].append(hits[0] / (hits[0] + hits[1]))
                    stats["vsAll"].append(hits[0] / sum(hits))
                    stats["vs2nd_weighted"].append(
                        hits_weighted[0] / (hits_weighted[0] + hits_weighted[1])
                    )
                    stats["vsAll_weighted"].append(
                        hits_weighted[0] / sum(hits_weighted)
                    )
                else:
                    # unanimous neighborhood: all ratios are trivially 1.0
                    stats["vs2nd"].append(1.0)
                    stats["vsAll"].append(1.0)
                    stats["vs2nd_weighted"].append(1.0)
                    stats["vsAll_weighted"].append(1.0)
        return (
            pd.Series(predictions),
            nn_idxs,
            nn_dists,
            pd.DataFrame(stats),
        )
[docs] def annotate_dataset( self, data: "anndata.AnnData", ) -> "anndata.AnnData": """Annotate dataset with celltype predictions. Parameters ---------- data: anndata.AnnData The annotated data matrix with rows for cells and columns for genes. This function assumes the data has been log normalized (i.e. via lognorm_counts) accordingly. Returns ------- anndata.AnnData A data object where: - celltype predictions are in obs["celltype_hint"] - embeddings are in obs["X_scimilarity"]. Examples -------- >>> ca = CellAnnotation(model_path="/opt/data/model") >>> data = annotate_dataset(data) """ from .utils import align_dataset embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) data.obsm["X_scimilarity"] = embeddings predictions, _, _, nn_stats = self.get_predictions_knn(embeddings) data.obs["celltype_hint"] = predictions.values data.obs["min_dist"] = nn_stats["min_dist"].values data.obs["celltype_hits"] = nn_stats["hits"].values data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values data.obs["celltype_hint_weighted_stat"] = nn_stats["vsAll_weighted"].values return data