"""Decima pytorch model class."""
import logging
from pathlib import Path
from tempfile import TemporaryDirectory
import torch
import wandb
from grelu.model.heads import ConvHead
from grelu.model.models import BaseModel, BorzoiModel
from torch import nn
class DecimaModel(BaseModel):
    """
    Decima model.

    Combines the embedding of a ``BorzoiModel`` with a ``ConvHead`` that
    predicts ``n_tasks`` outputs. The Borzoi weights can optionally be
    initialized from a local checkpoint file or from a wandb artifact.

    Args:
        n_tasks: Number of tasks.
        mask: Whether to use a mask. If true, the first convolution of the
            embedding gets one additional input channel for the gene mask.
        borzoi_kwargs: Keyword arguments for the Borzoi model. They override
            the defaults defined in ``__init__``.
        init_borzoi: Whether to initialize the Borzoi model.
        replicate: Replicate to initialize the model from. Either a replicate
            index (int, or one of the strings "0".."3") used to pick the wandb
            artifact, or a path to an existing ``.h5``/``.pth``/``.pt``/``.ckpt``
            file.

    Raises:
        ValueError: If ``replicate`` points to an existing file whose
            extension is not one of the supported ones.

    Examples:
        >>> model = DecimaModel(
        ...     n_tasks=10,
        ...     mask=True,
        ...     borzoi_kwargs=None,
        ...     init_borzoi=False,
        ...     replicate=0,
        ... )
        >>> model.load_state_dict(
        ...     torch.load(
        ...         "model.pth"
        ...     )
        ... )
    """

    def __init__(self, n_tasks: int, mask=True, borzoi_kwargs: dict = None, init_borzoi=False, replicate=0):
        # Defaults for the Borzoi backbone; user-supplied kwargs take precedence.
        borzoi_kwargs = {
            "crop_len": 5120,
            "n_tasks": 7611,
            "stem_channels": 512,
            "stem_kernel_size": 15,
            "init_channels": 608,
            "n_conv": 7,
            "kernel_size": 5,
            "n_transformers": 8,
            "key_len": 64,
            "value_len": 192,
            "pos_dropout": 0.0,
            "attn_dropout": 0.0,
            "n_heads": 8,
            "n_pos_features": 32,
            # backward compatibility with grelu<1.0.7
            "norm_kwargs": {"eps": 1e-5},
            "act_func": "gelu",
            "final_act_func": None,
            "final_pool_func": None,
            **(borzoi_kwargs or {}),
        }
        model = BorzoiModel(**borzoi_kwargs)

        # A replicate index given as a string is normalized to an int.
        if replicate in ("0", "1", "2", "3"):  # replicate index
            replicate = int(replicate)

        if init_borzoi:
            logger = logging.getLogger("decima")
            replicate_str = str(replicate)

            # Load state dict. An integer is always a replicate index, never a
            # path, even if a file with that name happens to exist.
            if not isinstance(replicate, int) and Path(replicate_str).exists():
                logger.info(f"Initializing weights from Borzoi model using file: {replicate}")
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from trusted sources.
                if replicate_str.endswith((".h5", ".pth", ".pt")):
                    state_dict = torch.load(replicate)
                elif replicate_str.endswith(".ckpt"):
                    # Checkpoint files keep the weights under "state_dict".
                    state_dict = torch.load(replicate)["state_dict"]
                else:
                    raise ValueError(f"Invalid replicate path: {replicate}")
            else:
                logger.info(f"Initializing weights from Borzoi model using wandb for replicate: {replicate}")
                wandb.login(host="https://api.wandb.ai/", anonymous="must")
                api = wandb.Api(overrides={"base_url": "https://api.wandb.ai/"})
                art = api.artifact(f"grelu/borzoi/human_state_dict_fold{replicate}:latest")
                # The weights must be read before the temporary directory is removed.
                with TemporaryDirectory() as d:
                    art.download(d)
                    state_dict = torch.load(Path(d) / f"fold{replicate}.h5")

            model.load_state_dict(state_dict)

        head = ConvHead(n_tasks=n_tasks, in_channels=1920, pool_func="avg")
        super().__init__(embedding=model.embedding, head=head)

        # Add a channel for the gene mask
        self.mask = mask
        if self.mask:
            stem_conv = self.embedding.conv_tower.blocks[0].conv
            weight = stem_conv.weight
            # Derive the layer shape from the existing stem weight so that
            # non-default stem_channels / stem_kernel_size keep working.
            out_channels, in_channels, kernel_size = weight.shape
            # A freshly initialized convolution with one more input channel
            # supplies the initial weights for the added mask channel.
            new_layer = nn.Conv1d(
                in_channels + 1,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding="same",
            )
            mask_weight = new_layer.weight[:, [-1], :]
            new_weight = nn.Parameter(torch.cat([weight.detach(), mask_weight.detach()], dim=1))
            stem_conv.weight = new_weight