Source code for scimilarity.cell_embedding

from typing import Optional, Tuple, Union


class CellEmbedding:
    """A class that embeds cell gene expression data using an ML model."""

    def __init__(
        self,
        model_path: str,
        use_gpu: bool = False,
    ):
        """Constructor.

        Parameters
        ----------
        model_path: str
            Path to the directory containing model files.
        use_gpu: bool, default: False
            Use GPU instead of CPU.

        Examples
        --------
        >>> ce = CellEmbedding(model_path="/opt/data/model")
        """
        import json
        import os
        import pandas as pd

        from .nn_models import Encoder

        self.model_path = model_path
        self.use_gpu = use_gpu

        self.filenames = {
            "model": os.path.join(self.model_path, "encoder.ckpt"),
            "gene_order": os.path.join(self.model_path, "gene_order.tsv"),
        }

        # get gene order
        with open(self.filenames["gene_order"], "r") as fh:
            self.gene_order = [line.strip() for line in fh]

        # get neural network model and infer network size
        with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh:
            layer_sizes = json.load(fh)

        # keys: network.1.weight, network.2.weight, ..., network.n.weight
        layers = [
            (key, layer_sizes[key])
            for key in sorted(list(layer_sizes.keys()))
            if "weight" in key and len(layer_sizes[key]) > 1
        ]
        parameters = {
            "latent_dim": layers[-1][1][0],  # last
            "hidden_dim": [layer[1][0] for layer in layers][0:-1],  # all but last
        }

        self.n_genes = len(self.gene_order)
        self.latent_dim = parameters["latent_dim"]

        self.model = Encoder(
            n_genes=self.n_genes,
            latent_dim=parameters["latent_dim"],
            hidden_dim=parameters["hidden_dim"],
            residual=False,
        )
        if self.use_gpu is True:
            self.model.cuda()
        self.model.load_state(self.filenames["model"])
        self.model.eval()

        # mapping between integer ids and label strings, and its inverse
        self.int2label = pd.read_csv(
            os.path.join(self.model_path, "label_ints.csv"), index_col=0
        )["0"].to_dict()
        self.label2int = {value: key for key, value in self.int2label.items()}
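    # Illustrative sketch only (not part of the original module): how the
    # layer-size inference in __init__ behaves for a hypothetical
    # layer_sizes.json. The "len(layer_sizes[key]) > 1" filter keeps 2D weight
    # matrices and drops 1D bias vectors; the latent dimension is the output
    # size of the last weight matrix, and the hidden dimensions are the output
    # sizes of all earlier ones. The sizes below are assumed values for
    # illustration, not the shipped model's actual sizes.
    #
    #   layer_sizes = {
    #       "network.1.weight": [1024, 28231],
    #       "network.2.weight": [1024, 1024],
    #       "network.3.weight": [128, 1024],
    #   }
    #   # layers keeps all three entries (each value has length 2), so:
    #   # parameters == {"latent_dim": 128, "hidden_dim": [1024, 1024]}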
    def get_embeddings(
        self,
        X: Union[
            "scipy.sparse.csr_matrix", "scipy.sparse.csc_matrix", "numpy.ndarray"
        ],
        num_cells: int = -1,
        buffer_size: int = 10000,
    ) -> "numpy.ndarray":
        """Calculate embeddings for a log-normalized gene expression matrix.

        Parameters
        ----------
        X: scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, numpy.ndarray
            Gene space aligned and log normalized (tp10k) gene expression matrix.
        num_cells: int, default: -1
            The number of cells to embed, starting from index 0.
            A value of -1 will embed all cells.
        buffer_size: int, default: 10000
            The number of cells to embed in one batch.

        Returns
        -------
        numpy.ndarray
            A 2D numpy array of embeddings [num_cells x latent_space_dimensions].

        Examples
        --------
        >>> from scimilarity.utils import align_dataset, lognorm_counts
        >>> ce = CellEmbedding(model_path="/opt/data/model")
        >>> data = align_dataset(data, ce.gene_order)
        >>> data = lognorm_counts(data)
        >>> embeddings = ce.get_embeddings(data.X)
        """
        import numpy as np
        from scipy.sparse import csr_matrix, csc_matrix
        import torch
        import zarr

        if num_cells == -1:
            num_cells = X.shape[0]

        if (
            (isinstance(X, csr_matrix) or isinstance(X, csc_matrix))
            and (
                isinstance(X.data, zarr.core.Array)
                or isinstance(X.indices, zarr.core.Array)
                or isinstance(X.indptr, zarr.core.Array)
            )
            and num_cells <= buffer_size
        ):
            # small zarr-backed sparse input: materialize components into memory
            X.data = X.data[...]
            X.indices = X.indices[...]
            X.indptr = X.indptr[...]

        embedding_parts = []
        with torch.inference_mode():  # disable gradients, not needed for inference
            for i in range(0, num_cells, buffer_size):
                profiles = None
                if isinstance(X, np.ndarray):
                    profiles = torch.Tensor(X[i : i + buffer_size])
                elif isinstance(X, torch.Tensor):
                    profiles = X[i : i + buffer_size]
                elif isinstance(X, csr_matrix) or isinstance(X, csc_matrix):
                    profiles = torch.Tensor(X[i : i + buffer_size].toarray())

                if profiles is None:
                    raise RuntimeError(f"Unknown data type {type(X)}.")

                if self.use_gpu is True:
                    profiles = profiles.cuda()
                embedding_parts.append(self.model(profiles))

        if not embedding_parts:
            raise RuntimeError("No valid cells detected.")

        if self.use_gpu:
            # detach, move from gpu into cpu, return as numpy array
            embedding = torch.vstack(embedding_parts).detach().cpu().numpy()
        else:
            # detach, return as numpy array
            embedding = torch.vstack(embedding_parts).detach().numpy()

        if np.isnan(embedding).any():
            raise RuntimeError("NaN detected in embeddings.")

        return embedding
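A minimal end-to-end usage sketch assembled from the docstring examples above. The model path comes from those examples; the input file name and the use of anndata.read_h5ad to obtain an AnnData object are placeholders for illustration, not part of this module.

import anndata

from scimilarity.cell_embedding import CellEmbedding
from scimilarity.utils import align_dataset, lognorm_counts

ce = CellEmbedding(model_path="/opt/data/model")   # path from the docstring example

data = anndata.read_h5ad("cells.h5ad")             # hypothetical input dataset
data = align_dataset(data, ce.gene_order)          # align genes to the model's gene order
data = lognorm_counts(data)                        # log-normalize counts (tp10k)
embeddings = ce.get_embeddings(data.X)             # (num_cells, ce.latent_dim) numpy array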