grelu.resources#

Attributes#

DEFAULT_WANDB_ENTITY

DEFAULT_WANDB_HOST

Classes#

LightningModel

Wrapper for predictive sequence models

Functions#

get_meme_file_path(→ str)

Return the path to a MEME file.

get_blacklist_file(→ str)

Return the path to a blacklist file

_check_wandb(→ None)

Check that the user is logged into Weights and Biases

projects(→ List[str])

List all projects in the model zoo

artifacts(→ List[str])

List all artifacts associated with a project in the model zoo

models(→ List[str])

List all models associated with a project in the model zoo

datasets(→ List[str])

List all datasets associated with a project in the model zoo

runs(→ List[str])

List attributes of all runs associated with a project in the model zoo

get_artifact(name, project[, host, alias])

Retrieve an artifact associated with a project in the model zoo

get_dataset_by_model(→ List[str])

List all datasets associated with a model in the model zoo

get_model_by_dataset(→ List[str])

List all models associated with a dataset in the model zoo

load_model(→ grelu.lightning.LightningModel)

Download and load a model from the model zoo

Package Contents#

class grelu.resources.LightningModel(model_params: dict, train_params: dict = {}, data_params: dict = {})[source]#

Bases: pytorch_lightning.LightningModule

Wrapper for predictive sequence models

Parameters:
  • model_params – Dictionary of parameters specifying model architecture

  • train_params – Dictionary specifying training parameters

  • data_params – Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training.
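
A minimal construction sketch follows. The exact keys accepted in model_params and train_params depend on the chosen architecture and trainer configuration; the keys shown here are assumptions for illustration, not the definitive schema:

    from grelu.lightning import LightningModel

    # Hypothetical parameter dictionaries; the accepted keys are assumptions.
    model_params = {
        "model_type": "ConvModel",  # assumed name of a built-in architecture
        "n_tasks": 1,               # assumed key: number of output tasks
    }
    train_params = {
        "lr": 1e-4,        # assumed key: learning rate
        "batch_size": 64,  # assumed key: training batch size
    }

    model = LightningModel(model_params=model_params, train_params=train_params)
    print(model.count_params())  # number of trainable parameters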

model_params[source]#
train_params[source]#
data_params[source]#
build_model() None[source]#

Build a model from parameter dictionary

initialize_loss() None[source]#

Create the specified loss function.

initialize_activation() None[source]#

Add a task-specific activation function to the model.

initialize_metrics()[source]#

Initialize the appropriate metrics for the given task.

update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None[source]#

Update metrics after each pass

format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor[source]#

Extract the one-hot encoded sequence from the input

forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor[source]#

Forward pass

training_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
validation_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
on_validation_epoch_end()[source]#

Calculate metrics for entire validation set

test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#

Calculate metrics after a single test step

on_test_epoch_end() None[source]#

Calculate metrics for entire test set

configure_optimizers() None[source]#

Configure optimizer for training

count_params() int[source]#

Return the number of gradient-enabled parameters in the model

parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]][source]#

Parses the devices argument and returns a tuple of accelerator and devices.

Parameters:

devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.

Returns:

A tuple of accelerator and devices.

parse_logger() str[source]#

Parses the name of the logger supplied in train_params.

add_transform(prediction_transform: Callable) None[source]#

Add a prediction transform

reset_transform() None[source]#

Remove a prediction transform

make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for training

make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for validation and testing

make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for prediction

train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#

Train model and optionally log metrics to wandb.

Parameters:
  • train_dataset (Dataset) – Dataset object that yields training examples

  • val_dataset (Dataset) – Dataset object that yields validation examples

  • checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.

Returns:

PyTorch Lightning Trainer
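
A hedged usage sketch, assuming model is a LightningModel and train_ds / val_ds are Dataset objects yielding the examples gReLU expects:

    # Train from scratch; metrics are logged to wandb if a logger is configured.
    trainer = model.train_on_dataset(train_ds, val_ds)

    # Resume from a saved checkpoint; the optimizer state is also restored.
    trainer = model.train_on_dataset(
        train_ds, val_ds, checkpoint_path="checkpoints/last.ckpt"  # hypothetical path
    )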

_get_dataset_attrs(dataset: Callable) None[source]#

Read data parameters from a dataset object

change_head(n_tasks: int, final_pool_func: str) None[source]#

Build a new head with the desired number of tasks

tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#

Fine-tune a pretrained model on a new dataset.

Parameters:
  • train_dataset – Dataset object that yields training examples

  • val_dataset – Dataset object that yields validation examples

  • final_act_func – Name of the final activation layer

  • final_pool_func – Name of the final pooling layer

  • freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.

Returns:

PyTorch Lightning Trainer
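
A fine-tuning sketch, assuming model is a pretrained LightningModel and train_ds / val_ds are Dataset objects for the new task; the pooling name “avg” is an assumption, not a confirmed option:

    trainer = model.tune_on_dataset(
        train_ds,
        val_ds,
        final_pool_func="avg",   # assumed pooling name; check your model config
        freeze_embedding=True,   # train only the new head
    )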

on_save_checkpoint(checkpoint: dict) None[source]#
predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray[source]#

Return model predictions directly on a single batch of sequences provided in string format.

Parameters:
  • x – DNA sequences as a string or list of strings.

  • device – Index of the device to use

Returns:

A numpy array of predictions.
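
A quick sketch, assuming model is a trained LightningModel; the sequence length (2048 here) is a placeholder and must match the model's expected input length:

    seqs = ["A" * 2048, "ACGT" * 512]  # two dummy sequences of length 2048
    preds = model.predict_on_seqs(seqs, device="cpu")
    print(preds.shape)  # one row of predictions per input sequence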

predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#

Predict for a dataset of sequences or variants

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • augment_aggfunc – Function used to aggregate predictions across all augmented versions of a sequence (defaults to “mean”)

  • compare_func – Function used to compare predictions for the alt and ref alleles of variants (e.g. their difference)

  • return_df – Return the predictions as a Pandas dataframe

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

Returns:

Model predictions as a numpy array or dataframe
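
A minimal sketch using a plain PyTorch dataset of one-hot encoded sequences. Treat the bare torch Dataset below as an assumption; real workflows typically use the dataset classes shipped with gReLU:

    import torch
    from torch.utils.data import Dataset

    class OneHotSeqs(Dataset):
        """Toy dataset yielding one-hot encoded sequences of shape (4, L)."""

        def __init__(self, n: int = 8, seq_len: int = 2048):
            # Random bases -> one-hot -> channels-first (4, seq_len)
            self.x = torch.eye(4)[torch.randint(4, (n, seq_len))].transpose(1, 2)

        def __len__(self):
            return self.x.shape[0]

        def __getitem__(self, idx):
            return self.x[idx]

    preds = model.predict_on_dataset(OneHotSeqs(), devices="cpu", batch_size=4)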

test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None)[source]#

Run test loop for a dataset

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use for inference

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

Returns:

Dataframe containing all calculated metrics on the test set.

embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#

Return embeddings for a dataset of sequences

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • device – Device ID to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

Returns:

Numpy array of shape (B, T, L) containing embeddings.
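
A short sketch reusing the toy OneHotSeqs dataset from the predict_on_dataset example above, under the same assumptions:

    # Embeddings of shape (B, T, L): dataset size, embedding channels, output length.
    emb = model.embed_on_dataset(OneHotSeqs(), device="cpu", batch_size=4)
    print(emb.shape)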

get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int][source]#

Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.

Parameters:
  • tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers

  • key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.

  • invert – Get indices for all tasks except those listed in tasks

Returns:

The index or indices of the corresponding task(s) in the model’s output.
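
A sketch with hypothetical task names; the real names live in model.data_params[“tasks”]:

    idx = model.get_task_idxs("K562")                  # one name -> one index
    idxs = model.get_task_idxs(["K562", "HepG2"])      # list of names -> list of indices
    rest = model.get_task_idxs(["K562"], invert=True)  # indices of all other tasks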

input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int[source]#

Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.

Parameters:
  • input_coord – Genomic coordinate of the input position

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Index of the output bin containing the given position.

output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int[source]#

Given the index of a bin in the output, get its corresponding start or end coordinate.

Parameters:
  • output_bin – Index of the bin in the model’s output

  • return_pos – “start” or “end”

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Genomic coordinate corresponding to the start (if return_pos = “start”) or end (if return_pos = “end”) of the bin.
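
A round-trip sketch for an input sequence starting at genomic position 1,000,000, assuming model is a LightningModel whose data_params determine the bin width:

    bin_idx = model.input_coord_to_output_bin(1_000_500, start_pos=1_000_000)
    bin_start = model.output_bin_to_input_coord(bin_idx, "start", start_pos=1_000_000)
    bin_end = model.output_bin_to_input_coord(bin_idx, "end", start_pos=1_000_000)
    print(bin_idx, bin_start, bin_end)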

input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame[source]#

Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.

Parameters:

intervals – A dataframe of genomic intervals

Returns:

A dataframe containing the genomic intervals corresponding to the model output from each input interval.
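
A sketch assuming the standard chrom/start/end interval columns used throughout gReLU; the interval width (2,048 bp) is a placeholder for the model's input length:

    import pandas as pd

    intervals = pd.DataFrame(
        {"chrom": ["chr1"], "start": [1_000_000], "end": [1_002_048]}
    )
    out_intervals = model.input_intervals_to_output_intervals(intervals)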

input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None[source]#

Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.

Parameters:
  • intervals – A dataframe of genomic intervals

  • start_pos – The start position of the sequence input to the model.

Returns:

Start and end indices of the output bins corresponding to each input interval.

grelu.resources.DEFAULT_WANDB_ENTITY = 'grelu'[source]#
grelu.resources.DEFAULT_WANDB_HOST = 'https://api.wandb.ai'[source]#
grelu.resources.get_meme_file_path(meme_motif_db: str) str[source]#

Return the path to a MEME file.

Parameters:

meme_motif_db – Path to a MEME file or the name of a MEME file included with gReLU. Current name options are “jaspar” and “consensus”.

Returns:

Path to the specified MEME file.
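
Both bundled databases can be resolved by name:

    from grelu.resources import get_meme_file_path

    jaspar_path = get_meme_file_path("jaspar")        # bundled JASPAR motifs
    consensus_path = get_meme_file_path("consensus")  # bundled consensus motifs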

grelu.resources.get_blacklist_file(genome: str) str[source]#

Return the path to a blacklist file

Parameters:

genome – Name of a genome whose blacklist file is included with gReLU. Current name options are “hg19”, “hg38” and “mm10”.

Returns:

Path to the specified blacklist file.
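
For example:

    from grelu.resources import get_blacklist_file

    blacklist_path = get_blacklist_file("hg38")  # "hg19" and "mm10" also available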

grelu.resources._check_wandb(host: str = DEFAULT_WANDB_HOST) None[source]#

Check that the user is logged into Weights and Biases

Parameters:

host – URL of the Weights & Biases host

grelu.resources.projects(host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all projects in the model zoo

Parameters:

host – URL of the Weights & Biases host

Returns:

List of project names

grelu.resources.artifacts(project: str, host: str = DEFAULT_WANDB_HOST, type_is: str | None = None, type_contains: str | None = None) List[str][source]#

List all artifacts associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

  • type_is – Return only artifacts with this type

  • type_contains – Return only artifacts whose type contains this string

Returns:

List of artifact names

grelu.resources.models(project: str, host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all models associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

Returns:

List of model names

grelu.resources.datasets(project: str, host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all datasets associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

Returns:

List of dataset names
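
A browsing sketch; it requires a Weights & Biases login, and the project picked below is simply whichever one is listed first:

    from grelu.resources import projects, models, datasets

    all_projects = projects()
    project = all_projects[0]
    print(models(project))    # model artifacts in the project
    print(datasets(project))  # dataset artifacts in the project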

grelu.resources.runs(project: str, host: str = DEFAULT_WANDB_HOST, field: str = 'id', filters: Dict[str, Any] | None = None) List[str][source]#

List attributes of all runs associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

  • field – Field to return from the run metadata

  • filters – Dictionary of filters to pass to api.runs

Returns:

List of run attributes

grelu.resources.get_artifact(name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest')[source]#

Retrieve an artifact associated with a project in the model zoo

Parameters:
  • name – Name of the artifact

  • project – Name of the project containing the artifact

  • host – URL of the Weights & Biases host

  • alias – Alias of the artifact

Returns:

The specified artifact

grelu.resources.get_dataset_by_model(model_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str][source]#

List all datasets associated with a model in the model zoo

Parameters:
  • model_name – Name of the model

  • project – Name of the project containing the model

  • host – URL of the Weights & Biases host

  • alias – Alias of the model artifact

Returns:

A list containing the names of all datasets linked to the model

grelu.resources.get_model_by_dataset(dataset_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str][source]#

List all models associated with a dataset in the model zoo

Parameters:
  • dataset_name – Name of the dataset

  • project – Name of the project containing the dataset

  • host – URL of the Weights & Biases host

  • alias – Alias of the dataset artifact

Returns:

A list containing the names of all models linked to the dataset
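
A cross-referencing sketch with hypothetical artifact and project names; use models() and datasets() to discover real ones:

    from grelu.resources import get_dataset_by_model, get_model_by_dataset

    linked_datasets = get_dataset_by_model("my-model", project="my-project")
    linked_models = get_model_by_dataset("my-dataset", project="my-project")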

grelu.resources.load_model(project: str, model_name: str, device: str | int = 'cpu', host: str = DEFAULT_WANDB_HOST, alias: str = 'latest', checkpoint_file: str = 'model.ckpt') grelu.lightning.LightningModel[source]#

Download and load a model from the model zoo

Parameters:
  • project – Name of the project containing the model

  • model_name – Name of the model

  • device – Device index on which to load the model.

  • host – URL of the Weights & Biases host

  • alias – Alias of the model artifact

  • checkpoint_file – Name of the checkpoint file contained in the model artifact

Returns:

A LightningModel object
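
An end-to-end sketch with placeholder project and model names; discover real entries with projects() and models():

    from grelu.resources import load_model

    model = load_model(project="my-project", model_name="my-model", device="cpu")
    preds = model.predict_on_seqs("A" * 2048)  # length must match the model's input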