class DecimaModel(BaseModel):
    """
    Decima model: a Borzoi trunk (optionally pretrained) with a new conv head.

    Args:
        n_tasks: Number of tasks predicted by the new ``ConvHead``.
        mask: Whether to widen the first trunk convolution by one extra input
            channel that carries the gene mask.
        borzoi_kwargs: Keyword arguments for the Borzoi model. They are merged
            over the built-in defaults below and take precedence over them.
        init_borzoi: Whether to load pretrained Borzoi weights into the trunk.
        replicate: Which pretrained weights to load when ``init_borzoi`` is True.
            Either a fold index (an int, or one of the strings ``"0"``-``"3"``),
            fetched from the W&B artifact
            ``grelu/borzoi/human_state_dict_fold{replicate}:latest``, or a path
            to an existing local ``.h5``/``.pth``/``.pt`` state dict or a
            ``.ckpt`` checkpoint whose ``"state_dict"`` entry is used.
    """

    def __init__(
        self,
        n_tasks: int,
        mask=True,
        borzoi_kwargs: dict = None,
        init_borzoi=False,
        replicate=0,
    ):
        # User-supplied kwargs override these defaults.
        borzoi_kwargs = {
            "crop_len": 5120,
            "n_tasks": 7611,
            "stem_channels": 512,
            "stem_kernel_size": 15,
            "init_channels": 608,
            "n_conv": 7,
            "kernel_size": 5,
            "n_transformers": 8,
            "key_len": 64,
            "value_len": 192,
            "pos_dropout": 0.0,
            "attn_dropout": 0.0,
            "n_heads": 8,
            "n_pos_features": 32,
            # backward compatibility with grelu<1.0.7
            "norm_kwargs": {"eps": 1e-5},
            "act_func": "gelu",
            "final_act_func": None,
            "final_pool_func": None,
            **(borzoi_kwargs or dict()),
        }
        model = BorzoiModel(**borzoi_kwargs)

        if init_borzoi:
            model.load_state_dict(self._load_borzoi_state_dict(replicate))

        # NOTE(review): in_channels=1920 is hard-coded; presumably this is the
        # trunk's output width under the default borzoi_kwargs — overriding
        # channel-related kwargs may require changing it. Confirm.
        head = ConvHead(n_tasks=n_tasks, in_channels=1920, pool_func="avg")
        super().__init__(embedding=model.embedding, head=head)

        # Add a channel for the gene mask
        self.mask = mask
        if self.mask:
            self._add_mask_channel()

    @staticmethod
    def _load_borzoi_state_dict(replicate):
        """Return the pretrained Borzoi state dict selected by ``replicate``.

        Args:
            replicate: A fold index (int, or one of the strings ``"0"``-``"3"``)
                to download from W&B, or a path to a local ``.h5``/``.pth``/
                ``.pt`` state dict or ``.ckpt`` checkpoint.

        Returns:
            The state dict, with tensors mapped to CPU.

        Raises:
            ValueError: If ``replicate`` is an existing path with an
                unsupported extension.
        """
        if replicate in ["0", "1", "2", "3"]:  # replicate index given as str
            replicate = int(replicate)

        # Integers always mean a fold index, so a stray file named e.g. "0" in
        # the working directory cannot hijack fold selection.
        path = str(replicate)
        if not isinstance(replicate, int) and Path(path).exists():
            # map_location="cpu": checkpoints saved on GPU must still load on
            # CPU-only hosts; load_state_dict copies into the model's params anyway.
            if path.endswith((".h5", ".pth", ".pt")):
                return torch.load(path, map_location="cpu")
            if path.endswith(".ckpt"):
                return torch.load(path, map_location="cpu")["state_dict"]
            raise ValueError(f"Invalid replicate path: {replicate}")

        wandb.login(host="https://api.wandb.ai/", anonymous="must")
        api = wandb.Api(overrides={"base_url": "https://api.wandb.ai/"})
        art = api.artifact(f"grelu/borzoi/human_state_dict_fold{replicate}:latest")
        with TemporaryDirectory() as d:
            art.download(d)
            return torch.load(Path(d) / f"fold{replicate}.h5", map_location="cpu")

    def _add_mask_channel(self):
        """Widen the first trunk convolution by one input channel (gene mask).

        Existing input-channel weights are kept unchanged. The new channel's
        weights are taken from a freshly initialised ``nn.Conv1d`` whose output
        channels and kernel size match the existing convolution.
        """
        conv = self.embedding.conv_tower.blocks[0].conv
        weight = conv.weight
        # Read the shape from the existing weight instead of hard-coding
        # 512/15, so overriding stem_channels / stem_kernel_size still works.
        out_channels, in_channels, kernel_size = weight.shape
        new_layer = nn.Conv1d(
            in_channels + 1,
            out_channels,
            kernel_size=(kernel_size,),
            stride=(1,),
            padding="same",
        )
        # NOTE(review): conv.in_channels is left at its previous value; the
        # forward pass uses the weight's shape — confirm nothing reads it.
        conv.weight = nn.Parameter(
            torch.cat([weight, new_layer.weight[:, [-1], :]], dim=1)
        )