grelu.lightning#
The LightningModel class.
Classes#
| ISMDataset | Dataset to perform in silico mutagenesis (ISM). |
| LabeledSeqDataset | A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory. |
| MotifScanDataset | Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence. |
| PatternMarginalizeDataset | Dataset to marginalize the effect of given sequence patterns across shuffled background sequences. |
| VariantDataset | Dataset class to perform inference on sequence variants. |
| VariantMarginalizeDataset | Dataset to marginalize the effect of given variants across shuffled background sequences. |
| PoissonMultinomialLoss | Poisson decomposition with multinomial specificity term. |
| MSE | Metric class to calculate the MSE for each task. |
| BestF1 | Metric class to calculate the best F1 score for each task. |
| PearsonCorrCoef | Metric class to calculate the Pearson correlation coefficient for each task. |
| ConvHead | A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis. |
| LightningModel | Wrapper for predictive sequence models. |
| LightningModelEnsemble | Combine multiple LightningModel objects into a single object. |
Functions#
| strings_to_one_hot | Convert a list of DNA sequences to one-hot encoded format. |
| get_aggfunc | Return a function to aggregate values. |
| get_compare_func | Return a function to compare two values. |
| make_list | Convert various kinds of inputs into a list. |
Package Contents#
- class grelu.lightning.ISMDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, genome: str | None = None, drop_ref: bool = False, positions: List[int] | None = None)[source]#
Bases:
torch.utils.data.Dataset
Dataset to perform in silico mutagenesis (ISM).
- Parameters:
seqs – DNA sequences as intervals, strings, indices or one-hot.
genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.
drop_ref – If True, the base that already exists at each position will not be included in the returned sequences.
positions – List of positions to mutate. If None, all positions will be mutated.
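A minimal construction sketch based on the signature above; the sequence and positions are illustrative.

    from grelu.lightning import ISMDataset

    # Saturation mutagenesis of a short sequence. With drop_ref=True, the
    # reference base at each position is skipped, yielding 3 variants per
    # mutated position.
    ds = ISMDataset(
        seqs=["ACGTACGTAC"],  # illustrative sequence
        drop_ref=True,
        positions=[2, 3, 4],  # mutate only these positions; None mutates all
    )

The resulting dataset can be passed to LightningModel.predict_on_dataset (documented below) to obtain per-mutation predictions.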
- class grelu.lightning.LabeledSeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, labels: numpy.ndarray, tasks: Sequence | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = None, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')[source]#
Bases:
torch.utils.data.Dataset
A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.
- Parameters:
seqs – DNA sequences as intervals, strings, indices or one-hot.
labels – A numpy array of shape (B, T, L) containing the labels.
tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.
end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
label_len – Uniform expected length (in base pairs) for output labels
max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.
label_aggfunc – Function to aggregate the labels over bin_size.
bin_size – Number of bases to aggregate in the label. Only used if label_aggfunc is not None. If None, it will be taken as equal to label_len.
min_label_clip – Minimum value for label
max_label_clip – Maximum value for label
label_transform_func – Function to transform label values.
seed – Random seed for reproducibility
augment_mode – “random” or “serial”
- _load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) None [source]#
- _load_tasks(tasks: pandas.DataFrame | List) None [source]#
- _load_labels(labels: numpy.ndarray) None [source]#
- get_labels() numpy.ndarray [source]#
Return the labels as a numpy array of shape (B, T, L). This does not account for data augmentation.
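A hedged sketch of constructing a labeled dataset from in-memory sequences and labels; the shapes, sequences, and task names are illustrative.

    import numpy as np
    from grelu.lightning import LabeledSeqDataset

    # 8 sequences of length 100 with labels of shape (B, T, L) = (8, 2, 1):
    # two tasks, one label value per sequence.
    seqs = ["ACGT" * 25 for _ in range(8)]               # illustrative sequences
    labels = np.random.rand(8, 2, 1).astype(np.float32)  # illustrative labels

    ds = LabeledSeqDataset(
        seqs=seqs,
        labels=labels,
        tasks=["task_a", "task_b"],  # hypothetical task names
        rc=True,                     # augment by reverse complementation
        max_seq_shift=2,             # augment by shifting up to 2 bp
    )
    y = ds.get_labels()              # (8, 2, 1); ignores augmentation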
- class grelu.lightning.MotifScanDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, motifs: List[str], genome: str | None = None, positions: List[int] | None = None)[source]#
Bases:
torch.utils.data.Dataset
Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.
- Parameters:
seqs – Background DNA sequences as intervals, strings, integer encoded or one-hot encoded.
motifs – A list of subsequences to insert into the background sequences.
genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.
positions – List of positions at which to insert the motif. If None, the motif will be inserted at every position.
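A brief sketch; the background sequence and motif are illustrative.

    from grelu.lightning import MotifScanDataset

    # Insert a motif at every allowed position of a background sequence.
    ds = MotifScanDataset(
        seqs=["ACGTACGTACGTACGT"],  # illustrative background
        motifs=["TGACTCA"],         # illustrative subsequence to insert
        positions=None,             # None: insert at every position
    )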
- class grelu.lightning.PatternMarginalizeDataset(seqs: List[str] | pandas.DataFrame | numpy.ndarray, patterns: List[str], genome: str | None = None, seq_len: int | None = None, seed: int | None = None, rc: bool = False, n_shuffles: int = 1)[source]#
Bases:
torch.utils.data.Dataset
Dataset to marginalize the effect of given sequence patterns across shuffled background sequences. All sequences are stored in memory.
- Parameters:
seqs – DNA sequences as intervals, strings, integer encoded or one-hot encoded.
patterns – List of alleles or motif sequences to insert into the background sequences.
n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.
genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.
seed – Seed for random number generator
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
- _load_seqs(seqs: pandas.DataFrame | List[str] | numpy.ndarray) None [source]#
Make the background sequences
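A sketch of marginalizing a pattern's effect over shuffled backgrounds; all values are illustrative.

    from grelu.lightning import PatternMarginalizeDataset

    # Each background sequence is shuffled n_shuffles times to build the
    # background distribution into which the pattern is inserted.
    ds = PatternMarginalizeDataset(
        seqs=["ACGTACGTACGTACGT"],  # illustrative background
        patterns=["TGACTCA"],       # illustrative pattern
        n_shuffles=20,
    )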
- class grelu.lightning.VariantDataset(variants: pandas.DataFrame, seq_len: int, genome: str | None = None, rc: bool = False, max_seq_shift: int = 0, frac_mutation: float = 0.0, n_mutated_seqs: int = 1, protect: List[int] | None = None, seed: int | None = None, augment_mode: str = 'serial')[source]#
Bases:
torch.utils.data.Dataset
Dataset class to perform inference on sequence variants.
- Parameters:
variants – pd.DataFrame with columns “chrom”, “pos”, “ref”, “alt”.
seq_len – Uniform expected length (in base pairs) for output sequences
genome – The name of the genome from which to read sequences.
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
frac_mutation – Fraction of bases to randomly mutate for data augmentation.
protect – A list of positions to protect from mutation.
n_mutated_seqs – Number of mutated sequences to generate from each input sequence for data augmentation.
- _load_alleles(variants: pandas.DataFrame) None [source]#
- _load_seqs(variants: pandas.DataFrame) None [source]#
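A sketch of building a variant dataset; the variant and genome name are illustrative assumptions.

    import pandas as pd
    from grelu.lightning import VariantDataset

    # Variants require "chrom", "pos", "ref" and "alt" columns.
    variants = pd.DataFrame(
        {"chrom": ["chr1"], "pos": [1_000_000], "ref": ["A"], "alt": ["G"]}
    )
    ds = VariantDataset(
        variants=variants,
        seq_len=200,
        genome="hg38",  # assumed genome name
    )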
- class grelu.lightning.VariantMarginalizeDataset(variants: pandas.DataFrame, genome: str, seq_len: int, seed: int | None = None, rc: bool = False, max_seq_shift: int = 0, n_shuffles: int = 100)[source]#
Bases:
torch.utils.data.Dataset
Dataset to marginalize the effect of given variants across shuffled background sequences. All sequences are stored in memory.
- Parameters:
variants – A dataframe of sequence variants
genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.
seed – Seed for random number generator
rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.
max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.
- _load_alleles(variants: pandas.DataFrame) None [source]#
Load the alleles to substitute into the background
- _load_seqs(variants: pandas.DataFrame) None [source]#
Load sequences surrounding the variant position
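The marginalization counterpart of the same idea, reusing the variants dataframe from the previous sketch.

    from grelu.lightning import VariantMarginalizeDataset

    # Each variant's effect is averaged over shuffled background sequences
    # (100 shuffles by default).
    ds = VariantMarginalizeDataset(variants=variants, genome="hg38", seq_len=200)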
- class grelu.lightning.PoissonMultinomialLoss(total_weight: float = 1, eps: float = 1e-07, log_input: bool = True, reduction: str = 'mean', multinomial_axis: str = 'length')[source]#
Bases:
torch.nn.Module
Poisson decomposition with multinomial specificity term.
- Parameters:
total_weight – Weight of the Poisson total term.
eps – Small value added to avoid log(0). Only needed if log_input = False.
log_input – If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts.
multinomial_axis – Either “length” or “task”, representing the axis along which the multinomial distribution should be calculated.
reduction – “mean” or “none”.
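A minimal sketch, assuming predictions and targets of shape (batch, tasks, length) consistent with the (B, T, L) convention used elsewhere on this page, log-space predictions (the log_input=True default), and the usual torch.nn (input, target) argument order.

    import torch
    from grelu.lightning import PoissonMultinomialLoss

    loss_fn = PoissonMultinomialLoss(total_weight=1, multinomial_axis="length")

    preds = torch.randn(4, 2, 128)       # log of predicted counts (log_input=True)
    target = torch.rand(4, 2, 128) * 10  # observed counts
    loss = loss_fn(preds, target)        # scalar, since reduction="mean"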
- class grelu.lightning.MSE(num_outputs: int = 1, average: bool = True)[source]#
Bases:
torchmetrics.Metric
Metric class to calculate the MSE for each task.
- Parameters:
num_outputs – Number of tasks
average – If true, return the average metric across tasks. Otherwise, return a separate value for each task
- As input to forward and update the metric accepts the following input:
preds: Predictions of shape (N, n_tasks, L)
target: Ground truth labels of shape (N, n_tasks, L)
- As output of forward and compute the metric returns the following output:
output: A tensor with the MSE
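A usage sketch following the standard torchmetrics update/compute pattern; BestF1 and PearsonCorrCoef below work the same way, except that BestF1 expects probabilities as preds.

    import torch
    from grelu.lightning import MSE

    metric = MSE(num_outputs=2, average=False)  # per-task values
    preds = torch.randn(8, 2, 10)               # (N, n_tasks, L)
    target = torch.randn(8, 2, 10)
    metric.update(preds, target)
    per_task_mse = metric.compute()             # tensor with one value per task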
- class grelu.lightning.BestF1(num_labels: int = 1, average: bool = True)[source]#
Bases:
torchmetrics.Metric
Metric class to calculate the best F1 score for each task.
- Parameters:
num_labels – Number of tasks
average – If true, return the average metric across tasks. Otherwise, return a separate value for each task
- As input to forward and update the metric accepts the following input:
preds: Probabilities of shape (N, n_tasks, L)
target: Ground truth labels of shape (N, n_tasks, L)
- As output of forward and compute the metric returns the following output:
output: A tensor with the best F1 score
- class grelu.lightning.PearsonCorrCoef(num_outputs: int = 1, average: bool = True)[source]#
Bases:
torchmetrics.Metric
Metric class to calculate the Pearson correlation coefficient for each task.
- Parameters:
num_outputs – Number of tasks
average – If true, return the average metric across tasks. Otherwise, return a separate value for each task
- As input to forward and update the metric accepts the following input:
preds: Predictions of shape (N, n_tasks, L)
target: Ground truth labels of shape (N, n_tasks, L)
- As output of forward and compute the metric returns the following output:
output: A tensor with the Pearson coefficient.
- class grelu.lightning.ConvHead(n_tasks: int, in_channels: int, act_func: str | None = None, pool_func: str | None = None, norm: bool = False, dtype=None, device=None)[source]#
Bases:
torch.nn.Module
A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis.
- Parameters:
n_tasks – Number of tasks (output channels)
in_channels – Number of channels in the input
norm – If True, batch normalization will be included.
act_func – Activation function for the convolutional layer
pool_func – Pooling function.
dtype – Data type for the layers.
device – Device for the layers.
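A sketch assuming a channels-first (batch, channels, length) input, which is how 1x1 convolutions over sequence embeddings are normally applied; the pool_func name is an assumption.

    import torch
    from grelu.lightning import ConvHead

    # Map a 512-channel embedding to 10 tasks, then pool along length.
    head = ConvHead(
        n_tasks=10,
        in_channels=512,
        pool_func="avg",  # assumed name for average pooling
        norm=True,
    )
    x = torch.randn(2, 512, 100)  # (batch, channels, length)
    y = head(x)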
- grelu.lightning.strings_to_one_hot(strings: str | List[str], add_batch_axis: bool = False) torch.Tensor [source]#
Convert a list of DNA sequences to one-hot encoded format.
- Parameters:
strings – A DNA sequence or a list of DNA sequences.
add_batch_axis – If True, a batch axis will be included in the output for single sequences. If False, the output for a single sequence will be a 2-dimensional tensor.
- Returns:
The one-hot encoded DNA sequence(s).
- Raises:
AssertionError – If the input sequences are not of the same length, or if the input is not a string or a list of strings.
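A brief sketch of both calling modes described above; the sequences are illustrative.

    from grelu.lightning import strings_to_one_hot

    x = strings_to_one_hot("ACGT")   # 2-D tensor for a single sequence
    xb = strings_to_one_hot(["ACGT", "TGCA"], add_batch_axis=True)  # batch axis added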
- grelu.lightning.get_aggfunc(func: str | Callable | None, tensor: bool = False) Callable [source]#
Return a function to aggregate values.
- Parameters:
func – A function or the name of a function. Supported names are “max”, “min”, “mean”, and “sum”. If a function is supplied, it will be returned unchanged.
tensor – If True, it is assumed that the inputs will be torch tensors. If False, it is assumed that the inputs will be numpy arrays.
- Returns:
The desired function.
- Raises:
NotImplementedError – If the input is neither a function nor a supported function name.
- grelu.lightning.get_compare_func(func: str | Callable | None, tensor: bool = False) Callable [source]#
Return a function to compare two values.
- Parameters:
func – A function or the name of a function. Supported names are “subtract”, “divide”, and “log2FC”. If a function is supplied, it will be returned unchanged. func cannot be None.
tensor – If True, it is assumed that the inputs will be torch tensors. If False, it is assumed that the inputs will be numpy arrays.
- Returns:
The desired function.
- Raises:
NotImplementedError – If the input is neither a function nor a supported function name.
- grelu.lightning.make_list(x: pandas.Series | numpy.ndarray | torch.Tensor | Sequence | int | float | str | None) list [source]#
Convert various kinds of inputs into a list
- Parameters:
x – An input value or sequence of values.
- Returns:
The input values in list format.
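A combined sketch of the three helpers above; that "log2FC" computes log2(a / b) is an assumption based on its name.

    import numpy as np
    from grelu.lightning import get_aggfunc, get_compare_func, make_list

    agg = get_aggfunc("mean")                   # numpy-based aggregator
    agg(np.array([1.0, 2.0, 3.0]))              # 2.0

    compare = get_compare_func("log2FC")        # assumed: log2(a / b)
    compare(np.array([4.0]), np.array([1.0]))   # array([2.])

    make_list("chr1")                           # ["chr1"]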
- class grelu.lightning.LightningModel(model_params: dict, train_params: dict = {}, data_params: dict = {})[source]#
Bases:
pytorch_lightning.LightningModule
Wrapper for predictive sequence models
- Parameters:
model_params – Dictionary of parameters specifying model architecture
train_params – Dictionary specifying training parameters
data_params – Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training.
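A construction sketch. The dictionary keys shown are illustrative assumptions, not a verified schema; consult the gReLU model documentation for the accepted keys.

    from grelu.lightning import LightningModel

    model = LightningModel(
        model_params={"model_type": "ConvModel", "n_tasks": 2},  # hypothetical keys
        train_params={"lr": 1e-4, "batch_size": 64},             # hypothetical keys
    )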
- update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None [source]#
Update metrics after each pass
- format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor [source]#
Extract the one-hot encoded sequence from the input
- forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor [source]#
Forward pass
- test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor [source]#
Calculate metrics after a single test step
- parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]] [source]#
Parses the devices argument and returns a tuple of accelerator and devices.
- Parameters:
devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.
- Returns:
A tuple of accelerator and devices.
- make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for training
- make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for validation and testing
- make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for prediction
- train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#
Train model and optionally log metrics to wandb.
- Parameters:
train_dataset (Dataset) – Dataset object that yields training examples
val_dataset (Dataset) – Dataset object that yields validation examples
checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.
- Returns:
PyTorch Lightning Trainer
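A sketch using the LabeledSeqDataset shown earlier; separate training and validation datasets are assumed to be constructed the same way.

    # train_ds and val_ds are LabeledSeqDataset objects (see above).
    trainer = model.train_on_dataset(train_ds, val_ds)
    # Pass checkpoint_path="..." to resume from a saved checkpoint.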
- change_head(n_tasks: int, final_pool_func: str) None [source]#
Build a new head with the desired number of tasks
- tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#
Fine-tune a pretrained model on a new dataset.
- Parameters:
train_dataset – Dataset object that yields training examples
val_dataset – Dataset object that yields validation examples
final_act_func – Name of the final activation layer
final_pool_func – Name of the final pooling layer
freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.
- Returns:
PyTorch Lightning Trainer
- predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray [source]#
A convenience function to return model predictions directly on a single batch of sequences in string format.
- Parameters:
x – DNA sequences as a string or list of strings.
device – Index of the device to use
- Returns:
A numpy array of predictions.
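A one-line usage sketch; the sequences are illustrative.

    preds = model.predict_on_seqs(["ACGTACGT", "TTTTACGT"], device="cpu")
    # preds is a numpy array of predictions for the batch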
- predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#
Predict for a dataset of sequences or variants
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
augment_aggfunc – Function (or name of a function) to aggregate predictions across all augmented versions of a sequence, e.g. "mean".
compare_func – Function (or name of a function) to compare predictions for the alt and ref alleles of variants, e.g. "subtract" or "log2FC".
return_df – Return the predictions as a Pandas dataframe
precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.
- Returns:
Model predictions as a numpy array or dataframe
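A sketch of variant-effect prediction, combining the VariantDataset constructed earlier with compare_func; "log2FC" is one of the names accepted by get_compare_func above.

    # ds is the VariantDataset constructed earlier.
    effects = model.predict_on_dataset(
        ds,
        devices="cpu",
        batch_size=64,
        compare_func="log2FC",  # alt-vs-ref log2 fold change per variant
    )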
- test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None)[source]#
Run test loop for a dataset
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use for inference
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.
- Returns:
Dataframe containing all calculated metrics on the test set.
- embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#
Return embeddings for a dataset of sequences
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
device – Device ID to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
- Returns:
Numpy array of shape (B, T, L) containing embeddings.
- get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int] [source]#
Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.
- Parameters:
tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers
key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.
invert – Get indices for all tasks except those listed in tasks
- Returns:
The index or indices of the corresponding task(s) in the model’s output.
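A short sketch; the task names are hypothetical and must match entries in model.data_params["tasks"].

    idx = model.get_task_idxs("astrocytes")              # single index
    idxs = model.get_task_idxs(["task_a", "task_b"])     # list of indices
    rest = model.get_task_idxs(["task_a"], invert=True)  # all other tasks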
- input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int [source]#
Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.
- Parameters:
input_coord – Genomic coordinate of the input position
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Index of the output bin containing the given position.
- output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int [source]#
Given the index of a bin in the output, get its corresponding start or end coordinate.
- Parameters:
output_bin – Index of the bin in the model’s output
return_pos – “start” or “end”
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Genomic coordinate corresponding to the start (if return_pos = start) or end (if return_pos=end) of the bin.
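A round-trip sketch, assuming the model input starts at an illustrative genomic position.

    start_pos = 1_000_000  # genomic coordinate of the first input base

    # Which output bin covers genomic position 1,000,128?
    b = model.input_coord_to_output_bin(1_000_128, start_pos=start_pos)

    # Recover the genomic span of that bin.
    bin_start = model.output_bin_to_input_coord(b, return_pos="start", start_pos=start_pos)
    bin_end = model.output_bin_to_input_coord(b, return_pos="end", start_pos=start_pos)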
- input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame [source]#
Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.
- Parameters:
intervals – A dataframe of genomic intervals
- Returns:
A dataframe containing the genomic intervals corresponding to the model output from each input interval.
- input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None [source]#
Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.
- Parameters:
intervals – A dataframe of genomic intervals
start_pos – The start position of the sequence input to the model.
- Returns:
The start and end indices of the output bins corresponding to each input interval.
- class grelu.lightning.LightningModelEnsemble(models: list, model_names: List[str] | None = None)[source]#
Bases:
pytorch_lightning.LightningModule
Combine multiple LightningModel objects into a single object. When predict_on_dataset is used, it will return the concatenated predictions from all the models in the order in which they were supplied.
- Parameters:
models – A list of LightningModel objects to combine.
model_names – An optional list of names for the sub-models, used as prefixes for task names. If not supplied, default names of the form "model1" are used.
- _combine_tasks() None [source]#
Combine the task metadata of all the sub-models into self.data_params[“tasks”]
- predict_on_dataset(dataset: Callable, **kwargs) numpy.ndarray [source]#
This will return the concatenated predictions from all the constituent models, in the order in which they were supplied. Predictions will be concatenated along the task axis.
- get_task_idxs(tasks: str | int | List[str] | List[int], key: str = 'name') int | List[int] [source]#
Return the task index given the name of the task. Note that task names should be supplied with a prefix indicating the model; for instance, to get the second model's predictions on astrocytes, the task name would be "{name of second model}_astrocytes". If model names were not supplied to __init__, the task name would be "model1_astrocytes".
- Parameters:
tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers
key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.
- Returns:
An integer or list of integers representing the indices of the tasks in the model output.
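A closing sketch of combining two trained models; the model and task names are hypothetical.

    from grelu.lightning import LightningModelEnsemble

    ensemble = LightningModelEnsemble(
        models=[model_a, model_b],    # trained LightningModel objects
        model_names=["atac", "rna"],  # optional prefixes for task names
    )
    preds = ensemble.predict_on_dataset(ds)          # concatenated along the task axis
    idx = ensemble.get_task_idxs("atac_astrocytes")  # "<model name>_<task name>"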