Predicting Gene Expression with Decima

Decima predicts gene expression at the cell-type level. This tutorial shows how to use its prediction API, both for genes in the training data and for custom genes.

Precomputed Predictions

Predicted expression levels (log(CPM + 1)) for all genes in the training data have been precomputed for each model replicate and saved in the metadata h5ad object, which is accessible through the DecimaResult class. The predicted_expression_matrix method returns the predicted gene expression averaged across the replicates.

from decima import DecimaResult

result = DecimaResult.load()
result.predicted_expression_matrix()
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns
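Because these values are on a log(CPM + 1) scale, recovering counts per million means inverting that transform. Below is a minimal sketch, assuming the natural logarithm (the convention of numpy.log1p and scanpy's log1p normalization); if a different base was used, the inverse changes accordingly. The inputs are the STRADA and GSE1 values for agg_0 from the table above.

```python
import math

def log1p_cpm_to_cpm(value: float) -> float:
    """Invert log(CPM + 1), assuming a natural log: CPM = exp(value) - 1."""
    return math.expm1(value)

def cpm_to_log1p_cpm(cpm: float) -> float:
    """Forward transform, log(CPM + 1), for round-tripping."""
    return math.log1p(cpm)

strada_agg0 = 2.973438  # STRADA in agg_0, from the averaged matrix above
cpm = log1p_cpm_to_cpm(strada_agg0)
print(f"{cpm:.2f} CPM")

# A difference on the log scale is a log ratio of (CPM + 1) values.
fold_change = math.exp(4.629774 - 2.973438)  # GSE1 vs STRADA in agg_0
print(f"GSE1/STRADA (CPM + 1) ratio: {fold_change:.2f}")
```

Working on the log scale keeps differences between genes or tracks additive, which is why comparisons are usually done before back-transforming.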

To access the predicted expression matrix of a specific model, use the model_name parameter. In this example, we obtain the predicted gene expression for the first model replicate.

result.predicted_expression_matrix(model_name="v1_rep0")
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns

Similarly, for the second model replicate:

result.predicted_expression_matrix(model_name="v1_rep1")
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns

The per-replicate predictions are stored as separate layers of the underlying AnnData object:

result.anndata.layers
Layers with keys: preds, v1_rep0, v1_rep1, v1_rep2, v1_rep3
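The averaged matrix and the per-replicate layers relate through an element-wise mean. The following minimal numpy sketch illustrates that relationship, assuming an arithmetic mean over the replicates on the log(CPM + 1) scale and layers shaped tracks × genes like predicted_expression_matrix. The toy arrays stand in for result.anndata.layers, and ensemble_mean is a hypothetical helper, not part of Decima.

```python
import numpy as np

def ensemble_mean(layers: dict[str, np.ndarray], replicate_names: list[str]) -> np.ndarray:
    """Average the named replicate layers element-wise (tracks x genes)."""
    stacked = np.stack([layers[name] for name in replicate_names], axis=0)
    return stacked.mean(axis=0)

# Toy stand-in for result.anndata.layers: 2 tracks x 3 genes per replicate,
# with replicate i offset by i so the mean is easy to check by hand.
replicates = [f"v1_rep{i}" for i in range(4)]
base = np.arange(6, dtype=float).reshape(2, 3)
layers = {name: base + i for i, name in enumerate(replicates)}

mean_matrix = ensemble_mean(layers, replicates)
print(mean_matrix)
```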

To view the metadata for the output tracks (one row per track, with the same agg_ index as the expression matrix), use cell_metadata:

result.cell_metadata
cell_type tissue organ disease study dataset region subregion celltype_coarse n_cells total_counts n_genes size_factor train_pearson val_pearson test_pearson
agg_0 Amygdala excitatory Amygdala_Amygdala CNS healthy jhpce#tran2021 brain_atlas Amygdala Amygdala NaN 331 1.592883e+07 17000 41431.465186 0.942459 0.841377 0.865640
agg_1 Amygdala excitatory Amygdala_Basolateral nuclear group (BLN) - lat... CNS healthy SCR_016152 brain_atlas Amygdala Basolateral nuclear group (BLN) - lateral nucl... NaN 11369 2.952133e+08 18080 40765.341481 0.943098 0.838936 0.861092
agg_2 Amygdala excitatory Amygdala_Bed nucleus of stria terminalis and n... CNS healthy SCR_016152 brain_atlas Amygdala Bed nucleus of stria terminalis and nearby - BNST NaN 139 2.593231e+06 15418 42556.387020 0.952170 0.854544 0.866654
agg_3 Amygdala excitatory Amygdala_Central nuclear group - CEN CNS healthy SCR_016152 brain_atlas Amygdala Central nuclear group - CEN NaN 3892 9.946371e+07 17959 42884.641430 0.959744 0.863585 0.881554
agg_4 Amygdala excitatory Amygdala_Corticomedial nuclear group (CMN) - a... CNS healthy SCR_016152 brain_atlas Amygdala Corticomedial nuclear group (CMN) - anterior c... NaN 2945 1.281619e+08 17885 41816.741933 0.951365 0.854304 0.868902
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 vascular associated smooth muscle cell upper lobe of right lung lung NA ENCODE scimilarity nan nan NaN 21 3.483375e+04 8515 35404.911768 0.735213 0.665647 0.654491
agg_9535 vascular associated smooth muscle cell urinary bladder urinary healthy GSE129845 scimilarity nan nan NaN 24 8.498500e+04 7337 26189.415789 0.809852 0.690022 0.656160
agg_9536 vascular associated smooth muscle cell uterus uterus NA ENCODE scimilarity nan nan NaN 272 5.700762e+05 14769 44938.403867 0.915329 0.808941 0.839993
agg_9537 vascular associated smooth muscle cell uterus uterus healthy e5f58829-1a66-40b5-a624-9046778e74f5 scimilarity nan nan NaN 472 1.089170e+07 14514 30145.422152 0.852339 0.717682 0.727469
agg_9538 vascular associated smooth muscle cell vasculature vasculature healthy e5f58829-1a66-40b5-a624-9046778e74f5 scimilarity nan nan NaN 1853 5.992697e+07 16764 36464.273371 0.909855 0.780413 0.796351

8856 rows × 16 columns

CLI API

To rerun gene expression prediction instead of using the precomputed scores, use the Decima command-line interface (CLI) to generate predictions for any set of genes. The decima predict-genes command takes the following options:

  • --genes: a comma-separated list of gene names, such as "STRADA,ETV4,USP25". If no genes are provided, predictions are made for all genes.

  • --model: the prediction model, for instance "ensemble" or a specific replicate such as "0".

  • --save-replicates: also save the predictions of each individual replicate.

  • -o: the output path for the predictions, in .h5ad format.

! decima predict-genes --genes "STRADA,ETV4,USP25" --model ensemble --save-replicates -o example/predict_genes/predictions.h5ad 
decima - INFO - Using device: 0 and genome: hg38 for prediction.
decima - INFO - Loading model ensemble...
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
rep2.safetensors: 100%|██████████████████████| 755M/755M [00:08<00:00, 85.4MB/s]
rep3.safetensors: 100%|██████████████████████| 755M/755M [00:07<00:00, 99.0MB/s]
decima - INFO - Making predictions
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:01 • 0:00:00 2.24it/s
/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
decima - WARNING - No ground truth expression matrix found in the metadata. Skipping evaluation.

After running this command, you can load the resulting predictions in Python using the DecimaResult class and access the predicted expression matrix as:

result = DecimaResult.load("example/predict_genes/predictions.h5ad")
result.predicted_expression_matrix()
STRADA ETV4 USP25
agg_0 3.060397 2.881919 3.469221
agg_1 3.117476 2.825525 3.596472
agg_2 3.156399 3.122201 3.718817
agg_3 3.213858 3.204461 3.629969
agg_4 3.103418 3.031820 3.512249
... ... ... ...
agg_9533 2.313128 2.173468 3.156261
agg_9535 0.952004 0.956674 1.250045
agg_9536 2.779305 2.705866 3.530668
agg_9537 1.342528 1.407362 1.867149
agg_9538 1.744911 1.633233 2.083091

8856 rows × 3 columns

Or, for a specific replicate (in this output file, the replicate is addressed as preds_v1_rep0 rather than v1_rep0):

result.predicted_expression_matrix(model_name="preds_v1_rep0")
STRADA ETV4 USP25
agg_0 2.932824 2.892707 2.858933
agg_1 2.816676 2.812716 3.058435
agg_2 2.742964 2.971174 2.950304
agg_3 2.804557 3.346401 2.837452
agg_4 2.815957 3.088338 2.973126
... ... ... ...
agg_9533 2.510271 2.760181 1.965718
agg_9535 1.245795 1.192324 0.407210
agg_9536 2.808904 3.369778 2.927307
agg_9537 1.580065 1.533221 0.879286
agg_9538 2.013909 2.077899 1.374621

8856 rows × 3 columns
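When scripting many runs, the decima predict-genes invocation can be assembled from Python. This sketch uses only the flags documented for the command (--genes, --model, --save-replicates, -o). build_predict_genes_command is a hypothetical helper, and it only builds the argument list without running anything.

```python
from __future__ import annotations

import shlex

def build_predict_genes_command(
    output_path: str,
    genes: list[str] | None = None,
    model: str = "ensemble",
    save_replicates: bool = False,
) -> list[str]:
    """Assemble argv for `decima predict-genes` from the documented flags."""
    cmd = ["decima", "predict-genes"]
    if genes:  # omitting --genes predicts all genes
        cmd += ["--genes", ",".join(genes)]
    cmd += ["--model", model]
    if save_replicates:
        cmd.append("--save-replicates")
    cmd += ["-o", output_path]
    return cmd

cmd = build_predict_genes_command(
    "example/predict_genes/predictions.h5ad",
    genes=["STRADA", "ETV4", "USP25"],
    save_replicates=True,
)
print(shlex.join(cmd))  # shell-ready string, handy for logging
```

To execute it, pass the list to subprocess.run(cmd, check=True); keeping it as a list avoids shell quoting problems with gene names.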

Python API

The same functionality is available through the Python API. The predict_gene_expression function lets you run gene expression prediction programmatically, passing the genes and the model directly in your Python code.

from decima.tools.inference import predict_gene_expression

ad = predict_gene_expression(
    genes=["STRADA", "ETV4", "USP25"],
    model="ensemble",
)
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:01 • 0:00:00 2.24it/s  

/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
No ground truth expression matrix found in the metadata. Skipping evaluation.

Developer API

Under the hood, the Decima prediction API uses the GeneDataset and SeqDataset PyTorch Dataset classes to prepare the data for prediction: GeneDataset handles genes from the training data, while SeqDataset handles custom DNA sequences. Internally, both represent sequences with one-hot encoding and apply a gene mask that marks which positions belong to the gene region.
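To make this representation concrete, here is a minimal numpy sketch that one-hot encodes a DNA string and adds a gene-mask channel marking the gene region. The A, C, G, T channel order, the all-zero encoding of N, the channels-first layout, and the inclusive mask end are assumptions for illustration; Decima's own encoding may differ in these details.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order

def one_hot_with_mask(seq: str, mask_start: int, mask_end: int) -> np.ndarray:
    """Encode seq as a (5, len(seq)) float32 array: four one-hot base
    channels plus a gene-mask channel that is 1 on mask_start..mask_end
    (inclusive) and 0 elsewhere. Unknown bases such as N stay all-zero."""
    codes = np.frombuffer(seq.upper().encode("ascii"), dtype=np.uint8)
    out = np.zeros((len(BASES) + 1, len(codes)), dtype=np.float32)
    for channel, base in enumerate(BASES):
        out[channel] = codes == ord(base)  # boolean match cast to 0.0/1.0
    out[len(BASES), mask_start : mask_end + 1] = 1.0
    return out

x = one_hot_with_mask("ACGTNacgt", mask_start=2, mask_end=5)
print(x.shape)
print(x.astype(int))
```

For a full-length input of 524,288 bases this yields a (5, 524288) array; the vectorized comparison per base avoids a Python loop over positions.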

For example, you can create a GeneDataset object to predict expression for genes in the training metadata. The predict_on_dataset method returns a dictionary containing the predicted expression values and a counter of any warnings raised during prediction.

from pprint import pprint
from decima.data.dataset import GeneDataset
from decima.hub import load_decima_model

model = load_decima_model("rep0", device=0)
ds = GeneDataset(genes=["STRADA", "ETV4", "USP25"])

preds = model.predict_on_dataset(ds, device=0)
pprint(preds)
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s  

{'expression': array([[2.932824  , 2.8166761 , 2.7429636 , ..., 2.8089042 , 1.5800648 ,
        2.0139086 ],
       [2.0204957 , 1.9390392 , 2.3947737 , ..., 1.5031691 , 0.5186505 ,
        0.80847657],
       [4.795655  , 4.9177384 , 4.8059716 , ..., 4.6602583 , 2.7780144 ,
        3.1483455 ]], dtype=float32),
 'warnings': Counter({'unknown': tensor(0),
                      'allele_mismatch_with_reference_genome': tensor(0)})}
/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
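The expression array has one row per input gene and one column per output track. Its first row matches the STRADA column of the CLI replicate output above (2.932824, 2.816676, 2.742964, …), which suggests rows follow the order of the genes passed to GeneDataset. The sketch below labels the array and sums the warning counter. label_predictions and total_warnings are hypothetical helpers, the toy input uses the first two columns of the output above (rounded), and real track names would presumably come from result.cell_metadata.index, assuming its rows are in model output order.

```python
import numpy as np

def label_predictions(preds: dict, gene_names: list[str], track_names: list[str]) -> dict:
    """Turn predict_on_dataset output into {gene: {track: value}}.

    Assumes preds["expression"] has shape (n_genes, n_tracks), with rows
    in the same order as gene_names."""
    expr = np.asarray(preds["expression"])
    if expr.shape != (len(gene_names), len(track_names)):
        raise ValueError(f"unexpected expression shape {expr.shape}")
    return {
        gene: dict(zip(track_names, map(float, row)))
        for gene, row in zip(gene_names, expr)
    }

def total_warnings(preds: dict) -> int:
    """Sum the warning counter; values may be ints or 0-d tensors."""
    return sum(int(v) for v in preds.get("warnings", {}).values())

toy = {
    "expression": np.array([[2.93, 2.82], [2.02, 1.94], [4.80, 4.92]], dtype=np.float32),
    "warnings": {"unknown": 0, "allele_mismatch_with_reference_genome": 0},
}
labeled = label_predictions(toy, ["STRADA", "ETV4", "USP25"], ["agg_0", "agg_1"])
print(labeled["ETV4"]["agg_1"], total_warnings(toy))
```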

Predicting Expression for Custom Genes

If you have custom genes, you can create a SeqDataset object to predict expression for those genes.

To do this, prepare a FASTA file where:

  • Each sequence is exactly the Decima context size (524,288 bases).

  • The FASTA header for each sequence must include the gene name and the gene mask coordinates, in the format >gene_name|gene_mask_start=X|gene_mask_end=Y, where X and Y are the start and end positions (0-based, inclusive) of the gene region within the sequence. The gene mask tells the model which part of the sequence corresponds to the gene whose expression is predicted.
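A FASTA file in this format can be written with the standard library alone. The sketch below follows the header format and the 0-based, inclusive mask convention stated above. format_record and write_decima_fasta are hypothetical helpers, and context_size is a parameter only so the example runs on a toy sequence; leave it at 524,288 for real input.

```python
import tempfile
from pathlib import Path

DECIMA_CONTEXT = 524_288

def format_record(gene: str, seq: str, mask_start: int, mask_end: int,
                  context_size: int = DECIMA_CONTEXT) -> str:
    """Validate one entry and return it as a two-line FASTA record."""
    if len(seq) != context_size:
        raise ValueError(f"{gene}: length {len(seq)} != {context_size}")
    if not 0 <= mask_start <= mask_end < context_size:
        raise ValueError(f"{gene}: gene mask {mask_start}-{mask_end} out of range")
    header = f">{gene}|gene_mask_start={mask_start}|gene_mask_end={mask_end}"
    return f"{header}\n{seq}\n"

def write_decima_fasta(path, records, context_size: int = DECIMA_CONTEXT) -> None:
    """records: iterable of (gene, seq, mask_start, mask_end) tuples."""
    text = "".join(format_record(*r, context_size=context_size) for r in records)
    Path(path).write_text(text)

# Toy 16-base sequence, so context_size is overridden for the demo only.
toy_records = [("GENE_A", "ACGT" * 4, 4, 9)]
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "toy.fasta"
    write_decima_fasta(out, toy_records, context_size=16)
    content = out.read_text()
print(content)
```

Validating length and mask bounds before writing catches off-by-one errors early, rather than at prediction time.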

For example, here are the first 200 characters of each line of seqs.fasta:

! cat ../../tests/data/seqs.fasta | cut -c 1-200
>CD68|gene_mask_start=163840|gene_mask_end=166460
CTCTGCAGAGAGCGAGGACGGTGTGTCTGCCAGCGCCTTTGACTTCACTGTCTCCAACTTTGTGGACAACCTGTATGGCTACCCGGAAGGCAAGGATGTGCTTCGGGAGACCATCAAGTTTATGTACACAGACTGGGCCGACCGGGACAATGGCGAAATGCGCCGCAAAACCCTGCTGGCGCTCTTTACTGACCACCAAT
>SPI1|gene_mask_start=163840|gene_mask_end=187556
TGCCACTTTTAGATATGTTCATGGGTGCAGATACGGCTTTATTTATTTGAGACAGAGTTTCACTCTTGTTGCCCAGGGTGGAGTGCAGTGGTGCGATCTCAGCTCACTGCAGCCTTCGCCTCCCGGGTTGAAGCGATTCTTCTGCCTCAACCTCGAGTAGCTGGGATTATAGGCACCTGCCAGCATGCCTGGCTAATTTT

Load this file into a SeqDataset and predict with the model loaded in the previous section:

from decima.data.dataset import SeqDataset

ds = SeqDataset.from_fasta("../../tests/data/seqs.fasta")

preds = model.predict_on_dataset(ds, device=0)
pprint(preds["expression"])
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s  

array([[0.22018507, 0.2193921 , 0.23078808, ..., 1.0356065 , 1.2123175 ,
        1.4411153 ],
       [0.2651469 , 0.1334098 , 0.14885235, ..., 0.44431016, 0.26546577,
        0.304252  ]], dtype=float32)

Besides from_fasta, a SeqDataset can be created from a pandas DataFrame with the columns seq, gene_mask_start, gene_mask_end, and gene_name (using SeqDataset.from_dataframe), or from a one-hot encoded tensor (using SeqDataset.from_one_hot). See the SeqDataset documentation for more details.
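The FASTA header fields map directly onto the DataFrame columns that from_dataframe expects. This sketch parses a FASTA file in the format above into a list of dicts with keys gene_name, gene_mask_start, gene_mask_end, and seq. parse_header and read_decima_fasta are hypothetical helpers, and the toy input is not real sequence data.

```python
import io

def parse_header(header: str) -> dict:
    """Parse '>gene|gene_mask_start=X|gene_mask_end=Y' into a dict."""
    gene_name, *fields = header.lstrip(">").strip().split("|")
    record = {"gene_name": gene_name}
    for field in fields:
        key, _, value = field.partition("=")
        record[key] = int(value)
    return record

def read_decima_fasta(handle) -> list[dict]:
    """Read records with keys gene_name, gene_mask_start, gene_mask_end, seq.
    Wrapped sequence lines are joined per record."""
    records, current, chunks = [], None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current is not None:
                current["seq"] = "".join(chunks)
                records.append(current)
            current, chunks = parse_header(line), []
        else:
            chunks.append(line)
    if current is not None:
        current["seq"] = "".join(chunks)
        records.append(current)
    return records

fasta = io.StringIO(
    ">CD68|gene_mask_start=2|gene_mask_end=5\nACGT\nACGT\n"
    ">SPI1|gene_mask_start=0|gene_mask_end=3\nTTTTGGGG\n"
)
records = read_decima_fasta(fasta)
print(records[0])
```

With pandas installed, pd.DataFrame(records) gives a table with exactly the four columns listed above. Whether SeqDataset.from_dataframe accepts it unchanged is an assumption worth checking against its documentation.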