grelu.data.dataset

PyTorch dataset classes to load sequence data.

All dataset classes produce either one-hot encoded sequences of shape (4, L) or sequence-label pairs of shape (4, L) and (T, L).
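To make these shapes concrete, here is a minimal NumPy sketch that one-hot encodes a DNA string into a (4, L) array and pairs it with a (T, L) label array. It is illustrative only, not the library's own encoder: the A, C, G, T row order and the all-zero column used for N are assumptions.

```python
import numpy as np

# Assumed channel order (hypothetical; check the library's own convention)
BASES = "ACGT"


def one_hot(seq: str) -> np.ndarray:
    """Encode a DNA string as a float32 array of shape (4, L).

    Any character outside ACGT (e.g. N) becomes an all-zero column in this sketch.
    """
    out = np.zeros((4, len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq.upper()):
        row = BASES.find(base)  # -1 for characters that are not A, C, G or T
        if row >= 0:
            out[row, pos] = 1.0
    return out


seq = one_hot("ACGTN")
labels = np.zeros((2, 5), dtype=np.float32)  # T = 2 tasks, L = 5 positions
print(seq.shape, labels.shape)  # (4, 5) (2, 5)
```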

Classes

LabeledSeqDataset

A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.

DFSeqDataset

LabeledSeqDataset-derived class for a dataframe containing sequences (or genomic intervals) and labels.

AnnDataSeqDataset

LabeledSeqDataset-derived class for an AnnData object.

BigWigSeqDataset

LabeledSeqDataset-derived class for genomic intervals and BigWig files.

SeqDataset

Dataset to cycle through unlabeled sequences for inference. All sequences are stored in memory.

VariantDataset

Dataset class to perform inference on sequence variants.

VariantMarginalizeDataset

Dataset to marginalize the effect of given variants across shuffled background sequences.

PatternMarginalizeDataset

Dataset to marginalize the effect of given sequence patterns across shuffled background sequences.

ISMDataset

Dataset to perform in silico mutagenesis (ISM).

MotifScanDataset

Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.

Module Contents

class grelu.data.dataset.LabeledSeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, labels: numpy.ndarray, tasks: Sequence | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = None, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')

Bases: torch.utils.data.Dataset

A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.

Parameters:
  • seqs – DNA sequences as genomic intervals, strings, integer-encoded indices, or one-hot encoded arrays.

  • labels – A numpy array of shape (B, T, L) containing the labels, where B is the number of sequences, T the number of tasks and L the label length.

  • tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • label_len – Uniform expected length (in base pairs) for output labels.

  • max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.

  • label_aggfunc – Function to aggregate the labels over bin_size.

  • bin_size – Number of bases to aggregate in the label. Only used if label_aggfunc is not None. If None, it will be taken as equal to label_len.

  • min_label_clip – Minimum value for labels; smaller values are clipped to this value.

  • max_label_clip – Maximum value for labels; larger values are clipped to this value.

  • label_transform_func – Function to transform label values.

  • seed – Random seed for reproducibility.

  • augment_mode – Augmentation mode, either “random” or “serial”.
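The label options above (label_aggfunc, bin_size, min_label_clip, max_label_clip, label_transform_func) can be pictured with the following sketch, which bins a (B, T, L) label array, clips it and then transforms it. The helper process_labels is hypothetical: the order of operations (aggregate, then clip, then transform), the lookup of string names such as “sum” on NumPy, and the restriction to callable transforms are all assumptions, not documented behaviour.

```python
from typing import Callable, Optional, Union

import numpy as np


def process_labels(
    labels: np.ndarray,
    bin_size: int,
    aggfunc: Union[str, Callable] = "sum",
    min_clip: Optional[float] = None,
    max_clip: Optional[float] = None,
    transform: Optional[Callable] = None,
) -> np.ndarray:
    """Bin, clip and transform labels of shape (B, T, L).

    Hypothetical helper: the aggregate -> clip -> transform order is an assumption.
    A callable aggfunc must accept an `axis` keyword, like np.sum or np.mean.
    """
    B, T, L = labels.shape
    if L % bin_size != 0:
        raise ValueError("label length must be divisible by bin_size")
    # String names are looked up on NumPy, e.g. "sum" -> np.sum, "mean" -> np.mean
    func = getattr(np, aggfunc) if isinstance(aggfunc, str) else aggfunc
    # Split the length axis into (n_bins, bin_size) and aggregate within each bin
    binned = func(labels.reshape(B, T, L // bin_size, bin_size), axis=-1)
    if min_clip is not None or max_clip is not None:
        binned = np.clip(binned, min_clip, max_clip)
    if transform is not None:
        binned = transform(binned)
    return binned


labels = np.arange(8, dtype=float).reshape(1, 1, 8)  # B=1, T=1, L=8
print(process_labels(labels, bin_size=4).tolist())  # [[[6.0, 22.0]]]
print(process_labels(labels, bin_size=4, max_clip=10, transform=np.log1p).shape)  # (1, 1, 2)
```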

end
genome
min_label_clip
max_label_clip
label_transform_func
seq_len
label_len
label_aggfunc
bin_size
rc
max_seq_shift
max_pair_shift
padded_seq_len
padded_label_len
n_seqs
n_tasks
label_transform
augmenter
n_augmented
n_alleles = 1
predict = False
_load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) → None
_load_tasks(tasks: pandas.DataFrame | List) → None
_load_labels(labels: numpy.ndarray) → None
__len__() → int
get_labels() → numpy.ndarray

Return the labels as a numpy array of shape (B, T, L). This does not account for data augmentation.

__getitem__(idx: int) → torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
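The attributes n_augmented and padded_seq_len hint at how shift and reverse-complement augmentation combine. The sketch below assumes that in “serial” mode the dataset length is n_seqs * n_augmented, that a flat index is split into a sequence index and an augmentation index, and that each stored sequence is padded by max_seq_shift bases on both sides. It works on strings rather than one-hot tensors, and the helper names and enumeration order (shifts outer, strands inner) are hypothetical.

```python
import itertools

# Complement table for the reverse-complement step
COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def augmentations(max_seq_shift: int, rc: bool) -> list:
    """All (shift, reverse_complement) combinations in a fixed order."""
    shifts = range(-max_seq_shift, max_seq_shift + 1)
    strands = (False, True) if rc else (False,)
    return list(itertools.product(shifts, strands))


def get_item(padded_seqs: list, seq_len: int, max_seq_shift: int, rc: bool, idx: int) -> str:
    """Serially map a flat index to (sequence, augmentation) and apply it.

    Hypothetical sketch: assumes len(dataset) == n_seqs * n_augmented.
    """
    augs = augmentations(max_seq_shift, rc)
    seq_idx, aug_idx = divmod(idx, len(augs))
    shift, do_rc = augs[aug_idx]
    start = max_seq_shift + shift  # ranges over 0 .. 2 * max_seq_shift
    window = padded_seqs[seq_idx][start : start + seq_len]
    if do_rc:
        window = window.translate(COMPLEMENT)[::-1]
    return window


padded = ["AACGTTG"]  # one sequence: seq_len=5 padded by 1 base on each side
n_augmented = len(augmentations(max_seq_shift=1, rc=True))
print(n_augmented)  # 6
print([get_item(padded, 5, 1, True, i) for i in range(n_augmented)])
# ['AACGT', 'ACGTT', 'ACGTT', 'AACGT', 'CGTTG', 'CAACG']
```

Enumerating augmentations serially makes every epoch visit each (sequence, augmentation) pair exactly once, whereas a random mode would presumably sample augmentations on the fly.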
class grelu.data.dataset.DFSeqDataset(df: pandas.DataFrame, tasks: pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')

Bases: LabeledSeqDataset

LabeledSeqDataset-derived class for a dataframe containing sequences (or genomic intervals) and labels.

Parameters:
  • df – DataFrame containing either DNA sequences in the first column or genomic intervals in the first 3 columns. All remaining columns are assumed to be labels.

  • tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
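The df layout described above (sequences in the first column or intervals in the first three, labels in the rest) can be split as in this pandas sketch. The helper split_df is hypothetical: it detects intervals by the column names chrom/start/end and assumes each task contributes a single value per row, giving labels of shape (B, T, 1).

```python
import numpy as np
import pandas as pd


def split_df(df: pd.DataFrame):
    """Split a DFSeqDataset-style frame into inputs, labels and task names.

    Hypothetical helper: treats the frame as intervals if its first three columns
    are named chrom/start/end, otherwise as sequences in the first column.
    """
    if list(df.columns[:3]) == ["chrom", "start", "end"]:
        inputs, label_cols = df.iloc[:, :3], df.columns[3:]
    else:
        inputs, label_cols = df.iloc[:, 0], df.columns[1:]
    # One value per task and row, so the label length axis has size 1: (B, T, 1)
    labels = df[label_cols].to_numpy(dtype=np.float32)[:, :, None]
    return inputs, labels, list(label_cols)


df = pd.DataFrame({"seq": ["ACGT", "TTGA"], "task_a": [0.5, 1.0], "task_b": [2.0, 0.0]})
inputs, labels, tasks = split_df(df)
print(labels.shape, tasks)  # (2, 2, 1) ['task_a', 'task_b']
```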

class grelu.data.dataset.AnnDataSeqDataset(adata, label_key: str | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')

Bases: LabeledSeqDataset

LabeledSeqDataset-derived class for an AnnData object.

Parameters:
  • adata – AnnData object containing genomic intervals in .var.

  • label_key – If labels are stored in .varm, the key under which they are stored.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

class grelu.data.dataset.BigWigSeqDataset(intervals: pandas.DataFrame, bw_files: str | List[str], tasks: List[str] | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = np.sum, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')

Bases: LabeledSeqDataset

LabeledSeqDataset-derived class for genomic intervals and BigWig files. Labels are read into memory.

Parameters:
  • intervals – A pandas DataFrame containing genomic intervals.

  • bw_files – Path to a bigWig file, or a list of bigWig file paths.

  • tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • label_len – Uniform expected length (in base pairs) for output labels.

  • max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.

  • label_aggfunc – Function to aggregate the labels over bin_size. Defaults to np.sum.

  • bin_size – Number of bases to aggregate in the label.

  • min_label_clip – Minimum value for labels; smaller values are clipped to this value.

  • max_label_clip – Maximum value for labels; larger values are clipped to this value.

  • label_transform_func – Function to transform label values.

_load_labels(bw_files: str | List[str]) → None

Load the labels from the provided bigWig files.

class grelu.data.dataset.SeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')

Bases: torch.utils.data.Dataset

Dataset to cycle through unlabeled sequences for inference. All sequences are stored in memory.

Parameters:
  • seqs – DNA sequences as genomic intervals, strings, integer-encoded indices, or one-hot encoded arrays.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
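The end parameter controls which side of a sequence is trimmed or padded when its length differs from seq_len. This is a hypothetical sketch of that idea on strings: padding with “N” and assigning any odd leftover base to the right end when end is “both” are assumptions, not the library's documented behaviour.

```python
def resize(seq: str, seq_len: int, end: str = "both") -> str:
    """Trim or pad seq to exactly seq_len bases, acting on the chosen end(s).

    Hypothetical sketch: padding uses "N", and with end="both" any odd
    leftover base is taken from (or added to) the right end.
    """
    if end not in ("left", "right", "both"):
        raise ValueError('end must be "left", "right" or "both"')
    n = abs(seq_len - len(seq))  # number of bases to add or remove
    left = n if end == "left" else (0 if end == "right" else n // 2)
    right = n - left
    if len(seq) < seq_len:
        return "N" * left + seq + "N" * right
    return seq[left : len(seq) - right]


print(resize("ACGTACGT", 5, "both"))   # CGTAC
print(resize("ACGTACGT", 5, "left"))   # TACGT
print(resize("ACGTACGT", 5, "right"))  # ACGTA
print(resize("ACG", 6, "both"))        # NACGNN
```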

end
genome
seq_len
rc
max_seq_shift
n_seqs
augmenter
n_augmented
n_alleles = 1
_load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) → None
__len__() → int
__getitem__(idx: int) → torch.Tensor
class grelu.data.dataset.VariantDataset(variants: pandas.DataFrame, seq_len: int, genome: str | None = None, rc: bool = False, max_seq_shift: int = 0, frac_mutation: float = 0.0, n_mutated_seqs: int = 1, protect: List[int] | None = None, seed: int | None = None, augment_mode: str = 'serial')

Bases: torch.utils.data.Dataset

Dataset class to perform inference on sequence variants.

Parameters:
  • variants – pd.DataFrame with columns “chrom”, “pos”, “ref”, “alt”.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • genome – The name of the genome from which to read sequences.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • frac_mutation – Fraction of bases to randomly mutate for data augmentation.

  • protect – A list of positions to protect from mutation.

  • n_mutated_seqs – Number of mutated sequences to generate from each input sequence for data augmentation.
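Because n_alleles is 2 for this class, each variant presumably yields a reference and an alternate sequence. The sketch below shows one way to build such a pair around a single-nucleotide variant. It is hypothetical: chrom_seq stands in for a chromosome read from genome, pos is treated as a 0-based offset, and the window starts seq_len // 2 bases upstream of the variant.

```python
from typing import Tuple


def variant_seqs(chrom_seq: str, pos: int, ref: str, alt: str, seq_len: int) -> Tuple[str, str]:
    """Return (ref_seq, alt_seq) windows of length seq_len centred on an SNV.

    Hypothetical sketch: chrom_seq stands in for a chromosome read from `genome`,
    pos is a 0-based offset, and only single-nucleotide variants are handled.
    """
    if len(ref) != 1 or len(alt) != 1:
        raise ValueError("this sketch only handles single-nucleotide variants")
    if chrom_seq[pos] != ref:
        raise ValueError("reference allele does not match the sequence")
    start = pos - seq_len // 2
    if start < 0 or start + seq_len > len(chrom_seq):
        raise ValueError("window extends past the end of the chromosome")
    ref_seq = chrom_seq[start : start + seq_len]
    center = pos - start  # offset of the variant inside the window
    alt_seq = ref_seq[:center] + alt + ref_seq[center + 1 :]
    return ref_seq, alt_seq


print(variant_seqs("AAAACGTAAAA", pos=5, ref="G", alt="T", seq_len=5))  # ('ACGTA', 'ACTTA')
```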

genome
seq_len
rc
max_seq_shift
frac_mutated_bases
n_mutated_bases
n_mutated_seqs
n_alleles = 2
n_seqs
augmenter
n_augmented
_load_alleles(variants: pandas.DataFrame) → None
_load_seqs(variants: pandas.DataFrame) → None
__len__() → int
__getitem__(idx: int) → torch.Tensor
class grelu.data.dataset.VariantMarginalizeDataset(variants: pandas.DataFrame, genome: str, seq_len: int, seed: int | None = None, rc: bool = False, max_seq_shift: int = 0, n_shuffles: int = 100)

Bases: torch.utils.data.Dataset

Dataset to marginalize the effect of given variants across shuffled background sequences. All sequences are stored in memory.

Parameters:
  • variants – A dataframe of sequence variants.

  • genome – The name of the genome from which to read the sequences surrounding each variant.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • seed – Seed for the random number generator.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.
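Marginalizing a variant means placing the reference and alternate alleles into many shuffled versions of the surrounding sequence and averaging the difference in predictions. This sketch is hypothetical: a plain mononucleotide shuffle stands in for whatever shuffling scheme the library uses, alleles are written at the centre, and toy_model is an illustrative stand-in for a trained model.

```python
import random
from statistics import mean
from typing import Callable, List, Tuple


def shuffled_allele_pairs(
    seq: str, ref: str, alt: str, n_shuffles: int, seed: int = 0
) -> List[Tuple[str, str]]:
    """Return n_shuffles (ref_seq, alt_seq) pairs built on shuffled copies of seq.

    Hypothetical sketch: a plain mononucleotide shuffle stands in for the
    library's shuffling scheme, and single-base alleles are written at the centre.
    """
    if len(ref) != 1 or len(alt) != 1:
        raise ValueError("this sketch only handles single-base alleles")
    rng = random.Random(seed)
    center = len(seq) // 2
    pairs = []
    for _ in range(n_shuffles):
        bg = "".join(rng.sample(seq, len(seq)))  # random permutation of the bases
        pairs.append((bg[:center] + ref + bg[center + 1 :], bg[:center] + alt + bg[center + 1 :]))
    return pairs


def marginal_effect(pairs: List[Tuple[str, str]], model: Callable[[str], float]) -> float:
    """Average alt-minus-ref prediction over all shuffled backgrounds."""
    return mean(model(alt_seq) - model(ref_seq) for ref_seq, alt_seq in pairs)


def toy_model(s: str) -> float:
    """Hypothetical stand-in for a trained model: counts "GC" dinucleotides."""
    return float(s.count("GC"))


pairs = shuffled_allele_pairs("ACGTACGTACG", ref="A", alt="G", n_shuffles=100, seed=1)
print(len(pairs), marginal_effect(pairs, toy_model))
```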

genome
seed
seq_len
rc = False
max_seq_shift = 0
n_shuffles
augmenter
n_augmented
bg = None
curr_seq_idx = None
_load_alleles(variants: pandas.DataFrame) → None

Load the alleles to substitute into the background.

_load_seqs(variants: pandas.DataFrame) → None

Load sequences surrounding the variant position.

__update__(idx: int) → None

Update the current background.

__len__() → int
__getitem__(idx: int) → torch.Tensor
class grelu.data.dataset.PatternMarginalizeDataset(seqs: List[str] | pandas.DataFrame | numpy.ndarray, patterns: List[str], genome: str | None = None, seq_len: int | None = None, seed: int | None = None, rc: bool = False, n_shuffles: int = 1)

Bases: torch.utils.data.Dataset

Dataset to marginalize the effect of given sequence patterns across shuffled background sequences. All sequences are stored in memory.

Parameters:
  • seqs – DNA sequences as intervals, strings, integer encoded or one-hot encoded.

  • patterns – List of alleles or motif sequences to insert into the background sequences.

  • n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.

  • genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.

  • seq_len – Uniform expected length (in base pairs) for output sequences.

  • seed – Seed for the random number generator.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
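Patterns may be longer than a single base, so inserting them into a background means overwriting a centred stretch of the sequence while keeping its length fixed. The sketch below is hypothetical: it uses a mononucleotide shuffle, assumes centred insertion, and reuses each shuffled background for every pattern (the bg and curr_seq_idx attributes suggest a cached background, but that is an assumption) so that patterns are compared on identical backgrounds.

```python
import random
from typing import List


def insert_centered(bg: str, pattern: str) -> str:
    """Overwrite the middle of bg with pattern, keeping the total length fixed."""
    if len(pattern) > len(bg):
        raise ValueError("pattern is longer than the background")
    start = (len(bg) - len(pattern)) // 2
    return bg[:start] + pattern + bg[start + len(pattern) :]


def pattern_backgrounds(seq: str, patterns: List[str], n_shuffles: int = 1, seed: int = 0) -> List[List[str]]:
    """Return sequences indexed as [shuffle][pattern].

    Hypothetical sketch: each shuffled background is generated once and reused
    for every pattern, so differences between patterns share the same background.
    """
    rng = random.Random(seed)
    out = []
    for _ in range(n_shuffles):
        bg = "".join(rng.sample(seq, len(seq)))  # mononucleotide shuffle
        out.append([insert_centered(bg, p) for p in patterns])
    return out


print(insert_centered("AAAAAAAA", "GC"))    # AAAGCAAA
print(insert_centered("AAAAAAAA", "TATA"))  # AATATAAA
```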

genome
seed
seq_len
rc
n_shuffles
augmenter
n_augmented
bg = None
curr_seq_idx = None
_load_alleles(patterns: List[str]) → None
_load_seqs(seqs: pandas.DataFrame | List[str] | numpy.ndarray) → None

Make the background sequences.

__update__(idx: int) → None

Update the current background.

__len__() → int
__getitem__(idx: int) → torch.Tensor
class grelu.data.dataset.ISMDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, genome: str | None = None, drop_ref: bool = False, positions: List[int] | None = None)

Bases: torch.utils.data.Dataset

Dataset to perform in silico mutagenesis (ISM).

Parameters:
  • seqs – DNA sequences as genomic intervals, strings, integer-encoded indices, or one-hot encoded arrays.

  • genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.

  • drop_ref – If True, the base that already exists at each position will not be included in the returned sequences.

  • positions – List of positions to mutate. If None, all positions will be mutated.
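In silico mutagenesis substitutes every base at every selected position. The sketch below enumerates those substitutions on a string to show how drop_ref and positions affect the number of outputs; it is a hypothetical illustration, and the idea that n_alleles is 3 with drop_ref and 4 without is an assumption based on the parameter description.

```python
from typing import Iterable, List, Optional


def ism_seqs(seq: str, positions: Optional[Iterable[int]] = None, drop_ref: bool = False) -> List[str]:
    """Enumerate single-base substitutions of seq for in silico mutagenesis.

    Hypothetical sketch: positions=None mutates every position, and drop_ref=True
    skips the base already present, leaving 3 alleles per position instead of 4.
    """
    if positions is None:
        positions = range(len(seq))
    out = []
    for pos in positions:
        for base in "ACGT":
            if drop_ref and base == seq[pos]:
                continue  # skip the reference base at this position
            out.append(seq[:pos] + base + seq[pos + 1 :])
    return out


print(len(ism_seqs("ACGT")))                         # 16
print(len(ism_seqs("ACGT", drop_ref=True)))          # 12
print(ism_seqs("AC", positions=[1], drop_ref=True))  # ['AA', 'AG', 'AT']
```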

positions
genome
drop_ref
n_alleles
n_seqs
seq_len
n_augmented
_load_seqs(seqs) → None
__len__() → int
__getitem__(idx: int, return_compressed=False) → torch.Tensor
class grelu.data.dataset.MotifScanDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, motifs: List[str], genome: str | None = None, positions: List[int] | None = None)

Bases: torch.utils.data.Dataset

Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.

Parameters:
  • seqs – Background DNA sequences as intervals, strings, integer encoded or one-hot encoded.

  • motifs – A list of subsequences to insert into the background sequences.

  • genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.

  • positions – List of positions at which to insert the motif. If None, the motif will be inserted at all positions.
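Motif scanning overwrites the background with each motif at each start position. The sketch below is hypothetical: by default it only uses start positions where the longest motif fits (the max_motif_len attribute hints at this, but it is an assumption), and positions vary in the outer loop.

```python
from typing import Iterable, List, Optional


def motif_scan_seqs(seq: str, motifs: List[str], positions: Optional[Iterable[int]] = None) -> List[str]:
    """Insert each motif at each start position, overwriting the background.

    Hypothetical sketch: by default only start positions where the longest motif
    fits are used, and positions vary in the outer loop.
    """
    max_len = max(len(m) for m in motifs)
    if positions is None:
        positions = range(len(seq) - max_len + 1)
    out = []
    for pos in positions:
        for motif in motifs:
            if pos < 0 or pos + len(motif) > len(seq):
                raise ValueError(f"motif {motif!r} does not fit at position {pos}")
            out.append(seq[:pos] + motif + seq[pos + len(motif) :])
    return out


print(motif_scan_seqs("AAAAA", ["GC"]))  # ['GCAAA', 'AGCAA', 'AAGCA', 'AAAGC']
```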

positions
genome
motifs
max_motif_len
n_alleles
n_seqs
seq_len
n_augmented
_load_seqs(seqs)
__len__() → int
__getitem__(idx: int, return_compressed=False) → torch.Tensor