Train a single-task regression model from scratch#
In this tutorial, we train a single-task convolutional regression model to predict total coverage over ATAC-seq peaks, starting from an ATAC-seq fragment file.
import os
import numpy as np
import pandas as pd
import torch
# os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # Optionally restrict training to a specific GPU
Set experiment parameters#
experiment='tutorial_3'
if not os.path.exists(experiment):
    os.makedirs(experiment)
Peak and fragment files#
We downloaded pseudobulk scATAC data for human microglia from Corces et al. (2020): https://www.nature.com/articles/s41588-020-00721-x. Here, we use the grelu.resources module to download the fragment file and peak file from the model zoo:
import grelu.resources
# Download these datasets into local directories
fragment_file_dir = grelu.resources.get_artifact(
project='microglia-scatac-tutorial', name='fragment_file').download()
peak_file_dir = grelu.resources.get_artifact(
project='microglia-scatac-tutorial', name='peak_file').download()
# Paths to files
frag_file = os.path.join(fragment_file_dir, "Microglia_full.bed")
peak_file = os.path.join(peak_file_dir, "Microglia_full_peaks.narrowPeak")
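Optionally, you can peek at the first few lines of the fragment file to confirm that it downloaded correctly. This is just a sketch; it assumes the fragment file is a plain tab-separated, BED-like file (chromosome, start, end, cell barcode, count).
# Peek at the first few fragments (assumes a tab-separated BED-like fragment file)
pd.read_table(frag_file, header=None, nrows=3, comment="#")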
Set parameters#
seq_len=2114 # Length of the input sequence
label_len=1000 # Length over which we calculate total coverage
val_chroms=["chr10"]
test_chroms=["chr11"]
genome="hg38"
Read peak file#
We read peak coordinates from the narrowPeak file.
import grelu.io.bed
peaks = grelu.io.bed.read_narrowpeak(peak_file)
peaks.tail(3)
| | chrom | start | end | name | score | strand | signal | pvalue | qvalue | summit |
|---|---|---|---|---|---|---|---|---|---|---|
| 83316 | chrY | 56870777 | 56870983 | Microglia_full_peak_83318 | 94 | . | 3.83352 | 11.58170 | 9.41381 | 116 |
| 83317 | chrY | 56873629 | 56873811 | Microglia_full_peak_83319 | 49 | . | 3.01171 | 7.02098 | 4.99754 | 105 |
| 83318 | chrY | 56874075 | 56874225 | Microglia_full_peak_83320 | 42 | . | 2.86533 | 6.19767 | 4.20704 | 17 |
Summit-center peaks#
We extract the genomic coordinates for the 2114 bases surrounding the summit of each peak.
import grelu.data.preprocess
peaks = grelu.data.preprocess.extend_from_coord(
peaks,
seq_len=seq_len,
center_col="summit"
)
peaks.tail(3)
| | chrom | start | end |
|---|---|---|---|
| 83316 | chrY | 56869836 | 56871950 |
| 83317 | chrY | 56872677 | 56874791 |
| 83318 | chrY | 56873035 | 56875149 |
Filter peaks#
We filter the peaks to include only those within autosomes. You can also use “autosomesX” or “autosomesXY” to include sex chromosomes.
peaks = grelu.data.preprocess.filter_chromosomes(peaks, 'autosomes')
Keeping 80823 intervals
We drop peaks that are close to ENCODE hg38 blacklist regions.
peaks = grelu.data.preprocess.filter_blacklist(
peaks,
genome=genome,
window=50 # Remove peaks if they are within 50 bp of a blacklist region
)
Keeping 80028 intervals
Get GC matched negative regions#
To ensure that the model also learns to recognize regions that are not peaks, we will include a set of “negative” (non-peak) regions with similar GC content to the peaks.
negatives = grelu.data.preprocess.get_gc_matched_intervals(
peaks,
binwidth=0.02, # resolution of measuring GC content
genome=genome,
chroms="autosomes", # negative regions will also be chosen from autosomes
#gc_bw_file='gc_hg38_2114.bw',
blacklist=genome, # negative regions overlapping the blacklist will be dropped
seed=0,
)
negatives.head(3)
Calculating GC content genomewide and saving to gc_hg38_2114.bw
Extracting matching intervals
GC paired t-test: 0.027, 9.33e-26
Filtering blacklist
Keeping 77564 intervals
| | chrom | start | end |
|---|---|---|---|
| 0 | chr3 | 69786703 | 69788817 |
| 1 | chr10 | 111915184 | 111917298 |
| 2 | chr17 | 4597230 | 4599344 |
We can visualize a histogram of GC content in the peaks and negative regions to verify that they are similar.
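One way to do this (a minimal sketch, not part of the pipeline above) is to fetch each interval's sequence from a local hg38 FASTA, compute its GC fraction with pysam, and plot the two distributions with matplotlib. The FASTA path below is an assumption based on the genome location used later in this tutorial; adjust it to wherever your hg38 genome is stored.
import pysam
import matplotlib.pyplot as plt

# Path to a local hg38 FASTA (assumption; adjust as needed)
fasta = pysam.FastaFile("/root/.local/share/genomes/hg38/hg38.fa")

def gc_fraction(intervals):
    # Fraction of G/C bases in each interval's sequence
    fracs = []
    for chrom, start, end in zip(intervals.chrom, intervals.start, intervals.end):
        seq = fasta.fetch(chrom, start, end).upper()
        fracs.append((seq.count("G") + seq.count("C")) / len(seq))
    return fracs

plt.hist(gc_fraction(peaks), bins=50, density=True, alpha=0.5, label="peaks")
plt.hist(gc_fraction(negatives), bins=50, density=True, alpha=0.5, label="negatives")
plt.xlabel("GC fraction")
plt.ylabel("Density")
plt.legend()
plt.show()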
Combine peaks and negative regions#
The full dataset includes both peaks and negative regions.
regions = pd.concat([peaks, negatives])
len(regions)
157592
Make BigWig file#
We convert the ATAC-seq fragment file into a bigWig file, which contains the number of Tn5 insertions at each genomic position.
bw_file = grelu.data.preprocess.make_insertion_bigwig(
frag_file = frag_file,
plus_shift=0,
minus_shift=1, # This corrects the +4/-5 Tn5 shift to a +4/-4 shift
genome=genome,
chroms="autosomes", # The output bigWig file contains coverage over autosomes.
)
Making bedgraph file
cat /code/gReLU/docs/tutorials/artifacts/fragment_file:v0/Microglia_full.bed | awk -v OFS="\t" '{print $1,$2+0,$3,1000,0,"+";
print $1,$2,$3+1,1000,0,"-"}' | sort -k1,1 | grep -e ^chr1 -e ^chr2 -e ^chr3 -e ^chr4 -e ^chr5 -e ^chr6 -e ^chr7 -e ^chr8 -e ^chr9 -e ^chr10 -e ^chr11 -e ^chr12 -e ^chr13 -e ^chr14 -e ^chr15 -e ^chr16 -e ^chr17 -e ^chr18 -e ^chr19 -e ^chr20 -e ^chr21 -e ^chr22 | bedtools genomecov -bg -5 -i stdin -g /root/.local/share/genomes/hg38/hg38.fa.sizes | bedtools sort -i stdin > ./Microglia_full.bedGraph
Making bigWig file
bedGraphToBigWig ./Microglia_full.bedGraph /root/.local/share/genomes/hg38/hg38.fa.sizes ./Microglia_full.bw
Deleting temporary files
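Optionally, you can verify the new bigWig file. For example, pyBigWig (an assumption; it is not used elsewhere in this tutorial) can report how many bases are covered and the range of values:
import pyBigWig

# Inspect the bigWig header: number of covered bases, min/max values, etc.
bw = pyBigWig.open(bw_file)
print(bw.header())
bw.close()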
Split data by chromosome#
We now split the dataset by chromosome to create separate datasets for training, validation and testing.
train, val, test = grelu.data.preprocess.split(
regions, val_chroms=val_chroms, test_chroms=test_chroms)
Selecting training samples
Keeping 141228 intervals
Selecting validation samples
Keeping 7921 intervals
Selecting test samples
Keeping 8443 intervals
Final sizes: train: (126990, 3), val: (577, 3), test: (673, 3)
Make labeled datasets#
We now make PyTorch dataset objects to load paired sequences and coverage values from the genome and the bigWig file. We use the BigWigSeqDataset class.
We first make the training dataset. To increase model robustness, we use several forms of data augmentation here: rc=True (reverse complementing the input sequence), max_seq_shift=2 (shifting the coordinates of the input sequence by up to 2 bp in either direction, also known as jitter), and max_pair_shift=20 (shifting both the input sequence and the region over which coverage is calculated by up to 20 bp in either direction).
Further, we use label_aggfunc="sum", which means that the label will be the summed coverage over the central label_len bases of each interval.
import grelu.data.dataset
train_ds = grelu.data.dataset.BigWigSeqDataset(
intervals = train,
bw_files=[bw_file],
label_len=label_len,
label_aggfunc="sum",
rc=True, # reverse complement
max_seq_shift=2, # Shift the sequence
max_pair_shift=20, # Shift both sequence and label
augment_mode="random",
seed=0,
genome=genome,
)
We do not apply any augmentations to the validation and test datasets (although it is possible to do so).
val_ds = grelu.data.dataset.BigWigSeqDataset(
intervals = val,
bw_files=[bw_file],
label_len=label_len,
label_aggfunc="sum",
genome=genome,
)
test_ds = grelu.data.dataset.BigWigSeqDataset(
intervals = test,
bw_files=[bw_file],
label_len=label_len,
label_aggfunc="sum",
genome=genome,
)
len(train_ds), len(val_ds), len(test_ds)
(126990, 577, 673)
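To see what a single training example looks like, you can index the dataset directly. The unpacking below is a sketch that assumes each item is a (sequence, label) pair of tensors, with the sequence one-hot encoded over seq_len bases; if your version of gReLU returns a different structure, adjust accordingly.
# Inspect one training example (assumes each item is a (sequence, label) pair)
x, y = train_ds[0]
print(x.shape, y.shape)  # one-hot sequence over seq_len bases, and a single summed coverage value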
Build model#
model_params = {
'model_type':'DilatedConvModel',
'crop_len':(seq_len-label_len)//2,
'n_tasks':1,
'channels':512,
'n_conv':8,
}
train_params = {
'task':'regression',
'loss': 'poisson', # Poisson loss. Other regression loss functions are "mse" and "poisson_multinomial"
'logger':'csv',
'lr':1e-4,
'batch_size':256,
'max_epochs':10,
'devices':0,
'num_workers':16,
'save_dir':experiment,
}
import grelu.lightning
model = grelu.lightning.LightningModel(model_params, train_params)
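As an optional sanity check (a sketch, not required for training), you can count the model's trainable parameters; LightningModel is a PyTorch Lightning module, so the standard parameters() iterator is available:
# Count trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")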
Train model#
# See the tutorial_3 folder for logs.
trainer = model.train_on_dataset(train_ds, val_ds)
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:06<00:00, 0.49it/s]
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 497/497 [04:43<00:00, 1.75it/s, v_num=0, train_loss_step=224.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.59it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.57it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.03it/s]
Epoch 1: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00, 1.67it/s, v_num=0, train_loss_step=250.0, train_loss_epoch=172.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.63it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.63it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.10it/s]
Epoch 2: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:33<00:00, 1.82it/s, v_num=0, train_loss_step=129.0, train_loss_epoch=152.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.68it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.63it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.10it/s]
Epoch 3: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:57<00:00, 1.67it/s, v_num=0, train_loss_step=123.0, train_loss_epoch=137.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.62it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.58it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.08it/s]
Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:30<00:00, 1.84it/s, v_num=0, train_loss_step=65.30, train_loss_epoch=128.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.62it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.58it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.04it/s]
Epoch 5: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:37<00:00, 1.79it/s, v_num=0, train_loss_step=231.0, train_loss_epoch=125.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 9.26it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 9.38it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.41it/s]
Epoch 6: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00, 1.67it/s, v_num=0, train_loss_step=68.40, train_loss_epoch=124.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.59it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.59it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.05it/s]
Epoch 7: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:24<00:00, 1.88it/s, v_num=0, train_loss_step=75.60, train_loss_epoch=123.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.61it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.62it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.08it/s]
Epoch 8: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00, 1.66it/s, v_num=0, train_loss_step=140.0, train_loss_epoch=121.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.67it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.63it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.09it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:20<00:00, 1.90it/s, v_num=0, train_loss_step=132.0, train_loss_epoch=119.0]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/3 [00:00<?, ?it/s]
Validation DataLoader 0: 33%|████████████████████████████████████▋ | 1/3 [00:00<00:00, 4.58it/s]
Validation DataLoader 0: 67%|█████████████████████████████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 4.62it/s]
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.09it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:24<00:00, 1.88it/s, v_num=0, train_loss_step=132.0, train_loss_epoch=116.0]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:25<00:00, 1.87it/s, v_num=0, train_loss_step=132.0, train_loss_epoch=116.0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric      ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         val_loss          │    1055.2911376953125     │
│          val_mse          │         239550.0          │
│        val_pearson        │    0.4029926061630249     │
└───────────────────────────┴───────────────────────────┘
Load the best model version#
best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)
tutorial_3/2024_28_05_06_44/version_0/checkpoints/epoch=9-step=4970.ckpt
model = grelu.lightning.LightningModel.load_from_checkpoint(best_checkpoint)
Evaluate model#
We now evaluate the model’s performance on the test dataset.
test_metrics = model.test_on_dataset(
test_ds,
batch_size=256,
devices=0,
num_workers=8
)
test_metrics
Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.01it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │    114.12532806396484     │
│          test_mse         │        83015.90625        │
│       test_pearson        │    0.6458674669265747     │
└───────────────────────────┴───────────────────────────┘
| | test_mse | test_pearson |
|---|---|---|
| Microglia_full | 83015.90625 | 0.645867 |
Run inference on held out sequences#
We can now get the predicted total coverage for each sequence in the test set.
preds = model.predict_on_dataset(
test_ds, devices=0, num_workers=8
)
preds.shape
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 5.29it/s]
(673, 1, 1)
We can visualize a scatter plot of predicted vs. true coverage values; a sketch is shown below.
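The sketch below (an assumption, not part of the tutorial pipeline) recomputes the true label for each test interval directly from the bigWig file by summing coverage over the central label_len bases, and plots it against the model's predictions. It assumes pyBigWig and matplotlib are available and that predict_on_dataset preserves the order of the test intervals.
import pyBigWig
import matplotlib.pyplot as plt

# True label: summed coverage over the central label_len bases of each test interval
offset = (seq_len - label_len) // 2
bw = pyBigWig.open(bw_file)
truth = np.array([
    np.nansum(bw.values(chrom, start + offset, start + offset + label_len))
    for chrom, start in zip(test.chrom, test.start)
])
bw.close()

plt.scatter(truth, preds.squeeze(), s=5, alpha=0.5)
plt.xlabel("True coverage (summed over central 1000 bp)")
plt.ylabel("Predicted coverage")
plt.show()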
Perform a marginalization experiment#
See tutorial 2 for an example of interpreting the model using ISM and TF-MoDISco, which discovers motifs that contribute to the model's predictions. Here, we instead study the learned effect of a single specific motif.
To understand the effect of the AC0622:ELF_SPIB:Ets motif, we perform a marginalization experiment: we insert the motif into shuffled background sequences and compare the model's predictions before and after the insertion.
First, we read this motif from the MEME file and extract the consensus sequence.
import grelu.io.meme
import grelu.interpret.motifs
motifs, _ = grelu.io.meme.read_meme_file("consensus", names=["AC0622:ELF_SPIB:Ets"])
patterns = grelu.interpret.motifs.motifs_to_strings(motifs)
print(patterns)
Read 1 motifs from file.
['AAGAGGAAGT']
We will select some peaks from the test set to shuffle and use as the background.
result = grelu.interpret.motifs.marginalize_patterns(
model=model,
patterns=patterns,
seqs = test.head(100), # First 100 peaks
genome = "hg38",
devices = 0,
num_workers = 8,
batch_size = 512,
n_shuffles = 1, # Each peak will be shuffled 1 time, conserving dinucleotide frequency
seed = 0,
compare_func = "subtract", # Return the change in the prediction after inserting the pattern
)
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 0.33it/s]
result.squeeze().mean()
27.102833
We see that on average, inserting this motif into a shuffled peak sequence increases the coverage predicted by the model.
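To see how consistent this effect is across backgrounds, you could also plot the distribution of per-sequence changes rather than just their mean (a sketch assuming matplotlib is available):
import matplotlib.pyplot as plt

# Distribution of the predicted change in coverage across the 100 shuffled backgrounds
plt.hist(np.asarray(result).squeeze(), bins=20)
plt.xlabel("Change in predicted coverage after inserting the motif")
plt.ylabel("Number of background sequences")
plt.show()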