Source code for reglm.dataset

import numpy as np
import torch
from torch.utils.data import Dataset


class CharDataset(Dataset):
    """
    A dataset class to produce tokenized sequences for training regLM.

    Each example is represented as 0<LABEL><SEQ>1; hence 0 is the start
    token and 1 is the end token. 0 is also the value left in unused
    trailing positions of the fixed-length tensors returned by
    ``__getitem__``.

    Attributes:
        seqs (list): The sequences, as passed in.
        labels (list): The label strings, as passed in.
        seq_len (int): Maximum sequence length.
        label_len (int): Number of label tokens per example.
        unique_labels (set): Distinct label characters seen in ``labels``.
        label_stoi / label_itos (dict): Label character <-> token id.
        base_stoi / base_itos (dict): Base character <-> token id.
    """

    def __init__(self, seqs, labels, seq_len=None):
        """
        Args:
            seqs (list): List of sequences.
            labels (list): List of labels as strings
            seq_len (int): Maximum sequence length. If not supplied (or
                falsy), the length of the longest sequence is used.
        """
        # Check
        assert len(seqs) == len(labels), "seqs and labels should have equal length"
        assert (
            len({len(x) for x in labels}) == 1
        ), "All labels should be of equal length"

        # Store data
        self.seqs = seqs
        self.labels = labels

        # Maximum sequence length, kept as a plain Python int
        self.seq_len = seq_len or max(len(seq) for seq in self.seqs)
        self.label_len = len(self.labels[0])

        # Every distinct character used in any label
        self.unique_labels = {tok for lab in self.labels for tok in lab}
        assert (
            len(self.unique_labels) <= 10
        ), ">10 label classes are currently not supported"

        # Encoding
        # NOTE: the base token ids (7-11) overlap with the ids of label
        # characters "5"-"9". Labels and bases are told apart purely by
        # position (see `decode`), so the values are left untouched here.
        self.label_stoi = {
            "0": 2,
            "1": 3,
            "2": 4,
            "3": 5,
            "4": 6,
            "5": 7,
            "6": 8,
            "7": 9,
            "8": 10,
            "9": 11,
        }
        self.base_stoi = {
            "A": 7,
            "C": 8,
            "G": 9,
            "T": 10,
            "N": 11,
        }
        self.label_itos = {v: k for k, v in self.label_stoi.items()}
        self.base_itos = {v: k for k, v in self.base_stoi.items()}

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.seqs)

    def encode_seq(self, seq):
        """
        Encode a sequence as a torch tensor of tokens

        Args:
            seq (str): DNA sequence

        Returns:
            torch.LongTensor of shape (len(seq),)

        Raises:
            KeyError: If ``seq`` contains a character missing from
                ``base_stoi``.
        """
        return torch.tensor([self.base_stoi[tok] for tok in seq], dtype=torch.long)

    def encode_label(self, label):
        """
        Encode a label as a torch tensor of tokens

        Args:
            label (str): label token sequence

        Returns:
            torch.LongTensor of shape (len(label),)

        Raises:
            KeyError: If ``label`` contains a character missing from
                ``label_stoi``.
        """
        # An explicit dtype keeps the result integer-typed even when the
        # label is empty (torch.tensor([]) would default to float32).
        return torch.tensor([self.label_stoi[tok] for tok in label], dtype=torch.long)

    def decode(self, idxs, is_labeled=False):
        """
        Given a torch tensor of tokens, return the decoded sequence as a
        string.

        Args:
            idxs (list, torch.LongTensor): list or 1-D tensor
            is_labeled (bool): Whether labels are included. If True, the
                first ``label_len`` tokens are decoded as label characters
                and the remainder as bases.

        Returns:
            labeled sequence as a string

        Raises:
            KeyError: If a token is not in the relevant vocabulary. This
                includes the start/end/padding tokens 0 and 1, which must
                be stripped by the caller beforehand.
        """
        if isinstance(idxs, torch.Tensor):
            idxs = idxs.detach().cpu().tolist()

        if not is_labeled:
            # Only a sequence is provided
            return "".join(self.base_itos[i] for i in idxs)

        # Split the input into label and sequence, decode each with its
        # own vocabulary, and rejoin.
        label = idxs[: self.label_len]
        seq = idxs[self.label_len :]
        return "".join(
            [self.label_itos[i] for i in label] + [self.base_itos[i] for i in seq]
        )

    def __getitem__(self, idx):
        """
        Return a single labeled example as a tensor of tokens
        x = 0<LABEL><SEQ>
        y = <SEQ>1

        Args:
            idx: Index of example to return

        Returns:
            x (torch.LongTensor): tensor of shape
                (1 + self.label_len + self.seq_len, )
            y (torch.LongTensor): tensor of shape (self.seq_len + 1, )

        Raises:
            ValueError: If the selected sequence is longer than
                ``self.seq_len`` and therefore cannot fit in the output.
        """
        # Get sequence and make sure it fits in the fixed-size tensors;
        # otherwise the slice assignments below fail with an opaque
        # shape-mismatch error.
        seq = self.seqs[idx]
        if len(seq) > self.seq_len:
            raise ValueError(
                f"Sequence at index {idx} has length {len(seq)}, "
                f"which exceeds seq_len={self.seq_len}"
            )

        # Encode sequence and label
        seq = self.encode_seq(seq)
        label = self.encode_label(self.labels[idx])

        # Generate empty tensors
        x = torch.zeros(self.seq_len + self.label_len + 1, dtype=torch.long)
        y = torch.zeros(self.seq_len + 1, dtype=torch.long)

        # Input: START(0) + label + sequence + trailing zeros (will be ignored)
        seq_start = 1 + self.label_len
        x[1:seq_start] = label
        x[seq_start : seq_start + len(seq)] = seq

        # Output: sequence + END (1) + trailing zeros (will be ignored)
        y[: len(seq)] = seq
        y[len(seq)] = 1

        return x, y