Fine-tuning Borzoi to create a Decima model¶
import glob
import anndata
import scanpy as sc
import pandas as pd
import bioframe as bf
import os
inputdir = "./data"
outdir = "./example"
ad_file_path = os.path.join(inputdir, "data.h5ad")
h5_file_path = os.path.join(outdir, "data.h5")
1. Load input anndata file¶
The input anndata file needs to be in the format (pseudobulks x genes).
ad = sc.read(ad_file_path)
ad
AnnData object with n_obs × n_vars = 50 × 920
obs: 'cell_type', 'tissue', 'disease', 'study'
var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'
uns: 'log1p'
.obs should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts.
Note that the original Decima model does NOT separate pseudobulks by sample, i.e. different samples from the same cell type, tissue, disease and study were merged. We also recommend filtering out pseudobulks with few cells or low read count.
ad.obs.head()
| | cell_type | tissue | disease | study |
|---|---|---|---|---|
| pseudobulk_0 | ct_0 | t_0 | d_0 | st_0 |
| pseudobulk_1 | ct_0 | t_0 | d_1 | st_0 |
| pseudobulk_2 | ct_0 | t_0 | d_2 | st_1 |
| pseudobulk_3 | ct_0 | t_0 | d_0 | st_1 |
| pseudobulk_4 | ct_0 | t_0 | d_1 | st_2 |
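The recommendation to filter out pseudobulks with few cells or low read counts can be sketched as a boolean mask over rows. This is only an illustration: the thresholds, the helper name `keep_pseudobulks`, and the per-pseudobulk cell count array are assumptions, not part of Decima.

```python
import numpy as np

# Hypothetical thresholds; choose values appropriate for your own dataset.
MIN_CELLS = 50
MIN_TOTAL_COUNTS = 100_000


def keep_pseudobulks(counts, n_cells, min_cells=MIN_CELLS, min_total=MIN_TOTAL_COUNTS):
    """Return a boolean mask over pseudobulks (rows of `counts`)."""
    counts = np.asarray(counts)
    n_cells = np.asarray(n_cells)
    total_counts = counts.sum(axis=1)  # total reads per pseudobulk
    return (n_cells >= min_cells) & (total_counts >= min_total)


# Toy data: 3 pseudobulks x 2 genes.
counts = np.array([[60_000, 50_000], [10, 20], [90_000, 30_000]])
n_cells = np.array([200, 500, 10])
mask = keep_pseudobulks(counts, n_cells)
print(mask)
```

With an AnnData object, a mask like this could then be applied as `ad = ad[mask].copy()`, assuming the number of cells is stored in a column of `.obs`.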
.var should be a dataframe with a unique index per gene. The index can be the gene name or Ensembl ID, as long as it is unique. Other essential columns are: chrom, start, end and strand (the gene coordinates).
You can also include other columns with metadata about the genes, e.g. Ensembl ID, type of gene.
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26191000 | 26715288 | + | 26354840 | 26879128 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 41275257 | 41799545 | - | 41111417 | 41635705 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 79937866 | 80462154 | - | 79774026 | 80298314 | 524288 | 163840 | 524288 | train |
| gene_4 | chr16 | 3905208 | 4429496 | - | 3741368 | 4265656 | 524288 | 163840 | 524288 | train |
| gene_5 | chr10 | 22495641 | 23019929 | + | 22659481 | 23183769 | 524288 | 163840 | 524288 | train |
.X should contain the total counts per gene and pseudobulk. These should be non-negative integers.
ad.X[:5, :5]
array([[0. , 7.295568 , 7.295568 , 7.295568 , 7.295568 ],
[7.316388 , 7.316388 , 0. , 7.316388 , 7.316388 ],
[7.3014727, 7.3014727, 7.3014727, 7.3014727, 0. ],
[7.3014727, 0. , 7.3014727, 7.3014727, 0. ],
[7.3407264, 7.3407264, 0. , 7.3407264, 7.3407264]],
dtype=float32)
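Because `.X` is expected to hold raw counts, it can be worth checking this before normalizing. A minimal sketch, assuming a dense NumPy array; the helper name `looks_like_raw_counts` is hypothetical:

```python
import numpy as np


def looks_like_raw_counts(x):
    """True if every entry is finite, non-negative and integer-valued."""
    x = np.asarray(x, dtype=float)
    return bool(
        np.all(np.isfinite(x))
        and np.all(x >= 0)
        and np.all(x == np.round(x))
    )


print(looks_like_raw_counts([[0, 3], [5, 1]]))        # integer counts
print(looks_like_raw_counts([[0.0, 7.295568]]))       # non-integer values
```

Values such as those printed above are not integers, so a check like this would flag them as something other than raw counts.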
2. Normalize and log transform data¶
We first transform the counts to log(CPM+1) values. CPM = Counts Per Million.
sc.pp.normalize_total(ad, target_sum=1e6)
sc.pp.log1p(ad)
WARNING: adata.X seems to be already log-transformed.
ad.X[:5, :5]
array([[0. , 7.297041 , 7.297041 , 7.297041 , 7.297041 ],
[7.317892 , 7.317892 , 0. , 7.317892 , 7.317892 ],
[7.302954 , 7.302954 , 7.302954 , 7.302954 , 0. ],
[7.3014727, 0. , 7.3014727, 7.3014727, 0. ],
[7.3422675, 7.3422675, 0. , 7.3422675, 7.3422675]],
dtype=float32)
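For reference, the log(CPM+1) transformation can be written out by hand. The sketch below uses plain NumPy on a toy matrix and assumes every pseudobulk has a non-zero total count; it is meant to illustrate the arithmetic, not to replace the scanpy calls.

```python
import numpy as np


def log_cpm(counts):
    """Scale each row to sum to one million, then apply log(x + 1)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)  # total counts per pseudobulk
    cpm = counts / totals * 1e6                 # counts per million
    return np.log1p(cpm)                        # natural log of (CPM + 1)


# Toy data: 2 pseudobulks x 2 genes.
counts = np.array([[1, 3], [0, 10]])
out = log_cpm(counts)
print(out)
```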
3. Create intervals surrounding genes¶
Decima is trained on 524,288 bp sequences surrounding the genes. We therefore take the given gene coordinates and extend them to create intervals of this length.
from decima.data.preprocess import var_to_intervals
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26191000 | 26715288 | + | 26354840 | 26879128 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 41275257 | 41799545 | - | 41111417 | 41635705 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 79937866 | 80462154 | - | 79774026 | 80298314 | 524288 | 163840 | 524288 | train |
| gene_4 | chr16 | 3905208 | 4429496 | - | 3741368 | 4265656 | 524288 | 163840 | 524288 | train |
| gene_5 | chr10 | 22495641 | 23019929 | + | 22659481 | 23183769 | 524288 | 163840 | 524288 | train |
First, we copy the start and end columns to gene_start and gene_end. We also create a new column gene_length.
ad.var["gene_start"] = ad.var.start.tolist()
ad.var["gene_end"] = ad.var.end.tolist()
ad.var["gene_length"] = ad.var["gene_end"] - ad.var["gene_start"]
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26191000 | 26715288 | + | 26191000 | 26715288 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 41275257 | 41799545 | - | 41275257 | 41799545 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 79937866 | 80462154 | - | 79937866 | 80462154 | 524288 | 163840 | 524288 | train |
| gene_4 | chr16 | 3905208 | 4429496 | - | 3905208 | 4429496 | 524288 | 163840 | 524288 | train |
| gene_5 | chr10 | 22495641 | 23019929 | + | 22495641 | 23019929 | 524288 | 163840 | 524288 | train |
Now, we extend the gene coordinates to create enclosing intervals:
ad = var_to_intervals(ad, chr_end_pad=10000, genome="hg38")
# Replace genome name if necessary
The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
0 intervals extended beyond the chromosome start and have been shifted
1 intervals extended beyond the chromosome end and have been shifted
1 intervals did not extend far enough upstream of the TSS and have been dropped
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26027160 | 26551448 | + | 26191000 | 26715288 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 41439097 | 41963385 | - | 41275257 | 41799545 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 80101706 | 80625994 | - | 79937866 | 80462154 | 524288 | 163840 | 524288 | train |
| gene_4 | chr16 | 4069048 | 4593336 | - | 3905208 | 4429496 | 524288 | 163840 | 524288 | train |
| gene_5 | chr10 | 22331801 | 22856089 | + | 22495641 | 23019929 | 524288 | 163840 | 524288 | train |
The columns start and end now contain the start and end coordinates of the 524,288 bp intervals.
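The interval arithmetic can be sketched as follows. This is not the `var_to_intervals` implementation: it only reproduces the strand-aware placement implied by the message and table above (163,840 bp upstream of the gene start, 524,288 bp in total) and ignores the shifting and dropping of intervals near chromosome ends. The helper name `gene_to_interval` is hypothetical.

```python
SEQ_LEN = 524_288   # total interval length
UPSTREAM = 163_840  # bases placed upstream of the gene start


def gene_to_interval(gene_start, gene_end, strand):
    """Return (start, end) of the interval enclosing a gene."""
    if strand == "+":
        # Transcription runs left to right: pad to the left of gene_start.
        start = gene_start - UPSTREAM
        end = start + SEQ_LEN
    elif strand == "-":
        # Transcription runs right to left: pad to the right of gene_end.
        end = gene_end + UPSTREAM
        start = end - SEQ_LEN
    else:
        raise ValueError(f"Unexpected strand: {strand!r}")
    return start, end


# Coordinates of gene_0 (+) and gene_1 (-) from the tables above.
print(gene_to_interval(26191000, 26715288, "+"))
print(gene_to_interval(41275257, 41799545, "-"))
```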
3. Split genes into training, validation and test sets¶
We load the coordinates of the genomic regions used to train Borzoi:
splits_file = "https://raw.githubusercontent.com/calico/borzoi/main/data/sequences_human.bed.gz"
# replace human with mouse for mm10 splits
splits = pd.read_table(splits_file, header=None, names=["chrom", "start", "end", "fold"])
splits.head()
| | chrom | start | end | fold |
|---|---|---|---|---|
| 0 | chr4 | 82524421 | 82721029 | fold0 |
| 1 | chr13 | 18604798 | 18801406 | fold0 |
| 2 | chr2 | 189923408 | 190120016 | fold0 |
| 3 | chr10 | 59875743 | 60072351 | fold0 |
| 4 | chr1 | 117109467 | 117306075 | fold0 |
Now, we overlap our gene intervals with these regions:
overlaps = bf.overlap(ad.var.reset_index(names="gene"), splits, how="left")
overlaps = overlaps[["gene", "fold_"]].drop_duplicates().astype(str)
overlaps.head()
| | gene | fold_ |
|---|---|---|
| 0 | gene_0 | fold5 |
| 15 | gene_1 | fold0 |
| 30 | gene_2 | fold0 |
| 45 | gene_4 | fold2 |
| 59 | gene_5 | fold2 |
Based on the overlap, we divide our gene intervals into training, validation and test sets.
test_genes = overlaps.gene[overlaps.fold_ == "fold3"].tolist()
val_genes = overlaps.gene[overlaps.fold_ == "fold4"].tolist()
train_genes = set(overlaps.gene).difference(set(test_genes).union(val_genes))
And add this information back to ad.var.
ad.var["dataset"] = "test"
ad.var.loc[ad.var.index.isin(val_genes), "dataset"] = "val"
ad.var.loc[ad.var.index.isin(train_genes), "dataset"] = "train"
/tmp/slurmjob.38313776/ipykernel_505470/3109841685.py:1: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual.
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26027160 | 26551448 | + | 26191000 | 26715288 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 41439097 | 41963385 | - | 41275257 | 41799545 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 80101706 | 80625994 | - | 79937866 | 80462154 | 524288 | 163840 | 524288 | train |
| gene_4 | chr16 | 4069048 | 4593336 | - | 3905208 | 4429496 | 524288 | 163840 | 524288 | train |
| gene_5 | chr10 | 22331801 | 22856089 | + | 22495641 | 23019929 | 524288 | 163840 | 524288 | train |
ad.var.dataset.value_counts()
dataset
train 765
test 80
val 74
Name: count, dtype: int64
We have now divided the 919 genes remaining in our dataset into separate sets to be used for training, validation and testing.
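A gene interval can overlap regions from more than one fold, so the order in which the labels are assigned matters. The sketch below mirrors the precedence of the assignments above in plain Python (any overlap with fold4 gives "val", otherwise any overlap with fold3 gives "test", everything else "train"). The helper name `assign_datasets` and the toy input are assumptions for illustration.

```python
def assign_datasets(gene_folds, test_fold="fold3", val_fold="fold4"):
    """Map each gene to 'train', 'val' or 'test' from the folds it overlaps."""
    assignment = {}
    for gene, folds in gene_folds.items():
        if val_fold in folds:
            assignment[gene] = "val"    # validation label is written last above
        elif test_fold in folds:
            assignment[gene] = "test"
        else:
            assignment[gene] = "train"  # no overlap with held-out folds
    return assignment


# Toy input: gene -> set of overlapping Borzoi folds.
gene_folds = {
    "gene_a": {"fold0"},
    "gene_b": {"fold3"},
    "gene_c": {"fold3", "fold4"},
    "gene_d": {"fold2", "fold5"},
}
datasets = assign_datasets(gene_folds)
print(datasets)
```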
4. Save processed anndata¶
We will save the processed anndata file containing these intervals and data splits.
ad.write_h5ad(ad_file_path)
5. Create an hdf5 file¶
To train Decima, we need to extract the genomic sequences for all the intervals and convert them to one-hot encoded format. We save these one-hot encoded inputs to an hdf5 file.
from decima.data.write_hdf5 import write_hdf5
! mkdir -p example
write_hdf5(file=h5_file_path, ad=ad, pad=5000, genome="hg38")
# Change genome name if necessary
Writing metadata
Writing task indices
Writing genes array of shape: (919, 2)
Writing labels array of shape: (919, 50, 1)
Making gene masks
Writing mask array of shape: (919, 534288)
Encoding sequences
Writing sequence array of shape: (919, 534288)
Done!
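To make the encoding step concrete, here is a minimal sketch of one-hot encoding a DNA string, together with the padded sequence length implied by `pad=5000`. The channel order (A, C, G, T), the all-zero encoding of other characters such as N, and the helper name `one_hot_encode` are assumptions; the array layout that `write_hdf5` actually stores may differ.

```python
import numpy as np

BASES = "ACGT"      # assumed channel order
SEQ_LEN = 524_288   # interval length used above
PAD = 5_000         # value passed as `pad` above


def one_hot_encode(seq):
    """Return a (4, len(seq)) array; unknown bases (e.g. N) stay all-zero."""
    index = {base: i for i, base in enumerate(BASES)}
    encoded = np.zeros((4, len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq.upper()):
        row = index.get(base)
        if row is not None:
            encoded[row, pos] = 1.0
    return encoded


encoded = one_hot_encode("ACGTN")
print(encoded)

# Padding on both sides gives the sequence length reported in the log output.
padded_length = SEQ_LEN + 2 * PAD
print(padded_length)
```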
6. Set training parameters¶
# Learning rate default=0.001
lr = 5e-5
# Total weight parameter for the loss function
total_weight = 1e-4
# Gradient accumulation steps
grad = 5
# batch-size. default=4
bs = 4
# max-seq-shift. default=5000
shift = 5000
# Number of epochs. Default 1
epochs = 15
# logger
logger = "wandb" # Change to csv to save logs locally
# Number of workers default=16
workers = 16
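With gradient accumulation, the optimizer takes one step every `grad` batches, so the number of sequences contributing to each update is the product of the two settings. A small arithmetic sketch, assuming one device per training run:

```python
# Values copied from the parameters above.
bs = 4    # --batch-size
grad = 5  # --gradient-accumulation

# Sequences contributing to each optimizer step on one device.
effective_batch_size = bs * grad
print(effective_batch_size)
```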
7. Generate training commands¶
cmds = []
for model in range(4):
name = f"finetune_test_{model}"
device = model
cmd = (
f"decima finetune --name {name} "
+ f"--model {model} --device {device} "
+ f"--matrix-file {ad_file_path} --h5-file {h5_file_path} "
+ f"--outdir {outdir} --learning-rate {lr} "
+ f"--loss-total-weight {total_weight} --gradient-accumulation {grad} "
+ f"--batch-size {bs} --max-seq-shift {shift} "
+ f"--epochs {epochs} --logger {logger} --num-workers {workers}"
)
cmds.append(cmd)
for cmd in cmds:
print(cmd)
decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
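If your paths may contain spaces or other shell-sensitive characters, the command strings can be assembled from an argument list and quoted with the standard library. This is a sketch of an alternative way to build the same commands; the helper name `build_finetune_cmd` is hypothetical and only flags that appear in the commands above are used.

```python
import shlex


def build_finetune_cmd(model, matrix_file, h5_file, outdir, **options):
    """Build a shell-safe `decima finetune` command string."""
    args = [
        "decima", "finetune",
        "--name", f"finetune_test_{model}",
        "--model", str(model),
        "--device", str(model),
        "--matrix-file", matrix_file,
        "--h5-file", h5_file,
        "--outdir", outdir,
    ]
    # Keyword names use underscores; the CLI flags above use hyphens.
    for key, value in options.items():
        args += [f"--{key.replace('_', '-')}", str(value)]
    return shlex.join(args)  # quotes arguments only where needed


cmd = build_finetune_cmd(
    0,
    "./my data/data.h5ad",  # path with a space, to show the quoting
    "./example/data.h5",
    "./example",
    learning_rate=5e-5,
    batch_size=4,
    epochs=15,
)
print(cmd)
```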
Here, we train the model for only 1 epoch, with a batch size of 1, so that the tutorial runs quickly. Train for more epochs when fine-tuning on your own data.
! CUDA_VISIBLE_DEVICES=0 decima finetune \
--name finetune_test_0 \
--model 0 \
--device 0 \
--matrix-file {ad_file_path} \
--h5-file {h5_file_path} \
--outdir {outdir} \
--learning-rate {lr} \
--loss-total-weight {total_weight} \
--gradient-accumulation {grad} \
--batch-size 1 \
--max-seq-shift {shift} \
--epochs 1 \
--logger {logger} \
--num-workers {workers}
# Uncomment if necessary
# import wandb
# wandb.login(host="https://genentech.wandb.io", anonymous="never", relogin=True)
8. Make and evaluate predictions using trained models¶
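Using the training commands above, we can train several model replicates and pass their checkpoints to the prediction command as a comma-separated list. Here, we pass the same checkpoint twice to illustrate how multiple replicates are supplied. Now, we can use the trained model to predict gene expression: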
checkpoint = glob.glob(os.path.join(outdir, "lightning_logs/*/checkpoints/*.ckpt"))[0]
print(checkpoint)
./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt
# comma-separated list of model checkpoints
checkpoint_list = ",".join([checkpoint, checkpoint])
checkpoint_list
'./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt,./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt'
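When several replicates have been trained, each run writes its checkpoints under its own subdirectory of `lightning_logs`. The sketch below collects one checkpoint per run, choosing the most recently modified file. The helper name `latest_checkpoint_per_run` and the selection rule are assumptions; the directory layout follows the glob pattern used above. A temporary directory stands in for the real output directory so that the example is self-contained.

```python
import glob
import os
import tempfile


def latest_checkpoint_per_run(outdir):
    """Return the most recently modified .ckpt file of each run, sorted by run name."""
    pattern = os.path.join(outdir, "lightning_logs", "*", "checkpoints", "*.ckpt")
    by_run = {}
    for path in glob.glob(pattern):
        run = os.path.basename(os.path.dirname(os.path.dirname(path)))
        if run not in by_run or os.path.getmtime(path) > os.path.getmtime(by_run[run]):
            by_run[run] = path
    return [by_run[run] for run in sorted(by_run)]


# Demonstration with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for run in ("run_a", "run_b"):
        ckpt_dir = os.path.join(tmp, "lightning_logs", run, "checkpoints")
        os.makedirs(ckpt_dir)
        open(os.path.join(ckpt_dir, "epoch=0-step=1.ckpt"), "w").close()
    checkpoints = latest_checkpoint_per_run(tmp)
    run_names = [os.path.basename(os.path.dirname(os.path.dirname(p))) for p in checkpoints]
    checkpoint_list = ",".join(checkpoints)

print(run_names)
```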
! CUDA_VISIBLE_DEVICES=0 decima predict-genes \
--output example/test_preds.h5ad \
--model {checkpoint_list} \
--metadata {ad_file_path} \
--device 0 \
--batch-size 8 \
--num-workers 16 \
--max_seq_shift 0 \
--genome hg38 \
--save-replicates
decima - INFO - Using device: 0 and genome: hg38 for prediction.
decima - INFO - Loading model ['./example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt', './example/lightning_logs/g0m7s659/checkpoints/epoch=0-step=153.ckpt']...
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/lightning_fabric/utilities/cloud_io.py:73: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
decima - INFO - Making predictions
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /home/lala8/.local/bin/decima predict-genes --ou ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115/115 0:05:16 • 0:00:00 0.36it/s it/s it/s
/gpfs/scratchfs01/site/u/lala8/conda/envs/decima/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
Performance on genes in the train dataset.
Mean Pearson Correlation per gene: Mean: 0.01.
Mean Pearson Correlation per gene using size factor (baseline): 0.03.
Mean Pearson Correlation per pseudobulk: -0.00
Performance on genes in the val dataset.
Mean Pearson Correlation per gene: Mean: -0.02.
Mean Pearson Correlation per gene using size factor (baseline): 0.05.
Mean Pearson Correlation per pseudobulk: 0.01
Performance on genes in the test dataset.
Mean Pearson Correlation per gene: Mean: -0.02.
Mean Pearson Correlation per gene using size factor (baseline): -0.00.
Mean Pearson Correlation per pseudobulk: -0.02
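The reported metrics are Pearson correlations between measured and predicted expression, computed either per gene (across pseudobulks) or per pseudobulk (across genes). The sketch below shows one way to compute both with NumPy on a toy (pseudobulks x genes) matrix. It is not Decima's evaluation code, and the helper name `rowwise_pearson` is hypothetical.

```python
import numpy as np


def rowwise_pearson(a, b):
    """Pearson correlation between matching rows of two 2-D arrays."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a_centered = a - a.mean(axis=1, keepdims=True)
    b_centered = b - b.mean(axis=1, keepdims=True)
    numerator = (a_centered * b_centered).sum(axis=1)
    denominator = np.sqrt((a_centered ** 2).sum(axis=1) * (b_centered ** 2).sum(axis=1))
    # Rows with zero variance would give a zero denominator (NaN result).
    return numerator / denominator


# Toy data: 3 pseudobulks x 3 genes.
true = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 1.0], [3.0, 1.0, 2.0]])
pred = 2 * true + 1  # perfectly correlated predictions

per_pseudobulk = rowwise_pearson(true, pred)  # one value per pseudobulk
per_gene = rowwise_pearson(true.T, pred.T)    # one value per gene
print(per_pseudobulk.mean(), per_gene.mean())
```

Restricting the columns to the genes of one split before calling the function would give per-split values analogous to those in the log above.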
We can open the output h5ad file to see the individual predictions and metrics.
ad_out = anndata.read_h5ad("example/test_preds.h5ad")
ad_out
AnnData object with n_obs × n_vars = 50 × 919
obs: 'cell_type', 'tissue', 'disease', 'study', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'
var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset', 'pearson', 'size_factor_pearson'
layers: 'preds', 'preds_finetune_test_0'
.layers['preds'] contains the average predictions, whereas the layers prefixed with preds_ (here, .layers['preds_finetune_test_0']) contain the predictions made by the individual models. You will see that performance metrics have been added to both .obs and .var.
ad_out.obs.head()
| | cell_type | tissue | disease | study | size_factor | train_pearson | val_pearson | test_pearson |
|---|---|---|---|---|---|---|---|---|
| pseudobulk_0 | ct_0 | t_0 | d_0 | st_0 | 4947.391113 | -0.023931 | -0.007586 | -0.125109 |
| pseudobulk_1 | ct_0 | t_0 | d_1 | st_0 | 4851.750488 | 0.000509 | -0.023935 | 0.067142 |
| pseudobulk_2 | ct_0 | t_0 | d_2 | st_1 | 4922.177734 | -0.011712 | 0.132251 | -0.088595 |
| pseudobulk_3 | ct_0 | t_0 | d_0 | st_1 | 4921.185547 | 0.028251 | -0.114663 | 0.034809 |
| pseudobulk_4 | ct_0 | t_0 | d_1 | st_2 | 4750.456543 | -0.009447 | -0.069073 | 0.072288 |
ad_out.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset | pearson | size_factor_pearson |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 26027160 | 26551448 | + | 26191000 | 26715288 | 524288 | 163840 | 524288 | train | -0.206382 | -0.067028 |
| gene_1 | chr19 | 41439097 | 41963385 | - | 41275257 | 41799545 | 524288 | 163840 | 524288 | train | -0.001485 | -0.033103 |
| gene_2 | chr1 | 80101706 | 80625994 | - | 79937866 | 80462154 | 524288 | 163840 | 524288 | train | 0.052491 | 0.232561 |
| gene_4 | chr16 | 4069048 | 4593336 | - | 3905208 | 4429496 | 524288 | 163840 | 524288 | train | 0.035831 | -0.040373 |
| gene_5 | chr10 | 22331801 | 22856089 | + | 22495641 | 23019929 | 524288 | 163840 | 524288 | train | 0.368011 | -0.066814 |