grelu.data.dataset
PyTorch dataset classes to load sequence data.
All dataset classes produce either one-hot encoded sequences of shape (4, L) or sequence-label pairs of shape (4, L) and (T, L).
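As a rough illustration of the (4, L) layout mentioned above, the sketch below one-hot encodes a DNA string in plain Python. The row order A, C, G, T and the all-zero column for ambiguous bases such as N are assumptions made for this example; the text above only states the shape, and `one_hot` is a hypothetical helper, not a gReLU function.

```python
# Assumed channel order; the documentation above only specifies the (4, L) shape.
BASES = "ACGT"


def one_hot(seq: str) -> list[list[int]]:
    """Return a 4 x L nested list: one row per base, one column per position."""
    seq = seq.upper()
    # Bases outside ACGT (e.g. N) produce an all-zero column in this sketch.
    return [[1 if base == b else 0 for base in seq] for b in BASES]


encoded = one_hot("ACGTN")
print(len(encoded), len(encoded[0]))  # 4 5
```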
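Classes
LabeledSeqDataset | A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.
DFSeqDataset | LabeledSeqDataset derived class for a dataframe containing sequences (or genomic intervals) and labels.
AnnDataSeqDataset | LabeledSeqDataset derived class for an AnnData object.
BigWigSeqDataset | LabeledSeqDataset derived class for genomic intervals and BigWig files.
SeqDataset | Dataset to cycle through unlabeled sequences for inference. All sequences are stored in memory.
VariantDataset | Dataset class to perform inference on sequence variants.
VariantMarginalizeDataset | Dataset to marginalize the effect of given variants across shuffled background sequences.
PatternMarginalizeDataset | Dataset to marginalize the effect of given sequence patterns across shuffled background sequences.
ISMDataset | Dataset to perform in silico mutagenesis (ISM).
MotifScanDataset | Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.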
Module Contents
- class grelu.data.dataset.LabeledSeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, labels: numpy.ndarray, tasks: Sequence | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = None, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')
Bases: torch.utils.data.Dataset
A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.
- Parameters:
seqs – DNA sequences as intervals, strings, indices or one-hot.
labels – A numpy array of shape (B, T, L) containing the labels.
tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
label_len – Uniform expected length (in base pairs) for output labels
max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.
label_aggfunc – Function to aggregate the labels over bin_size.
bin_size – Number of bases to aggregate in the label. Only used if label_aggfunc is not None. If None, it will be taken as equal to label_len.
min_label_clip – Minimum value for label
max_label_clip – Maximum value for label
label_transform_func – Function to transform label values.
seed – Random seed for reproducibility
augment_mode – “random” or “serial”
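The label parameters above (label_aggfunc, bin_size, min_label_clip, max_label_clip) can be pictured with a small plain-Python sketch for a single task. The order of operations (aggregate first, then clip) and the requirement that the track length be a multiple of bin_size are assumptions of this example; `bin_and_clip` is a hypothetical helper, not part of gReLU.

```python
from typing import Callable, Optional, Sequence


def bin_and_clip(
    track: Sequence[float],
    bin_size: int,
    aggfunc: Callable[[Sequence[float]], float] = sum,
    min_clip: Optional[float] = None,
    max_clip: Optional[float] = None,
) -> list[float]:
    """Aggregate one task's per-base values over bins, then clip the result."""
    if len(track) % bin_size != 0:
        raise ValueError("track length must be a multiple of bin_size")
    # One aggregated value per consecutive, non-overlapping bin.
    binned = [aggfunc(track[i:i + bin_size]) for i in range(0, len(track), bin_size)]
    if min_clip is not None:
        binned = [max(v, min_clip) for v in binned]
    if max_clip is not None:
        binned = [min(v, max_clip) for v in binned]
    return binned


coverage = [0, 1, 2, 3, 10, 10, 0, 0]
print(bin_and_clip(coverage, bin_size=4, max_clip=15))  # [6, 15]
```

Here the second bin sums to 20 and is clipped down to 15, while the first bin (sum 6) is left unchanged.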
- _load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) → None
- _load_tasks(tasks: pandas.DataFrame | List) → None
- _load_labels(labels: numpy.ndarray) → None
- get_labels() → numpy.ndarray
Return the labels as a numpy array of shape (B, T, L). This does not account for data augmentation.
- class grelu.data.dataset.DFSeqDataset(df: pandas.DataFrame, tasks: pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')
Bases: LabeledSeqDataset
LabeledSeqDataset derived class for a dataframe containing sequences (or genomic intervals) and labels.
- Parameters:
df – DataFrame containing either DNA sequences in the first column or genomic intervals in the first 3 columns. All remaining columns are assumed to be labels.
tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
- class grelu.data.dataset.AnnDataSeqDataset(adata, label_key: str | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')
Bases: LabeledSeqDataset
LabeledSeqDataset derived class for an AnnData object.
- Parameters:
adata – AnnData object containing genomic intervals in .var
label_key – If labels are stored in .varm, the key under which they are stored.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
- class grelu.data.dataset.BigWigSeqDataset(intervals: pandas.DataFrame, bw_files: str | List[str], tasks: List[str] | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = np.sum, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')
Bases: LabeledSeqDataset
LabeledSeqDataset derived class for genomic intervals and BigWig files. Labels are read into memory.
- Parameters:
intervals – A Pandas dataframe containing genomic intervals
bw_files – List of bigWig files
tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.
label_aggfunc – Function to aggregate the labels over bin_size.
bin_size – Number of bases to aggregate in the label.
min_label_clip – Minimum value for label
max_label_clip – Maximum value for label
label_transform_func – Function to transform label values.
- class grelu.data.dataset.SeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, seed: int | None = None, augment_mode: str = 'serial')
Bases: torch.utils.data.Dataset
Dataset to cycle through unlabeled sequences for inference. All sequences are stored in memory.
- Parameters:
seqs – DNA sequences
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
- _load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) → None
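The rc and max_seq_shift augmentations described in the parameter list can be sketched on plain strings as follows. This is only an illustration under stated assumptions: the shift is drawn uniformly from [-max_seq_shift, max_seq_shift] around a centred window, and reverse complementation is applied with probability 0.5. Neither detail is confirmed by the text above, and all function names are hypothetical.

```python
import random

# Translation table for complementing bases; N maps to itself.
COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Complement every base, then reverse the string."""
    return seq.translate(COMPLEMENT)[::-1]


def shifted_window(region: str, seq_len: int, shift: int) -> str:
    """Cut a seq_len window from region, offset by `shift` bases from the centre."""
    start = (len(region) - seq_len) // 2 + shift
    if start < 0 or start + seq_len > len(region):
        raise ValueError("shift moves the window outside the region")
    return region[start:start + seq_len]


def augment(region: str, seq_len: int, max_seq_shift: int, rc: bool,
            rng: random.Random) -> str:
    """Apply one random shift and, optionally, a random reverse complement."""
    shift = rng.randint(-max_seq_shift, max_seq_shift)  # assumed symmetric range
    seq = shifted_window(region, seq_len, shift)
    if rc and rng.random() < 0.5:  # assumed 50% chance of reverse complementing
        seq = reverse_complement(seq)
    return seq


print(reverse_complement("AAC"))               # GTT
print(shifted_window("AACCGGTTAC", 6, 0))      # CCGGTT
print(shifted_window("AACCGGTTAC", 6, 2))      # GGTTAC
```

With max_seq_shift set to 0 and rc set to False, `augment` returns the centred window unchanged, which mirrors the "will not be augmented" cases in the parameter descriptions.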
- class grelu.data.dataset.VariantDataset(variants: pandas.DataFrame, seq_len: int, genome: str | None = None, rc: bool = False, max_seq_shift: int = 0, frac_mutation: float = 0.0, n_mutated_seqs: int = 1, protect: List[int] | None = None, seed: int | None = None, augment_mode: str = 'serial')
Bases: torch.utils.data.Dataset
Dataset class to perform inference on sequence variants.
- Parameters:
variants – pd.DataFrame with columns “chrom”, “pos”, “ref”, “alt”.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
frac_mutation – Fraction of bases to randomly mutate for data augmentation.
protect – A list of positions to protect from mutation.
n_mutated_seqs – Number of mutated sequences to generate from each input sequence for data augmentation.
- _load_alleles(variants: pandas.DataFrame) → None
- _load_seqs(variants: pandas.DataFrame) → None
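To make the variant inputs more concrete, the sketch below builds a reference and an alternate sequence of length seq_len around a single-nucleotide variant. It assumes a 1-based `pos` (as in VCF files), restricts itself to single-base alleles, and places the variant just right of the window centre; these conventions, and the helper name `variant_sequences`, are assumptions of this example rather than documented gReLU behaviour.

```python
def variant_sequences(chrom_seq: str, pos: int, ref: str, alt: str,
                      seq_len: int) -> tuple[str, str]:
    """Return (ref_seq, alt_seq) windows of length seq_len around a SNV."""
    if len(ref) != 1 or len(alt) != 1:
        raise ValueError("this sketch only handles single-base alleles")
    idx = pos - 1  # assumed 1-based position, converted to a 0-based index
    if chrom_seq[idx] != ref:
        raise ValueError("ref allele does not match the reference sequence")
    start = idx - seq_len // 2  # put the variant near the window centre
    if start < 0 or start + seq_len > len(chrom_seq):
        raise ValueError("window extends past the end of the sequence")
    ref_seq = chrom_seq[start:start + seq_len]
    offset = idx - start
    # Swap in the alternate allele at the variant position.
    alt_seq = ref_seq[:offset] + alt + ref_seq[offset + 1:]
    return ref_seq, alt_seq


print(variant_sequences("AAACCCGGGTTT", pos=7, ref="G", alt="T", seq_len=6))
# ('CCCGGG', 'CCCTGG')
```

A model's predictions on the two sequences can then be compared to estimate the effect of the variant.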
- class grelu.data.dataset.VariantMarginalizeDataset(variants: pandas.DataFrame, genome: str, seq_len: int, seed: int | None = None, rc: bool = False, max_seq_shift: int = 0, n_shuffles: int = 100)
Bases: torch.utils.data.Dataset
Dataset to marginalize the effect of given variants across shuffled background sequences. All sequences are stored in memory.
- Parameters:
variants – A dataframe of sequence variants
genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.
seed – Seed for random number generator
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.
- _load_alleles(variants: pandas.DataFrame) → None
Load the alleles to substitute into the background
- _load_seqs(variants: pandas.DataFrame) → None
Load sequences surrounding the variant position
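The idea of marginalizing over shuffled backgrounds can be sketched as follows: the sequence around a variant is shuffled n_shuffles times, and each shuffled background is paired with a copy carrying the reference allele and a copy carrying the alternate allele. This example uses a simple per-base shuffle from Python's `random` module as a stand-in; the shuffling scheme actually used by the class is not described above, and the function names are hypothetical.

```python
import random
from typing import Optional


def shuffled_backgrounds(seq: str, n_shuffles: int,
                         seed: Optional[int] = None) -> list[str]:
    """Return n_shuffles per-base shuffles of seq (a simplifying assumption)."""
    rng = random.Random(seed)
    backgrounds = []
    for _ in range(n_shuffles):
        bases = list(seq)
        rng.shuffle(bases)
        backgrounds.append("".join(bases))
    return backgrounds


def substitute(seq: str, offset: int, allele: str) -> str:
    """Overwrite seq with allele starting at offset, keeping the length fixed."""
    return seq[:offset] + allele + seq[offset + len(allele):]


def marginalize_inputs(seq: str, offset: int, ref: str, alt: str,
                       n_shuffles: int, seed: Optional[int] = None):
    """Pair every shuffled background with its ref and alt versions."""
    return [
        (substitute(bg, offset, ref), substitute(bg, offset, alt))
        for bg in shuffled_backgrounds(seq, n_shuffles, seed)
    ]


pairs = marginalize_inputs("ACGTACGTAC", offset=4, ref="A", alt="G",
                           n_shuffles=3, seed=0)
print(len(pairs))  # 3
```

Averaging the difference between predictions on the ref and alt member of each pair gives an effect estimate that does not depend on one specific genomic context.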
- class grelu.data.dataset.PatternMarginalizeDataset(seqs: List[str] | pandas.DataFrame | numpy.ndarray, patterns: List[str], genome: str | None = None, seq_len: int | None = None, seed: int | None = None, rc: bool = False, n_shuffles: int = 1)
Bases: torch.utils.data.Dataset
Dataset to marginalize the effect of given sequence patterns across shuffled background sequences. All sequences are stored in memory.
- Parameters:
seqs – DNA sequences as intervals, strings, integer encoded or one-hot encoded.
patterns – List of alleles or motif sequences to insert into the background sequences.
n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.
genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.
seed – Seed for random number generator
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
- _load_seqs(seqs: pandas.DataFrame | List[str] | numpy.ndarray) → None
Make the background sequences
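One way to picture the resulting inputs is a grid with one row per background sequence, holding the unmodified background followed by one copy per pattern. The sketch below assumes that patterns overwrite bases at the centre of each background and that backgrounds have already been prepared; the placement rule, the output layout and the helper names are all assumptions of this example.

```python
def insert_at_centre(background: str, pattern: str) -> str:
    """Overwrite the middle of background with pattern, keeping the length fixed."""
    if len(pattern) > len(background):
        raise ValueError("pattern is longer than the background")
    start = (len(background) - len(pattern)) // 2
    return background[:start] + pattern + background[start + len(pattern):]


def pattern_inputs(backgrounds: list[str], patterns: list[str]) -> list[list[str]]:
    """For each background: the original sequence, then one copy per pattern."""
    return [[bg] + [insert_at_centre(bg, p) for p in patterns] for bg in backgrounds]


grid = pattern_inputs(["AAAAAAAA", "TTTTTTTT"], ["CG", "GATA"])
print(grid[0])  # ['AAAAAAAA', 'AAACGAAA', 'AAGATAAA']
```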
- class grelu.data.dataset.ISMDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, genome: str | None = None, drop_ref: bool = False, positions: List[int] | None = None)
Bases: torch.utils.data.Dataset
Dataset to perform in silico mutagenesis (ISM).
- Parameters:
seqs – DNA sequences as intervals, strings, indices or one-hot.
genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.
drop_ref – If True, the base that already exists at each position will not be included in the returned sequences.
positions – List of positions to mutate. If None, all positions will be mutated.
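The effect of drop_ref and positions can be illustrated with a plain-Python enumeration of single-base substitutions. This is a minimal sketch on strings, assuming 0-based positions and the base order A, C, G, T; `ism_sequences` is a hypothetical helper and does not reproduce the class's actual output format.

```python
from typing import Optional, Sequence

BASES = "ACGT"  # assumed base order for this example


def ism_sequences(seq: str, positions: Optional[Sequence[int]] = None,
                  drop_ref: bool = False) -> list[str]:
    """Enumerate every single-base substitution at the requested positions."""
    seq = seq.upper()
    if positions is None:
        positions = range(len(seq))  # mutate all positions by default
    mutated = []
    for pos in positions:
        for base in BASES:
            if drop_ref and base == seq[pos]:
                continue  # skip the base already present at this position
            mutated.append(seq[:pos] + base + seq[pos + 1:])
    return mutated


print(len(ism_sequences("ACG")))                           # 12
print(len(ism_sequences("ACG", drop_ref=True)))            # 9
print(ism_sequences("ACG", positions=[1], drop_ref=True))  # ['AAG', 'AGG', 'ATG']
```

With drop_ref left as False, each position contributes four sequences, one of which is identical to the input; setting it to True leaves three per position.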
- class grelu.data.dataset.MotifScanDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, motifs: List[str], genome: str | None = None, positions: List[int] | None = None)
Bases: torch.utils.data.Dataset
Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.
- Parameters:
seqs – Background DNA sequences as intervals, strings, integer encoded or one-hot encoded.
motifs – A list of subsequences to insert into the background sequences.
genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.
positions – List of positions at which to insert the motif. If None, the motif will be inserted at all positions.
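Motif scanning can be sketched as sliding each motif along a background sequence and writing it in at every start position. The example below assumes that "inserting" means overwriting bases so the sequence length stays fixed, that positions are 0-based start indices, and that only positions where the whole motif fits are valid; these are assumptions of this sketch, and `motif_scan_sequences` is a hypothetical helper.

```python
from typing import Optional, Sequence


def motif_scan_sequences(seq: str, motifs: Sequence[str],
                         positions: Optional[Sequence[int]] = None) -> list[str]:
    """Return one modified copy of seq per (motif, start position) pair."""
    scanned = []
    for motif in motifs:
        last_start = len(seq) - len(motif)
        # By default, use every start position where the motif fits entirely.
        starts = range(last_start + 1) if positions is None else positions
        for start in starts:
            if start < 0 or start > last_start:
                raise ValueError(f"motif {motif!r} does not fit at position {start}")
            scanned.append(seq[:start] + motif + seq[start + len(motif):])
    return scanned


print(motif_scan_sequences("AAAA", ["CG"]))  # ['CGAA', 'ACGA', 'AACG']
```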