Source code for scimilarity.cell_annotation

from typing import Optional, Union, List, Set, Tuple

from .cell_search_knn import CellSearchKNN


[docs]class CellAnnotation(CellSearchKNN): """A class that annotates cells using a cell embedding and then knn search.""" def __init__( self, model_path: str, use_gpu: bool = False, filenames: Optional[dict] = None, ): """Constructor. Parameters ---------- model_path: str Path to the directory containing model files. use_gpu: bool, default: False Use GPU instead of CPU. filenames: dict, optional, default: None Use a dictionary of custom filenames for files instead default. Examples -------- >>> ca = CellAnnotation(model_path="/opt/data/model") """ import os super().__init__( model_path=model_path, use_gpu=use_gpu, ) self.annotation_path = os.path.join(model_path, "annotation") os.makedirs(self.annotation_path, exist_ok=True) if filenames is None: filenames = {} self.filenames["knn"] = os.path.join( self.annotation_path, filenames.get("knn", "labelled_kNN.bin") ) self.filenames["celltype_labels"] = os.path.join( self.annotation_path, filenames.get("celltype_labels", "reference_labels.tsv"), ) # get knn self.load_knn_index(self.filenames["knn"]) # get int2label self.idx2label = None self.classes = None if self.knn is not None: with open(self.filenames["celltype_labels"], "r") as fh: self.idx2label = {i: line.strip() for i, line in enumerate(fh)} self.classes = set(self.label2int.keys()) self.safelist = None self.blocklist = None
[docs] def build_knn( self, input_data: Union["anndata.AnnData", List[str]], knn_filename: str = "labelled_kNN.bin", celltype_labels_filename: str = "reference_labels.tsv", obs_field: str = "celltype_name", ef_construction: int = 1000, M: int = 80, target_labels: Optional[List[str]] = None, ): """Build and save a knn index from a h5ad data file or directory of aligned.zarr stores. Parameters ---------- input_data: Union[anndata.AnnData, List[str]], If a list, it should contain a list of zarr store locations (zarr format saved by anndata). The zarr data should contain cells that are already log normalized and gene space aligned. Otherwise, the annotated data matrix with rows for cells and columns for genes. NOTE: The data should be curated to only contain valid cell ontology labels. knn_filename: str, default: "labelled_kNN.bin" Filename of the knn index. celltype_labels_filename: str, default: "reference_labels.tsv" Filename of the cell type reference labels. obs_field: str, default: "celltype_name" The obs column name of celltype labels. ef_construction: int, default: 1000 The size of the dynamic list for the nearest neighbors. See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md M: int, default: 80 The number of bi-directional links created for every new element during construction. See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md target_labels: Optional[List[str]], default: None Optional list of cell type names to filter the data. Examples -------- >>> ca.build_knn(filename="/opt/data/train/train.h5ad") """ import anndata import hnswlib import numpy as np import os import pandas as pd from .utils import align_dataset from .zarr_dataset import ZarrDataset from tqdm import tqdm if isinstance(input_data, list): data_list = input_data embeddings_list = [] labels = [] for filename in tqdm(data_list): dataset = ZarrDataset(filename) obs = pd.DataFrame({obs_field: dataset.get_obs(obs_field)}) obs.index = obs.index.astype(str) data = anndata.AnnData( X=dataset.get_X(in_mem=True), obs=obs, var=pd.DataFrame(index=dataset.var_index), dtype=np.float32, ) if target_labels is not None: data = data[data.obs["celltype_name"].isin(target_labels)].copy() if len(data.obs) == 0: continue embeddings_list.append( self.get_embeddings(align_dataset(data, self.gene_order).X) ) labels.extend(data.obs[obs_field].tolist()) embeddings = np.concatenate(embeddings_list) else: data = input_data if target_labels is not None: data = data[data.obs["celltype_name"].isin(target_labels)].copy() if len(data.obs) == 0: raise RuntimeError("No cells remain after filtering.") embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) labels = data.obs[obs_field].tolist() # save knn n_cells, n_dims = embeddings.shape self.knn = hnswlib.Index(space="cosine", dim=n_dims) self.knn.init_index(max_elements=n_cells, ef_construction=ef_construction, M=M) self.knn.set_ef(ef_construction) self.knn.add_items(embeddings, range(len(embeddings))) knn_fullpath = os.path.join(self.annotation_path, knn_filename) if os.path.isfile(knn_fullpath): # backup existing os.rename(knn_fullpath, knn_fullpath + ".bak") self.knn.save_index(knn_fullpath) # save labels celltype_labels_fullpath = os.path.join( self.annotation_path, celltype_labels_filename ) if os.path.isfile(celltype_labels_fullpath): # backup existing os.rename( celltype_labels_fullpath, celltype_labels_fullpath + ".bak", ) with open(celltype_labels_fullpath, "w") as f: f.write("\n".join(labels)) # load new int2label with open(celltype_labels_fullpath, "r") as fh: self.idx2label = {i: line.strip() for i, line in enumerate(fh)}
[docs] def reset_knn(self): """Reset the knn such that nothing is marked deleted. Examples -------- >>> ca.reset_knn() """ self.blocklist = None self.safelist = None # hnswlib does not have a marked status, so we need to unmark all for i in self.idx2label: try: # throws an expection if not already marked self.knn.unmark_deleted(i) except: pass
[docs] def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): """Blocklist celltypes. Parameters ---------- labels: List[str], Set[str] A list or set containing blocklist labels. Notes ----- Blocking a celltype will persist for this instance of the class and subsequent predictions will have this blocklist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples -------- >>> ca.blocklist_celltypes(["T cell"]) """ self.reset_knn() self.blocklist = set(labels) self.safelist = None for i, celltype_name in self.idx2label.items(): if celltype_name in self.blocklist: self.knn.mark_deleted(i) # mark blocklist
[docs] def safelist_celltypes(self, labels: Union[List[str], Set[str]]): """Safelist celltypes. Parameters ---------- labels: List[str], Set[str] A list or set containing safelist labels. Notes ----- Safelisting a celltype will persist for this instance of the class and subsequent predictions will have this safelist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples -------- >>> ca.safelist_celltypes(["CD4-positive, alpha-beta T cell"]) """ self.blocklist = None self.safelist = set(labels) for i in self.idx2label: # mark all try: # throws an exception if already marked self.knn.mark_deleted(i) except: pass for i, celltype_name in self.idx2label.items(): if celltype_name in self.safelist: self.knn.unmark_deleted(i) # unmark safelist
[docs] def get_predictions_knn( self, embeddings: "numpy.ndarray", k: int = 50, ef: int = 100, weighting: bool = False, disable_progress: bool = False, ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]: """Get predictions from knn search results. Parameters ---------- embeddings: numpy.ndarray Embeddings as a numpy array. k: int, default: 50 The number of nearest neighbors. ef: int, default: 100 The size of the dynamic list for the nearest neighbors. See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md weighting: bool, default: False Use distance weighting when getting the consensus prediction. disable_progress: bool, default: False Disable tqdm progress bar Returns ------- predictions: pandas.Series A pandas series containing celltype label predictions. nn_idxs: numpy.ndarray A 2D numpy array of nearest neighbor indices [num_cells x k]. nn_dists: numpy.ndarray A 2D numpy array of nearest neighbor distances [num_cells x k]. stats: pandas.DataFrame Prediction statistics dataframe with columns: "hits" is a json string with the count for every class in k cells. "min_dist" is the minimum distance. "max_dist" is the maximum distance "vs2nd" is sum(best) / sum(best + 2nd best). "vsAll" is sum(best) / sum(all hits). "hits_weighted" is a json string with the weighted count for every class in k cells. "vs2nd_weighted" is weighted sum(best) / sum(best + 2nd best). "vsAll_weighted" is weighted sum(best) / sum(all hits). Examples -------- >>> ca = CellAnnotation(model_path="/opt/data/model") >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X) >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(embeddings) """ from collections import defaultdict import json import operator import numpy as np import pandas as pd import time from tqdm import tqdm start_time = time.time() nn_idxs, nn_dists = self.get_nearest_neighbors( embeddings=embeddings, k=k, ef=ef ) end_time = time.time() if not disable_progress: print( f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min" ) stats = { "hits": [], "hits_weighted": [], "min_dist": [], "max_dist": [], "vs2nd": [], "vsAll": [], "vs2nd_weighted": [], "vsAll_weighted": [], } if k == 1: predictions = pd.Series(nn_idxs.flatten()).map(self.idx2label) else: predictions = [] for nns, d_nns in tqdm( zip(nn_idxs, nn_dists), total=nn_idxs.shape[0], disable=disable_progress ): # count celltype in nearest neighbors (optionally with distance weights) celltype = defaultdict(float) celltype_weighted = defaultdict(float) for neighbor, dist in zip(nns, d_nns): celltype[self.idx2label[neighbor]] += 1.0 celltype_weighted[self.idx2label[neighbor]] += 1.0 / float( max(dist, 1e-6) ) # predict based on consensus max occurrence if weighting: predictions.append( max(celltype_weighted.items(), key=operator.itemgetter(1))[0] ) else: predictions.append( max(celltype.items(), key=operator.itemgetter(1))[0] ) # compute prediction stats stats["hits"].append(json.dumps(celltype)) stats["hits_weighted"].append(json.dumps(celltype_weighted)) stats["min_dist"].append(np.min(d_nns)) stats["max_dist"].append(np.max(d_nns)) hits = sorted(celltype.values(), reverse=True) hits_weighted = [ max(x, 1e-6) for x in sorted(celltype_weighted.values(), reverse=True) ] if len(hits) > 1: stats["vs2nd"].append(hits[0] / (hits[0] + hits[1])) stats["vsAll"].append(hits[0] / sum(hits)) stats["vs2nd_weighted"].append( hits_weighted[0] / (hits_weighted[0] + hits_weighted[1]) ) stats["vsAll_weighted"].append( hits_weighted[0] / sum(hits_weighted) ) else: stats["vs2nd"].append(1.0) stats["vsAll"].append(1.0) stats["vs2nd_weighted"].append(1.0) stats["vsAll_weighted"].append(1.0) return ( pd.Series(predictions), nn_idxs, nn_dists, pd.DataFrame(stats), )
[docs] def annotate_dataset( self, data: "anndata.AnnData", ) -> "anndata.AnnData": """Annotate dataset with celltype predictions. Parameters ---------- data: anndata.AnnData The annotated data matrix with rows for cells and columns for genes. This function assumes the data has been log normalized (i.e. via lognorm_counts) accordingly. Returns ------- anndata.AnnData A data object where: - celltype predictions are in obs["celltype_hint"] - embeddings are in obs["X_scimilarity"]. Examples -------- >>> ca = CellAnnotation(model_path="/opt/data/model") >>> data = annotate_dataset(data) """ from .utils import align_dataset embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) data.obsm["X_scimilarity"] = embeddings predictions, _, _, nn_stats = self.get_predictions_knn(embeddings) data.obs["celltype_hint"] = predictions.values data.obs["min_dist"] = nn_stats["min_dist"].values data.obs["celltype_hits"] = nn_stats["hits"].values data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values data.obs["celltype_hint_weighted_stat"] = nn_stats["vsAll_weighted"].values return data