import numpy as np
import pandas as pd
from enformer_pytorch.data import str_to_one_hot
from reglm.regression import SeqDataset
def ISM_at_pos(seq, pos, drop_ref=True):
    """
    Perform in-silico mutagenesis at a single position in the sequence.

    Args:
        seq (str): DNA sequence
        pos (int): Position to mutate
        drop_ref (bool): If True, the original base at the mutated position is
            dropped from the returned alternatives.

    Returns:
        List of mutated DNA sequences, of length 3 or 4
    """
    alt_bases = ["A", "C", "G", "T"]
    if drop_ref:
        alt_bases.remove(seq[pos])
    return [seq[:pos] + base + seq[pos + 1 :] for base in alt_bases]
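# Example (illustrative sketch): mutating position 2 of a short sequence with the
# default drop_ref=True returns the three non-reference substitutions; with
# drop_ref=False the unchanged reference sequence is included as well.
#
#   ISM_at_pos("ACGTA", 2)
#   # -> ['ACATA', 'ACCTA', 'ACTTA']
#   ISM_at_pos("ACGTA", 2, drop_ref=False)
#   # -> ['ACATA', 'ACCTA', 'ACGTA', 'ACTTA']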
def ISM(seq, drop_ref=True):
    """
    Perform in-silico mutagenesis of a DNA sequence.

    Args:
        seq (str): DNA sequence
        drop_ref (bool): If True, the original base at each mutated position is
            dropped from the returned alternatives.

    Returns:
        List of mutated DNA sequences, of length 3*len(seq) or 4*len(seq)
    """
    return list(
        np.concatenate(
            [ISM_at_pos(seq, pos, drop_ref=drop_ref) for pos in range(len(seq))]
        )
    )
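# Example (illustrative sketch): saturation mutagenesis of a 3-bp sequence yields
# 3 * len(seq) = 9 single-base variants when the reference base is dropped.
#
#   ISM("ACG")
#   # returns the 9 variants CCG, GCG, TCG, AAG, AGG, ATG, ACA, ACC, ACT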
def ISM_predict(seqs, model, seq_len=None, batch_size=512, device=0, num_workers=8):
    """
    Perform in-silico mutagenesis of DNA sequences and make predictions with a
    regression model to get per-base importance scores.

    Args:
        seqs (list): List of DNA sequences of equal length
        model (pl.LightningModule): regression model
        seq_len (int): Maximum sequence length for the regression model
        batch_size (int): Batch size for prediction
        device (int): GPU index for prediction
        num_workers (int): Number of workers for prediction

    Returns:
        preds (np.array): Array of shape (number of sequences, length of sequences, 4)
    """
    # Check that all sequences have the same length
    actual_seq_lens = [len(seq) for seq in seqs]
    assert (
        len(set(actual_seq_lens)) == 1
    ), "This function currently requires all sequences to have equal length"
    actual_seq_len = actual_seq_lens[0]

    # Perform ISM, keeping the reference base, so each input sequence yields
    # 4 * seq_len mutated sequences
    mutated_seqs = np.concatenate(
        [ISM(seq, drop_ref=False) for seq in seqs]
    )  # N*4*seq_len sequences

    # Get predictions for all mutated sequences
    dataset = SeqDataset(mutated_seqs, seq_len=seq_len)
    preds = model.predict_on_dataset(
        dataset,
        device=device,
        batch_size=batch_size,
        num_workers=num_workers,
    ).squeeze()  # flat array of length N*4*seq_len

    # Reshape the predictions to one value per sequence, position and base
    assert preds.shape[0] == len(seqs) * 4 * actual_seq_len, preds.shape
    preds = preds.reshape(len(seqs), actual_seq_len, 4)  # N, seq_len, 4
    return preds
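# Example (illustrative sketch; `regression_model` is a placeholder for a trained
# regression model exposing `predict_on_dataset` as used above, and is not defined
# in this module):
#
#   seqs = generate_random_sequences(n=2, seq_len=80, seed=0)
#   preds = ISM_predict(seqs, regression_model, seq_len=80, batch_size=256)
#   preds.shape   # (2, 80, 4)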
def ISM_score(seqs, preds):
    """
    Calculate a per-base importance score from ISM predictions.

    Args:
        seqs (list): List of sequences
        preds (np.array): ISM predictions for seqs, of shape (N, seq_len, 4),
            e.g. the output of ISM_predict

    Returns:
        scores (np.array): Array of shape (N x seq_len), containing
            per-base importance scores
    """
    # Convert original sequences to one-hot
    one_hot = str_to_one_hot(seqs).numpy()  # N, seq_len, 4

    # Get the predictions for the reference bases
    ref_preds = np.sum(preds * one_hot, axis=2, keepdims=True)  # N, seq_len, 1

    # Take the negative log-ratio of each predicted value
    # relative to the prediction for the original sequence
    scores = -np.log2(preds / ref_preds)  # N, seq_len, 4

    # Mask out the reference base at each position
    mask = ~one_hot.astype(bool)

    # Average the effect of the possible mutations at each position
    mean_scores = np.zeros_like(scores[:, :, 0])
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            mean_scores[i, j] = scores[i, j][mask[i, j]].mean()
    return mean_scores
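# Example (illustrative sketch, continuing from the ISM_predict example above with
# the same placeholder `regression_model`): the per-base score is the mean negative
# log2 ratio of the mutant predictions to the reference prediction at each position.
#
#   preds = ISM_predict(seqs, regression_model, seq_len=80)
#   scores = ISM_score(seqs, preds)
#   scores.shape   # (len(seqs), 80)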
def generate_random_sequences(n=1, seq_len=1024, seed=None):
    """
    Generate random DNA sequences.

    Args:
        n (int): Number of sequences to generate (default 1).
        seq_len (int): Length of each sequence (default 1024).
        seed (int): Seed value for the random number generator (default None).

    Returns:
        Generated sequences as a list of strings.
    """
    # Set random seed
    rng = np.random.RandomState(seed)

    # Generate sequences
    seqs = rng.choice(["A", "C", "G", "T"], size=n * seq_len, replace=True).reshape(
        [n, seq_len]
    )
    return ["".join(seq) for seq in seqs]
def motif_likelihood(seqs, motif, label, model):
    """
    Return the log-likelihood of a motif occurring at the end of
    each of the given sequences.

    Args:
        seqs (list): Sequences
        motif (str): Motif sequence
        label (str): Label for the regLM model
        model (pl.LightningModule): regLM model

    Returns:
        Per-sequence log-likelihoods of the motif
    """
    # Per-position log-likelihoods for each sequence with the motif appended
    log_likelihood_per_pos = model.P_seqs_given_labels(
        seqs=[seq + motif for seq in seqs],
        labels=[label] * len(seqs),
        per_pos=True,
        log=True,
    )

    # Keep only the positions covered by the motif and sum over them
    motif_likelihood = log_likelihood_per_pos[:, -len(motif) :]
    assert motif_likelihood.shape == (len(seqs), len(motif)), motif_likelihood.shape
    return motif_likelihood.sum(1)
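# Example (illustrative sketch; `reglm_model` is a placeholder for a trained regLM
# model exposing `P_seqs_given_labels` as used above, and "00" is assumed to be a
# valid label for that model):
#
#   background = generate_random_sequences(n=10, seq_len=100, seed=0)
#   ll = motif_likelihood(background, "TGACTCA", "00", reglm_model)
#   len(ll)   # 10, one motif log-likelihood per background sequence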
def motif_insert(motif_dict, model, label, ref_label, seq_len, n=100):
    """
    Insert motifs into random sequences and calculate the log-likelihood ratio
    of each motif given a label vs. a reference label.

    Args:
        motif_dict (dict): Dictionary with key-value pairs such as
            motif ID: consensus sequence
        model (pl.LightningModule): regLM model
        label (str): Label for the regLM model
        ref_label (str): Reference label for the regLM model
        seq_len (int): Length of the random sequences preceding the motif
        n (int): Number of random sequences to insert the motif in

    Returns:
        (pd.DataFrame): Dataframe containing log-likelihood ratios of
            motif-containing sequences
    """
    out = pd.DataFrame()
    random_seqs = generate_random_sequences(n=n, seq_len=seq_len)
    for motif_id, consensus in motif_dict.items():
        # Compute the log-likelihood of the motif under the label of interest
        # and under the reference label
        LL_with_label = motif_likelihood(random_seqs, consensus, label, model)
        LL_with_ref = motif_likelihood(random_seqs, consensus, ref_label, model)

        # Compute the log-likelihood ratio
        ratio = LL_with_label - LL_with_ref
        curr_out = pd.DataFrame(
            {
                "Sequence": random_seqs,
                "Motif": motif_id,
                "LL_ratio": ratio,
            }
        )
        out = pd.concat([out, curr_out])
    return out
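# Example (illustrative sketch; `reglm_model`, the labels "44" / "00", and the motif
# consensus sequences are placeholders, not part of this module):
#
#   motifs = {"motif_1": "TGACTCA", "motif_2": "GATAAG"}
#   df = motif_insert(motifs, reglm_model, label="44", ref_label="00", seq_len=100, n=50)
#   # df has columns "Sequence", "Motif" and "LL_ratio"; a higher LL_ratio means the
#   # motif is more likely under `label` than under `ref_label` in a random context.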