Predicting Gene Expression with Decima

Decima predicts gene expression at the cell-type level. This tutorial shows how to use its prediction API both for genes in the training data and for custom genes.

Precomputed Predictions

Predictions for all genes in the training data are precomputed for each model replicate, stored in the metadata h5ad object, and accessible through the DecimaResult class. The predicted_expression_matrix method (not a class) returns the predicted gene expression averaged across the replicates, with one row per cell-type aggregate (agg_0, agg_1, and so on) and one column per gene.

from decima import DecimaResult

result = DecimaResult.load()
result.predicted_expression_matrix()
wandb: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:08.3 (375.3MB/s)
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns

To access the predicted expression matrix of a specific model, pass the model_name parameter. In this example, we obtain the predicted gene expression of the first model replicate:

result.predicted_expression_matrix(model_name="v1_rep0")
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns

And for the second model replicate:

result.predicted_expression_matrix(model_name="v1_rep1")
STRADA ETV4 USP25 ZSWIM5 C21orf58 MIR497HG CFAP74 GSE1 LPP CLK1 ... STRIP2 TNFRSF1A RBM14-RBM4 C1orf21 LINC00639 NPDC1 ZNF425 COL5A1 BRD3 EVI5L
agg_0 2.973438 1.845565 4.592531 5.099802 1.774879 0.356812 2.590836 4.629774 4.897171 3.326940 ... 2.836060 0.297015 1.883849 4.293593 1.463565 3.183534 2.340202 2.374942 2.911916 3.230072
agg_1 2.954213 1.896726 4.688557 5.510440 1.666929 0.352725 2.292625 4.459535 4.915286 3.192858 ... 3.125704 0.242543 1.908177 4.439424 1.236739 3.494824 2.425672 2.054568 2.713408 3.491463
agg_2 2.938851 2.197247 4.861410 5.617520 1.773381 0.380867 2.394917 4.415038 4.836399 3.390717 ... 3.082098 0.263285 2.006456 4.383455 1.208590 4.013819 2.408381 2.297343 2.892222 3.695785
agg_3 3.045972 2.138573 4.863791 5.273604 1.760097 0.463555 2.391702 3.940975 4.857763 3.410926 ... 2.882890 0.290327 1.922963 4.550189 1.430520 3.693118 2.297103 2.121887 2.626117 3.223912
agg_4 3.025518 2.019096 4.602948 5.257001 1.755338 0.382190 2.432810 4.392480 4.959488 3.250500 ... 3.082296 0.258540 2.038277 4.464807 1.249043 3.665800 2.400820 2.255862 2.925619 3.471005
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
agg_9533 2.333562 0.633322 4.675825 2.793023 0.752030 0.692083 0.503531 4.327948 6.903193 3.695593 ... 0.549795 2.270181 1.563218 4.395422 0.550088 1.330252 1.044471 3.759369 2.491346 1.872717
agg_9535 0.835037 0.358773 1.964896 0.307449 0.337240 0.834196 0.093885 1.853794 3.700790 4.467302 ... 0.176885 1.370898 1.022708 3.400267 0.052162 1.908870 0.253417 1.448111 1.622033 1.064292
agg_9536 3.008039 1.209324 4.798392 3.931870 1.401328 1.638555 0.969720 4.779201 6.631931 4.127797 ... 1.174298 1.870530 2.506874 5.151776 0.967644 1.809947 2.205356 4.244005 2.974467 2.659873
agg_9537 1.241936 0.455059 2.919995 0.571672 0.486448 1.175586 0.145397 2.412148 4.759118 4.913945 ... 0.371035 1.361073 1.668085 4.005738 0.078611 1.571750 0.508187 2.067150 2.323764 1.429850
agg_9538 1.715507 0.700955 3.044732 0.858696 0.903406 1.763168 0.215304 2.604478 4.549708 4.839124 ... 0.594310 1.801298 2.075996 3.933860 0.165590 1.970268 0.993521 2.232347 2.473388 1.902884

8856 rows × 18457 columns

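The precomputed predictions are stored as layers of the underlying AnnData object, which you can list with:

result.anndata.layers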
Layers with keys: preds, v1_rep0, v1_rep1, v1_rep2, v1_rep3

CLI API

To rerun gene expression prediction instead of using the precomputed scores, use the Decima command-line interface (CLI). The decima predict-genes command generates new predictions for any set of genes you specify:

  • --genes takes a comma-separated list of gene names (such as "STRADA,ETV4,USP25"). If no genes are provided, predictions are made for all genes.

  • --model selects the prediction model, for instance "ensemble" or a specific replicate such as "0".

  • --save-replicates saves the predictions of each individual replicate.

  • -o sets the output file path for the predictions, in .h5ad format.
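If you drive the CLI from a Python script rather than a notebook cell, the command can be assembled as an argument list. This is a minimal sketch: build_predict_genes_cmd is a hypothetical helper (not part of Decima), it uses only the flags described here, and the actual subprocess call is left commented out so the snippet runs without Decima installed.

```python
import shlex
from typing import Optional, Sequence


def build_predict_genes_cmd(
    genes: Optional[Sequence[str]],
    model: str = "ensemble",
    save_replicates: bool = False,
    output: str = "predictions.h5ad",
) -> list[str]:
    """Hypothetical helper: build the argv list for `decima predict-genes`."""
    cmd = ["decima", "predict-genes"]
    if genes:  # omitting --genes predicts all genes
        cmd += ["--genes", ",".join(genes)]
    cmd += ["--model", model]  # "ensemble" or a replicate such as "0"
    if save_replicates:
        cmd.append("--save-replicates")
    cmd += ["-o", output]
    return cmd


cmd = build_predict_genes_cmd(
    ["STRADA", "ETV4", "USP25"],
    model="ensemble",
    save_replicates=True,
    output="example/predict_genes/predictions.h5ad",
)
print(shlex.join(cmd))  # shell-safe string, useful for logging

# To run it (requires the decima CLI on PATH):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing a list instead of a single shell string avoids quoting problems with gene names.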

! decima predict-genes --genes "STRADA,ETV4,USP25" --model ensemble --save-replicates -o example/predict_genes/predictions.h5ad 
decima - INFO - Using device: 0 and genome: hg38 for prediction.
decima - INFO - Loading model ensemble...
wandb: Downloading large artifact 'rep0:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:02.6 (272.9MB/s)
wandb: Downloading large artifact 'rep1:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.8 (405.6MB/s)
wandb: Downloading large artifact 'rep2:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:02.0 (359.5MB/s)
wandb: Downloading large artifact 'rep3:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.8 (400.1MB/s)
decima - INFO - Making predictions
wandb: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.9 (1645.5MB/s)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
Predicting DataLoader 0: 100%|████████████████████| 3/3 [00:01<00:00,  1.81it/s]
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
decima - WARNING - No ground truth expression matrix found in the metadata. Skipping evaluation.

After running this command, you can load the resulting predictions in Python using the DecimaResult class and access the predicted expression matrix as:

result = DecimaResult.load("example/predict_genes/predictions.h5ad")
result.predicted_expression_matrix()
STRADA ETV4 USP25
agg_0 3.060561 2.882100 3.469085
agg_1 3.117702 2.825738 3.596372
agg_2 3.156642 3.122381 3.718749
agg_3 3.214047 3.204670 3.629874
agg_4 3.103570 3.032028 3.512117
... ... ... ...
agg_9533 2.313220 2.173597 3.156157
agg_9535 0.952112 0.956720 1.250005
agg_9536 2.779494 2.705945 3.530610
agg_9537 1.342694 1.407407 1.867070
agg_9538 1.745080 1.633314 2.082985

8856 rows × 3 columns

Or, for a specific replicate (replicate layers written by the CLI carry a preds_ prefix, e.g. preds_v1_rep0):

result.predicted_expression_matrix(model_name="preds_v1_rep0")
STRADA ETV4 USP25
agg_0 2.933017 2.892934 2.858879
agg_1 2.816778 2.812964 3.058485
agg_2 2.743120 2.971366 2.950323
agg_3 2.804692 3.346689 2.837382
agg_4 2.816030 3.088620 2.973081
... ... ... ...
agg_9533 2.510605 2.760446 1.965776
agg_9535 1.246022 1.192524 0.407104
agg_9536 2.809229 3.369957 2.927355
agg_9537 1.580334 1.533467 0.879179
agg_9538 2.014224 2.078125 1.374325

8856 rows × 3 columns
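To check how closely a single replicate tracks the ensemble, you can correlate the two matrices gene by gene with pandas DataFrame.corrwith, which computes one Pearson correlation per column. This sketch types in only the ten rows displayed above, so the resulting numbers are illustrative; with the real objects you would pass result.predicted_expression_matrix() and result.predicted_expression_matrix(model_name="preds_v1_rep0") to the same call.

```python
import pandas as pd

# Only the rows displayed above (cell-type aggregates x genes); the full matrices have 8856 rows.
index = ["agg_0", "agg_1", "agg_2", "agg_3", "agg_4",
         "agg_9533", "agg_9535", "agg_9536", "agg_9537", "agg_9538"]

ensemble = pd.DataFrame({
    "STRADA": [3.060561, 3.117702, 3.156642, 3.214047, 3.103570,
               2.313220, 0.952112, 2.779494, 1.342694, 1.745080],
    "ETV4": [2.882100, 2.825738, 3.122381, 3.204670, 3.032028,
             2.173597, 0.956720, 2.705945, 1.407407, 1.633314],
    "USP25": [3.469085, 3.596372, 3.718749, 3.629874, 3.512117,
              3.156157, 1.250005, 3.530610, 1.867070, 2.082985],
}, index=index)

rep0 = pd.DataFrame({
    "STRADA": [2.933017, 2.816778, 2.743120, 2.804692, 2.816030,
               2.510605, 1.246022, 2.809229, 1.580334, 2.014224],
    "ETV4": [2.892934, 2.812964, 2.971366, 3.346689, 3.088620,
             2.760446, 1.192524, 3.369957, 1.533467, 2.078125],
    "USP25": [2.858879, 3.058485, 2.950323, 2.837382, 2.973081,
              1.965776, 0.407104, 2.927355, 0.879179, 1.374325],
}, index=index)

# Pearson correlation across cell-type aggregates, one value per gene
per_gene_r = ensemble.corrwith(rep0)
print(per_gene_r.round(3))
```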

Python API

The same functionality is available through the Python API via predict_gene_expression from decima.tools.inference, so you can specify the genes, model, and other options directly in code:

from decima.tools.inference import predict_gene_expression

ad = predict_gene_expression(
    genes=["STRADA", "ETV4", "USP25"],
    model="ensemble",
)
wandb: Downloading large artifact 'rep0:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.3 (539.8MB/s)
wandb: Downloading large artifact 'rep1:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.4 (529.6MB/s)
wandb: Downloading large artifact 'rep2:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:00.6 (1166.2MB/s)
wandb: Downloading large artifact 'rep3:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:00.6 (1242.4MB/s)
wandb: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:05.1 (613.2MB/s)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
No ground truth expression matrix found in the metadata. Skipping evaluation.

Developer API

Under the hood, the Decima prediction API uses two PyTorch Dataset classes, GeneDataset and SeqDataset, to prepare the data for prediction. Together they handle different kinds of input, including genes from the metadata, custom genes, and raw DNA sequences. Internally, these datasets represent each sequence with one-hot encoding and apply a gene mask that marks which positions belong to the gene region.
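To make this representation concrete, here is a minimal NumPy sketch of one-hot encoding a DNA string and appending a gene-mask channel. The A, C, G, T channel order, the (channels, length) layout, and encoding N as all zeros are assumptions for illustration and may not match Decima's internal tensors; the 0-based, inclusive mask coordinates follow the FASTA header convention described later in this tutorial.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order, for illustration only


def one_hot_with_mask(seq: str, mask_start: int, mask_end: int) -> np.ndarray:
    """Encode `seq` as a (5, len(seq)) float32 array.

    Rows 0-3 are one-hot channels for A, C, G, T (other letters such as N
    stay all-zero); row 4 is the gene mask, set to 1 on the 0-based,
    inclusive interval [mask_start, mask_end].
    """
    if not 0 <= mask_start <= mask_end < len(seq):
        raise ValueError("gene mask must lie within the sequence")
    codes = np.frombuffer(seq.upper().encode("ascii"), dtype=np.uint8)
    x = np.zeros((5, codes.size), dtype=np.float32)
    for i, base in enumerate(BASES):
        x[i] = codes == ord(base)  # boolean match becomes 0.0 / 1.0
    x[4, mask_start:mask_end + 1] = 1.0  # mark the gene region
    return x


x = one_hot_with_mask("ACGTNACG", mask_start=2, mask_end=5)
print(x.astype(int))
```

For a real input the sequence would be the full 524,288-base context; the vectorized comparison keeps this fast at that length.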

For example, you can create a GeneDataset object to predict expression for genes in the metadata. The predict_on_dataset method returns a dictionary with the predicted expression values ("expression") and counts of any warnings raised during prediction ("warnings").

from pprint import pprint
from decima.data.dataset import GeneDataset
from decima.hub import load_decima_model

model = load_decima_model("rep0", device=0)
ds = GeneDataset(genes=["STRADA", "ETV4", "USP25"])

preds = model.predict_on_dataset(ds, device=0)
pprint(preds)
wandb: Downloading large artifact 'rep0:latest', 720.03MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:01.2 (580.7MB/s)
wandb: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...
wandb:   1 of 1 files downloaded.  
Done. 00:00:02.7 (1164.2MB/s)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
{'expression': array([[2.9330175, 2.8167777, 2.7431197, ..., 2.8092291, 1.580334 ,
        2.0142236],
       [2.0205972, 1.9390993, 2.3948462, ..., 1.5032245, 0.51865  ,
        0.8084921],
       [4.795624 , 4.9177237, 4.8060527, ..., 4.6602893, 2.7781048,
        3.1485069]], dtype=float32),
 'warnings': Counter({'unknown': tensor(0),
                      'allele_mismatch_with_reference_genome': tensor(0)})}
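In the output above, preds["expression"] has one row per gene and one column per cell-type aggregate: the first and last values of its first row match the STRADA column of the preds_v1_rep0 matrix from the CLI section. The sketch below transposes such an array into a DataFrame laid out like predicted_expression_matrix() and totals the warning counts. The small random array stands in for real output, and the agg_* labels are hypothetical; reusing the index of result.predicted_expression_matrix() assumes the same cell-type order.

```python
from collections import Counter

import numpy as np
import pandas as pd

genes = ["STRADA", "ETV4", "USP25"]
n_cell_types = 5  # the real output above has 8856 columns

# Stand-in for: preds = model.predict_on_dataset(ds, device=0)
rng = np.random.default_rng(0)
preds = {
    "expression": rng.random((len(genes), n_cell_types), dtype=np.float32),
    "warnings": Counter({"unknown": 0, "allele_mismatch_with_reference_genome": 0}),
}

# Hypothetical row labels; with real output you could reuse the index of
# result.predicted_expression_matrix(), assuming the same cell-type order.
cell_types = [f"agg_{i}" for i in range(n_cell_types)]

# (genes x cell types) -> (cell types x genes), like predicted_expression_matrix()
expr = pd.DataFrame(preds["expression"].T, index=cell_types, columns=genes)

# int() also works on the 0-d tensors in the real "warnings" Counter
n_warnings = sum(int(v) for v in preds["warnings"].values())
print(expr.round(3))
print("total warnings:", n_warnings)
```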

Predicting Expression for Custom Genes

If you have custom genes, you can create a SeqDataset object to predict expression for those genes.

To do this, prepare a FASTA file where:

  • Each sequence is exactly the Decima context size (524,288 bases).

  • The FASTA header of each sequence must include the gene name and the gene mask coordinates in the format >gene_name|gene_mask_start=X|gene_mask_end=Y, where X and Y are the start and end positions (0-based, inclusive) of the gene region within the sequence. The gene mask marks the region of the sequence corresponding to the gene whose expression will be predicted.
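A minimal sketch of writing and reading headers in this format follows. The helper names are hypothetical (not part of Decima); they only enforce the two rules above (context-size length, mask within the sequence). The poly-A sequence and mask coordinates are placeholders, and each sequence is written on a single line.

```python
import re
from pathlib import Path

CONTEXT_SIZE = 524_288  # Decima context size, as stated above
HEADER_RE = re.compile(
    r"^>(?P<gene>[^|]+)\|gene_mask_start=(?P<start>\d+)\|gene_mask_end=(?P<end>\d+)$"
)


def format_record(gene: str, seq: str, mask_start: int, mask_end: int) -> str:
    """Hypothetical helper: one FASTA record with a gene-mask header."""
    if len(seq) != CONTEXT_SIZE:
        raise ValueError(f"{gene}: length {len(seq)} != {CONTEXT_SIZE}")
    if not 0 <= mask_start <= mask_end < CONTEXT_SIZE:
        raise ValueError(f"{gene}: gene mask {mask_start}-{mask_end} out of range")
    header = f">{gene}|gene_mask_start={mask_start}|gene_mask_end={mask_end}"
    return f"{header}\n{seq}\n"  # sequence kept on a single line


def parse_header(line: str) -> tuple[str, int, int]:
    """Hypothetical helper: recover (gene, mask_start, mask_end) from a header line."""
    m = HEADER_RE.match(line.strip())
    if m is None:
        raise ValueError(f"unexpected header: {line!r}")
    return m["gene"], int(m["start"]), int(m["end"])


# Placeholder poly-A sequence and made-up mask coordinates
record = format_record("MY_GENE", "A" * CONTEXT_SIZE, mask_start=1_000, mask_end=2_000)
Path("my_seqs.fasta").write_text(record)
print(parse_header(record.splitlines()[0]))
```

To put several sequences in one file, concatenate the records before writing.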

For example, you can inspect the first 200 characters of each line of seqs.fasta with:

! cat ../../tests/data/seqs.fasta | cut -c 1-200
from decima.data.dataset import SeqDataset

# Build a dataset from the FASTA file; `model` and `pprint` come from the previous cell
ds = SeqDataset.from_fasta("../../tests/data/seqs.fasta")

preds = model.predict_on_dataset(ds, device=0)
pprint(preds["expression"])
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
array([[0.22026505, 0.21946757, 0.23086782, ..., 1.0358492 , 1.2124686 ,
        1.441437  ],
       [0.26515013, 0.13341773, 0.14887223, ..., 0.44430685, 0.26546437,
        0.3042384 ]], dtype=float32)

Besides from_fasta, a SeqDataset can be created from a pandas DataFrame with the columns seq, gene_mask_start, gene_mask_end, and gene_name using SeqDataset.from_dataframe, or from a one-hot encoded tensor using SeqDataset.from_one_hot. See the SeqDataset documentation for more details.
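For the DataFrame route, here is a minimal sketch of building a table with the four required columns and checking it against the FASTA rules above (context-size length, gene mask within the sequence) before handing it to Decima. validate_seq_table is a hypothetical helper, the sequences, names, and coordinates are placeholders, and the SeqDataset call is left as a comment so the snippet runs without Decima installed.

```python
import pandas as pd

CONTEXT_SIZE = 524_288  # Decima context size
REQUIRED = {"seq", "gene_mask_start", "gene_mask_end", "gene_name"}


def validate_seq_table(df: pd.DataFrame, context_size: int = CONTEXT_SIZE) -> None:
    """Hypothetical pre-flight check: raise ValueError if a row breaks the input rules."""
    missing = REQUIRED - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    bad_len = df["seq"].str.len() != context_size
    if bad_len.any():
        raise ValueError(f"wrong length: {df.loc[bad_len, 'gene_name'].tolist()}")
    ok_mask = (
        (df["gene_mask_start"] >= 0)
        & (df["gene_mask_start"] <= df["gene_mask_end"])
        & (df["gene_mask_end"] < context_size)
    )
    if not ok_mask.all():
        raise ValueError(f"bad gene mask: {df.loc[~ok_mask, 'gene_name'].tolist()}")


# Placeholder sequences, names, and coordinates
df = pd.DataFrame({
    "seq": ["A" * CONTEXT_SIZE, "C" * CONTEXT_SIZE],
    "gene_mask_start": [1_000, 5_000],
    "gene_mask_end": [2_000, 9_000],
    "gene_name": ["GENE_A", "GENE_B"],
})
validate_seq_table(df)

# With Decima installed (check the SeqDataset docs for the exact arguments):
# from decima.data.dataset import SeqDataset
# ds = SeqDataset.from_dataframe(df)
# preds = model.predict_on_dataset(ds, device=0)
```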