Source code for reglm.regression
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from enformer_pytorch import Enformer
from enformer_pytorch.data import str_to_one_hot
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
class SeqDataset(Dataset):
"""
PyTorch dataset class for training Enformer-based regression models
Args:
seqs (list, pd.DataFrame): either a list of DNA sequences, or a dataframe whose
first column is DNA sequences and remaining columns are labels.
seq_len (int): Length of sequences to return. Sequences will be padded with Ns
on the right to reach this length.
"""
def __init__(self, seqs, seq_len=None):
super().__init__()
self.is_labeled = False
# Add seqs
if isinstance(seqs, pd.DataFrame):
self.seqs = seqs.iloc[:, 0]
self.labels = torch.Tensor(
seqs.values[:, 1:].astype(np.float32)
) # N, n_tasks
self.is_labeled = True
else:
self.seqs = seqs
# Add max sequence length
self.seq_len = seq_len or np.max([len(seq) for seq in self.seqs])
def __len__(self):
return len(self.seqs)
def __getitem__(self, idx):
# Get sequence
seq = self.seqs[idx]
# Get label
if self.is_labeled:
label = self.labels[idx]
# Pad sequence
if len(seq) < self.seq_len:
seq = seq + "N" * (self.seq_len - len(seq))
# One-hot encode
seq = str_to_one_hot(seq)
if self.is_labeled:
return seq, label
else:
return seq
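A minimal usage sketch (not part of the module): building a labeled SeqDataset from a dataframe and reading one example. The column names and toy values below are illustrative assumptions; only the column order matters, since the first column is read as sequences and the remaining columns as labels.

# Usage sketch: labeled SeqDataset from a dataframe (toy values, hypothetical column names)
import pandas as pd
from reglm.regression import SeqDataset

df = pd.DataFrame({
    "seq": ["ACGTACGT", "TTGACA", "GGGCCC"],   # first column: DNA sequences
    "activity": [1.2, 0.4, 2.8],               # remaining columns: labels
})
ds = SeqDataset(df, seq_len=10)  # shorter sequences are right-padded with Ns to length 10
seq, label = ds[0]               # one-hot encoded, padded sequence and its label tensor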
class EnformerModel(pl.LightningModule):
"""
Enformer-based single-task regression model that can be
trained from scratch or fine-tuned.
Args:
lr (float): Learning rate
loss (str): "poisson" or "mse"
pretrained (bool): If True, initialize from the pretrained Enformer model
dim (int): Number of conv layer filters
depth (int): Number of transformer layers
n_downsamples (int): Number of conv/pool blocks
"""
def __init__(
self,
lr=1e-4,
loss="poisson",
pretrained=False,
dim=1536,
depth=11,
n_downsamples=7,
):
super().__init__()
self.n_tasks = 1
self.save_hyperparameters(ignore=["model"])
# Build model
if pretrained:
self.trunk = Enformer.from_pretrained(
"EleutherAI/enformer-official-rough", target_length=-1
)._trunk
else:
self.trunk = Enformer.from_hparams(
dim=dim,
depth=depth,
heads=8,
num_downsamples=n_downsamples,
target_length=-1,
)._trunk
self.head = nn.Linear(dim * 2, self.n_tasks, bias=True)
# Training params
self.lr = lr
self.loss_type = loss
if loss == "poisson":
self.loss = nn.PoissonNLLLoss(log_input=True, full=True)
else:
self.loss = nn.MSELoss()
def forward(self, x, return_logits=False):
if isinstance(x, (list, tuple)):
if isinstance(x[0], str):
# If x is a list of strings, convert it into a one-hot encoded tensor
x = str_to_one_hot(x)
else:
# If x is a list (input, target) use only the input
x = x[0]
x = self.trunk(x) # N, L, dim*2
x = self.head(x) # N, L, 1
x = x.mean(1) # N, 1
if (self.loss_type == "poisson") and (not return_logits):
x = torch.exp(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x, return_logits=True)
loss = self.loss(logits, y)
self.log(
"train_loss", loss, logger=True, on_step=True, on_epoch=True, prog_bar=True
)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x, return_logits=True)
loss = self.loss(logits, y)
self.log("val_loss", loss, logger=True, on_step=False, on_epoch=True)
return loss
def validation_epoch_end(self, output):
print("\nval_loss", torch.mean(torch.Tensor(output)).item())
def train_on_dataset(
self,
train_dataset,
val_dataset,
device=0,
batch_size=512,
num_workers=1,
save_dir=".",
max_epochs=10,
weights=None,
):
torch.set_float32_matmul_precision("medium")
# Make trainer
trainer = pl.Trainer(
max_epochs=max_epochs,
accelerator="gpu",
devices=[device],
logger=CSVLogger(save_dir),
callbacks=[ModelCheckpoint(monitor="val_loss", mode="min")],
)
# Make dataloaders
if weights is None:
train_dl = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
)
else:
sampler = WeightedRandomSampler(
weights=weights,
num_samples=len(weights),
replacement=True,
)
train_dl = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
)
val_dl = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
# Fit model
trainer.fit(model=self, train_dataloaders=train_dl, val_dataloaders=val_dl)
return trainer
def predict_on_dataset(
self,
dataset,
device=0,
num_workers=1,
batch_size=512,
):
torch.set_float32_matmul_precision("medium")
# Make dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
# Make trainer
trainer = pl.Trainer(accelerator="gpu", devices=[device], logger=None)
# Run inference
return (
torch.concat(trainer.predict(self, dataloader)).cpu().detach().numpy()
) # N, 1
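An end-to-end sketch of how train_on_dataset and predict_on_dataset fit together, assuming a CUDA GPU at index 0 is available. The toy dataframes, the 128 bp sequences (a length divisible by 2**n_downsamples for the default conv tower), and the small batch/epoch settings are illustrative assumptions, not values from the regLM experiments.

# Usage sketch: train a single-task EnformerModel and predict on new sequences
import pandas as pd
from reglm.regression import EnformerModel, SeqDataset

train_df = pd.DataFrame({"seq": ["ACGT" * 32] * 8, "label": [1.0] * 8})  # 128 bp toy sequences
val_df = pd.DataFrame({"seq": ["TTGA" * 32] * 4, "label": [0.5] * 4})

train_ds = SeqDataset(train_df)
val_ds = SeqDataset(val_df, seq_len=train_ds.seq_len)

model = EnformerModel(lr=1e-4, loss="poisson", pretrained=False)
trainer = model.train_on_dataset(
    train_ds, val_ds, device=0, batch_size=4, max_epochs=1, save_dir="."
)

# Inference: an unlabeled SeqDataset built from a list of strings is sufficient
test_ds = SeqDataset(["ACGT" * 32, "GGCC" * 32], seq_len=train_ds.seq_len)
preds = model.predict_on_dataset(test_ds, device=0, batch_size=4)  # numpy array of shape (N, 1)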
class MultiTaskEnformerModel(nn.Module):
"""
Combine two or three single-task Enformer models into a single object.
Args:
model1 (EnformerModel): Model for the first task
model2 (EnformerModel): Model for the second task
model3 (EnformerModel): Optional model for the third task
mean (bool): If True, predict_on_dataset returns the mean prediction across tasks
specificity (int): If set, predict_on_dataset returns the prediction for this task
index minus the maximum prediction over the remaining tasks
"""
def __init__(self, model1, model2, model3=None, mean=False, specificity=None):
super().__init__()
self.model1 = model1
self.model2 = model2
self.model3 = model3
self.n_tasks = 2 if model3 is None else 3
self.mean = mean
self.specificity = specificity
def forward(self, x):
if self.model3 is None:
x = [self.model1(x), self.model2(x)]
else:
x = [self.model1(x), self.model2(x), self.model3(x)]
x = torch.cat(x, dim=1)  # N, n_models
return x
def predict_on_dataset(self, ds, **kwargs):
if self.model3 is None:
out = np.concatenate(
[
self.model1.predict_on_dataset(ds, **kwargs),
self.model2.predict_on_dataset(ds, **kwargs),
],
axis=1,
)
else:
out = np.concatenate(
[
self.model1.predict_on_dataset(ds, **kwargs),
self.model2.predict_on_dataset(ds, **kwargs),
self.model3.predict_on_dataset(ds, **kwargs),
],
axis=1,
)
if self.mean:
out = out.mean(1)
elif self.specificity is not None:
non_specific = [x for x in range(self.n_tasks) if x != self.specificity]
out = out[:, self.specificity] - np.max(out[:, non_specific], axis=1)
return out