Predicting Gene Expression with Decima
Decima predicts gene expression at the cell-type level. This tutorial demonstrates how to use the prediction API both for genes in the training data and for custom genes.
Precomputed Predictions
Predicted expression levels (log(CPM + 1)) for all genes in the training data have been precomputed for each model replicate, saved to the metadata h5ad object, and are accessible through the DecimaResult class. The predicted_expression_matrix method returns the predicted gene expression averaged across the replicates.
from decima import DecimaResult
result = DecimaResult.load()
result.predicted_expression_matrix()
|  | STRADA | ETV4 | USP25 | ZSWIM5 | C21orf58 | MIR497HG | CFAP74 | GSE1 | LPP | CLK1 | ... | STRIP2 | TNFRSF1A | RBM14-RBM4 | C1orf21 | LINC00639 | NPDC1 | ZNF425 | COL5A1 | BRD3 | EVI5L |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| agg_0 | 2.973438 | 1.845565 | 4.592531 | 5.099802 | 1.774879 | 0.356812 | 2.590836 | 4.629774 | 4.897171 | 3.326940 | ... | 2.836060 | 0.297015 | 1.883849 | 4.293593 | 1.463565 | 3.183534 | 2.340202 | 2.374942 | 2.911916 | 3.230072 |
| agg_1 | 2.954213 | 1.896726 | 4.688557 | 5.510440 | 1.666929 | 0.352725 | 2.292625 | 4.459535 | 4.915286 | 3.192858 | ... | 3.125704 | 0.242543 | 1.908177 | 4.439424 | 1.236739 | 3.494824 | 2.425672 | 2.054568 | 2.713408 | 3.491463 |
| agg_2 | 2.938851 | 2.197247 | 4.861410 | 5.617520 | 1.773381 | 0.380867 | 2.394917 | 4.415038 | 4.836399 | 3.390717 | ... | 3.082098 | 0.263285 | 2.006456 | 4.383455 | 1.208590 | 4.013819 | 2.408381 | 2.297343 | 2.892222 | 3.695785 |
| agg_3 | 3.045972 | 2.138573 | 4.863791 | 5.273604 | 1.760097 | 0.463555 | 2.391702 | 3.940975 | 4.857763 | 3.410926 | ... | 2.882890 | 0.290327 | 1.922963 | 4.550189 | 1.430520 | 3.693118 | 2.297103 | 2.121887 | 2.626117 | 3.223912 |
| agg_4 | 3.025518 | 2.019096 | 4.602948 | 5.257001 | 1.755338 | 0.382190 | 2.432810 | 4.392480 | 4.959488 | 3.250500 | ... | 3.082296 | 0.258540 | 2.038277 | 4.464807 | 1.249043 | 3.665800 | 2.400820 | 2.255862 | 2.925619 | 3.471005 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| agg_9533 | 2.333562 | 0.633322 | 4.675825 | 2.793023 | 0.752030 | 0.692083 | 0.503531 | 4.327948 | 6.903193 | 3.695593 | ... | 0.549795 | 2.270181 | 1.563218 | 4.395422 | 0.550088 | 1.330252 | 1.044471 | 3.759369 | 2.491346 | 1.872717 |
| agg_9535 | 0.835037 | 0.358773 | 1.964896 | 0.307449 | 0.337240 | 0.834196 | 0.093885 | 1.853794 | 3.700790 | 4.467302 | ... | 0.176885 | 1.370898 | 1.022708 | 3.400267 | 0.052162 | 1.908870 | 0.253417 | 1.448111 | 1.622033 | 1.064292 |
| agg_9536 | 3.008039 | 1.209324 | 4.798392 | 3.931870 | 1.401328 | 1.638555 | 0.969720 | 4.779201 | 6.631931 | 4.127797 | ... | 1.174298 | 1.870530 | 2.506874 | 5.151776 | 0.967644 | 1.809947 | 2.205356 | 4.244005 | 2.974467 | 2.659873 |
| agg_9537 | 1.241936 | 0.455059 | 2.919995 | 0.571672 | 0.486448 | 1.175586 | 0.145397 | 2.412148 | 4.759118 | 4.913945 | ... | 0.371035 | 1.361073 | 1.668085 | 4.005738 | 0.078611 | 1.571750 | 0.508187 | 2.067150 | 2.323764 | 1.429850 |
| agg_9538 | 1.715507 | 0.700955 | 3.044732 | 0.858696 | 0.903406 | 1.763168 | 0.215304 | 2.604478 | 4.549708 | 4.839124 | ... | 0.594310 | 1.801298 | 2.075996 | 3.933860 | 0.165590 | 1.970268 | 0.993521 | 2.232347 | 2.473388 | 1.902884 |
8856 rows × 18457 columns
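Because the values are on the log(CPM + 1) scale, comparisons in terms of counts per million require undoing that transform. The minimal sketch below uses two rows and two genes copied from the table above, and assumes the logarithm is natural (so np.expm1 is the exact inverse); the tutorial does not state the base, so check this before relying on it.

```python
import numpy as np
import pandas as pd

# Two tracks and two genes copied from the table above (log(CPM + 1) scale).
log_expr = pd.DataFrame(
    {"STRADA": [2.973438, 2.954213], "LPP": [4.897171, 4.915286]},
    index=["agg_0", "agg_1"],
)

# Undo the log(CPM + 1) transform, assuming a natural logarithm:
# expm1(x) = exp(x) - 1 recovers CPM.
cpm = np.expm1(log_expr)

# Per-gene ratio of predicted CPM between two tracks (agg_1 relative to agg_0).
fold_change = cpm.loc["agg_1"] / cpm.loc["agg_0"]
print(fold_change)
```

With the full matrix, the same two lines apply unchanged to the DataFrame returned by predicted_expression_matrix.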
To access the predicted expression matrix of a specific model replicate, pass the model_name parameter. For example, to obtain the predictions of the first model replicate:
result.predicted_expression_matrix(model_name="v1_rep0")
|  | STRADA | ETV4 | USP25 | ZSWIM5 | C21orf58 | MIR497HG | CFAP74 | GSE1 | LPP | CLK1 | ... | STRIP2 | TNFRSF1A | RBM14-RBM4 | C1orf21 | LINC00639 | NPDC1 | ZNF425 | COL5A1 | BRD3 | EVI5L |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| agg_0 | 2.973438 | 1.845565 | 4.592531 | 5.099802 | 1.774879 | 0.356812 | 2.590836 | 4.629774 | 4.897171 | 3.326940 | ... | 2.836060 | 0.297015 | 1.883849 | 4.293593 | 1.463565 | 3.183534 | 2.340202 | 2.374942 | 2.911916 | 3.230072 |
| agg_1 | 2.954213 | 1.896726 | 4.688557 | 5.510440 | 1.666929 | 0.352725 | 2.292625 | 4.459535 | 4.915286 | 3.192858 | ... | 3.125704 | 0.242543 | 1.908177 | 4.439424 | 1.236739 | 3.494824 | 2.425672 | 2.054568 | 2.713408 | 3.491463 |
| agg_2 | 2.938851 | 2.197247 | 4.861410 | 5.617520 | 1.773381 | 0.380867 | 2.394917 | 4.415038 | 4.836399 | 3.390717 | ... | 3.082098 | 0.263285 | 2.006456 | 4.383455 | 1.208590 | 4.013819 | 2.408381 | 2.297343 | 2.892222 | 3.695785 |
| agg_3 | 3.045972 | 2.138573 | 4.863791 | 5.273604 | 1.760097 | 0.463555 | 2.391702 | 3.940975 | 4.857763 | 3.410926 | ... | 2.882890 | 0.290327 | 1.922963 | 4.550189 | 1.430520 | 3.693118 | 2.297103 | 2.121887 | 2.626117 | 3.223912 |
| agg_4 | 3.025518 | 2.019096 | 4.602948 | 5.257001 | 1.755338 | 0.382190 | 2.432810 | 4.392480 | 4.959488 | 3.250500 | ... | 3.082296 | 0.258540 | 2.038277 | 4.464807 | 1.249043 | 3.665800 | 2.400820 | 2.255862 | 2.925619 | 3.471005 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| agg_9533 | 2.333562 | 0.633322 | 4.675825 | 2.793023 | 0.752030 | 0.692083 | 0.503531 | 4.327948 | 6.903193 | 3.695593 | ... | 0.549795 | 2.270181 | 1.563218 | 4.395422 | 0.550088 | 1.330252 | 1.044471 | 3.759369 | 2.491346 | 1.872717 |
| agg_9535 | 0.835037 | 0.358773 | 1.964896 | 0.307449 | 0.337240 | 0.834196 | 0.093885 | 1.853794 | 3.700790 | 4.467302 | ... | 0.176885 | 1.370898 | 1.022708 | 3.400267 | 0.052162 | 1.908870 | 0.253417 | 1.448111 | 1.622033 | 1.064292 |
| agg_9536 | 3.008039 | 1.209324 | 4.798392 | 3.931870 | 1.401328 | 1.638555 | 0.969720 | 4.779201 | 6.631931 | 4.127797 | ... | 1.174298 | 1.870530 | 2.506874 | 5.151776 | 0.967644 | 1.809947 | 2.205356 | 4.244005 | 2.974467 | 2.659873 |
| agg_9537 | 1.241936 | 0.455059 | 2.919995 | 0.571672 | 0.486448 | 1.175586 | 0.145397 | 2.412148 | 4.759118 | 4.913945 | ... | 0.371035 | 1.361073 | 1.668085 | 4.005738 | 0.078611 | 1.571750 | 0.508187 | 2.067150 | 2.323764 | 1.429850 |
| agg_9538 | 1.715507 | 0.700955 | 3.044732 | 0.858696 | 0.903406 | 1.763168 | 0.215304 | 2.604478 | 4.549708 | 4.839124 | ... | 0.594310 | 1.801298 | 2.075996 | 3.933860 | 0.165590 | 1.970268 | 0.993521 | 2.232347 | 2.473388 | 1.902884 |
8856 rows × 18457 columns
Similarly, for the second model replicate:
result.predicted_expression_matrix(model_name="v1_rep1")
|  | STRADA | ETV4 | USP25 | ZSWIM5 | C21orf58 | MIR497HG | CFAP74 | GSE1 | LPP | CLK1 | ... | STRIP2 | TNFRSF1A | RBM14-RBM4 | C1orf21 | LINC00639 | NPDC1 | ZNF425 | COL5A1 | BRD3 | EVI5L |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| agg_0 | 2.973438 | 1.845565 | 4.592531 | 5.099802 | 1.774879 | 0.356812 | 2.590836 | 4.629774 | 4.897171 | 3.326940 | ... | 2.836060 | 0.297015 | 1.883849 | 4.293593 | 1.463565 | 3.183534 | 2.340202 | 2.374942 | 2.911916 | 3.230072 |
| agg_1 | 2.954213 | 1.896726 | 4.688557 | 5.510440 | 1.666929 | 0.352725 | 2.292625 | 4.459535 | 4.915286 | 3.192858 | ... | 3.125704 | 0.242543 | 1.908177 | 4.439424 | 1.236739 | 3.494824 | 2.425672 | 2.054568 | 2.713408 | 3.491463 |
| agg_2 | 2.938851 | 2.197247 | 4.861410 | 5.617520 | 1.773381 | 0.380867 | 2.394917 | 4.415038 | 4.836399 | 3.390717 | ... | 3.082098 | 0.263285 | 2.006456 | 4.383455 | 1.208590 | 4.013819 | 2.408381 | 2.297343 | 2.892222 | 3.695785 |
| agg_3 | 3.045972 | 2.138573 | 4.863791 | 5.273604 | 1.760097 | 0.463555 | 2.391702 | 3.940975 | 4.857763 | 3.410926 | ... | 2.882890 | 0.290327 | 1.922963 | 4.550189 | 1.430520 | 3.693118 | 2.297103 | 2.121887 | 2.626117 | 3.223912 |
| agg_4 | 3.025518 | 2.019096 | 4.602948 | 5.257001 | 1.755338 | 0.382190 | 2.432810 | 4.392480 | 4.959488 | 3.250500 | ... | 3.082296 | 0.258540 | 2.038277 | 4.464807 | 1.249043 | 3.665800 | 2.400820 | 2.255862 | 2.925619 | 3.471005 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| agg_9533 | 2.333562 | 0.633322 | 4.675825 | 2.793023 | 0.752030 | 0.692083 | 0.503531 | 4.327948 | 6.903193 | 3.695593 | ... | 0.549795 | 2.270181 | 1.563218 | 4.395422 | 0.550088 | 1.330252 | 1.044471 | 3.759369 | 2.491346 | 1.872717 |
| agg_9535 | 0.835037 | 0.358773 | 1.964896 | 0.307449 | 0.337240 | 0.834196 | 0.093885 | 1.853794 | 3.700790 | 4.467302 | ... | 0.176885 | 1.370898 | 1.022708 | 3.400267 | 0.052162 | 1.908870 | 0.253417 | 1.448111 | 1.622033 | 1.064292 |
| agg_9536 | 3.008039 | 1.209324 | 4.798392 | 3.931870 | 1.401328 | 1.638555 | 0.969720 | 4.779201 | 6.631931 | 4.127797 | ... | 1.174298 | 1.870530 | 2.506874 | 5.151776 | 0.967644 | 1.809947 | 2.205356 | 4.244005 | 2.974467 | 2.659873 |
| agg_9537 | 1.241936 | 0.455059 | 2.919995 | 0.571672 | 0.486448 | 1.175586 | 0.145397 | 2.412148 | 4.759118 | 4.913945 | ... | 0.371035 | 1.361073 | 1.668085 | 4.005738 | 0.078611 | 1.571750 | 0.508187 | 2.067150 | 2.323764 | 1.429850 |
| agg_9538 | 1.715507 | 0.700955 | 3.044732 | 0.858696 | 0.903406 | 1.763168 | 0.215304 | 2.604478 | 4.549708 | 4.839124 | ... | 0.594310 | 1.801298 | 2.075996 | 3.933860 | 0.165590 | 1.970268 | 0.993521 | 2.232347 | 2.473388 | 1.902884 |
8856 rows × 18457 columns
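The precomputed predictions are stored as layers of the underlying AnnData object, one layer per replicate (v1_rep0 to v1_rep3) alongside preds:
result.anndata.layers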
Layers with keys: preds, v1_rep0, v1_rep1, v1_rep2, v1_rep3
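To see how the per-replicate layers relate to the averaged prediction, the sketch below stacks four replicate matrices and averages them. It uses small random toy arrays named after the layers listed above rather than real Decima output, and it assumes the ensemble is a simple arithmetic mean of the log(CPM + 1) values; the tutorial only says the result is averaged across replicates. The standard deviation across replicates is a quick way to spot entries where the replicates disagree.

```python
import numpy as np

# Toy stand-ins for the four replicate layers (tracks x genes); the real
# layers are far larger. Values are random and purely illustrative.
rng = np.random.default_rng(0)
base = rng.uniform(0, 5, size=(3, 2))
layers = {f"v1_rep{i}": base + rng.normal(0, 0.1, size=base.shape) for i in range(4)}

# Stack into one array of shape (n_replicates, n_tracks, n_genes).
stacked = np.stack([layers[name] for name in sorted(layers)])

# Ensemble as the arithmetic mean over replicates (an assumption), plus the
# per-entry spread across replicates.
ensemble = stacked.mean(axis=0)
spread = stacked.std(axis=0)
print(ensemble.round(3))
print(spread.round(3))
```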
As a reminder, the metadata for the output tracks, whose index matches the rows of the expression matrix, is available through cell_metadata:
result.cell_metadata
|  | cell_type | tissue | organ | disease | study | dataset | region | subregion | celltype_coarse | n_cells | total_counts | n_genes | size_factor | train_pearson | val_pearson | test_pearson |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| agg_0 | Amygdala excitatory | Amygdala_Amygdala | CNS | healthy | jhpce#tran2021 | brain_atlas | Amygdala | Amygdala | NaN | 331 | 1.592883e+07 | 17000 | 41431.465186 | 0.942459 | 0.841377 | 0.865640 |
| agg_1 | Amygdala excitatory | Amygdala_Basolateral nuclear group (BLN) - lat... | CNS | healthy | SCR_016152 | brain_atlas | Amygdala | Basolateral nuclear group (BLN) - lateral nucl... | NaN | 11369 | 2.952133e+08 | 18080 | 40765.341481 | 0.943098 | 0.838936 | 0.861092 |
| agg_2 | Amygdala excitatory | Amygdala_Bed nucleus of stria terminalis and n... | CNS | healthy | SCR_016152 | brain_atlas | Amygdala | Bed nucleus of stria terminalis and nearby - BNST | NaN | 139 | 2.593231e+06 | 15418 | 42556.387020 | 0.952170 | 0.854544 | 0.866654 |
| agg_3 | Amygdala excitatory | Amygdala_Central nuclear group - CEN | CNS | healthy | SCR_016152 | brain_atlas | Amygdala | Central nuclear group - CEN | NaN | 3892 | 9.946371e+07 | 17959 | 42884.641430 | 0.959744 | 0.863585 | 0.881554 |
| agg_4 | Amygdala excitatory | Amygdala_Corticomedial nuclear group (CMN) - a... | CNS | healthy | SCR_016152 | brain_atlas | Amygdala | Corticomedial nuclear group (CMN) - anterior c... | NaN | 2945 | 1.281619e+08 | 17885 | 41816.741933 | 0.951365 | 0.854304 | 0.868902 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| agg_9533 | vascular associated smooth muscle cell | upper lobe of right lung | lung | NA | ENCODE | scimilarity | nan | nan | NaN | 21 | 3.483375e+04 | 8515 | 35404.911768 | 0.735213 | 0.665647 | 0.654491 |
| agg_9535 | vascular associated smooth muscle cell | urinary bladder | urinary | healthy | GSE129845 | scimilarity | nan | nan | NaN | 24 | 8.498500e+04 | 7337 | 26189.415789 | 0.809852 | 0.690022 | 0.656160 |
| agg_9536 | vascular associated smooth muscle cell | uterus | uterus | NA | ENCODE | scimilarity | nan | nan | NaN | 272 | 5.700762e+05 | 14769 | 44938.403867 | 0.915329 | 0.808941 | 0.839993 |
| agg_9537 | vascular associated smooth muscle cell | uterus | uterus | healthy | e5f58829-1a66-40b5-a624-9046778e74f5 | scimilarity | nan | nan | NaN | 472 | 1.089170e+07 | 14514 | 30145.422152 | 0.852339 | 0.717682 | 0.727469 |
| agg_9538 | vascular associated smooth muscle cell | vasculature | vasculature | healthy | e5f58829-1a66-40b5-a624-9046778e74f5 | scimilarity | nan | nan | NaN | 1853 | 5.992697e+07 | 16764 | 36464.273371 | 0.909855 | 0.780413 | 0.796351 |
8856 rows × 16 columns
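Since the metadata and the expression matrix share the same track index, metadata columns can be used to subset or aggregate predictions. The sketch below rebuilds four rows from the tables above as small pandas DataFrames (so it runs without Decima); with real data you would use result.cell_metadata and result.predicted_expression_matrix() in their place.

```python
import pandas as pd

# Four tracks copied from the tables above; real objects have 8856 rows.
meta = pd.DataFrame(
    {
        "cell_type": [
            "Amygdala excitatory",
            "Amygdala excitatory",
            "vascular associated smooth muscle cell",
            "vascular associated smooth muscle cell",
        ],
        "organ": ["CNS", "CNS", "uterus", "vasculature"],
    },
    index=["agg_0", "agg_1", "agg_9536", "agg_9538"],
)
preds = pd.DataFrame(
    {"STRADA": [2.973438, 2.954213, 3.008039, 1.715507]},
    index=meta.index,
)

# Select the tracks from one organ via a boolean mask on the metadata.
cns_preds = preds.loc[meta["organ"] == "CNS"]

# Mean predicted expression per organ; rows align on the shared index.
per_organ = preds.groupby(meta["organ"]).mean()
print(per_organ)
```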
CLI API
To rerun gene expression prediction instead of using the precomputed scores, use the Decima command-line interface (CLI). The decima predict-genes command generates new predictions for any set of genes you specify and accepts the following options:
- --genes: a comma-separated list of gene names (e.g. "STRADA,ETV4,USP25"). If no genes are provided, predictions are made for all genes.
- --model: the prediction model to use, e.g. "ensemble" or a specific replicate such as "0".
- --save-replicates: also save the predictions of each individual replicate.
- -o: the output file path for the predictions, in .h5ad format.
! decima predict-genes --genes "STRADA,ETV4,USP25" --model ensemble --save-replicates -o example/predict_genes/predictions.h5ad
decima - INFO - Using device: 0 and genome: hg38 for prediction.
decima - INFO - Loading model ensemble...
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
rep2.safetensors: 100%|██████████████████████| 755M/755M [00:08<00:00, 85.4MB/s]
rep3.safetensors: 100%|██████████████████████| 755M/755M [00:07<00:00, 99.0MB/s]
decima - INFO - Making predictions
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:01 • 0:00:00 2.24it/s
/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
decima - WARNING - No ground truth expression matrix found in the metadata. Skipping evaluation.
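If you need to launch the CLI from a Python script (for example, to loop over several gene sets), it helps to build the argument list programmatically. The helper build_predict_genes_cmd below is hypothetical and uses only the flags shown in this tutorial; it does not run Decima unless you uncomment the subprocess call.

```python
import shlex


def build_predict_genes_cmd(genes=None, model="ensemble",
                            save_replicates=False, output="predictions.h5ad"):
    """Hypothetical helper: assemble a `decima predict-genes` call using only
    the flags documented in this tutorial."""
    cmd = ["decima", "predict-genes"]
    if genes:  # omitting --genes predicts all genes
        cmd += ["--genes", ",".join(genes)]
    cmd += ["--model", str(model)]
    if save_replicates:
        cmd.append("--save-replicates")
    cmd += ["-o", output]
    return cmd


cmd = build_predict_genes_cmd(
    ["STRADA", "ETV4", "USP25"],
    save_replicates=True,
    output="example/predict_genes/predictions.h5ad",
)
print(shlex.join(cmd))  # shell-quoted form, handy for logging

# To actually run it (requires Decima to be installed):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing a list (rather than a single shell string) to subprocess.run avoids shell-quoting issues with gene names.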
After running this command, you can load the resulting predictions in Python using the DecimaResult class and access the predicted expression matrix as:
result = DecimaResult.load("example/predict_genes/predictions.h5ad")
result.predicted_expression_matrix()
|  | STRADA | ETV4 | USP25 |
|---|---|---|---|
| agg_0 | 3.060397 | 2.881919 | 3.469221 |
| agg_1 | 3.117476 | 2.825525 | 3.596472 |
| agg_2 | 3.156399 | 3.122201 | 3.718817 |
| agg_3 | 3.213858 | 3.204461 | 3.629969 |
| agg_4 | 3.103418 | 3.031820 | 3.512249 |
| ... | ... | ... | ... |
| agg_9533 | 2.313128 | 2.173468 | 3.156261 |
| agg_9535 | 0.952004 | 0.956674 | 1.250045 |
| agg_9536 | 2.779305 | 2.705866 | 3.530668 |
| agg_9537 | 1.342528 | 1.407362 | 1.867149 |
| agg_9538 | 1.744911 | 1.633233 | 2.083091 |
8856 rows × 3 columns
Or, for a specific replicate (note that in this file the replicate is referenced as preds_v1_rep0 rather than v1_rep0):
result.predicted_expression_matrix(model_name="preds_v1_rep0")
|  | STRADA | ETV4 | USP25 |
|---|---|---|---|
| agg_0 | 2.932824 | 2.892707 | 2.858933 |
| agg_1 | 2.816676 | 2.812716 | 3.058435 |
| agg_2 | 2.742964 | 2.971174 | 2.950304 |
| agg_3 | 2.804557 | 3.346401 | 2.837452 |
| agg_4 | 2.815957 | 3.088338 | 2.973126 |
| ... | ... | ... | ... |
| agg_9533 | 2.510271 | 2.760181 | 1.965718 |
| agg_9535 | 1.245795 | 1.192324 | 0.407210 |
| agg_9536 | 2.808904 | 3.369778 | 2.927307 |
| agg_9537 | 1.580065 | 1.533221 | 0.879286 |
| agg_9538 | 2.013909 | 2.077899 | 1.374621 |
8856 rows × 3 columns
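A common next step is to find the tracks with the highest predicted expression for each gene. The sketch below copies the first five rows of the ensemble predictions shown above into a DataFrame and uses pandas nlargest; with the full matrix, the resulting track IDs could be mapped to cell types through result.cell_metadata.

```python
import pandas as pd

# First five rows of the ensemble predictions shown above (log(CPM + 1)).
preds = pd.DataFrame(
    {
        "STRADA": [3.060397, 3.117476, 3.156399, 3.213858, 3.103418],
        "ETV4": [2.881919, 2.825525, 3.122201, 3.204461, 3.031820],
        "USP25": [3.469221, 3.596472, 3.718817, 3.629969, 3.512249],
    },
    index=["agg_0", "agg_1", "agg_2", "agg_3", "agg_4"],
)

# For each gene, the k tracks with the highest predicted expression.
k = 2
top_tracks = {gene: preds[gene].nlargest(k).index.tolist() for gene in preds.columns}
print(top_tracks)
```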
Python API
The same functionality is available through the Python API. The predict_gene_expression function takes the list of genes and the model to use as arguments, so predictions can be run programmatically:
from decima.tools.inference import predict_gene_expression
ad = predict_gene_expression(
genes=["STRADA", "ETV4", "USP25"],
model="ensemble",
)
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:01 • 0:00:00 2.24it/s
/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
No ground truth expression matrix found in the metadata. Skipping evaluation.
Developer API
Under the hood, the Decima prediction API uses the GeneDataset and SeqDataset PyTorch Dataset classes to prepare the data for prediction. These classes handle different types of input data, including genes in the metadata, custom genes, and raw DNA sequences. Internally, they represent sequences using one-hot encoding and apply a gene mask that marks which positions belong to the gene region.
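As a rough illustration of these two representations, the sketch below one-hot encodes a short DNA string and builds a binary gene mask with NumPy. It is not Decima's implementation: the A/C/G/T channel order and the handling of non-ACGT characters are assumptions, and the mask treats both bounds as 0-based and inclusive, following the FASTA header convention described later in this tutorial.

```python
import numpy as np

# Assumed channel order; Decima's internal order is not stated here.
BASES = np.frombuffer(b"ACGT", dtype=np.uint8)


def one_hot(seq):
    """Return a (4, len(seq)) float32 one-hot array; non-ACGT characters
    such as N become all-zero columns."""
    codes = np.frombuffer(seq.upper().encode("ascii"), dtype=np.uint8)
    return (BASES[:, None] == codes[None, :]).astype(np.float32)


def gene_mask(length, start, end):
    """Return a (length,) float32 mask equal to 1 over [start, end],
    treating both bounds as 0-based and inclusive."""
    mask = np.zeros(length, dtype=np.float32)
    mask[start:end + 1] = 1.0
    return mask


x = one_hot("ACGTNa")
m = gene_mask(6, 1, 3)
print(x)
print(m)
```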
For example, you can create a GeneDataset object to predict expression for genes in the model metadata. The predict_on_dataset method returns a dictionary containing the predicted expression values and counts of any warnings raised during prediction.
from pprint import pprint
from decima.data.dataset import GeneDataset
from decima.hub import load_decima_model
model = load_decima_model("rep0", device=0)
ds = GeneDataset(genes=["STRADA", "ETV4", "USP25"])
preds = model.predict_on_dataset(ds, device=0)
pprint(preds)
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s
{'expression': array([[2.932824 , 2.8166761 , 2.7429636 , ..., 2.8089042 , 1.5800648 ,
2.0139086 ],
[2.0204957 , 1.9390392 , 2.3947737 , ..., 1.5031691 , 0.5186505 ,
0.80847657],
[4.795655 , 4.9177384 , 4.8059716 , ..., 4.6602583 , 2.7780144 ,
3.1483455 ]], dtype=float32),
'warnings': Counter({'unknown': tensor(0),
'allele_mismatch_with_reference_genome': tensor(0)})}
/home/lala8/miniforge3/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
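In the output above, the expression array has one row per gene and one column per output track (its first row matches the STRADA column of the v1_rep0 predictions from the CLI section), which is the transpose of the predicted_expression_matrix layout. The sketch below labels and transposes such an array with pandas; it uses a toy array, and it assumes the columns follow the order of result.cell_metadata, which this tutorial does not state explicitly.

```python
import numpy as np
import pandas as pd

# Toy stand-in for preds["expression"]: one row per gene, one column per track.
genes = ["STRADA", "ETV4", "USP25"]
# With real output, track IDs would presumably come from
# result.cell_metadata.index (an assumption about column order).
track_ids = ["agg_0", "agg_1", "agg_2", "agg_3"]
expression = np.arange(12, dtype=np.float32).reshape(len(genes), len(track_ids))

# Transpose to tracks x genes, matching predicted_expression_matrix.
df = pd.DataFrame(expression.T, index=track_ids, columns=genes)
print(df)
```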
Predicting Expression for Custom Genes
If you have custom genes, you can create a SeqDataset object to predict expression for those genes.
To do this, prepare a FASTA file in which:
- Each sequence is exactly the Decima context size (524,288 bases).
- The header of each sequence includes the gene name and the gene mask coordinates, in the format >gene_name|gene_mask_start=X|gene_mask_end=Y, where X and Y specify the start and end positions (0-based, inclusive) of the gene region within the sequence. The gene mask indicates which region of the sequence corresponds to the gene whose expression will be predicted.
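The requirements above can be checked while writing the file. The helper fasta_record below is hypothetical (not part of Decima): it formats one record with the required header and rejects sequences of the wrong length or masks outside the sequence. The toy sequence is a repeated ACGT string used only to reach the right length.

```python
CONTEXT_SIZE = 524_288  # Decima context size stated above


def fasta_record(gene_name, seq, mask_start, mask_end):
    """Hypothetical helper: format one FASTA record with the required header,
    validating the sequence length and the (inclusive) gene mask bounds."""
    if len(seq) != CONTEXT_SIZE:
        raise ValueError(f"{gene_name}: length {len(seq)} != {CONTEXT_SIZE}")
    if not 0 <= mask_start <= mask_end < CONTEXT_SIZE:
        raise ValueError(f"{gene_name}: invalid gene mask [{mask_start}, {mask_end}]")
    header = f">{gene_name}|gene_mask_start={mask_start}|gene_mask_end={mask_end}"
    return f"{header}\n{seq}\n"


# Toy sequence of the right length; a real input would be a genomic window.
record = fasta_record("toy_gene", "ACGT" * (CONTEXT_SIZE // 4), 163840, 166460)
print(record[:60])

# To write the file:
# with open("custom_seqs.fasta", "w") as fh:
#     fh.write(record)
```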
For example, here are the headers and the first 200 bases of each sequence in seqs.fasta:
! cat ../../tests/data/seqs.fasta | cut -c 1-200
>CD68|gene_mask_start=163840|gene_mask_end=166460
CTCTGCAGAGAGCGAGGACGGTGTGTCTGCCAGCGCCTTTGACTTCACTGTCTCCAACTTTGTGGACAACCTGTATGGCTACCCGGAAGGCAAGGATGTGCTTCGGGAGACCATCAAGTTTATGTACACAGACTGGGCCGACCGGGACAATGGCGAAATGCGCCGCAAAACCCTGCTGGCGCTCTTTACTGACCACCAAT
>SPI1|gene_mask_start=163840|gene_mask_end=187556
TGCCACTTTTAGATATGTTCATGGGTGCAGATACGGCTTTATTTATTTGAGACAGAGTTTCACTCTTGTTGCCCAGGGTGGAGTGCAGTGGTGCGATCTCAGCTCACTGCAGCCTTCGCCTCCCGGGTTGAAGCGATTCTTCTGCCTCAACCTCGAGTAGCTGGGATTATAGGCACCTGCCAGCATGCCTGGCTAATTTT
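When debugging a custom FASTA file, it can help to parse the headers yourself and check the gene mask coordinates. The sketch below uses a regular expression for the header format described above and applies it to the two headers shown; the gene region length is computed with the inclusive-end convention stated in this tutorial.

```python
import re

HEADER_RE = re.compile(
    r"^>(?P<gene_name>[^|]+)\|gene_mask_start=(?P<start>\d+)\|gene_mask_end=(?P<end>\d+)$"
)


def parse_header(line):
    """Parse '>gene_name|gene_mask_start=X|gene_mask_end=Y' into (name, X, Y)."""
    match = HEADER_RE.match(line.strip())
    if match is None:
        raise ValueError(f"unexpected header: {line!r}")
    return match["gene_name"], int(match["start"]), int(match["end"])


for header in [
    ">CD68|gene_mask_start=163840|gene_mask_end=166460",
    ">SPI1|gene_mask_start=163840|gene_mask_end=187556",
]:
    name, start, end = parse_header(header)
    print(name, start, end, "gene region length:", end - start + 1)
```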
from decima.data.dataset import SeqDataset
ds = SeqDataset.from_fasta("../../tests/data/seqs.fasta")
preds = model.predict_on_dataset(ds, device=0)
pprint(preds["expression"])
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s
array([[0.22018507, 0.2193921 , 0.23078808, ..., 1.0356065 , 1.2123175 ,
1.4411153 ],
[0.2651469 , 0.1334098 , 0.14885235, ..., 0.44431016, 0.26546577,
0.304252 ]], dtype=float32)
SeqDataset can also be created from a pandas DataFrame with the columns seq, gene_mask_start, gene_mask_end, and gene_name using SeqDataset.from_dataframe, or from a one-hot encoded tensor using SeqDataset.from_one_hot. See the SeqDataset documentation for more details.
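A minimal sketch of preparing such a DataFrame is shown below. Only the pandas part runs here. The sequences are toy repeats of the required length, and the mask coordinates are borrowed from the CD68 and SPI1 headers above purely for illustration. The commented-out Decima call assumes from_dataframe takes the DataFrame as its first argument; check the SeqDataset documentation for the exact signature.

```python
import pandas as pd

CONTEXT_SIZE = 524_288  # Decima context size

# Toy inputs with the four required columns; real sequences would be
# genomic windows of exactly CONTEXT_SIZE bases.
df = pd.DataFrame(
    {
        "seq": ["ACGT" * (CONTEXT_SIZE // 4), "GATC" * (CONTEXT_SIZE // 4)],
        "gene_mask_start": [163840, 163840],
        "gene_mask_end": [166460, 187556],
        "gene_name": ["toy_gene_a", "toy_gene_b"],
    }
)

# Basic validation before handing the frame to Decima.
if not df["seq"].str.len().eq(CONTEXT_SIZE).all():
    raise ValueError("every sequence must be exactly CONTEXT_SIZE bases long")
if not (df["gene_mask_start"] <= df["gene_mask_end"]).all():
    raise ValueError("gene_mask_start must not exceed gene_mask_end")

# With Decima installed (signature assumed):
# from decima.data.dataset import SeqDataset
# ds = SeqDataset.from_dataframe(df)
# preds = model.predict_on_dataset(ds, device=0)
```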