Train a single-task regression model from scratch

In this tutorial, we train a single-task convolutional regression model to predict total coverage over ATAC-seq peaks, starting from an ATAC-seq fragment file.

import os
import numpy as np
import pandas as pd
import torch

# Optionally, restrict this notebook to a specific GPU:
# os.environ["CUDA_VISIBLE_DEVICES"] = "4"

Set experiment parameters

experiment = 'tutorial_3' # Output directory for logs and checkpoints
os.makedirs(experiment, exist_ok=True)

Peak and fragment files

The data are pseudobulk scATAC-seq profiles of human microglia from Corces et al. (2020): https://www.nature.com/articles/s41588-020-00721-x. Here, we use the grelu.resources module to download the fragment file and peak file from the model zoo:

import grelu.resources
# Download these datasets into local directories

fragment_file_dir = grelu.resources.get_artifact(
    project='microglia-scatac-tutorial', name='fragment_file').download()

peak_file_dir = grelu.resources.get_artifact(
    project='microglia-scatac-tutorial', name='peak_file').download()
# Paths to files
frag_file = os.path.join(fragment_file_dir, "Microglia_full.bed")
peak_file = os.path.join(peak_file_dir, "Microglia_full_peaks.narrowPeak")

Set parameters

seq_len=2114 # Length of the input sequence
label_len=1000 # Length over which we calculate total coverage
val_chroms=["chr10"] # Chromosomes held out for validation
test_chroms=["chr11"] # Chromosomes held out for testing
genome="hg38"

Read peak file

We read peak coordinates from the narrowPeak file.

import grelu.io.bed

peaks = grelu.io.bed.read_narrowpeak(peak_file)
peaks.tail(3)
       chrom     start       end                       name  score  strand   signal    pvalue   qvalue  summit
83316   chrY  56870777  56870983  Microglia_full_peak_83318     94       .  3.83352  11.58170  9.41381     116
83317   chrY  56873629  56873811  Microglia_full_peak_83319     49       .  3.01171   7.02098  4.99754     105
83318   chrY  56874075  56874225  Microglia_full_peak_83320     42       .  2.86533   6.19767  4.20704      17

Summit-center peaks

We extract the genomic coordinates of a seq_len (2,114 bp) window centered on the summit of each peak.

import grelu.data.preprocess

peaks = grelu.data.preprocess.extend_from_coord(
    peaks,
    seq_len=seq_len,
    center_col="summit"
)
peaks.tail(3)
       chrom     start       end
83316   chrY  56869836  56871950
83317   chrY  56872677  56874791
83318   chrY  56873035  56875149
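To make the centering arithmetic concrete, here is a minimal pandas sketch. It is not gReLU's implementation of extend_from_coord; it assumes, as in narrowPeak files, that summit is an offset from the peak start. Applied to the last three peaks shown earlier, it reproduces the coordinates above.

```python
import pandas as pd


def center_on_summit(intervals: pd.DataFrame, seq_len: int) -> pd.DataFrame:
    """Return seq_len-bp intervals centered on each peak summit.

    Assumes `summit` is an offset from `start` (narrowPeak convention).
    """
    center = intervals["start"] + intervals["summit"]  # absolute summit position
    start = center - seq_len // 2
    return pd.DataFrame({
        "chrom": intervals["chrom"],
        "start": start,
        "end": start + seq_len,
    })


# The last three peaks from the narrowPeak file above
toy_peaks = pd.DataFrame({
    "chrom": ["chrY", "chrY", "chrY"],
    "start": [56870777, 56873629, 56874075],
    "end": [56870983, 56873811, 56874225],
    "summit": [116, 105, 17],
})
centered = center_on_summit(toy_peaks, seq_len=2114)
print(centered)
```

For example, the first peak's summit lies at 56870777 + 116 = 56870893, so the window starts 2114 // 2 = 1057 bp earlier, at 56869836.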

Filter peaks

We keep only peaks on autosomes (chr1 to chr22). You can also use “autosomesX” or “autosomesXY” to include sex chromosomes.

peaks = grelu.data.preprocess.filter_chromosomes(peaks, 'autosomes')
Keeping 80823 intervals
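For intuition, the autosome filter amounts to a membership test on the chromosome column. The sketch below is an assumed equivalent written in plain pandas, not the library's code; it assumes UCSC-style names ("chr1" to "chr22"), so unplaced contigs and sex chromosomes are dropped.

```python
import pandas as pd

# Assumption: UCSC-style chromosome names, as used by hg38 here
AUTOSOMES = [f"chr{i}" for i in range(1, 23)]


def keep_autosomes(intervals: pd.DataFrame) -> pd.DataFrame:
    """Keep only intervals located on chr1 to chr22."""
    return intervals[intervals["chrom"].isin(AUTOSOMES)]


toy = pd.DataFrame({
    "chrom": ["chr1", "chrX", "chr22", "chrY", "chr1_KI270706v1_random"],
    "start": [0, 0, 0, 0, 0],
    "end": [10, 10, 10, 10, 10],
})
print(keep_autosomes(toy)["chrom"].tolist())
```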

We drop peaks that are close to ENCODE hg38 blacklist regions.

peaks = grelu.data.preprocess.filter_blacklist(
    peaks,
    genome=genome,
    window=50 # Remove peaks if they are within 50 bp of a blacklist region
)
Keeping 80028 intervals

Get GC-matched negative regions

To ensure that the model also learns to recognize regions that are not peaks, we will include a set of “negative” (non-peak) regions with similar GC content to the peaks.

negatives = grelu.data.preprocess.get_gc_matched_intervals(
    peaks,
    binwidth=0.02, # resolution of measuring GC content
    genome=genome,
    chroms="autosomes", # negative regions will also be chosen from autosomes
    # gc_bw_file='gc_hg38_2114.bw', # Uncomment to reuse the GC bigWig written by a previous run
    blacklist=genome, # negative regions overlapping the blacklist will be dropped
    seed=0,
)
negatives.head(3)
Calculating GC content genomewide and saving to gc_hg38_2114.bw
Extracting matching intervals
GC paired t-test: 0.027, 9.33e-26
Filtering blacklist
Keeping 77564 intervals
   chrom      start        end
0   chr3   69786703   69788817
1  chr10  111915184  111917298
2  chr17    4597230    4599344
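The idea behind GC matching can be sketched in a few lines: compute each sequence's GC fraction, assign it to a bin of width binwidth, and for every positive draw one unused candidate from the same bin. This is a simplified illustration on toy strings, not get_gc_matched_intervals itself, which works on genomic intervals and also applies the chromosome and blacklist filters shown above.

```python
import random
from collections import defaultdict


def gc_fraction(seq: str) -> float:
    """Fraction of G and C bases in a sequence."""
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / len(seq)


def gc_bin(seq: str, binwidth: float = 0.02) -> int:
    # int() truncates, which equals floor for non-negative GC fractions
    return int(gc_fraction(seq) / binwidth)


def gc_matched_sample(positives, candidates, binwidth=0.02, seed=0):
    """For each positive, draw one unused candidate from the same GC bin, if any."""
    rng = random.Random(seed)
    pool = defaultdict(list)
    for seq in candidates:
        pool[gc_bin(seq, binwidth)].append(seq)
    for bucket in pool.values():
        rng.shuffle(bucket)  # random choice within each bin
    matched = []
    for seq in positives:
        bucket = pool.get(gc_bin(seq, binwidth))
        if bucket:
            matched.append(bucket.pop())  # sample without replacement
    return matched


matched = gc_matched_sample(["AAAA", "GGGG"], ["TTTT", "CCCC", "ATAT"])
print(matched)
```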

We can visualize a histogram of GC content in the peaks and negative regions to verify that they are similar.

import grelu.visualize
grelu.visualize.plot_gc_match(
    positives=peaks, negatives=negatives, binwidth=0.02, genome="hg38", figsize=(4, 3)
)
[Figure: histograms of GC content in peaks and GC-matched negative regions]

Combine peaks and negative regions

The full dataset includes both peaks and negative regions.

regions = pd.concat([peaks, negatives])
len(regions)
157592

Make BigWig file

We convert the ATAC-seq fragment file into a bigWig file that contains the number of Tn5 insertions at each genomic position.

bw_file = grelu.data.preprocess.make_insertion_bigwig(
    frag_file = frag_file,
    plus_shift=0,
    minus_shift=1, # This corrects the +4/-5 Tn5 shift to a +4/-4 shift
    genome=genome,
    chroms="autosomes", # The output bigWig file contains coverage over autosomes.
)
Making bedgraph file
cat /code/gReLU/docs/tutorials/artifacts/fragment_file:v0/Microglia_full.bed | awk -v OFS="\t" '{print $1,$2+0,$3,1000,0,"+";
    print $1,$2,$3+1,1000,0,"-"}' | sort -k1,1 | grep -e ^chr1 -e ^chr2 -e ^chr3 -e ^chr4 -e ^chr5 -e ^chr6 -e ^chr7 -e ^chr8 -e ^chr9 -e ^chr10 -e ^chr11 -e ^chr12 -e ^chr13 -e ^chr14 -e ^chr15 -e ^chr16 -e ^chr17 -e ^chr18 -e ^chr19 -e ^chr20 -e ^chr21 -e ^chr22  | bedtools genomecov -bg -5 -i stdin -g /root/.local/share/genomes/hg38/hg38.fa.sizes | bedtools sort -i stdin > ./Microglia_full.bedGraph
Making bigWig file
bedGraphToBigWig ./Microglia_full.bedGraph /root/.local/share/genomes/hg38/hg38.fa.sizes ./Microglia_full.bw
Deleting temporary files
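Reading the awk/bedtools command in the log above, each fragment contributes two insertion sites: its start shifted by plus_shift, and the last base of its end shifted by minus_shift (bedtools genomecov -5 counts the 5' end of each strand). The sketch below expresses that reading in pure Python with 0-based, half-open coordinates; it is an illustration of the counting, not the function's actual code path.

```python
from collections import Counter


def insertion_counts(fragments, plus_shift=0, minus_shift=1):
    """Count Tn5 insertion sites from (chrom, start, end) fragments.

    Coordinates are 0-based and half-open, as in BED files.
    """
    counts = Counter()
    for chrom, start, end in fragments:
        counts[(chrom, start + plus_shift)] += 1         # + strand 5' end
        counts[(chrom, end - 1 + minus_shift)] += 1      # - strand 5' end
    return counts


counts = insertion_counts([("chr1", 100, 150), ("chr1", 100, 120)])
print(counts)
```

With minus_shift=1, the minus-strand insertion of the fragment chr1:100-150 is recorded at position 150 rather than 149, which is the one-base correction described in the code comment above.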

Split data by chromosome

We now split the dataset by chromosome to create separate datasets for training, validation and testing.

train, val, test = grelu.data.preprocess.split(
    regions, val_chroms=val_chroms, test_chroms=test_chroms)
Selecting training samples
Keeping 141228 intervals


Selecting validation samples
Keeping 7921 intervals


Selecting test samples
Keeping 8443 intervals
Final sizes: train: (126990, 3), val: (577, 3), test: (673, 3)
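Holding out whole chromosomes keeps evaluation sequences physically separate from training sequences. A minimal pandas sketch of the chromosome assignment follows; it covers only that assignment. gReLU's split also reports further changes between the "Keeping" counts and the final sizes above, which this sketch does not attempt to model.

```python
import pandas as pd


def split_by_chrom(intervals: pd.DataFrame, val_chroms, test_chroms):
    """Assign intervals to train / validation / test sets by chromosome."""
    is_val = intervals["chrom"].isin(val_chroms)
    is_test = intervals["chrom"].isin(test_chroms)
    return (
        intervals[~is_val & ~is_test],  # training: everything not held out
        intervals[is_val],
        intervals[is_test],
    )


toy_regions = pd.DataFrame({
    "chrom": ["chr1", "chr10", "chr11", "chr2", "chr10"],
    "start": [0, 100, 200, 300, 400],
    "end": [2114, 2214, 2314, 2414, 2514],
})
toy_train, toy_val, toy_test = split_by_chrom(toy_regions, ["chr10"], ["chr11"])
print(len(toy_train), len(toy_val), len(toy_test))
```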

Make labeled datasets

We now make PyTorch dataset objects that load paired sequences and coverage values from the genome and the bigWig file, using the BigWigSeqDataset class.

We first make the training dataset. To increase model robustness, we use several forms of data augmentation: rc=True (reverse complementing the input sequence), max_seq_shift=2 (shifting the coordinates of the input sequence by up to 2 bp in either direction, also known as jitter), and max_pair_shift=20 (shifting both the input sequence and the region over which coverage is calculated by up to 20 bp in either direction).

We also set label_aggfunc="sum", so the label for each interval is the total coverage summed over its central label_len (1,000 bp) region.

import grelu.data.dataset

train_ds = grelu.data.dataset.BigWigSeqDataset(
    intervals = train,
    bw_files=[bw_file],
    label_len=label_len,
    label_aggfunc="sum",
    rc=True, # reverse complement
    max_seq_shift=2, # Shift the sequence
    max_pair_shift=20, # Shift both sequence and label
    augment_mode="random",
    seed=0,
    genome=genome,
)
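The sketch below is one interpretation of the three augmentations described above, applied to a toy chromosome string and a per-base coverage list. It is not BigWigSeqDataset's code; in particular, how augment_mode="random" samples augmentations is not modeled, and the function name augment_example is hypothetical.

```python
import random

COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Reverse-complement a DNA string."""
    return seq.translate(COMPLEMENT)[::-1]


def augment_example(chrom_seq, coverage, start, seq_len, label_len,
                    max_seq_shift, max_pair_shift, rc, rng):
    """Return one randomly augmented (sequence, summed label) pair.

    Assumes `start` is far enough from the chromosome ends that all
    shifted windows stay in bounds.
    """
    pair_shift = rng.randint(-max_pair_shift, max_pair_shift)  # moves sequence and label together
    seq_shift = rng.randint(-max_seq_shift, max_seq_shift)     # moves only the sequence (jitter)
    seq_start = start + pair_shift + seq_shift
    seq = chrom_seq[seq_start:seq_start + seq_len]
    # The label window stays centered on the pair-shifted, un-jittered interval
    label_start = start + pair_shift + (seq_len - label_len) // 2
    label = sum(coverage[label_start:label_start + label_len])
    if rc and rng.random() < 0.5:
        # Summed coverage does not depend on strand, so the label is unchanged
        seq = reverse_complement(seq)
    return seq, label


rng = random.Random(0)
print(augment_example("ACGTACGTAC", [1] * 10, start=2, seq_len=6, label_len=2,
                      max_seq_shift=0, max_pair_shift=0, rc=True, rng=rng))
```

Because this model predicts a single strand-independent total, reverse complementing changes only the input. A model with stranded or positional outputs would also need its labels flipped.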

We do not apply any augmentations to the validation and test datasets (although it is possible to do so).

val_ds = grelu.data.dataset.BigWigSeqDataset(
    intervals = val,
    bw_files=[bw_file],
    label_len=label_len,
    label_aggfunc="sum", 
    genome=genome,
)

test_ds = grelu.data.dataset.BigWigSeqDataset(
    intervals = test,
    bw_files=[bw_file],
    label_len=label_len,
    label_aggfunc="sum",
    genome=genome,
)

len(train_ds), len(val_ds), len(test_ds)
(126990, 577, 673)

Build model

model_params = {
    'model_type':'DilatedConvModel',
    'crop_len':(seq_len-label_len)//2,
    'n_tasks':1,
    'channels':512,
    'n_conv':8,
}
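The crop_len value follows from the two window sizes: trimming (2114 - 1000) // 2 = 557 positions from each end of a 2,114-position output leaves exactly the central 1,000 positions over which the label is summed. The snippet below only illustrates that arithmetic on a hypothetical per-position profile; it assumes an output aligned one-to-one with the input bases and says nothing about the model's internals.

```python
import numpy as np

SEQ_LEN, LABEL_LEN = 2114, 1000  # same values as seq_len and label_len above

crop_len = (SEQ_LEN - LABEL_LEN) // 2
print(crop_len)

# Hypothetical per-position profile aligned with the input sequence
profile = np.ones(SEQ_LEN)
central = profile[crop_len:SEQ_LEN - crop_len]
print(central.shape, central.sum())
```

If seq_len - label_len were odd, floor division would leave the cropped output one position longer than label_len, so choosing an even difference keeps the two windows aligned.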

train_params = {
    'task':'regression',
    'loss': 'poisson', # Poisson loss. Other regression loss functions are "mse" and "poisson_multinomial"
    'logger':'csv',
    'lr':1e-4,
    'batch_size':256,
    'max_epochs':10,
    'devices':0,
    'num_workers':16,
    'save_dir':experiment,
}

import grelu.lightning
model = grelu.lightning.LightningModel(model_params, train_params)
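The Poisson loss selected in train_params treats each label as a count drawn from a Poisson distribution whose mean is the model's prediction, which suits summed Tn5 insertion counts. Its negative log-likelihood is pred - label * log(pred) + log(label!), and the last term is constant with respect to the model. The sketch below computes the standard form in NumPy; gReLU's exact implementation (for example, whether it works on log-scale outputs or how it averages) may differ.

```python
import numpy as np


def poisson_nll(pred, target, eps=1e-8):
    """Mean Poisson negative log-likelihood, omitting the constant log(target!) term.

    eps guards against log(0) when a prediction is exactly zero.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    return float(np.mean(pred - target * np.log(pred + eps)))


# For a fixed label, the loss is smallest when the prediction equals the label
for p in (4.0, 5.0, 6.0):
    print(p, poisson_nll([p], [5.0]))
```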

Train model

# See the tutorial_3 folder for logs.
trainer = model.train_on_dataset(train_ds, val_ds)
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 497/497 [04:43<00:00,  1.75it/s, v_num=0, train_loss_step=224.0]
Epoch 1: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00,  1.67it/s, v_num=0, train_loss_step=250.0, train_loss_epoch=172.0]
Epoch 2: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:33<00:00,  1.82it/s, v_num=0, train_loss_step=129.0, train_loss_epoch=152.0]
Epoch 3: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:57<00:00,  1.67it/s, v_num=0, train_loss_step=123.0, train_loss_epoch=137.0]
Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:30<00:00,  1.84it/s, v_num=0, train_loss_step=65.30, train_loss_epoch=128.0]
Epoch 5: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:37<00:00,  1.79it/s, v_num=0, train_loss_step=231.0, train_loss_epoch=125.0]
Epoch 6: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00,  1.67it/s, v_num=0, train_loss_step=68.40, train_loss_epoch=124.0]
Epoch 7: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:24<00:00,  1.88it/s, v_num=0, train_loss_step=75.60, train_loss_epoch=123.0]
Epoch 8: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:58<00:00,  1.66it/s, v_num=0, train_loss_step=140.0, train_loss_epoch=121.0]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 497/497 [04:25<00:00,  1.87it/s, v_num=0, train_loss_step=132.0, train_loss_epoch=116.0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         val_loss              1055.2911376953125     │
│          val_mse                   239550.0          │
│        val_pearson            0.4029926061630249     │
└───────────────────────────┴───────────────────────────┘

Load the best model version

best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)
tutorial_3/2024_28_05_06_44/version_0/checkpoints/epoch=9-step=4970.ckpt
model = grelu.lightning.LightningModel.load_from_checkpoint(best_checkpoint)

Evaluate model

We now evaluate the model’s performance on the test dataset.

test_metrics = model.test_on_dataset(
    test_ds,
    batch_size=256,
    devices=0,
    num_workers=8
)

test_metrics
Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.01it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss             114.12532806396484     │
│         test_mse                  83015.90625        │
│       test_pearson            0.6458674669265747     │
└───────────────────────────┴───────────────────────────┘
                   test_mse  test_pearson
Microglia_full  83015.90625      0.645867
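For reference, the two reported metrics have standard definitions, sketched below in NumPy. This is not gReLU's metric code, whose batching and per-task handling may differ. Because MSE is computed on raw summed coverage, it is large in absolute terms: a test_mse of about 83,016 corresponds to a root-mean-square error of roughly 288 insertions. Pearson correlation is scale-free and measures how well predictions rank and track the true totals.

```python
import numpy as np


def mse(pred, label):
    """Mean squared error between predictions and labels."""
    return float(np.mean((np.asarray(pred) - np.asarray(label)) ** 2))


def pearson(pred, label):
    """Pearson correlation between flattened predictions and labels."""
    return float(np.corrcoef(np.ravel(pred), np.ravel(label))[0, 1])


print(mse([1, 2], [1, 4]))
print(pearson([1, 2, 3], [2, 4, 6]))
print(np.sqrt(83015.90625))  # RMSE implied by the test_mse above
```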

Run inference on held-out sequences

We can now get the predicted total coverage for each sequence in the test set.

preds = model.predict_on_dataset(
    test_ds, devices=0, num_workers=8
)
preds.shape
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  5.29it/s]
(673, 1, 1)

We can visualize predicted vs. true total coverage, on a log scale, as a scatter plot.

grelu.visualize.plot_pred_scatter(
    preds=np.log(preds),
    labels=np.log(test_ds.get_labels()),
    density=True, # Color points by local density
    figsize=(3, 2.5), # width, height
    size=.5
)
[Figure: scatter plot of log predicted vs. log true total coverage on the test set, colored by point density]

Perform a marginalization experiment

See tutorial 2 for an example of interpreting the model with ISM and TF-MoDISco, which identify motifs that contribute to the model’s predictions. Here, we instead study the learned effect of a single specific motif.

To understand the effect of the AC0622:ELF_SPIB:Ets motif, we perform a marginalization experiment: we insert the motif into shuffled background sequences and compare the model’s predictions before and after insertion.

First, we read this motif from the "consensus" motif collection and extract its consensus sequence.

import grelu.io.meme
import grelu.interpret.motifs

motifs, _ = grelu.io.meme.read_meme_file("consensus", names=["AC0622:ELF_SPIB:Ets"])
patterns = grelu.interpret.motifs.motifs_to_strings(motifs)

print(patterns)
Read 1 motifs from file.
['AAGAGGAAGT']
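A consensus sequence is commonly derived by taking the most probable base at each position of a motif's probability matrix. The sketch below does exactly that, assuming a NumPy array of shape (4, motif_length) with rows ordered A, C, G, T. It illustrates the idea and is not necessarily how motifs_to_strings is implemented; ties go to the first base in the alphabet because argmax returns the first maximum.

```python
import numpy as np


def ppm_to_consensus(ppm, alphabet="ACGT"):
    """Most probable base at each position of a (4, length) probability matrix."""
    ppm = np.asarray(ppm)
    return "".join(alphabet[i] for i in ppm.argmax(axis=0))


# Hypothetical 3-position motif (rows: A, C, G, T; columns: positions)
toy_ppm = [
    [0.7, 0.1, 0.1],
    [0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1],
    [0.1, 0.1, 0.7],
]
print(ppm_to_consensus(toy_ppm))
```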

As background sequences, we use the first 100 peaks in the test set, each shuffled once while preserving dinucleotide frequencies.

result = grelu.interpret.motifs.marginalize_patterns(
    model=model,
    patterns=patterns,
    seqs = test.head(100), # First 100 peaks
    genome = "hg38",
    devices = 0,
    num_workers = 8,
    batch_size = 512,
    n_shuffles = 1, # Each peak will be shuffled 1 time, conserving dinucleotide frequency
    seed = 0,
    compare_func = "subtract", # Return the change in the prediction after inserting the pattern
)
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  0.33it/s]
result.squeeze().mean()
27.102833

On average, inserting this motif into a shuffled peak sequence increases the model’s predicted total coverage by about 27 (the mean of the prediction differences, since compare_func="subtract").
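The logic of the marginalization experiment can be summarized in a short pure-Python sketch: shuffle each background, predict, insert the pattern, predict again, and take the difference. Several simplifications are assumptions and differ from marginalize_patterns. The shuffle here is a plain mononucleotide shuffle (the call above preserves dinucleotide frequencies), the pattern is inserted at the sequence center, and toy_predict is a hypothetical stand-in for the trained model that simply counts the Ets core "GGAA".

```python
import random
from statistics import mean


def shuffle_seq(seq, rng):
    # Simplification: mononucleotide shuffle, not dinucleotide-preserving
    chars = list(seq)
    rng.shuffle(chars)
    return "".join(chars)


def insert_at_center(seq, pattern):
    """Overwrite the central len(pattern) bases of seq with pattern."""
    start = (len(seq) - len(pattern)) // 2
    return seq[:start] + pattern + seq[start + len(pattern):]


def marginalize(predict, seqs, pattern, n_shuffles=1, seed=0):
    """Return predict(with pattern) - predict(background) for each shuffled background."""
    rng = random.Random(seed)
    diffs = []
    for seq in seqs:
        for _ in range(n_shuffles):
            background = shuffle_seq(seq, rng)
            diffs.append(predict(insert_at_center(background, pattern)) - predict(background))
    return diffs


def toy_predict(seq):
    # Hypothetical model: score = number of "GGAA" occurrences
    return float(seq.count("GGAA"))


diffs = marginalize(toy_predict, ["A" * 50, "C" * 50], "AAGAGGAAGT")
print(diffs, mean(diffs))
```

Averaging the differences over many backgrounds, as result.squeeze().mean() does above, isolates the effect attributable to the inserted pattern rather than to any particular background sequence.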