{ "cells": [ { "cell_type": "markdown", "id": "8943365d-1db3-4b0f-9ba0-18ed16e124fd", "metadata": {}, "source": [ "# Data preparation, model training, and building post-training structures\n", "\n", "We have transitioned the SCimilarity codebase from [Zarr](https://zarr.dev/) to [TileDB](https://tiledb.com/) to streamline data loading, training, and the building of post-training data structures.\n", "\n", "This tutorial describes a high-level workflow for training SCimilarity, including data structure requirements, dataloaders, and scripts to build post-training data structures.\n", "\n", "The workflow is divided into the following sections:\n", "\n", " 1. Process the data into CellArr TileDB format.\n", " 2. Train using the TileDB dataloader.\n", " 3. Build all post-training data structures, such as kNN indices and TileDB arrays." ] }, { "cell_type": "markdown", "id": "086e39f2-6d87-40bd-bb84-5a8f06011d35", "metadata": {}, "source": [ "## 1. Construct a TileDB store in CellArr format\n", "\n", "To facilitate random access to a cell corpus that is too large to fit into memory, we currently recommend a TileDB structure as used in [CellArr](https://github.com/CellArr/cellarr). The construction requires a collection of AnnData H5AD files, which CellArr organizes into TileDB arrays and dataframes.\n", "\n", "For more details on how to construct CellArr TileDB structures from AnnData H5AD files, please see the [CellArr documentation](https://cellarr.github.io/cellarr/) and the [tutorial](https://cellarr.github.io/cellarr/tutorial_cellxgene.html) on creating CellArr objects from CELLxGENE datasets.\n", "\n", "**Before you start creating a CellArr object, read the sections below on requirements and on processing the AnnData.**" ] }, { "cell_type": "markdown", "id": "90bbe463-55f5-42c6-8165-b66638213df2", "metadata": {}, "source": [ "### 1.1 Processing the AnnData\n", "\n", "We recommend that you process the AnnData before constructing the CellArr object. 
The required `obs` columns are described in the next section.\n", "\n", "Some important requirements:\n", "\n", " 1. `var.index` must contain gene symbols, not Ensembl IDs. If your data uses Ensembl IDs, you will need to map them to gene symbols.\n", " 2. If you map Ensembl IDs to gene symbols, you will need to consolidate counts in cases where multiple Ensembl IDs map to the same gene symbol. We provide a utility function for this, `utils.consolidate_duplicate_symbols`.\n", " 3. Please ensure that `var.index` contains only gene symbols, as it can sometimes contain a mix of symbols and Ensembl IDs.\n", "\n", "An example of this process, assuming you have an Ensembl ID to gene symbol mapping table, is shown below. The example expects a tab-separated table whose first column is `Gene stable ID` and which has a `Gene name` column; you may need to modify it depending on the structure of your mapping table." ] }, { "cell_type": "code", "execution_count": null, "id": "57f0a94b-9f2c-4eb2-8f92-763074c9a6ca", "metadata": {}, "outputs": [], "source": [ "def get_id2name(mapping_table_file: str):\n", " import pandas as pd\n", "\n", " # Map Ensembl gene IDs (\"Gene stable ID\") to gene symbols (\"Gene name\")\n", " id2name = pd.read_csv(mapping_table_file, delimiter=\"\\t\", index_col=0, dtype=\"unicode\")[\"Gene name\"]\n", " id2name = id2name.reset_index().drop_duplicates().set_index(\"Gene stable ID\")[\"Gene name\"]\n", " # Also map each symbol to itself so that var indices already given as symbols are kept\n", " id2name = pd.concat([id2name, pd.Series(id2name.values, index=id2name.values).drop_duplicates()])\n", " return id2name\n", "\n", "def convert_ids_to_names(adata, mapping_table_file: str):\n", " id2name = get_id2name(mapping_table_file)\n", " # If no var indices match the mapping, fall back to a \"symbol\" column if present\n", " if not any(adata.var.index.isin(id2name.keys())) and \"symbol\" in adata.var.columns:\n", " adata.var = adata.var.set_index(\"symbol\", drop=False)\n", " adata.var.index = adata.var.index.str.replace(\"'\", \"\")\n", " adata.var.index = adata.var.index.str.replace('\"', \"\")\n", " # Keep only genes present in the mapping, then convert them to symbols\n", " adata = adata[:, (adata.var.index.isin(id2name.keys())) & ~(adata.var.index.isnull())].copy()\n", " adata.var.index = id2name[adata.var.index]\n", " return adata\n", "\n", "def clean_var(adata, mapping_table_file: str):\n", " from scimilarity.utils import consolidate_duplicate_symbols\n", " \n", " adata = 
convert_ids_to_names(adata, mapping_table_file)\n", " adata = consolidate_duplicate_symbols(adata)\n", " adata.var.index.name = \"symbol\"\n", " return adata" ] }, { "cell_type": "markdown", "id": "613ba0ca-6ac6-4cf2-abfc-0cded8fea50e", "metadata": {}, "source": [ "### 1.2 Required columns\n", "\n", "By default, we require specific columns to be present in the metadata, many of which are used for filtering cells. If you would like to change the default filter and/or the required columns, you can do so with the `filter_condition` parameter of the dataloader class `tiledb_data_models.CellMultisetDataModule`.\n", "\n", "You may include any columns you wish in the metadata, but the following columns are required by default:\n", "\n", " 1. A string column for the study identifier, e.g., `datasetID`.\n", " 2. A string column for the sample identifier, e.g., `sampleID`.\n", " 3. A string column for the cell type label that contains the **Cell Ontology term identifier**, e.g., `cellTypeOntologyID`.\n", " 4. A string column named `tissue` that contains the tissue annotation.\n", " 5. A string column named `disease` that contains the disease annotation.\n", " 6. An int column named `n_genes_by_counts` that contains the number of genes expressed, as produced by `scanpy.pp.calculate_qc_metrics`.\n", " 7. An int column named `total_counts` that contains the number of UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", " 8. An int column named `total_counts_mt` that contains the number of mitochondrial UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", " 9. A float column named `pct_counts_mt` that contains the percentage of mitochondrial UMIs, as produced by `scanpy.pp.calculate_qc_metrics`.\n", " 10. An int column named `predicted_doublets` that is `1` for a predicted doublet and `0` otherwise, as produced by [scrublet](https://github.com/swolock/scrublet) or your choice of doublet detection tool.\n", "\n", "**NOTE**: Not every column needs to have values. 
For example, some cells do not have cell type annotations, in which case you may leave the value blank.\n", "\n", "An example of this process is shown below. You may need to customize it based on the columns that exist in your data." ] }, { "cell_type": "code", "execution_count": null, "id": "cda21c09-4c8e-4bc4-86da-afa10bf5304f", "metadata": {}, "outputs": [], "source": [ "# datasetID and sampleID are passed explicitly because it is not trivial to infer them from obs column names\n", "# Feel free to add cell type ontology ID, tissue, or disease arguments if you wish to set those explicitly per AnnData\n", "def clean_obs(adata, datasetID, sampleID):\n", " # You may use a doublet prediction tool other than scrublet,\n", " # but the \"predicted_doublets\" column described above is still required\n", " import numpy as np\n", " import scrublet as scr\n", " import scanpy as sc\n", "\n", " # Determine whether mitochondrial genes use the \"MT-\" (human) or \"mt-\" (mouse) prefix\n", " mito_prefix = \"MT-\"\n", " if adata.var_names.str.startswith(\"mt-\").any():\n", " mito_prefix = \"mt-\"\n", "\n", " # QC stats, computed from the raw counts stored in adata.layers[\"counts\"]\n", " adata.var[\"mt\"] = adata.var_names.str.startswith(mito_prefix)\n", " sc.pp.calculate_qc_metrics(\n", " adata,\n", " qc_vars=[\"mt\"],\n", " percent_top=None,\n", " log1p=False,\n", " inplace=True,\n", " layer=\"counts\",\n", " )\n", " obs = adata.obs.copy()\n", "\n", " obs.columns = [x[0].lower() + x[1:] for x in obs.columns] # lowercase first letter\n", "\n", " obs[\"datasetID\"] = datasetID\n", " obs[\"sampleID\"] = sampleID\n", "\n", " # Scrublet can fail, in which case we mark all cells as not doublets (0)\n", " try:\n", " scrub = scr.Scrublet(adata.layers['counts'])\n", " doublet_scores, predicted_doublets = scrub.scrub_doublets(verbose=False)\n", " obs[\"doublet_score\"] = doublet_scores\n", " obs[\"predicted_doublets\"] = predicted_doublets.astype(int)\n", " except Exception:\n", " obs[\"doublet_score\"] = np.nan\n", " obs[\"predicted_doublets\"] = 0\n", "\n", " # You will need to go through all columns that might 
contain cell type ontology IDs\n", " # or convert cell type names to ontology IDs (see scimilarity.ontologies for helper functions)\n", " if \"cellTypeOntologyID\" not in obs.columns:\n", " obs[\"cellTypeOntologyID\"] = \"\"\n", "\n", " # You will need to go through all columns that might contain tissue\n", " if \"tissue\" not in obs.columns and \"meta_tissue\" not in obs.columns:\n", " obs[\"tissue\"] = \"\"\n", " elif \"tissue\" not in obs.columns and \"meta_tissue\" in obs.columns:\n", " obs = obs.rename(columns={\"meta_tissue\": \"tissue\"})\n", "\n", " # You will need to go through all columns that might contain disease\n", " if \"disease\" not in obs.columns and \"meta_disease\" not in obs.columns:\n", " obs[\"disease\"] = \"\"\n", " elif \"disease\" not in obs.columns and \"meta_disease\" in obs.columns:\n", " obs = obs.rename(columns={\"meta_disease\": \"disease\"})\n", "\n", " columns = [\n", " \"datasetID\", \"sampleID\", \"cellTypeOntologyID\", \"tissue\", \"disease\",\n", " \"n_genes_by_counts\", \"total_counts\", \"total_counts_mt\", \"pct_counts_mt\",\n", " \"doublet_score\", \"predicted_doublets\",\n", " ]\n", " obs = obs[columns].copy()\n", "\n", " convert_dict = {\n", " \"datasetID\": str,\n", " \"sampleID\": str,\n", " \"cellTypeOntologyID\": str,\n", " \"tissue\": str,\n", " \"disease\": str,\n", " \"n_genes_by_counts\": int,\n", " \"total_counts\": int,\n", " \"total_counts_mt\": int,\n", " \"pct_counts_mt\": float,\n", " \"doublet_score\": float,\n", " \"predicted_doublets\": int,\n", " }\n", " adata.obs = obs.astype(convert_dict)\n", " return adata" ] }, { "cell_type": "markdown", "id": "84451603-8e9d-4990-b87d-50c031ae5582", "metadata": {}, "source": [ "### 1.3 Reasons for the required columns\n", "\n", "First, the cell type ontology identifier is required because we use the Cell Ontology to determine ancestors or descendants during triplet mining. 
It is important that the cell type annotation uses a controlled vocabulary that conforms to the Cell Ontology.\n", "\n", "Second, the CellArr cell corpus is a database of **all** cells, not just those used in training. Cells not used in training are still used in cell search. By default, the dataloader selects the cells that can be used in training according to the following criteria:\n", "\n", " - The cell type annotation exists and is an ontology identifier\n", " - `total_counts` > 1000\n", " - `n_genes_by_counts` > 500\n", " - `pct_counts_mt` < 20\n", " - `predicted_doublets` == 0\n", "\n", "These criteria can be customized in the dataloader, but the above are the defaults.\n", "\n", "Lastly, standardizing the datasets makes it easier to create a CellArr object and provides the columns that the dataloader expects. Feel free to add any additional columns you wish." ] }, { "cell_type": "markdown", "id": "9d1b3213-4bf4-43c2-ba3f-1fd7f19ddc0a", "metadata": {}, "source": [ "### 1.4 Select a gene space\n", "\n", "CellArr creates TileDB structures that cover the union of all genes across all your datasets. Using this full gene space is usually inefficient for the model, as only a subset of genes is well represented across studies, so we filter the gene space for model training.\n", "\n", "An example of this process is shown below. It uses the sample metadata to count, for each gene, the number of studies in which it was measured, plots that distribution, and selects well-represented genes. 
**You should save the gene order file so that it can be used for training.**" ] }, { "cell_type": "code", "execution_count": null, "id": "8c794227-4d6a-4991-bc66-261c0df2310e", "metadata": {}, "outputs": [], "source": [ "def get_genes_vs_studies(cellarr_path):\n", " import os\n", " import pandas as pd\n", " import tiledb\n", " import seaborn as sns\n", " from matplotlib import pyplot as plt\n", "\n", " plt.rcParams[\"pdf.fonttype\"] = 42\n", "\n", " SAMPLEURI = \"sample_metadata\"\n", "\n", " # Load the sample metadata, one row per sample\n", " sample_tdb = tiledb.open(os.path.join(cellarr_path, SAMPLEURI), \"r\")\n", " sample_df = sample_tdb.query(attrs=[\"datasetID\", \"cellarr_cell_counts\"]).df[:]\n", " sample_tdb.close()\n", " \n", " # Order studies by their largest per-sample cell count\n", " df = sample_df.groupby(\"datasetID\", observed=True)[\"cellarr_cell_counts\"].max().sort_values(ascending=False).reset_index()\n", " df = df[df.index.isin(df[\"datasetID\"].drop_duplicates().index.values)]\n", " \n", " # Select one representative sample per study\n", " selected = []\n", " for idx, row in df.iterrows():\n", " datasetID = row[\"datasetID\"]\n", " selected.append(sample_df[sample_df[\"datasetID\"]==datasetID].index[0])\n", " print(len(selected))\n", " \n", " sample_tdb = tiledb.open(os.path.join(cellarr_path, SAMPLEURI), \"r\")\n", " genes_df = sample_tdb.df[sorted(selected)]\n", " sample_tdb.close()\n", " \n", " # Count the number of studies in which each gene was originally measured\n", " gene_counts = {}\n", " for idx, row in genes_df.iterrows():\n", " genes = row['cellarr_original_gene_set'].split(',')\n", " for g in genes:\n", " if g not in gene_counts:\n", " gene_counts[g] = 0\n", " gene_counts[g] += 1\n", " gene_counts = {k: v for k, v in sorted(gene_counts.items(), key=lambda item: item[1], reverse=False)}\n", " \n", " # Build a gene vs. study count dataframe and visualize it\n", " gene = []\n", " count = []\n", " for k, v in gene_counts.items():\n", " gene.append(k)\n", " count.append(v)\n", " gene_counts_df = pd.DataFrame({\"gene\": gene, \"count\": count}) \n", " gene_counts_df = 
gene_counts_df.sort_values(by=\"count\", ascending=False).reset_index(drop=True)\n", " \n", " fig, ax = plt.subplots(1)\n", " sns.barplot(ax=ax, x=\"gene\", y=\"count\", data=gene_counts_df)\n", " plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)\n", " plt.tight_layout()\n", "\n", " return gene_counts_df\n", "\n", "def select_genes(gene_counts_df, min_studies, filename=\"scimilarity_gene_order.tsv\"):\n", " # Keep genes measured in at least min_studies studies and write them, sorted, one per line\n", " gene_order = gene_counts_df[gene_counts_df[\"count\"]>=min_studies][\"gene\"].values.tolist()\n", " gene_order = sorted(gene_order)\n", " with open(filename, \"w\") as f:\n", " f.write(\"\\n\".join(gene_order))\n", " return gene_order" ] }, { "cell_type": "markdown", "id": "ce12b4c0-c63b-4d76-9f04-0f1d9d2b3cba", "metadata": {}, "source": [ "## 2. Training\n", "\n", "We provide an example training script at `scripts/train.py`, which you can customize to your specifications.\n", "\n", "Training follows the `pytorch-lightning` paradigm:\n", " - Data module class `CellMultisetDataModule` from `tiledb_data_models`\n", " - Training module class `MetricLearning` from `training_models`\n", "\n", "If training is interrupted, it can be resumed from the checkpoints logged by `pytorch-lightning`. The training script already implements this by specifying the log directory. By default, we use `tensorboard` as our logger, but you can change this to your preferred logger.\n", "\n", "Once training is complete, all files are saved to the model directory specified by the user. These include CSV files that contain the indices and metadata for the cells used in training and validation." ] }, { "cell_type": "markdown", "id": "2a3afff8-8b1a-4c6b-ba26-a66fba324449", "metadata": {}, "source": [ "## 3. 
Post-training structures" ] }, { "cell_type": "markdown", "id": "c71986f0-af76-487b-b66c-d1e8a6f3c7a6", "metadata": {}, "source": [ "### 3.1 Embeddings TileDB\n", "\n", "The first post-training structure to create is the embeddings TileDB, which is used to create the kNN indices.\n", "\n", "We provide an example script at `scripts/build_embeddings.py`, which embeds all cells in your CellArr object and saves a TileDB array to `{model_path}/cellsearch/cell_embedding`." ] }, { "cell_type": "markdown", "id": "4c76e0b2-cba5-4848-98d6-00410df104f1", "metadata": {}, "source": [ "### 3.2 Annotation kNN\n", "\n", "The annotation kNN uses the embeddings TileDB and the training cells CSV (`train_cells.csv.gz`) to create an `hnswlib` kNN index.\n", "\n", "We provide an example script at `scripts/build_annotation_knn.py`. By default, the kNN index file is saved to `{model_path}/annotation/labelled_kNN.bin` and the reference labels are saved to `{model_path}/annotation/reference_labels.tsv`." ] }, { "cell_type": "markdown", "id": "ae102661-f637-47b4-8de9-ffbe35587041", "metadata": {}, "source": [ "### 3.3 Cell search kNN\n", "\n", "The cell search kNN uses the embeddings TileDB to create a `tiledb_vector_search` kNN index for all cells in your CellArr object.\n", "\n", "We provide an example script at `scripts/build_cellsearch_knn.py`. By default, the kNN index is saved to `{model_path}/cellsearch/full_kNN.bin`.\n", "\n", "**NOTE**: Previously, we used `hnswlib` for the cell search kNN index, but as the data grew, we transitioned to an on-disk kNN index. 
The `CellQuery` class still defaults to an `hnswlib` index, so if you followed the process described in this tutorial, you should initialize `CellQuery` with:\n", "\n", "`cq = CellQuery(model_path, knn_type=\"tiledb_vector_search\")`" ] }, { "cell_type": "markdown", "id": "6a2e1255-7b5b-4d54-a4c2-28133246b12f", "metadata": {}, "source": [ "### 3.4 Cell search metadata\n", "\n", "The `CellQuery` class includes a TileDB dataframe of cell metadata containing precomputed cell type predictions for all cells, as well as other metadata from the CellArr object.\n", "\n", "We provide an example script at `scripts/build_cellsearch_metadata.py` that uses the CellArr object, the embeddings TileDB, and the annotation kNN to construct the TileDB dataframe containing the cell metadata.\n", "\n", "This process copies cell metadata from the CellArr object into the model's cell metadata. Although this is somewhat inefficient in storage, we prefer to keep the \"raw\" data in the CellArr object separate from the \"processed\" data in the model structures, since you may wish to train entirely different models on the same CellArr object.\n", "\n", "**NOTE**: We often apply a safelist to the cell type predictions in the cell metadata to remove cell types that are too general or too specific. The script has an option to use a safelist, given as a file with one cell type per line. It is up to you to decide whether to use a safelist and which cell types to include in it." ] }, { "cell_type": "markdown", "id": "ad48d52c-f9b9-4aa4-8adc-59fdd2dd7f64", "metadata": {}, "source": [ "## Conclusion\n", "\n", "At this point, you have all the structures you need to run SCimilarity.\n", "\n", "One advantage of using the CellArr structure is that you can retrieve the original count data with `utils.adata_from_tiledb`, using the indices in the model's cell metadata, which are aligned with the CellArr index. 
This includes results from a kNN cell search or from filters you apply to the cell metadata (e.g., \"get all regulatory T cells\").\n", "\n", "An example of how to search for cells and then retrieve their count data is shown below." ] }, { "cell_type": "code", "execution_count": null, "id": "4a18ada6-b8e3-4edb-957f-72806227681b", "metadata": {}, "outputs": [], "source": [ "def search_then_gather_cells(model_path, cellarr_path, adata, centroid_column_name):\n", " from scimilarity import CellQuery\n", " from scimilarity.utils import adata_from_tiledb\n", "\n", " cq = CellQuery(model_path, knn_type=\"tiledb_vector_search\")\n", " centroid_embedding, nn_idxs, nn_dists, results_metadata, qc_stats = cq.search_centroid_nearest(adata, centroid_column_name)\n", " \n", " # Retrieve raw counts from CellArr using the cell metadata indices, which are aligned with the CellArr index\n", " results_adata = adata_from_tiledb(results_metadata[\"index\"].values, tiledb_base_path=cellarr_path)\n", "\n", " # The above results come from the CellArr raw data, so augment obs with the columns in `results_metadata`\n", " for c in results_metadata.columns:\n", " if c not in results_adata.obs.columns:\n", " results_adata.obs[c] = results_metadata[c].values\n", " return results_adata" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }