grelu.lightning#

The LightningModel class.

Submodules#

Attributes#

default_train_params

Classes#

ISMDataset

Dataset to perform in silico mutagenesis (ISM).

LabeledSeqDataset

A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.

MotifScanDataset

Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.

PatternMarginalizeDataset

Dataset to marginalize the effect of given sequence patterns across shuffled background sequences.

VariantDataset

Dataset class to perform inference on sequence variants.

VariantMarginalizeDataset

Dataset to marginalize the effect of given variants across shuffled background sequences.

PoissonMultinomialLoss

Poisson decomposition with multinomial specificity term.

MSE

Metric class to calculate the MSE for each task.

BestF1

Metric class to calculate the best F1 score for each task.

PearsonCorrCoef

Metric class to calculate the Pearson correlation coefficient for each task.

ConvHead

A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis.

LightningModel

Wrapper for predictive sequence models

LightningModelEnsemble

Combine multiple LightningModel objects into a single object.

Functions#

strings_to_one_hot(→ torch.Tensor)

Convert a list of DNA sequences to one-hot encoded format.

get_aggfunc(→ Callable)

Return a function to aggregate values.

get_compare_func(→ Callable)

Return a function to compare two values.

make_list(→ list)

Convert various kinds of inputs into a list.

Package Contents#

class grelu.lightning.ISMDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, genome: str | None = None, drop_ref: bool = False, positions: List[int] | None = None)[source]#

Bases: torch.utils.data.Dataset

Dataset to perform in silico mutagenesis (ISM).

Parameters:
  • seqs – DNA sequences as intervals, strings, indices or one-hot.

  • genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.

  • drop_ref – If True, the base that already exists at each position will not be included in the returned sequences.

  • positions – List of positions to mutate. If None, all positions will be mutated.

positions[source]#
genome[source]#
drop_ref[source]#
n_alleles[source]#
n_seqs[source]#
seq_len[source]#
n_augmented[source]#
_load_seqs(seqs) None[source]#
__len__() int[source]#
__getitem__(idx: int, return_compressed=False) torch.Tensor[source]#
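
A minimal usage sketch with illustrative values; the item count assumes three alternate bases per position when drop_ref is True:

    from grelu.lightning import ISMDataset

    # Saturation mutagenesis of a single 10-bp sequence
    ds = ISMDataset(seqs=["ACGTACGTAC"], drop_ref=True)
    x = ds[0]                # one-hot encoded sequence carrying one substitution
    print(len(ds), x.shape)
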
class grelu.lightning.LabeledSeqDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, labels: numpy.ndarray, tasks: Sequence | pandas.DataFrame | None = None, seq_len: int | None = None, genome: str | None = None, end: str = 'both', rc: bool = False, max_seq_shift: int = 0, label_len: int | None = None, max_pair_shift: int = 0, label_aggfunc: str | Callable | None = None, bin_size: int | None = None, min_label_clip: int | None = None, max_label_clip: int | None = None, label_transform_func: str | Callable | None = None, seed: int | None = None, augment_mode: str = 'serial')[source]#

Bases: torch.utils.data.Dataset

A general Dataset class for DNA sequences and labels. All sequences and labels will be stored in memory.

Parameters:
  • seqs – DNA sequences as intervals, strings, indices or one-hot.

  • labels – A numpy array of shape (B, T, L) containing the labels.

  • tasks – A list of task names or a pandas dataframe containing task information. If a dataframe is supplied, the row indices should be the task names.

  • seq_len – Uniform expected length (in base pairs) for output sequences

  • genome – The name of the genome from which to read sequences. Only needed if genomic intervals are supplied.

  • end – Which end of the sequence to resize if necessary. Supported values are “left”, “right” and “both”.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • label_len – Uniform expected length (in base pairs) for output labels

  • max_pair_shift – Maximum number of bases to shift both the sequence and label for augmentation. If 0, sequence and label pairs will not be augmented by shifting.

  • label_aggfunc – Function to aggregate the labels over bin_size.

  • bin_size – Number of bases to aggregate in the label. Only used if label_aggfunc is not None. If None, it will be taken as equal to label_len.

  • min_label_clip – Minimum value for label

  • max_label_clip – Maximum value for label

  • label_transform_func – Function to transform label values.

  • seed – Random seed for reproducibility

  • augment_mode – “random” or “serial”

end[source]#
genome[source]#
min_label_clip[source]#
max_label_clip[source]#
label_transform_func[source]#
seq_len[source]#
label_len[source]#
label_aggfunc[source]#
bin_size[source]#
rc[source]#
max_seq_shift[source]#
max_pair_shift[source]#
padded_seq_len[source]#
padded_label_len[source]#
n_seqs[source]#
n_tasks[source]#
label_transform[source]#
augmenter[source]#
n_augmented[source]#
n_alleles = 1[source]#
predict = False[source]#
_load_seqs(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray) None[source]#
_load_tasks(tasks: pandas.DataFrame | List) None[source]#
_load_labels(labels: numpy.ndarray) None[source]#
__len__() int[source]#
get_labels() numpy.ndarray[source]#

Return the labels as a numpy array of shape (B, T, L). This does not account for data augmentation.

__getitem__(idx: int) torch.Tensor | Tuple[torch.Tensor, torch.Tensor][source]#
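
A hedged construction example; the task name, shapes and augmentation settings are illustrative only:

    import numpy as np
    from grelu.lightning import LabeledSeqDataset

    seqs = ["ACGT" * 25, "TGCA" * 25]                    # two 100-bp sequences
    labels = np.random.rand(2, 1, 1).astype(np.float32)  # (B, T, L)
    ds = LabeledSeqDataset(seqs, labels=labels, tasks=["cov"], rc=True)
    x, y = ds[0]   # one-hot sequence and its label
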
class grelu.lightning.MotifScanDataset(seqs: str | Sequence | pandas.DataFrame | numpy.ndarray, motifs: List[str], genome: str | None = None, positions: List[int] | None = None)[source]#

Bases: torch.utils.data.Dataset

Dataset to perform in silico motif scanning by inserting a motif at each position of a sequence.

Parameters:
  • seqs – Background DNA sequences as intervals, strings, integer encoded or one-hot encoded.

  • motifs – A list of subsequences to insert into the background sequences.

  • genome – The name of the genome from which to read sequences. This is only needed if genomic intervals are supplied in seqs.

  • positions – List of positions at which to insert the motif. If None, the motif will be inserted at every position.

positions[source]#
genome[source]#
motifs[source]#
max_motif_len[source]#
n_alleles[source]#
n_seqs[source]#
seq_len[source]#
n_augmented[source]#
_load_seqs(seqs)[source]#
__len__() int[source]#
__getitem__(idx: int, return_compressed=False) torch.Tensor[source]#
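
A sketch of motif scanning over a uniform background; the motif and sequence are illustrative:

    from grelu.lightning import MotifScanDataset

    ds = MotifScanDataset(
        seqs=["A" * 50],       # one 50-bp background sequence
        motifs=["TGACTCA"],    # subsequence to insert at each position
    )
    print(len(ds))             # one item per motif per insertion position
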
class grelu.lightning.PatternMarginalizeDataset(seqs: List[str] | pandas.DataFrame | numpy.ndarray, patterns: List[str], genome: str | None = None, seq_len: int | None = None, seed: int | None = None, rc: bool = False, n_shuffles: int = 1)[source]#

Bases: torch.utils.data.Dataset

Dataset to marginalize the effect of given sequence patterns across shuffled background sequences. All sequences are stored in memory.

Parameters:
  • seqs – DNA sequences as intervals, strings, integer encoded or one-hot encoded.

  • patterns – List of alleles or motif sequences to insert into the background sequences.

  • n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.

  • genome – The name of the genome from which to read sequences. Only used if genomic intervals are supplied.

  • seed – Seed for random number generator

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

genome[source]#
seed[source]#
seq_len[source]#
rc[source]#
n_shuffles[source]#
augmenter[source]#
n_augmented[source]#
bg = None[source]#
curr_seq_idx = None[source]#
_load_alleles(patterns: List[str]) None[source]#
_load_seqs(seqs: pandas.DataFrame | List[str] | numpy.ndarray) None[source]#

Make the background sequences

__update__(idx: int) None[source]#

Update the current background

__len__() int[source]#
__getitem__(idx: int) torch.Tensor[source]#
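
A sketch of pattern marginalization with illustrative values; the background shuffling procedure is whatever the class implements internally:

    from grelu.lightning import PatternMarginalizeDataset

    ds = PatternMarginalizeDataset(
        seqs=["ACGTACGTAC" * 10],   # 100-bp background sequence
        patterns=["TGACTCA"],       # pattern whose effect to marginalize
        n_shuffles=20,
        seed=0,
    )
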
class grelu.lightning.VariantDataset(variants: pandas.DataFrame, seq_len: int, genome: str | None = None, rc: bool = False, max_seq_shift: int = 0, frac_mutation: float = 0.0, n_mutated_seqs: int = 1, protect: List[int] | None = None, seed: int | None = None, augment_mode: str = 'serial')[source]#

Bases: torch.utils.data.Dataset

Dataset class to perform inference on sequence variants.

Parameters:
  • variants – pd.DataFrame with columns “chrom”, “pos”, “ref”, “alt”.

  • seq_len – Uniform expected length (in base pairs) for output sequences

  • genome – The name of the genome from which to read sequences.

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • frac_mutation – Fraction of bases to randomly mutate for data augmentation.

  • protect – A list of positions to protect from mutation.

  • n_mutated_seqs – Number of mutated sequences to generate from each input sequence for data augmentation.

genome[source]#
seq_len[source]#
rc[source]#
max_seq_shift[source]#
frac_mutated_bases[source]#
n_mutated_bases[source]#
n_mutated_seqs[source]#
n_alleles = 2[source]#
n_seqs[source]#
augmenter[source]#
n_augmented[source]#
_load_alleles(variants: pandas.DataFrame) None[source]#
_load_seqs(variants: pandas.DataFrame) None[source]#
__len__() int[source]#
__getitem__(idx: int) torch.Tensor[source]#
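
A construction sketch; this assumes the named genome (here "hg38") has already been downloaded and is available locally:

    import pandas as pd
    from grelu.lightning import VariantDataset

    variants = pd.DataFrame(
        {"chrom": ["chr1"], "pos": [1_000_000], "ref": ["A"], "alt": ["G"]}
    )
    ds = VariantDataset(variants, seq_len=200, genome="hg38")
    # Items cover both alleles of each variant (n_alleles = 2)
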
class grelu.lightning.VariantMarginalizeDataset(variants: pandas.DataFrame, genome: str, seq_len: int, seed: int | None = None, rc: bool = False, max_seq_shift: int = 0, n_shuffles: int = 100)[source]#

Bases: torch.utils.data.Dataset

Dataset to marginalize the effect of given variants across shuffled background sequences. All sequences are stored in memory.

Parameters:
  • variants – A dataframe of sequence variants

  • genome – The name of the genome from which to read sequences surrounding each variant.

  • seed – Seed for random number generator

  • rc – If True, sequences will be augmented by reverse complementation. If False, they will not be reverse complemented.

  • max_seq_shift – Maximum number of bases to shift the sequence for augmentation. This is normally a small value (< 10). If 0, sequences will not be augmented by shifting.

  • n_shuffles – Number of times to shuffle each background sequence to generate a background distribution.

genome[source]#
seed[source]#
seq_len[source]#
rc = False[source]#
max_seq_shift = 0[source]#
n_shuffles[source]#
augmenter[source]#
n_augmented[source]#
bg = None[source]#
curr_seq_idx = None[source]#
_load_alleles(variants: pandas.DataFrame) None[source]#

Load the alleles to substitute into the background

_load_seqs(variants: pandas.DataFrame) None[source]#

Load sequences surrounding the variant position

__update__(idx: int) None[source]#

Update the current background

__len__() int[source]#
__getitem__(idx: int) torch.Tensor[source]#
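
The marginalization variant of the same idea, again assuming a locally available genome:

    from grelu.lightning import VariantMarginalizeDataset

    ds = VariantMarginalizeDataset(
        variants=variants,   # dataframe with chrom/pos/ref/alt columns, as above
        genome="hg38",
        seq_len=200,
        n_shuffles=100,
        seed=0,
    )
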
class grelu.lightning.PoissonMultinomialLoss(total_weight: float = 1, eps: float = 1e-07, log_input: bool = True, reduction: str = 'mean', multinomial_axis: str = 'length')[source]#

Bases: torch.nn.Module

Poisson decomposition with multinomial specificity term.

Parameters:
  • total_weight – Weight of the Poisson total term.

  • eps – Added small value to avoid log(0). Only needed if log_input = False.

  • log_input – If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts.

  • multinomial_axis – Either “length” or “task”, representing the axis along which the multinomial distribution should be calculated.

  • reduction – “mean” or “none”.

eps[source]#
total_weight[source]#
log_input[source]#
reduction[source]#
forward(input: torch.Tensor, target: torch.Tensor) torch.Tensor[source]#

Loss computation

Parameters:
  • input – Tensor of shape (B, T, L)

  • target – Tensor of shape (B, T, L)

Returns:

Loss value
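
A quick illustration on random tensors, following the (B, T, L) shape convention above:

    import torch
    from grelu.lightning import PoissonMultinomialLoss

    loss_fn = PoissonMultinomialLoss(total_weight=0.2, log_input=True)
    preds = torch.randn(2, 4, 128)                  # log-scale predictions
    target = torch.poisson(torch.ones(2, 4, 128))   # non-negative counts
    loss = loss_fn(preds, target)                   # scalar, since reduction="mean"
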

class grelu.lightning.MSE(num_outputs: int = 1, average: bool = True)[source]#

Bases: torchmetrics.Metric

Metric class to calculate the MSE for each task.

Parameters:
  • num_outputs – Number of tasks

  • average – If True, return the average metric across tasks. Otherwise, return a separate value for each task.

As input to forward and update the metric accepts the following input:

preds: Predictions of shape (N, n_tasks, L)

target: Ground truth labels of shape (N, n_tasks, L)

As output of forward and compute the metric returns the following output:

output: A tensor with the MSE

average[source]#
update(preds: torch.Tensor, target: torch.Tensor) None[source]#
compute() torch.Tensor[source]#
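
All three metric classes follow the standard torchmetrics update/compute pattern; a sketch with MSE (BestF1 and PearsonCorrCoef are used the same way):

    import torch
    from grelu.lightning import MSE

    metric = MSE(num_outputs=4, average=False)
    preds = torch.randn(8, 4, 10)     # (N, n_tasks, L)
    target = torch.randn(8, 4, 10)
    metric.update(preds, target)
    per_task = metric.compute()       # one value per task, since average=False
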
class grelu.lightning.BestF1(num_labels: int = 1, average: bool = True)[source]#

Bases: torchmetrics.Metric

Metric class to calculate the best F1 score for each task.

Parameters:
  • num_labels – Number of tasks

  • average – If True, return the average metric across tasks. Otherwise, return a separate value for each task.

As input to forward and update the metric accepts the following input:

preds: Probabilities of shape (N, n_tasks, L)

target: Ground truth labels of shape (N, n_tasks, L)

As output of forward and compute the metric returns the following output:

output: A tensor with the best F1 score

average[source]#
update(preds: torch.Tensor, target: torch.Tensor) None[source]#
compute() torch.Tensor[source]#
class grelu.lightning.PearsonCorrCoef(num_outputs: int = 1, average: bool = True)[source]#

Bases: torchmetrics.Metric

Metric class to calculate the Pearson correlation coefficient for each task.

Parameters:
  • num_outputs – Number of tasks

  • average – If True, return the average metric across tasks. Otherwise, return a separate value for each task.

As input to forward and update the metric accepts the following input:

preds: Predictions of shape (N, n_tasks, L)

target: Ground truth labels of shape (N, n_tasks, L)

As output of forward and compute the metric returns the following output:

output: A tensor with the Pearson coefficient.

pearson[source]#
average[source]#
update(preds: torch.Tensor, target: torch.Tensor) None[source]#
compute() torch.Tensor[source]#
reset() None[source]#
class grelu.lightning.ConvHead(n_tasks: int, in_channels: int, act_func: str | None = None, pool_func: str | None = None, norm: bool = False, dtype=None, device=None)[source]#

Bases: torch.nn.Module

A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis.

Parameters:
  • n_tasks – Number of tasks (output channels)

  • in_channels – Number of channels in the input

  • norm – If True, batch normalization will be included.

  • act_func – Activation function for the convolutional layer

  • pool_func – Pooling function.

  • dtype – Data type for the layers.

  • device – Device for the layers.

n_tasks[source]#
in_channels[source]#
act_func[source]#
pool_func[source]#
norm[source]#
channel_transform[source]#
pool[source]#
forward(x: torch.Tensor) torch.Tensor[source]#
Parameters:

x – Input data.
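
A usage sketch; the pooling name "avg" is an assumption about the supported pool_func options:

    import torch
    from grelu.lightning import ConvHead

    head = ConvHead(n_tasks=8, in_channels=256, pool_func="avg")
    x = torch.randn(2, 256, 100)   # (batch, channels, length)
    y = head(x)                    # channels mapped 256 -> 8, then pooled over length
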

grelu.lightning.strings_to_one_hot(strings: str | List[str], add_batch_axis: bool = False) torch.Tensor[source]#

Convert a list of DNA sequences to one-hot encoded format.

Parameters:
  • strings – A DNA sequence or a list of DNA sequences.

  • add_batch_axis – If True, a batch axis will be included in the output for single sequences. If False, the output for a single sequence will be a 2-dimensional tensor.

Returns:

The one-hot encoded DNA sequence(s).

Raises:

AssertionError – If the input sequences are not of the same length, or if the input is not a string or a list of strings.
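
Example usage; the (batch, 4 bases, length) output layout shown in the comments is an assumption about the encoding convention:

    from grelu.lightning import strings_to_one_hot

    one_hot = strings_to_one_hot(["ACGT", "TTAA"])
    print(one_hot.shape)     # expected: (2, 4, 4)

    single = strings_to_one_hot("ACGT", add_batch_axis=False)
    print(single.ndim)       # 2, per the add_batch_axis note above
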

grelu.lightning.get_aggfunc(func: str | Callable | None, tensor: bool = False) Callable[source]#

Return a function to aggregate values.

Parameters:
  • func – A function or the name of a function. Supported names are “max”, “min”, “mean”, and “sum”. If a function is supplied, it will be returned unchanged.

  • tensor – If True, it is assumed that the inputs will be torch tensors. If False, it is assumed that the inputs will be numpy arrays.

Returns:

The desired function.

Raises:

NotImplementedError – If the input is neither a function nor a supported function name.
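
A sketch, assuming the supported names wrap the corresponding numpy (or torch) reductions:

    import numpy as np
    from grelu.lightning import get_aggfunc

    agg = get_aggfunc("mean", tensor=False)
    print(agg(np.array([1.0, 2.0, 3.0])))   # 2.0
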

grelu.lightning.get_compare_func(func: str | Callable | None, tensor: bool = False) Callable[source]#

Return a function to compare two values.

Parameters:
  • func – A function or the name of a function. Supported names are “subtract”, “divide”, and “log2FC”. If a function is supplied, it will be returned unchanged. func cannot be None.

  • tensor – If True, it is assumed that the inputs will be torch tensors. If False, it is assumed that the inputs will be numpy arrays.

Returns:

The desired function.

Raises:

NotImplementedError – If the input is neither a function nor a supported function name.
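
A sketch, assuming "log2FC" computes the log2 ratio of its first argument to its second:

    import numpy as np
    from grelu.lightning import get_compare_func

    compare = get_compare_func("log2FC", tensor=False)
    print(compare(np.array([4.0]), np.array([1.0])))   # expected: [2.]
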

grelu.lightning.make_list(x: pandas.Series | numpy.ndarray | torch.Tensor | Sequence | int | float | str | None) list[source]#

Convert various kinds of inputs into a list.

Parameters:

x – An input value or sequence of values.

Returns:

The input values in list format.
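
For example:

    import numpy as np
    from grelu.lightning import make_list

    make_list("chr1")        # ["chr1"]
    make_list(np.arange(3))  # [0, 1, 2]
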

grelu.lightning.default_train_params[source]#
class grelu.lightning.LightningModel(model_params: dict, train_params: dict = {}, data_params: dict = {})[source]#

Bases: pytorch_lightning.LightningModule

Wrapper for predictive sequence models

Parameters:
  • model_params – Dictionary of parameters specifying model architecture

  • train_params – Dictionary specifying training parameters

  • data_params – Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training.
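
A hedged construction sketch; the exact keys accepted in model_params and train_params depend on gReLU's registered architectures and trainer settings, so the names used here ("model_type", "n_tasks", "lr", "max_epochs") are assumptions:

    from grelu.lightning import LightningModel

    model = LightningModel(
        model_params={"model_type": "ConvModel", "n_tasks": 1},   # assumed keys
        train_params={"lr": 1e-4, "batch_size": 64, "max_epochs": 5},
    )
    print(model.count_params())   # number of trainable parameters
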

model_params[source]#
train_params[source]#
data_params[source]#
build_model() None[source]#

Build a model from parameter dictionary

initialize_loss() None[source]#

Create the specified loss function.

initialize_activation() None[source]#

Add a task-specific activation function to the model.

initialize_metrics()[source]#

Initialize the appropriate metrics for the given task.

update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None[source]#

Update metrics after each pass

format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor[source]#

Extract the one-hot encoded sequence from the input

forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor[source]#

Forward pass

training_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
validation_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
on_validation_epoch_end()[source]#

Calculate metrics for entire validation set

test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#

Calculate metrics after a single test step

on_test_epoch_end() None[source]#

Calculate metrics for entire test set

configure_optimizers() None[source]#

Configure the optimizer for training

count_params() int[source]#

Number of gradient-enabled (trainable) parameters in the model

parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]][source]#

Parses the devices argument and returns a tuple of accelerator and devices.

Parameters:

devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.

Returns:

A tuple of accelerator and devices.

parse_logger() str[source]#

Parses the name of the logger supplied in train_params.

add_transform(prediction_transform: Callable) None[source]#

Add a prediction transform

reset_transform() None[source]#

Remove a prediction transform

make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for training

make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for validation and testing

make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for prediction

train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#

Train model and optionally log metrics to wandb.

Parameters:
  • train_dataset (Dataset) – Dataset object that yields training examples

  • val_dataset (Dataset) – Dataset object that yields validation examples

  • checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.

Returns:

PyTorch Lightning Trainer
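
A hypothetical training call, assuming train_ds and val_ds are LabeledSeqDataset objects built as sketched earlier:

    trainer = model.train_on_dataset(train_ds, val_ds)
    # Or resume from a saved checkpoint:
    # trainer = model.train_on_dataset(train_ds, val_ds, checkpoint_path="last.ckpt")
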

_get_dataset_attrs(dataset: Callable) None[source]#

Read data parameters from a dataset object

change_head(n_tasks: int, final_pool_func: str) None[source]#

Build a new head with the desired number of tasks

tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#

Fine-tune a pretrained model on a new dataset.

Parameters:
  • train_dataset – Dataset object that yields training examples

  • val_dataset – Dataset object that yields validation examples

  • final_act_func – Name of the final activation layer

  • final_pool_func – Name of the final pooling layer

  • freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.

Returns:

PyTorch Lightning Trainer
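
A hypothetical fine-tuning call; the pooling name "avg" and the dataset variables are assumptions:

    trainer = model.tune_on_dataset(
        new_train_ds,
        new_val_ds,
        final_pool_func="avg",    # assumed pooling name
        freeze_embedding=True,    # train only the new head
    )
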

on_save_checkpoint(checkpoint: dict) None[source]#
predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray[source]#

A convenience function to return model predictions directly on a single batch of sequences in string format.

Parameters:
  • x – DNA sequences as a string or list of strings.

  • device – Index of the device to use

Returns:

A numpy array of predictions.
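
For example:

    preds = model.predict_on_seqs(["ACGT" * 50], device="cpu")
    # preds: numpy array with one row per input sequence
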

predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#

Predict for a dataset of sequences or variants

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • augment_aggfunc – Function used to aggregate predictions across all augmented versions of a sequence (e.g. “mean”).

  • compare_func – Function used to compare the alternate and reference allele predictions for variants (e.g. “subtract” or “log2FC”).

  • return_df – If True, return the predictions as a Pandas dataframe.

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

Returns:

Model predictions as a numpy array or dataframe
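
A hedged usage sketch; ds can be any of the dataset classes above, and the variant comparison assumes a VariantDataset:

    preds = model.predict_on_dataset(ds, devices="cpu", batch_size=64)
    # preds: numpy array of shape (B, T, L)

    # For variants, collapse the ref/alt predictions into one effect size:
    effects = model.predict_on_dataset(variant_ds, compare_func="log2FC")
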

test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None)[source]#

Run test loop for a dataset

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use for inference

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

Returns:

Dataframe containing all calculated metrics on the test set.

embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#

Return embeddings for a dataset of sequences

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • device – Device ID to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

Returns:

Numpy array of shape (B, T, L) containing embeddings.

get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int][source]#

Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.

Parameters:
  • tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers

  • key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.

  • invert – Get indices for all tasks except those listed in tasks

Returns:

The index or indices of the corresponding task(s) in the model’s output.
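
For instance, with hypothetical task names:

    idx = model.get_task_idxs("astrocytes")                   # -> int
    idxs = model.get_task_idxs(["astrocytes", "microglia"])   # -> list of ints
    rest = model.get_task_idxs(["astrocytes"], invert=True)   # all other tasks
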

input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int[source]#

Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.

Parameters:
  • input_coord – Genomic coordinate of the input position

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Index of the output bin containing the given position.

output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int[source]#

Given the index of a bin in the output, get its corresponding start or end coordinate.

Parameters:
  • output_bin – Index of the bin in the model’s output

  • return_pos – “start” or “end”

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Genomic coordinate corresponding to the start (if return_pos = “start”) or end (if return_pos = “end”) of the bin.
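
The two coordinate methods are inverses up to bin resolution; an illustrative round trip with made-up coordinates:

    bin_idx = model.input_coord_to_output_bin(1_000_128, start_pos=1_000_000)
    start = model.output_bin_to_input_coord(
        bin_idx, return_pos="start", start_pos=1_000_000
    )
    # start <= 1_000_128 < the bin's end coordinate
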

input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame[source]#

Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.

Parameters:

intervals – A dataframe of genomic intervals

Returns:

A dataframe containing the genomic intervals corresponding to the model output from each input interval.

input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None[source]#

Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.

Parameters:
  • intervals – A dataframe of genomic intervals

  • start_pos – The start position of the sequence input to the model.

Returns:

Start and end indices of the output bins corresponding to each input interval.

class grelu.lightning.LightningModelEnsemble(models: list, model_names: List[str] | None = None)[source]#

Bases: pytorch_lightning.LightningModule

Combine multiple LightningModel objects into a single object. When predict_on_dataset is used, it will return the concatenated predictions from all the models in the order in which they were supplied.

Parameters:
  • models (list) – A list of multiple LightningModel objects

  • model_names (list) – A name for each model. This will be prefixed to the names of the individual tasks predicted by the model. If not supplied, the models will be named “model0”, “model1”, etc.

models[source]#
model_names[source]#
model_params[source]#
data_params[source]#
_combine_tasks() None[source]#

Combine the task metadata of all the sub-models into self.data_params[“tasks”]

forward(x: torch.Tensor) torch.Tensor[source]#

Forward Pass.

predict_on_dataset(dataset: Callable, **kwargs) numpy.ndarray[source]#

This will return the concatenated predictions from all the constituent models, in the order in which they were supplied. Predictions will be concatenated along the task axis.

get_task_idxs(tasks: str | int | List[str] | List[int], key: str = 'name') int | List[int][source]#

Return the task index given the name of the task. Note that task names should be supplied with a prefix indicating the model number, so for instance if you want the predictions from the second model on astrocytes, the task name would be “{name of second model}_astrocytes”. If model names were not supplied to __init__, the task name would be “model1_astrocytes”.

Parameters:
  • tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers

  • key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.

Returns:

An integer or list of integers representing the indices of the tasks in the model output.
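
A closing sketch of the ensemble wrapper; model_a and model_b stand for previously trained LightningModel objects, and the task name is hypothetical:

    from grelu.lightning import LightningModelEnsemble

    ensemble = LightningModelEnsemble([model_a, model_b], model_names=["atac", "rna"])
    preds = ensemble.predict_on_dataset(ds)          # concatenated along the task axis
    idx = ensemble.get_task_idxs("rna_astrocytes")   # "<model name>_<task name>"
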