Fine-tuning Borzoi to create a Decima model

import glob
import anndata
import scanpy as sc
import pandas as pd
import bioframe as bf
import os
inputdir = "./data"
outdir = "./example"
ad_file_path = os.path.join(inputdir, "data.h5ad")
h5_file_path = os.path.join(outdir, "data.h5")

1. Load input anndata file

The input anndata file needs to be in the format (pseudobulks x genes).

ad = sc.read(ad_file_path)
ad
AnnData object with n_obs × n_vars = 50 × 920
    obs: 'cell_type', 'tissue', 'disease', 'study'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'
    uns: 'log1p'

.obs should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts.

Note that the original Decima model does NOT separate pseudobulks by sample, i.e. different samples from the same cell type, tissue, disease and study were merged. We also recommend filtering out pseudobulks with few cells or low read count.

ad.obs.head()
             cell_type tissue disease study
pseudobulk_0      ct_0    t_0     d_0  st_0
pseudobulk_1      ct_0    t_0     d_1  st_0
pseudobulk_2      ct_0    t_0     d_2  st_1
pseudobulk_3      ct_0    t_0     d_0  st_1
pseudobulk_4      ct_0    t_0     d_1  st_2
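
The recommendations above (unique index, removal of small or shallow pseudobulks) can be checked in code. The following is a minimal sketch, not part of Decima: the column names `n_cells` and `total_counts` and both thresholds are illustrative assumptions that you should adapt to your own metadata.

```python
import pandas as pd


def filter_pseudobulks(obs: pd.DataFrame, min_cells: int = 50, min_counts: int = 100_000) -> pd.DataFrame:
    """Keep pseudobulks with enough cells and reads.

    `n_cells`, `total_counts` and both thresholds are hypothetical choices.
    """
    if not obs.index.is_unique:
        raise ValueError("obs index must be unique per pseudobulk")
    keep = (obs["n_cells"] >= min_cells) & (obs["total_counts"] >= min_counts)
    return obs.loc[keep]


# Toy metadata: only the first pseudobulk passes both thresholds
obs = pd.DataFrame(
    {"n_cells": [500, 12, 80], "total_counts": [2_000_000, 900_000, 40_000]},
    index=["pseudobulk_0", "pseudobulk_1", "pseudobulk_2"],
)
kept = filter_pseudobulks(obs)
print(kept.index.tolist())
```

With a real object, the same selection could then be applied to the anndata itself, for example with `ad = ad[kept.index].copy()`.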

.var should be a dataframe with a unique index per gene. The index can be the gene name or Ensembl ID, as long as it is unique. Other essential columns are: chrom, start, end and strand (the gene coordinates).

You can also include other columns with metadata about the genes, e.g. Ensembl ID, type of gene.

ad.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset
gene_0 chr1 26191000 26715288 + 26354840 26879128 524288 163840 524288 train
gene_1 chr19 41275257 41799545 - 41111417 41635705 524288 163840 524288 train
gene_2 chr1 79937866 80462154 - 79774026 80298314 524288 163840 524288 train
gene_4 chr16 3905208 4429496 - 3741368 4265656 524288 163840 524288 train
gene_5 chr10 22495641 23019929 + 22659481 23183769 524288 163840 524288 train

.X should contain the total counts per gene and pseudobulk. These should be non-negative integers.

ad.X[:5, :5]
array([[0.       , 7.295568 , 7.295568 , 7.295568 , 7.295568 ],
       [7.316388 , 7.316388 , 0.       , 7.316388 , 7.316388 ],
       [7.3014727, 7.3014727, 7.3014727, 7.3014727, 0.       ],
       [7.3014727, 0.       , 7.3014727, 7.3014727, 0.       ],
       [7.3407264, 7.3407264, 0.       , 7.3407264, 7.3407264]],
      dtype=float32)
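
A quick way to test this requirement on your own matrix is sketched below. The helper `check_counts` is hypothetical (it is not a Decima or scanpy function) and assumes a dense array; for a sparse matrix you could pass its `.data` attribute instead.

```python
import numpy as np


def check_counts(X) -> None:
    """Raise ValueError if X contains negative or non-integer values."""
    X = np.asarray(X)
    if np.any(X < 0):
        raise ValueError("counts must be non-negative")
    if not np.allclose(X, np.round(X)):
        raise ValueError("counts must be integers; X may already be normalized or log-transformed")


# Passes silently for raw counts
check_counts(np.array([[0, 12, 3], [7, 0, 41]]))
```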

2. Normalize and log transform data

We first transform the counts to log(CPM+1) values. CPM = Counts Per Million.

sc.pp.normalize_total(ad, target_sum=1e6)
sc.pp.log1p(ad)
WARNING: adata.X seems to be already log-transformed.
ad.X[:5, :5]
array([[0.       , 7.297041 , 7.297041 , 7.297041 , 7.297041 ],
       [7.317892 , 7.317892 , 0.       , 7.317892 , 7.317892 ],
       [7.302954 , 7.302954 , 7.302954 , 7.302954 , 0.       ],
       [7.3014727, 0.       , 7.3014727, 7.3014727, 0.       ],
       [7.3422675, 7.3422675, 0.       , 7.3422675, 7.3422675]],
      dtype=float32)
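
To make the transformation concrete, here is a small NumPy sketch of the same arithmetic on a toy count matrix. It assumes every pseudobulk has a non-zero total and uses the natural logarithm; it is meant as an illustration of log(CPM+1), not as a replacement for the scanpy calls above.

```python
import numpy as np


def log_cpm(counts) -> np.ndarray:
    """Return log(CPM + 1) for a (pseudobulks x genes) count matrix."""
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)  # total counts per pseudobulk
    cpm = counts / totals * 1e6                 # rescale each row to sum to one million
    return np.log1p(cpm)                        # natural log of (CPM + 1)


counts = np.array([[10, 30, 60], [0, 5, 5]])
out = log_cpm(counts)
print(out.round(3))
```

Zero counts stay at zero after the transformation, and undoing the log with `np.expm1` recovers rows that each sum to one million.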

3. Create intervals surrounding genes

Decima is trained on 524,288 bp sequences surrounding genes. Therefore, we have to take the given gene coordinates and extend them to create intervals of this length.

from decima.data.preprocess import var_to_intervals
ad.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset
gene_0 chr1 26191000 26715288 + 26354840 26879128 524288 163840 524288 train
gene_1 chr19 41275257 41799545 - 41111417 41635705 524288 163840 524288 train
gene_2 chr1 79937866 80462154 - 79774026 80298314 524288 163840 524288 train
gene_4 chr16 3905208 4429496 - 3741368 4265656 524288 163840 524288 train
gene_5 chr10 22495641 23019929 + 22659481 23183769 524288 163840 524288 train

First, we copy the start and end columns to gene_start and gene_end. We also create a new column gene_length.

ad.var["gene_start"] = ad.var.start.tolist()
ad.var["gene_end"] = ad.var.end.tolist()
ad.var["gene_length"] = ad.var["gene_end"] - ad.var["gene_start"]
ad.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset
gene_0 chr1 26191000 26715288 + 26191000 26715288 524288 163840 524288 train
gene_1 chr19 41275257 41799545 - 41275257 41799545 524288 163840 524288 train
gene_2 chr1 79937866 80462154 - 79937866 80462154 524288 163840 524288 train
gene_4 chr16 3905208 4429496 - 3905208 4429496 524288 163840 524288 train
gene_5 chr10 22495641 23019929 + 22495641 23019929 524288 163840 524288 train

Now, we extend the gene coordinates to create enclosing intervals:

ad = var_to_intervals(ad, chr_end_pad=10000, genome="hg38")
# Replace genome name if necessary
The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
0 intervals extended beyond the chromosome start and have been shifted
1 intervals extended beyond the chromosome end and have been shifted
1 intervals did not extend far enough upstream of the TSS and have been dropped
ad.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset
gene_0 chr1 26027160 26551448 + 26191000 26715288 524288 163840 524288 train
gene_1 chr19 41439097 41963385 - 41275257 41799545 524288 163840 524288 train
gene_2 chr1 80101706 80625994 - 79937866 80462154 524288 163840 524288 train
gene_4 chr16 4069048 4593336 - 3905208 4429496 524288 163840 524288 train
gene_5 chr10 22331801 22856089 + 22495641 23019929 524288 163840 524288 train

You see that the columns start and end now contain the start and end coordinates for the 524,288 bp intervals.
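
The arithmetic behind these coordinates can be sketched as follows. This is an illustration that reproduces the values in the table above, assuming 163,840 bp are placed upstream of the gene start in a strand-aware way; it is not the implementation of var_to_intervals and it omits the shifting and dropping of intervals near chromosome ends.

```python
INTERVAL_SIZE = 524_288  # total interval length in bp
UPSTREAM = 163_840       # bases placed upstream of the gene start


def gene_to_interval(gene_start: int, gene_end: int, strand: str) -> tuple:
    """Return (start, end) of the interval enclosing a gene."""
    if strand == "+":
        # The gene begins at gene_start; upstream lies at smaller coordinates
        start = gene_start - UPSTREAM
        return start, start + INTERVAL_SIZE
    if strand == "-":
        # The gene begins at gene_end; upstream lies at larger coordinates
        end = gene_end + UPSTREAM
        return end - INTERVAL_SIZE, end
    raise ValueError(f"unknown strand: {strand!r}")


print(gene_to_interval(26191000, 26715288, "+"))  # gene_0
print(gene_to_interval(41275257, 41799545, "-"))  # gene_1
```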

3. Split genes into training, validation and test sets

We load the coordinates of the genomic regions used to train Borzoi:

splits_file = "https://raw.githubusercontent.com/calico/borzoi/main/data/sequences_human.bed.gz"
# replace human with mouse for mm10 splits
splits = pd.read_table(splits_file, header=None, names=["chrom", "start", "end", "fold"])
splits.head()
   chrom      start        end   fold
0   chr4   82524421   82721029  fold0
1  chr13   18604798   18801406  fold0
2   chr2  189923408  190120016  fold0
3  chr10   59875743   60072351  fold0
4   chr1  117109467  117306075  fold0

Now, we overlap our gene intervals with these regions:

overlaps = bf.overlap(ad.var.reset_index(names="gene"), splits, how="left")
overlaps = overlaps[["gene", "fold_"]].drop_duplicates().astype(str)
overlaps.head()
      gene  fold_
0   gene_0  fold5
15  gene_1  fold0
30  gene_2  fold0
45  gene_4  fold2
59  gene_5  fold2
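
For readers unfamiliar with interval overlaps, the idea can be sketched in plain Python. The sketch assumes half-open coordinates (start included, end excluded); the second region in the example is hypothetical and only there to show a gene interval touching two folds.

```python
def intervals_overlap(a, b) -> bool:
    """True if two (chrom, start, end) half-open intervals share at least one base."""
    return a[0] == b[0] and a[1] < b[2] and b[1] < a[2]


def folds_for_interval(interval, regions) -> list:
    """Return the sorted fold labels of all regions overlapping the interval."""
    return sorted(
        {fold for chrom, start, end, fold in regions if intervals_overlap(interval, (chrom, start, end))}
    )


regions = [
    ("chr4", 82524421, 82721029, "fold0"),  # first row of the splits table
    ("chr4", 82721029, 82917637, "fold1"),  # hypothetical neighbouring region
]
print(folds_for_interval(("chr4", 82700000, 82800000), regions))
```

Because the gene intervals are much longer than the Borzoi regions, one interval can overlap several regions, which is why the overlap table is deduplicated on the (gene, fold) pair.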

Based on the overlap, we divide our gene intervals into training, validation and test sets.

test_genes = overlaps.gene[overlaps.fold_ == "fold3"].tolist()
val_genes = overlaps.gene[overlaps.fold_ == "fold4"].tolist()
train_genes = set(overlaps.gene).difference(set(test_genes).union(val_genes))

And add this information back to ad.var.

ad.var["dataset"] = "test"
ad.var.loc[ad.var.index.isin(val_genes), "dataset"] = "val"
ad.var.loc[ad.var.index.isin(train_genes), "dataset"] = "train"
/tmp/slurmjob.38313776/ipykernel_505470/3109841685.py:1: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual.
ad.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset
gene_0 chr1 26027160 26551448 + 26191000 26715288 524288 163840 524288 train
gene_1 chr19 41439097 41963385 - 41275257 41799545 524288 163840 524288 train
gene_2 chr1 80101706 80625994 - 79937866 80462154 524288 163840 524288 train
gene_4 chr16 4069048 4593336 - 3905208 4429496 524288 163840 524288 train
gene_5 chr10 22331801 22856089 + 22495641 23019929 524288 163840 524288 train
ad.var.dataset.value_counts()
dataset
train    765
test      80
val       74
Name: count, dtype: int64

We have now divided the 919 genes remaining in our dataset into separate sets to be used for training, validation and testing.
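
Since a gene interval may overlap regions from more than one fold, it can be worth confirming that no gene ended up in two lists. The helper below is a hypothetical sanity check, not part of Decima; it works on any three collections of gene names.

```python
def check_splits(train, val, test) -> dict:
    """Raise if any gene is shared between sets; otherwise return the set sizes."""
    train, val, test = set(train), set(val), set(test)
    shared = (train & val) | (train & test) | (val & test)
    if shared:
        raise ValueError(f"genes assigned to more than one set: {sorted(shared)}")
    return {"train": len(train), "val": len(val), "test": len(test)}


sizes = check_splits(["gene_0", "gene_1"], ["gene_2"], ["gene_4"])
print(sizes)
```

In this tutorial it could be called as `check_splits(train_genes, val_genes, test_genes)`.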

4. Save processed anndata

We will save the processed anndata file containing these intervals and data splits.

ad.write_h5ad(ad_file_path)

5. Create an hdf5 file

To train Decima, we need to extract the genomic sequences for all the intervals and convert them to one-hot encoded format. We save these one-hot encoded inputs to an hdf5 file.

from decima.data.write_hdf5 import write_hdf5
! mkdir -p example
write_hdf5(file=h5_file_path, ad=ad, pad=5000, genome="hg38")
# Change genome name if necessary
Writing metadata
Writing task indices
Writing genes array of shape: (919, 2)
Writing labels array of shape: (919, 50, 1)
Making gene masks
Writing mask array of shape: (919, 534288)
Encoding sequences
Writing sequence array of shape: (919, 534288)
Done!
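
As an illustration of what one-hot encoding means for DNA, here is a minimal sketch. The channel order (A, C, G, T) and the all-zero encoding of ambiguous bases such as N are assumptions for this example and may differ from what Decima does internally. The last line only checks the arithmetic that 524,288 plus 5,000 on each side equals the 534,288 positions reported above; reading this as padding on both sides of the interval is likewise an assumption.

```python
import numpy as np

BASES = "ACGT"  # assumed channel order


def one_hot_encode(seq: str) -> np.ndarray:
    """Encode a DNA string as a (4, length) array; unknown bases stay all-zero."""
    out = np.zeros((4, len(seq)), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        j = BASES.find(base)
        if j >= 0:
            out[j, i] = 1.0
    return out


enc = one_hot_encode("ACGTN")
print(enc)

# Interval length plus pad=5000 on both sides matches the array width above
assert 524_288 + 2 * 5_000 == 534_288
```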

6. Set training parameters

# Learning rate default=0.001
lr = 5e-5
# Total weight parameter for the loss function
total_weight = 1e-4
# Gradient accumulation steps
grad = 5
# batch-size. default=4
bs = 4
# max-seq-shift. default=5000
shift = 5000
# Number of epochs. Default 1
epochs = 15

# logger
logger = "wandb"  # Change to csv to save logs locally

# Number of workers default=16
workers = 16
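
Batch size and gradient accumulation interact: the optimizer sees their product as the effective batch size. The sketch below spells out this arithmetic. It assumes one training example per gene and that the trainer counts one step per optimizer update; under those assumptions, the 765 training genes with a batch size of 1 and 5 accumulation steps give 153 steps per epoch, which is consistent with the `step=153` in the checkpoint filename of the tutorial run.

```python
import math


def effective_batch_size(batch_size: int, grad_accum: int, n_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer update."""
    return batch_size * grad_accum * n_devices


def optimizer_steps_per_epoch(n_examples: int, batch_size: int, grad_accum: int) -> int:
    """Full accumulation cycles per epoch (leftover batches are ignored in this sketch)."""
    n_batches = math.ceil(n_examples / batch_size)
    return n_batches // grad_accum


print(effective_batch_size(4, 5))             # settings defined above
print(optimizer_steps_per_epoch(765, 1, 5))   # settings of the single-epoch tutorial run
```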

7. Generate training commands

cmds = []

for model in range(4):
    name = f"finetune_test_{model}"
    device = model

    cmd = (
        f"decima finetune --name {name} "
        + f"--model {model} --device {device} "
        + f"--matrix-file {ad_file_path} --h5-file {h5_file_path} "
        + f"--outdir {outdir} --learning-rate {lr} "
        + f"--loss-total-weight {total_weight} --gradient-accumulation {grad} "
        + f"--batch-size {bs} --max-seq-shift {shift} "
        + f"--epochs {epochs} --logger {logger} --num-workers {workers}"
    )
    cmds.append(cmd)
for cmd in cmds:
    print(cmd)
decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
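
If your paths may contain spaces or other shell-sensitive characters, building each command as a list of arguments is more robust than concatenating strings. The following sketch uses only flags that appear in the commands above; the helper name `build_finetune_args` is hypothetical and only a subset of the options is shown.

```python
import shlex


def build_finetune_args(model: int, options: dict) -> list:
    """Return the argument list for one replicate; options maps flag -> value."""
    args = [
        "decima", "finetune",
        "--name", f"finetune_test_{model}",
        "--model", str(model),
        "--device", str(model),
    ]
    for flag, value in options.items():
        args += [flag, str(value)]
    return args


options = {
    "--matrix-file": "./data/data.h5ad",
    "--h5-file": "./example/data.h5",
    "--outdir": "./example",
    "--learning-rate": 5e-5,
    "--epochs": 15,
}
arg_lists = [build_finetune_args(m, options) for m in range(4)]
print(shlex.join(arg_lists[0]))  # shell-quoted string, safe to copy into a terminal
```

Each list can also be passed directly to `subprocess.run(args, check=True)` without going through a shell.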

Here, we train a single model for 1 epoch with a batch size of 1 so that the tutorial runs quickly. Train for more epochs in your own runs.

! CUDA_VISIBLE_DEVICES=0 decima finetune \
--name finetune_test_0 \
--model 0 \
--device 0 \
--matrix-file {ad_file_path} \
--h5-file {h5_file_path} \
--outdir {outdir} \
--learning-rate {lr} \
--loss-total-weight {total_weight} \
--gradient-accumulation {grad} \
--batch-size 1 \
--max-seq-shift {shift} \
--epochs 1 \
--logger {logger} \
--num-workers {workers}
# Uncomment if necessary
# import wandb
# wandb.login(host="https://genentech.wandb.io", anonymous="never", relogin=True)

8. Make and evaluate predictions using trained models

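Above, we generated commands for four model replicates but ran only the first one, for a single epoch. To illustrate how replicates are combined, we pass the same checkpoint twice below; in practice, list one checkpoint per trained replicate. Now, we can use these models to predict gene expression: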

checkpoint = glob.glob(os.path.join(outdir, "lightning_logs/*/checkpoints/*.ckpt"))[0]
print(checkpoint)
./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt
# comma-separated list of model checkpoints
checkpoint_list = ",".join([checkpoint, checkpoint])
checkpoint_list
'./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt,./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt'
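
Note that glob returns paths in arbitrary order, so taking the first element may not select the run you expect once several runs exist in the output directory. One option, sketched here with a hypothetical helper, is to pick the most recently modified checkpoint. The demo builds a throwaway directory that mimics the layout used above.

```python
import glob
import os
import tempfile


def latest_checkpoint(outdir: str) -> str:
    """Return the most recently modified checkpoint under outdir."""
    pattern = os.path.join(outdir, "lightning_logs", "*", "checkpoints", "*.ckpt")
    paths = glob.glob(pattern)
    if not paths:
        raise FileNotFoundError(f"no checkpoints match {pattern}")
    return max(paths, key=os.path.getmtime)


# Demo with two fake runs; run_b gets the later modification time
with tempfile.TemporaryDirectory() as tmp:
    for i, run in enumerate(["run_a", "run_b"]):
        ckpt_dir = os.path.join(tmp, "lightning_logs", run, "checkpoints")
        os.makedirs(ckpt_dir)
        path = os.path.join(ckpt_dir, "epoch=0.ckpt")
        open(path, "w").close()
        os.utime(path, (1_000 + i, 1_000 + i))
    picked = latest_checkpoint(tmp)
    picked_run = os.path.basename(os.path.dirname(os.path.dirname(picked)))
    print(picked_run)
```
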
! CUDA_VISIBLE_DEVICES=0 decima predict-genes \
--output example/test_preds.h5ad \
--model {checkpoint_list} \
--metadata {ad_file_path} \
--device 0 \
--batch-size 8 \
--num-workers 16 \
--max_seq_shift 0 \
--genome hg38 \
--save-replicates
decima - INFO - Using device: 0 and genome: hg38 for prediction.
decima - INFO - Loading model ['./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt', './example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt']...
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/lightning_fabric/utilities/cloud_io.py:73: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
decima - INFO - Making predictions
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /home/lala8/.local/bin/decima predict-genes --ou ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115/115 0:05:16 • 0:00:00 0.36it/s
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
Performance on genes in the train dataset.
Mean Pearson Correlation per gene: Mean: 0.01.
Mean Pearson Correlation per gene using size factor (baseline): 0.03.
Mean Pearson Correlation per pseudobulk: -0.00

Performance on genes in the val dataset.
Mean Pearson Correlation per gene: Mean: -0.02.
Mean Pearson Correlation per gene using size factor (baseline): 0.05.
Mean Pearson Correlation per pseudobulk:  0.01

Performance on genes in the test dataset.
Mean Pearson Correlation per gene: Mean: -0.02.
Mean Pearson Correlation per gene using size factor (baseline): -0.00.
Mean Pearson Correlation per pseudobulk: -0.02


We can open the output h5ad file to see the individual predictions and metrics.

ad_out = anndata.read_h5ad("example/test_preds.h5ad")
ad_out
AnnData object with n_obs × n_vars = 50 × 919
    obs: 'cell_type', 'tissue', 'disease', 'study', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset', 'pearson', 'size_factor_pearson'
    layers: 'preds', 'preds_finetune_test_0'

.layers['preds'] contains the average predictions across models, whereas the remaining layers (here .layers['preds_finetune_test_0']) contain the predictions made by the individual models. You will see that performance metrics have been added to both .obs and .var.

ad_out.obs.head()
             cell_type tissue disease study  size_factor  train_pearson  val_pearson  test_pearson
pseudobulk_0      ct_0    t_0     d_0  st_0  4947.391113      -0.023931    -0.007586     -0.125109
pseudobulk_1      ct_0    t_0     d_1  st_0  4851.750488       0.000509    -0.023935      0.067142
pseudobulk_2      ct_0    t_0     d_2  st_1  4922.177734      -0.011712     0.132251     -0.088595
pseudobulk_3      ct_0    t_0     d_0  st_1  4921.185547       0.028251    -0.114663      0.034809
pseudobulk_4      ct_0    t_0     d_1  st_2  4750.456543      -0.009447    -0.069073      0.072288
ad_out.var.head()
chrom start end strand gene_start gene_end gene_length gene_mask_start gene_mask_end dataset pearson size_factor_pearson
gene_0 chr1 26027160 26551448 + 26191000 26715288 524288 163840 524288 train -0.206382 -0.067028
gene_1 chr19 41439097 41963385 - 41275257 41799545 524288 163840 524288 train -0.001485 -0.033103
gene_2 chr1 80101706 80625994 - 79937866 80462154 524288 163840 524288 train 0.052491 0.232561
gene_4 chr16 4069048 4593336 - 3905208 4429496 524288 163840 524288 train 0.035831 -0.040373
gene_5 chr10 22331801 22856089 + 22495641 23019929 524288 163840 524288 train 0.368011 -0.066814
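
To see what the per-gene and per-pseudobulk correlations measure, the sketch below computes Pearson correlations along either axis of a (pseudobulks x genes) matrix. It is an illustration with toy data and a hypothetical helper, and it is not guaranteed to match Decima's own evaluation code in every detail.

```python
import numpy as np


def pearson_along(true, pred, axis: int) -> np.ndarray:
    """Pearson correlation between true and pred, computed along the given axis.

    axis=0 correlates across pseudobulks (one value per gene);
    axis=1 correlates across genes (one value per pseudobulk).
    """
    true = np.asarray(true, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    t = true - true.mean(axis=axis, keepdims=True)
    p = pred - pred.mean(axis=axis, keepdims=True)
    denom = np.sqrt((t ** 2).sum(axis=axis) * (p ** 2).sum(axis=axis))
    return (t * p).sum(axis=axis) / denom


true = np.array([[1.0, 2.0, 3.0], [2.0, 1.0, 5.0], [3.0, 0.0, 4.0]])
pred = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0], [4.0, 1.0, 0.0]])
per_gene = pearson_along(true, pred, axis=0)
per_pseudobulk = pearson_along(true, pred, axis=1)
print(per_gene.round(3), per_pseudobulk.round(3))
```

Assuming `ad_out.X` holds the measured values as a dense array, `pearson_along(ad_out.X, ad_out.layers["preds"], axis=0)` would give one value per gene that can be compared with the `pearson` column in `ad_out.var`.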