grelu.resources#

Attributes#

Classes#

LightningModel

Wrapper for predictive sequence models

Functions#

get_meme_file_path(meme_motif_db)

Return the path to a MEME file.

get_default_config_file()

get_blacklist_file(genome)

_check_wandb([host])

projects([host])

artifacts(project[, host, type_is, type_contains])

models(project[, host])

datasets(project[, host])

runs(project[, host, field, filters])

get_artifact(name, project[, alias])

get_dataset_by_model(model_name, project[, alias])

get_model_by_dataset(dataset_name, project[, alias])

load_model(project, model_name[, alias, checkpoint_file])

Package Contents#

class grelu.resources.LightningModel(model_params: dict, train_params: dict = {}, data_params: dict = {})[source]#

Bases: pytorch_lightning.LightningModule

Wrapper for predictive sequence models

Parameters:
  • model_params – Dictionary of parameters specifying model architecture

  • train_params – Dictionary specifying training parameters

  • data_params – Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training.
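
A minimal sketch of constructing a LightningModel. The dictionary keys shown below ("model_type", "n_tasks", "lr", "batch_size", "max_epochs") are illustrative assumptions; use the keys expected by your chosen architecture and training setup:

    from grelu.resources import LightningModel

    # Sketch only: the keys inside model_params and train_params are assumptions
    # for illustration, not a definitive configuration.
    model = LightningModel(
        model_params={
            "model_type": "ConvModel",  # hypothetical architecture name
            "n_tasks": 1,               # number of output tasks
        },
        train_params={
            "lr": 1e-4,                 # assumed learning-rate key
            "batch_size": 64,
            "max_epochs": 5,
        },
    )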

build_model() None[source]#

Build a model from parameter dictionary

initialize_loss() None[source]#

Create the specified loss function.

initialize_activation() None[source]#

Add a task-specific activation function to the model.

initialize_metrics()[source]#

Initialize the appropriate metrics for the given task.

update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None[source]#

Update metrics after each pass

format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor[source]#

Extract the one-hot encoded sequence from the input

forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor[source]#

Forward pass

training_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
validation_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
on_validation_epoch_end()[source]#

Calculate metrics for entire validation set

test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#

Calculate metrics after a single test step

on_test_epoch_end() None[source]#

Calculate metrics for entire test set

configure_optimizers() None[source]#

Configure the optimizer for training

count_params() int[source]#

Number of gradient-enabled (trainable) parameters in the model

parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]][source]#

Parses the devices argument and returns a tuple of accelerator and devices.

Parameters:

devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.

Returns:

A tuple of accelerator and devices.
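
For illustration, assuming model is a LightningModel instance as constructed above:

    # Illustrative only: returns an (accelerator, devices) pair suitable for a
    # PyTorch Lightning Trainer; the exact values depend on the input.
    accelerator, devices = model.parse_devices([0, 1])  # train on GPUs 0 and 1
    accelerator, devices = model.parse_devices("cpu")   # train on CPU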

parse_logger() str[source]#

Parses the name of the logger supplied in train_params.

add_transform(prediction_transform: Callable) None[source]#

Add a prediction transform

reset_transform() None[source]#

Remove a prediction transform

make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for training

make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for validation and testing

make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for prediction

train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#

Train model and optionally log metrics to wandb.

Parameters:
  • train_dataset (Dataset) – Dataset object that yields training examples

  • val_dataset (Dataset) – Dataset object that yields validation examples

  • checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.

Returns:

PyTorch Lightning Trainer
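
A hedged sketch of a training call; train_ds and val_ds are placeholders for any Dataset objects yielding training and validation examples compatible with this model:

    # Sketch: train_ds and val_ds are placeholder Dataset objects.
    trainer = model.train_on_dataset(
        train_dataset=train_ds,
        val_dataset=val_ds,
    )
    print(trainer.logged_metrics)  # metrics recorded by the returned Trainer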

_get_dataset_attrs(dataset: Callable) None[source]#

Read data parameters from a dataset object

change_head(n_tasks: int, final_pool_func: str) None[source]#

Build a new head with the desired number of tasks

tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#

Fine-tune a pretrained model on a new dataset.

Parameters:
  • train_dataset – Dataset object that yields training examples

  • val_dataset – Dataset object that yields validation examples

  • final_act_func – Name of the final activation layer

  • final_pool_func – Name of the final pooling layer

  • freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.

Returns:

PyTorch Lightning Trainer
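
A sketch of fine-tuning a pretrained model while training only the new head; the dataset objects below are placeholders:

    # Sketch: fine-tune on new data, keeping the pretrained embedding layers frozen.
    trainer = model.tune_on_dataset(
        train_dataset=new_train_ds,  # placeholder Dataset of new training examples
        val_dataset=new_val_ds,      # placeholder Dataset of validation examples
        freeze_embedding=True,       # train only the newly built head
    )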

on_save_checkpoint(checkpoint: dict) None[source]#
predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray[source]#

A convenience function to return model predictions directly on a single batch of sequences in string format.

Parameters:
  • x – DNA sequences as a string or list of strings.

  • device – Device on which to run inference: “cpu” or the index of a GPU

Returns:

A numpy array of predictions.
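
A minimal sketch, assuming model is a trained LightningModel and the sequences match its expected input length:

    # Sketch: predict directly on raw DNA strings; the sequences below are placeholders.
    preds = model.predict_on_seqs(
        ["ACGT" * 128, "TTGC" * 128],
        device="cpu",
    )
    print(preds.shape)  # numpy array of predictions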

predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False)[source]#

Predict for a dataset of sequences or variants

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • augment_aggfunc – Function used to aggregate predictions across augmented versions of a sequence (by default, their mean)

  • compare_func – Function used to compare predictions for the alt and ref alleles of a variant (e.g. their difference)

  • return_df – Return the predictions as a Pandas dataframe

Returns:

Model predictions as a numpy array or dataframe
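
A sketch of batched prediction; ds is a placeholder Dataset that yields one-hot encoded sequences compatible with the model:

    # Sketch: return predictions as a labelled pandas DataFrame instead of an array.
    preds_df = model.predict_on_dataset(
        dataset=ds,
        devices="cpu",
        num_workers=1,
        batch_size=256,
        return_df=True,
    )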

test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#

Run test loop for a dataset

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use for inference

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

Returns:

Dataframe containing all calculated metrics on the test set.

embed_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#

Return embeddings for a dataset of sequences

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

Returns:

Numpy array of shape (B, T, L) containing embeddings.
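
A short sketch using the same placeholder dataset as above:

    # Sketch: compute sequence embeddings rather than predictions.
    emb = model.embed_on_dataset(ds, devices="cpu", num_workers=1, batch_size=256)
    print(emb.shape)  # (B, T, L), as described above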

get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int][source]#

Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.

Parameters:
  • tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers

  • key – Key of model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.

  • invert – Get indices for all tasks except those listed in tasks

Returns:

The index or indices of the corresponding task(s) in the model’s output.
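
A sketch of selecting output tasks by name; the task names below are hypothetical and assumed to be present in model.data_params[“tasks”][“name”]:

    # Sketch: map (hypothetical) task names to their output indices.
    idx = model.get_task_idxs("astrocyte")                   # single task index
    idxs = model.get_task_idxs(["astrocyte", "microglia"])   # list of indices
    rest = model.get_task_idxs(["astrocyte"], invert=True)   # all remaining tasks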

input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int[source]#

Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.

Parameters:
  • input_coord – Genomic coordinate of the input position

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Index of the output bin containing the given position.

output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int[source]#

Given the index of a bin in the output, get its corresponding start or end coordinate.

Parameters:
  • output_bin – Index of the bin in the model’s output

  • return_pos – “start” or “end”

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Genomic coordinate corresponding to the start (if return_pos = “start”) or end (if return_pos = “end”) of the bin.
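
A combined sketch of the two coordinate conversions; the coordinates used are arbitrary examples:

    # Sketch: map a genomic position to its output bin, then recover the bin's span.
    bin_idx = model.input_coord_to_output_bin(input_coord=1_000_150, start_pos=1_000_000)
    bin_start = model.output_bin_to_input_coord(bin_idx, return_pos="start", start_pos=1_000_000)
    bin_end = model.output_bin_to_input_coord(bin_idx, return_pos="end", start_pos=1_000_000)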

input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame[source]#

Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.

Parameters:

intervals – A dataframe of genomic intervals

Returns:

A dataframe containing the genomic intervals corresponding to the model output from each input interval.
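
A sketch assuming BED-like column names ("chrom", "start", "end"), which are an assumption rather than something specified on this page:

    import pandas as pd

    # Sketch: map input intervals to the genomic intervals covered by the model output.
    intervals = pd.DataFrame({"chrom": ["chr1"], "start": [1_000_000], "end": [1_002_048]})
    out_intervals = model.input_intervals_to_output_intervals(intervals)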

input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None[source]#

Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.

Parameters:
  • intervals – A dataframe of genomic intervals

  • start_pos – The start position of the sequence input to the model.

Returns:

Start and end indices of the output bins corresponding to each input interval.

grelu.resources.DEFAULT_WANDB_ENTITY = 'grelu'[source]#
grelu.resources.DEFAULT_WANDB_HOST = 'https://api.wandb.ai'[source]#
grelu.resources.get_meme_file_path(meme_motif_db)[source]#

Return the path to a MEME file.

Parameters:

meme_motif_db (str) – Path to a MEME file or the name of a MEME file included with gReLU. Current name options are “jaspar” and “consensus”.

Returns:

Path to the specified MEME file.

Return type:

(str)
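
For example, using one of the named databases listed above:

    from grelu.resources import get_meme_file_path

    # Look up the MEME file for the bundled JASPAR motif database and print its path.
    meme_path = get_meme_file_path("jaspar")
    print(meme_path)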

grelu.resources.get_default_config_file()[source]#
grelu.resources.get_blacklist_file(genome)[source]#
grelu.resources._check_wandb(host=DEFAULT_WANDB_HOST)[source]#
grelu.resources.projects(host=DEFAULT_WANDB_HOST)[source]#
grelu.resources.artifacts(project, host=DEFAULT_WANDB_HOST, type_is=None, type_contains=None)[source]#
grelu.resources.models(project, host=DEFAULT_WANDB_HOST)[source]#
grelu.resources.datasets(project, host=DEFAULT_WANDB_HOST)[source]#
grelu.resources.runs(project, host=DEFAULT_WANDB_HOST, field='id', filters=None)[source]#
grelu.resources.get_artifact(name, project, alias='latest')[source]#
grelu.resources.get_dataset_by_model(model_name, project, alias='latest')[source]#
grelu.resources.get_model_by_dataset(dataset_name, project, alias='latest')[source]#
grelu.resources.load_model(project, model_name, alias='latest', checkpoint_file='model.ckpt')[source]#
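
A sketch of discovering and loading a model from the gReLU model zoo hosted on Weights & Biases; the project and model names below are placeholders, not guaranteed to exist:

    from grelu.resources import projects, models, load_model

    # Sketch: list available projects and models, then load one by name.
    print(projects())                           # available W&B projects
    print(models(project="example-project"))    # models in one project (placeholder name)
    model = load_model(project="example-project", model_name="example-model", alias="latest")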