Simulation-based model interpretation with gReLU#

This tutorial demonstrates some of gReLU’s functions for making predictions on simulated genomic sequences with specific properties. This allows us to decipher the regulatory grammar and syntax learned by a model.

import os
import numpy as np
import pandas as pd
import scipy.stats

from plotnine import *
import seaborn as sns

%matplotlib inline
os.environ["WANDB_SILENT"] = "true"

Part 1: CATlas ATAC-seq binary classification model#

First we demonstrate some simple functions using the CATlas binary classification model. We begin by loading the model from the model zoo:

import grelu.resources

catlas = grelu.resources.load_model(repo_id="Genentech/human-atac-catlas-model", filename="model.ckpt")

This object contains the model metadata, so we can extract a dataframe describing the cell types for which the model makes predictions.

tasks = pd.DataFrame(catlas.data_params['tasks'])
print(len(tasks))
tasks.head(3)
204
name cell type
0 Follicular Follicular
1 Fibro General Fibro General
2 Acinar Acinar

We can also extract genomic intervals that were in the test set of this model:

test_intervals = pd.DataFrame(catlas.data_params['test']['intervals'])
test_intervals.head()
chrom start end cre_class in_fetal in_adult cre_module width cre_idx enformer_split split
0 chr1 143497510 143497710 Promoter Proximal no yes 113 400 53530 test test
1 chr1 143498052 143498252 Promoter Proximal no yes 4 400 53531 test test
2 chr1 143498633 143498833 Promoter yes no 46 400 53532 test test
3 chr1 143563536 143563736 Distal yes no 63 400 53533 test test
4 chr1 143575231 143575431 Distal no yes 28 400 53534 test test

Let us extract 100 intervals from the test set:

test_intervals = test_intervals.sample(100, random_state=0)

1.1 Motif marginalization#

In order to understand the function of a specific motif, we can insert the motif into shuffled sequences. By making predictions on the shuffled sequences before and after motif insertion, we can compute the extent to which the motif alone increases predicted activity.
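
As a rough illustration of this idea, the sketch below shuffles a background sequence, overwrites its center with a motif, and compares a score before and after insertion. Everything here is a simplification: `shuffle_sequence`, `insert_at_center` and `toy_model` are hypothetical helpers rather than gReLU functions, the shuffle is a plain mononucleotide shuffle, and the "model" is a trivial substring check standing in for a trained network.

```python
import random

def shuffle_sequence(seq, rng):
    # Simple mononucleotide shuffle (an assumption made for brevity)
    chars = list(seq)
    rng.shuffle(chars)
    return "".join(chars)

def insert_at_center(seq, pattern):
    # Overwrite the central bases with the pattern so the length is unchanged
    start = (len(seq) - len(pattern)) // 2
    return seq[:start] + pattern + seq[start + len(pattern):]

def toy_model(seq):
    # Hypothetical stand-in for a trained model: 1.0 if the core "GGAAGT" is present
    return 1.0 if "GGAAGT" in seq else 0.0

rng = random.Random(0)
pattern = "AAAAGAGGAAGTGA"  # SPI1 consensus string used in this tutorial

# One random 200 bp background, shuffled 3 times
background = "".join(rng.choice("ACGT") for _ in range(200))
before_seqs = [shuffle_sequence(background, rng) for _ in range(3)]
after_seqs = [insert_at_center(s, pattern) for s in before_seqs]

# Effect of the motif = prediction after insertion minus prediction before
effects = [toy_model(a) - toy_model(b) for a, b in zip(after_seqs, before_seqs)]
mean_effect = sum(effects) / len(effects)
print(mean_effect)
```

Averaging the per-sequence differences over many shuffled backgrounds is what "marginalizing" refers to here: the background varies while the motif stays fixed.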

To demonstrate, we extract a few selected motifs from the HOCOMOCO database.

import grelu.io.motifs
import grelu.interpret.motifs

motif_names = ['CTCF.H12CORE.0.P.B', 'IRF4.H12CORE.0.P.B', 'SPI1.H12CORE.0.P.B',
               'HNF1A.H12CORE.0.PS.A']
motifs = grelu.io.motifs.read_meme_file("hocomoco_v12", names=motif_names)

This is a dictionary containing the motifs. We can extract the consensus sequence for each motif:

patterns = grelu.interpret.motifs.motifs_to_strings(motifs)
print(patterns)
['TGGCCACTAGGGGGCGCTAG', 'AAAGAGGAACTGAAACT', 'AAAAGAGGAAGTGA', 'GGTTAATGATTAACC']

Now, we insert each of these motifs into shuffled sequences and compare the predictions before and after motif insertion. gReLU provides an easy function marginalize_patterns to perform this analysis.

import grelu.interpret.simulate

preds = grelu.interpret.simulate.marginalize_patterns(
    model=catlas,
    patterns=patterns,
    seqs=test_intervals,
    genome='hg38',
    devices=0,
    num_workers=16,
    batch_size=512,
    n_shuffles=3, # Each test sequence will be shuffled 3 times 
    seed=0,
    compare_func='subtract', # Return the difference in prediction with and without each motif. If not supplied, predictions with and without are returned separately.
).squeeze()
preds.shape
(100, 3, 4, 204)

The preds object has the shape (100, 3, 4, 204) where 100 is the number of background sequences, 3 is the number of times each sequence is shuffled, 4 is the number of motifs, and 204 is the number of cell types. Let’s reshape this so that each motif has 300 background sequences:

preds = preds.reshape(300, 4, 204)
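
To see what this reshape does to the row order, here is a small sketch on a toy array with made-up dimensions (2 sequences, 3 shuffles, 4 motifs, 5 tasks). It relies only on NumPy's default row-major reshape behaviour.

```python
import numpy as np

# Toy dimensions, chosen only for illustration
n_seqs, n_shuffles, n_motifs, n_tasks = 2, 3, 4, 5
x = np.arange(n_seqs * n_shuffles * n_motifs * n_tasks).reshape(
    n_seqs, n_shuffles, n_motifs, n_tasks
)

# Merge the first two axes, as done for preds above
flat = x.reshape(n_seqs * n_shuffles, n_motifs, n_tasks)

# Row i * n_shuffles + j of the merged array is shuffle j of sequence i
i, j = 1, 2
same = np.array_equal(flat[i * n_shuffles + j], x[i, j])
print(flat.shape, same)
```

In other words, the shuffles of one background sequence end up in consecutive rows, and the motif and task axes are untouched.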

We can now subset the results for our cell type of interest. For example, microglia:

microglia_preds = preds[:, :, tasks.name=='Microglia'].squeeze()
microglia_preds.shape
(300, 4)
microglia_preds = pd.DataFrame(microglia_preds, columns=motif_names)
microglia_preds.head()
CTCF.H12CORE.0.P.B IRF4.H12CORE.0.P.B SPI1.H12CORE.0.P.B HNF1A.H12CORE.0.PS.A
0 0.028803 0.602386 0.532793 -0.021205
1 0.014246 0.054271 0.016262 -0.002340
2 0.047286 0.107553 0.008252 -0.001463
3 0.015101 0.176377 0.167955 -0.013501
4 -0.053546 0.719574 0.607628 -0.095223

We can now visualize the effect of each of these motifs in microglia:

microglia_preds.plot.box(rot=45, figsize=(6,2),
    title='Increase in predicted probability of\naccessibility in microglia after inserting motif')
[Figure: box plot of the increase in predicted probability of accessibility in microglia after inserting each motif]

We can compute the significance of the observed changes using Wilcoxon’s signed-rank test:

scipy.stats.wilcoxon(microglia_preds['SPI1.H12CORE.0.P.B'], alternative='greater')
WilcoxonResult(statistic=45138.0, pvalue=3.4306111433259653e-51)

We can also compare the average effect of each motif across all 204 cell types. For this, we first average the effect across all 300 shuffled background sequences:

print(preds.shape)
mean_motif_effect = preds.mean(0)
print(mean_motif_effect.shape)
(300, 4, 204)
(4, 204)

This gives us an array of size (motifs x cell types). We can represent this as a pandas dataframe:

mean_motif_effect = pd.DataFrame(mean_motif_effect, index=motif_names, columns=tasks.name)
mean_motif_effect.iloc[:2, :5]
name Follicular Fibro General Acinar T Lymphocyte 1 (CD8+) T lymphocyte 2 (CD4+)
CTCF.H12CORE.0.P.B 0.388231 0.394323 0.432490 0.305734 0.294881
IRF4.H12CORE.0.P.B -0.036349 -0.012824 0.002279 0.056303 0.036172

For visualization, let us select the 20 cell types with highest variance:

variable_cell_types = mean_motif_effect.var(0).sort_values().tail(20).index
variable_cell_types
Index(['Fetal Photoreceptor', 'Fetal T Lymphocyte 1 (CD4+)',
       'Fetal Enterocyte 3', 'Fetal Enterocyte 1',
       'Fetal Retinal Progenitor 1', 'Memory B', 'Fetal Acinar 2',
       'Macrophage Gen or Alv', 'Hepatocyte', 'Plasma B',
       'Fetal Adrenal Neuron', 'Alpha 1', 'Fetal Adrenal Cortical', 'Acinar',
       'Fetal Thymocyte', 'Macrophage General',
       'Fetal Syncitio+Cytotrophoblast', 'Fetal B Lymphocyte 1 (SPIB+)',
       'Fetal Erythroblast 2', 'Fetal Hepatoblast'],
      dtype='object', name='name')
sns.clustermap(mean_motif_effect.loc[:, variable_cell_types].T, figsize=(6,6))
[Figure: clustered heatmap of mean motif effects across the 20 most variable cell types]
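
The selection of the most variable cell types can be sketched without pandas. The numbers below are hypothetical; the only thing carried over from the tutorial is the logic of computing a per-column variance, sorting in ascending order and keeping the tail. The sketch uses the sample variance (`statistics.variance`), which corresponds to the pandas default of `ddof=1`.

```python
import statistics

# Hypothetical mean motif effects: one list of per-motif values per cell type
effects = {
    "cell_A": [0.40, 0.01, 0.02],
    "cell_B": [0.10, 0.12, 0.11],
    "cell_C": [0.05, 0.60, 0.50],
}

# Variance of the motif effects within each cell type
variances = {cell: statistics.variance(vals) for cell, vals in effects.items()}

# Sort ascending and keep the last two, so the most variable cell type comes last
top2 = sorted(variances, key=variances.get)[-2:]
print(top2)
```

Because the sort is ascending, the last entry of the selection is the cell type whose response differs most between motifs.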

1.2 Motif spacing#

Given two motifs, we may want to know whether the spacing between them is important for regulatory function. Here, we insert a TEAD motif in the center of the shuffled peak and then insert a FOS motif at different distances from the TEAD motif. We compute the change in the fibroblast prediction caused by inserting the FOS motif, and marginalize this effect across all 100 shuffled peak sequences.

First, we load the motifs from the HOCOMOCO database:

motifs = grelu.io.motifs.read_meme_file(
    "hocomoco_v12", names=['TEAD1.H12CORE.0.PSM.A', 'FOS.H12CORE.0.P.B']
)

And extract the consensus sequence for each motif:

patterns = grelu.interpret.motifs.motifs_to_strings(motifs)
tead_pattern, fos_pattern = patterns
print(tead_pattern)
print(fos_pattern)
CCACATTCCAGGC
GTGAGTCACCA

gReLU provides another handy function marginalize_pattern_spacing to compute the effect of different inter-motif distances. This function produces two outputs: preds (model predictions) and distances (distances at which the second motif was inserted). Note that distances are calculated between the start positions of the two motifs.

preds, distances = grelu.interpret.simulate.marginalize_pattern_spacing(
    model = catlas,
    seqs = test_intervals,
    fixed_pattern = tead_pattern, # this motif will be fixed at the center of the sequence
    moving_pattern = fos_pattern, # this motif will be inserted at different distances
    genome = 'hg38',
    n_shuffles = 1, # Each test sequence will be shuffled once
    seed = 0,
    devices = 0,
    num_workers = 16,
    batch_size = 512,
    compare_func = 'subtract', # Return the difference in prediction after inserting FOS
)

As noted above, the function returns preds (the model’s predictions) and distances (the distance between the two motifs corresponding to each prediction).

preds = preds.squeeze()
preds.shape
(100, 167, 204)

preds has shape (number of sequences, number of positions, number of tasks). We are going to focus on one cell type (fibroblasts). So we subset the predictions for that cell type:

preds = preds[:, :, tasks.name=='Fibro General'].squeeze()
preds.shape
(100, 167)

Let’s also take a look at distances:

distances[:5]
[-93, -92, -91, -90, -89]

A convenient way to analyze these is to plot the predicted effect of the FOS motif vs. distance from the TEAD motif.

import grelu.visualize
grelu.visualize.plot_position_effect(
    preds, distances,
    title="Increase in predicted probability of accessibility\nin fibroblasts after inserting FOS motif",
    xlab="Distance from TEAD motif",
    figsize=(6,2)
)
[Figure: increase in predicted probability of accessibility in fibroblasts after inserting the FOS motif, plotted against distance from the TEAD motif]
pd.DataFrame({'mean_effect':preds.mean(0), 'dist':distances}).sort_values('mean_effect', ascending=False).head()
mean_effect dist
85 0.314871 15
86 0.199106 16
84 0.192458 14
87 0.149564 17
83 0.143681 13

gReLU shows us a clear distance dependence: a distance of 15 nucleotides between the start positions of the two motifs creates a composite motif that drives higher predicted accessibility. Since the TEAD motif is 13 bp long, this leaves two nucleotides between the end of the TEAD motif and the start of the FOS motif (CCACATTCCAGGC _ _ GTGAGTCACCA).
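
The conversion from a start-to-start distance to the gap between the two motifs is simple arithmetic, sketched below with the two consensus strings printed earlier. The `N` characters are just placeholders for the unspecified bases between the motifs.

```python
tead_pattern = "CCACATTCCAGGC"
fos_pattern = "GTGAGTCACCA"
best_distance = 15  # start-to-start distance with the largest mean effect

# Bases between the end of the TEAD motif and the start of the FOS motif
gap = best_distance - len(tead_pattern)

# Composite arrangement, with N marking the unspecified spacer bases
composite = tead_pattern + "N" * gap + fos_pattern
print(len(tead_pattern), gap, composite)
```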

del catlas

Part 2: Enformer#

Next we will examine regulatory relationships between gene TSSs and their genomic context using the Enformer model.

First, we import the Enformer model from the model zoo.

import grelu.resources
grelu.resources.get_model_info("Genentech/enformer-model")
{'id': 'Genentech/enformer-model',
 'tags': ['pytorch-lightning',
  'biology',
  'genomics',
  'tabular-regression',
  'dataset:Genentech/enformer-data',
  'license:mit',
  'region:us'],
 'card_data': {'base_model': None,
  'datasets': ['Genentech/enformer-data'],
  'eval_results': None,
  'language': None,
  'library_name': 'pytorch-lightning',
  'license': 'mit',
  'license_name': None,
  'license_link': None,
  'metrics': None,
  'model_name': None,
  'pipeline_tag': 'tabular-regression',
  'tags': ['biology', 'genomics']},
 'downloads': 0,
 'last_modified': datetime.datetime(2026, 2, 23, 22, 0, 16, tzinfo=datetime.timezone.utc),
 'files': ['.gitattributes',
  'README.md',
  'human.ckpt',
  'human_state_dict.h5',
  'mouse.ckpt',
  'mouse_state_dict.h5',
  'save_wandb_enformer_human.ipynb',
  'save_wandb_enformer_mouse.ipynb']}

We see that there are two checkpoints available, human.ckpt and mouse.ckpt for the human and mouse genomes respectively.

enformer = grelu.resources.load_model(repo_id="Genentech/enformer-model", filename="human.ckpt")

This object includes the task metadata. For the current analyses, we will focus on a single CAGE track in K562 cells, which is track 5111.

k562_cage_track_idx=5111
tasks = pd.DataFrame(enformer.data_params['tasks'])
tasks.iloc[[k562_cage_track_idx]]
name file clip scale sum_stat description assay sample
5111 CNhs12336 /home/drk/tillage/datasets/human/cage/fantom/C... 384 1 sum CAGE:chronic myelogenous leukemia cell line:K5... CAGE chronic myelogenous leukemia cell line:K562 EN...

We are going to analyze CAGE activity at the TSS of the ABCA8 gene. This gene is located on chromosome 17 and its TSS is at position 68955392. We define a region of 5000 bp centered on this position as the TSS region.

tss_interval = pd.DataFrame({
    'chrom':['chr17'], 'start':[68955392-2500], 'end':[68955392+2500]
})
print(tss_interval)
   chrom     start       end
0  chr17  68952892  68957892

And extract its sequence from the genome:

import grelu.sequence.format
tss_seq = grelu.sequence.format.convert_input_type(tss_interval, "strings", genome="hg38")[0] # Don't try to print this!

We will construct an input interval for Enformer centered on this TSS position.

import grelu.sequence.utils
input_interval = grelu.sequence.utils.resize(tss_interval, enformer.data_params['train']['seq_len'])
input_interval
chrom start end
0 chr17 68857088 69053696
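
The coordinates above are consistent with keeping the midpoint of the interval fixed and extending it symmetrically to the new length. The sketch below reproduces them with a hypothetical helper, `resize_centered`; it is not the gReLU implementation, and the sequence length of 196,608 bp is simply the width of the interval printed above.

```python
def resize_centered(start, end, seq_len):
    # Keep the midpoint fixed and extend symmetrically to seq_len
    center = (start + end) // 2
    new_start = center - seq_len // 2
    return new_start, new_start + seq_len

# TSS interval from above, resized to the width of the printed input interval
new_start, new_end = resize_centered(68952892, 68957892, 196608)
print(new_start, new_end)
```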

Let’s examine Enformer’s predictions for this interval:

input_seq = grelu.sequence.format.convert_input_type(input_interval, "strings", genome='hg38')[0]
input_pred = enformer.predict_on_seqs(input_seq, device=0)
input_pred.shape
(1, 5313, 896)

Enformer predicts 5313 output tracks, but we are only interested in the predictions for the K562 CAGE track - let’s plot those, and highlight the TSS region that we defined.

fig = grelu.visualize.plot_tracks(
    input_pred[:, k562_cage_track_idx],
    start_pos=input_interval.start[0], end_pos=input_interval.end[0],
    highlight_intervals=tss_interval,
)
[Figure: Enformer’s predicted K562 CAGE track across the input interval, with the TSS region highlighted]

We see that Enformer predicts transcriptional activity at the ABCA8 TSS.

Instead of making predictions for the entire sequence across all tracks and then subsetting the TSS and the K562 CAGE track each time, gReLU allows us to use a transform object that will compute only the relevant predictions.

Here, we can create an Aggregate transform that sums the predicted values in the bins overlapping the TSS region, in the K562 CAGE track. This transform, when applied to Enformer’s predictions, computes the statistic we’re interested in.

To do this, we first need the indices of the output bins that overlap with the TSS region:

enformer.input_intervals_to_output_bins(tss_interval, start_pos=input_interval.start[0])
start end
0 428 468
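
To make the mapping from genomic coordinates to output bins less opaque, here is a sketch that reproduces these indices by hand. It assumes that each Enformer output bin covers 128 bp and that the 896 output bins correspond to the center of the input, with an equal number of bins cropped from each edge. `interval_to_bins` is a hypothetical helper, not the gReLU method.

```python
import math

SEQ_LEN = 196608   # width of the input interval above
N_BINS = 896       # number of output bins per track (see input_pred.shape)
BIN_SIZE = 128     # assumed output resolution in bp
CROP_BINS = (SEQ_LEN // BIN_SIZE - N_BINS) // 2  # bins cropped from each edge

def interval_to_bins(start, end, input_start):
    # First bin that contains the interval start, after removing cropped bins
    bin_start = (start - input_start) // BIN_SIZE - CROP_BINS
    # One past the last bin that overlaps the interval end
    bin_end = math.ceil((end - input_start) / BIN_SIZE) - CROP_BINS
    return bin_start, bin_end

bins = interval_to_bins(68952892, 68957892, input_start=68857088)
print(CROP_BINS, bins)
```

Under these assumptions the 5,000 bp TSS region overlaps 40 bins, which matches the range returned by the model.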

Now we can create the transform:

from grelu.transforms.prediction_transforms import Aggregate
tss_activity_transform = Aggregate(
    tasks=k562_cage_track_idx,
    positions=list(range(428, 468)), # these are the indices of the central bins overlapping with the TSS
    length_aggfunc='sum' # Sum the output over all bins
)

So we can now compute the TSS activity in the wild-type sequence:

WT = tss_activity_transform.compute(input_pred).squeeze()
WT
array(9.004404, dtype=float32)
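
Conceptually, this transform selects one task and sums its predictions over a range of bins. The sketch below shows that operation on a tiny hypothetical prediction array stored as nested lists with layout (batch, tasks, bins); `aggregate_sum` is an illustrative helper, not part of gReLU.

```python
# Hypothetical predictions with shape (batch=1, tasks=3, bins=6)
preds = [[
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]]

def aggregate_sum(preds, task, positions):
    # For each sequence, sum the chosen task's values over the chosen bins
    return [sum(seq[task][p] for p in positions) for seq in preds]

# Select task 1 and sum bins 2 and 3
result = aggregate_sum(preds, task=1, positions=range(2, 4))
print(result)
```

The output has one value per input sequence, which is why the transformed predictions can be compared directly between wild-type and perturbed sequences.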

Next, we will simulate different changes to this locus and see how this statistic changes.

2.1 Context dependence test#

The context dependence test introduced in CREME (https://doi.org/10.1101/2023.07.03.547592) measures the extent to which a prediction of a given TSS activity is influenced by its context.

It includes the following steps:

  1. Predict TSS activity for the wild type sequence (denoted as WT).

  2. Dinucleotide shuffle the sequence (except the 5 kb tile centered at the TSS).

  3. Predict TSS activity for the shuffled sequence (denoted as MUTANT).

  4. Normalization: compute context effect on TSS using WT as control: (WT - MUTANT) / WT

  5. Repeat steps 2-4 ten times and average across different random dinucleotide shuffles.

A positive effect size means that the central TSS prediction for the mutated sequence is lower than WT, which indicates that we have perturbed an enhancing context. Negative effect size means that the central TSS prediction for the mutated sequence is higher than WT, which suggests that we have perturbed a silencing context.

We have already done the first step, which is to predict TSS activity in the wild-type (WT). So, we move on to the subsequent steps.

Using gReLU, we can shuffle the input sequence 10 times.

shuffled_seqs = grelu.sequence.utils.dinuc_shuffle(input_seq, n_shuffles=10, seed=0)

And re-insert the TSS region into the center of the shuffled sequences:

from grelu.sequence.mutate import mutate
shuf_seqs_w_tss = [mutate(s, tss_seq) for s in shuffled_seqs]

We now make predictions for the TSS in the shuffled sequences:

from grelu.data.dataset import SeqDataset
ds = SeqDataset(shuf_seqs_w_tss)

# Make predictions
MUTANT = enformer.predict_on_dataset(ds, devices=0, batch_size=8)

# Apply the transform to compute predicted TSS activity
MUTANT = tss_activity_transform.compute(MUTANT).squeeze()

MUTANT
array([1.7107706, 1.7943791, 1.6819942, 1.7473049, 1.6672367, 1.6260251,
       1.6094068, 1.7582744, 1.6140951, 1.6499794], dtype=float32)

And compute the final metric:

result = np.mean((WT - MUTANT) / WT)
result
0.81276417

The positive value shows that the TSS is in an enhancing context.
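
As a sanity check, the normalization can be reproduced by hand from the WT and MUTANT values printed above, using plain Python instead of NumPy. Small differences in the last digits are expected because the original values are float32.

```python
WT = 9.004404
MUTANT = [1.7107706, 1.7943791, 1.6819942, 1.7473049, 1.6672367,
          1.6260251, 1.6094068, 1.7582744, 1.6140951, 1.6499794]

# Normalized context effect for each shuffle, then averaged
per_shuffle = [(WT - m) / WT for m in MUTANT]
context_effect = sum(per_shuffle) / len(per_shuffle)
print(round(context_effect, 4))
```

Every individual shuffle gives an effect close to 0.8, so the conclusion does not depend on any single shuffled background.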

2.2 Enhancer Discovery by in silico tiling#

Which parts of the input sequence are responsible for its enhancing context? We can discover putative regulatory elements in silico by shuffling successive tiles in the sequence and computing the change in predicted TSS activity. gReLU provides a shuffle_tiles function for this analysis:

preds, tiles = grelu.interpret.simulate.shuffle_tiles(
    model=enformer,
    seqs=[input_seq],
    tile_len=5000, # Length of region to shuffle
    stride=2500, # Tiles will overlap by this much
    protect_center= 5000, # Do not disrupt the central TSS region
    n_shuffles = 5, # Number of times to shuffle each tile
    seed = 0,
    devices=0,
    num_workers=16,
    batch_size=8,
    prediction_transform=tss_activity_transform,
    compare_func='subtract',
)

Note that this produces two outputs: preds (the model’s predictions) and tiles (the regions that were shuffled, corresponding to the predictions).

print(tiles.shape)
tiles.head()
(73, 2)
start end
0 0 5000
1 2500 7500
2 5000 10000
3 7500 12500
4 10000 15000
preds = preds.squeeze()
preds.shape
(73, 5)
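
The tile coordinates can be reproduced with a simple rule: slide a window of tile_len across the sequence in steps of stride, and drop every tile that overlaps the protected central region. The sketch below is one rule that is consistent with the 73 tiles shown above; `make_tiles` is a hypothetical helper and the library's actual implementation may differ.

```python
def make_tiles(seq_len, tile_len, stride, protect_center):
    # Boundaries of the protected region in the middle of the sequence
    center_start = (seq_len - protect_center) // 2
    center_end = center_start + protect_center
    tiles = []
    for start in range(0, seq_len - tile_len + 1, stride):
        end = start + tile_len
        # Skip tiles that overlap the protected central region
        if start < center_end and end > center_start:
            continue
        tiles.append((start, end))
    return tiles

toy_tiles = make_tiles(seq_len=196608, tile_len=5000, stride=2500, protect_center=5000)
print(len(toy_tiles), toy_tiles[20], toy_tiles[37])
```

Note that the row index of a tile is not simply its start divided by the stride, because the tiles overlapping the center are missing from the table.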
grelu.visualize.plot_position_effect(
    preds.T, tiles.start.tolist(),
    title="Change in predicted TSS activity\nafter shuffling tile",
    xlab="Start position of tile",
    figsize=(6,2)
)
[Figure: change in predicted TSS activity after shuffling each tile, plotted against the start position of the tile]

The tiles with the most negative values, i.e. those whose shuffling reduces predicted TSS activity the most, are the putative enhancers. For a tabular representation, we can add the predictions to the dataframe, averaging across all shuffles of the same sequence, then sort the tiles by effect:

tiles['effect'] = preds.mean(1)
tiles.sort_values('effect').head(3)
start end effect
20 50000 55000 -7.333127
19 47500 52500 -7.235798
37 102500 107500 -6.969648

It seems that the tile from position 50,000 - 55,000 of the input sequence has the strongest enhancer activity. We will define this as our putative CRE and extract its sequence.

cre_pos = 50000
cre_seq = input_seq[cre_pos: cre_pos+5000]

2.3 Sufficiency test#

The sufficiency test (also from CREME) consists of the following steps:

  1. Predict TSS activity for the wild type sequence (WT).

  2. Dinucleotide shuffle the sequence.

  3. Add the TSS 5 kb tile and predict TSS activity (CONTROL).

  4. Add the CRE and the TSS tiles to the sequence and predict TSS activity (MUTANT).

  5. In the case of enhancing context sequences, compute the normalized CRE effect as (MUTANT - CONTROL) / WT.

  6. Repeat steps 2-5 for 10 different shuffles and average the normalized CRE effect per sequence.

We insert the CRE into our 10 shuffled sequences containing the TSS.

shuffled_seqs_w_tss_and_cre = [mutate(s, cre_seq, pos=cre_pos) for s in shuf_seqs_w_tss]

Note that what was previously called MUTANT (shuffled sequence with the TSS tile) is now called CONTROL, so we rename it:

CONTROL = MUTANT
del MUTANT

And then make predictions:

ds = SeqDataset(shuffled_seqs_w_tss_and_cre)
MUTANT = enformer.predict_on_dataset(ds, devices=0, batch_size=8)
MUTANT = tss_activity_transform.compute(MUTANT).squeeze() # Squeeze after the transform so that MUTANT lines up with CONTROL, one value per shuffle
results = np.mean((MUTANT - CONTROL) / WT)
results
0.028983079

This suggests that although shuffling the CRE has a large effect on TSS activity, the CRE alone is not sufficient to increase TSS activity much: inserting it into the shuffled sequences recovers only about 3% of the wild-type activity.
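
The normalized CRE effect pairs each MUTANT prediction with the CONTROL prediction from the same shuffled background. The sketch below makes that pairing explicit; the function name and the numbers are hypothetical and chosen only so the arithmetic is easy to follow.

```python
def normalized_cre_effect(mutant, control, wt):
    # Each MUTANT value is compared with the CONTROL from the same shuffle
    if len(mutant) != len(control):
        raise ValueError("mutant and control must have one value per shuffle")
    diffs = [(m - c) / wt for m, c in zip(mutant, control)]
    return sum(diffs) / len(diffs)

# Hypothetical values for two shuffles
wt = 10.0
control = [1.0, 2.0]   # shuffled background + TSS
mutant = [1.5, 3.0]    # shuffled background + TSS + CRE
effect = normalized_cre_effect(mutant, control, wt)
print(round(effect, 3))
```

A value near 0 means that adding the CRE barely changes the prediction relative to the TSS alone, while a value near 1 would mean the CRE restores activity to roughly the wild-type level.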

2.4 TSS-CRE Distance Test#

CREME defined a ‘TSS-CRE Distance test’ in which the TSS is inserted into a dinucleotide-shuffled sequence, the putative CRE is inserted at different distances from the TSS, and its effect is measured at each distance. We can do a similar experiment using the previously seen marginalize_pattern_spacing function.

Note that CREME normalizes effects using the maximum effect of the CRE at any distance, which we do not do here, but can be done manually.

preds, distances = grelu.interpret.simulate.marginalize_pattern_spacing(
    model = enformer,
    seqs = input_seq,
    fixed_pattern = tss_seq, # The TSS will be fixed in the center of the sequence
    moving_pattern = cre_seq, # The CRE will be inserted at different distances
    stride=5000, # The CRE insertion position will shift by 5000 bp
    n_shuffles = 10, # The input sequence will be shuffled 10 times and results averaged
    seed = 0,
    devices=0,
    num_workers=16,
    batch_size=8,
    prediction_transform=tss_activity_transform,
    compare_func='subtract', # Compute the change in prediction after inserting the CRE. 
)
preds = preds.squeeze()
preds.shape
(10, 37)
len(distances)
37
grelu.visualize.plot_position_effect(
    preds, distances,
    title="Change in predicted TSS activity\nafter inserting enhancer",
    xlab="Distance from TSS",
    figsize=(6,2)
)
[Figure: change in predicted TSS activity after inserting the enhancer at each distance]
df = pd.DataFrame({'distance':distances, 'effect': preds.mean(0)})
df.sort_values('effect', ascending=False).head()
distance effect
15 -20804 0.590919
16 -15804 0.557747
21 19196 0.536221
20 14196 0.508207
17 -10804 0.492683
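
The manual normalization mentioned earlier could look like the sketch below, which divides each mean effect by the largest mean effect observed at any distance. It uses only the five rows displayed above (the top row is the overall maximum, since the table is sorted in descending order), and it is one plausible reading of the CREME normalization rather than a confirmed reimplementation.

```python
# Mean effect of the CRE at the five distances shown above
effects = {
    -20804: 0.590919,
    -15804: 0.557747,
    19196: 0.536221,
    14196: 0.508207,
    -10804: 0.492683,
}

# Scale by the largest effect so the best distance has a value of 1
max_effect = max(effects.values())
normalized = {dist: eff / max_effect for dist, eff in effects.items()}

for dist, value in normalized.items():
    print(dist, round(value, 3))
```

After this scaling, effects are expressed as a fraction of the strongest effect of the CRE, which makes distance profiles easier to compare between different TSS-CRE pairs.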