grelu.resources#

grelu.resources contains additional data files that can be used by gReLU functions, as well as functions to load these files and to access externally stored resources such as model checkpoints and datasets in the model zoo.

Attributes#

DEFAULT_WANDB_ENTITY

DEFAULT_WANDB_HOST

Classes#

LightningModel

Wrapper for predictive sequence models

Functions#

get_meme_file_path(→ str)

Return the path to a MEME file.

get_blacklist_file(→ str)

Return the path to a blacklist file

_check_wandb(→ None)

Check that the user is logged into Weights and Biases

projects(→ List[str])

List all projects in the model zoo

artifacts(→ List[str])

List all artifacts associated with a project in the model zoo

models(→ List[str])

List all models associated with a project in the model zoo

datasets(→ List[str])

List all datasets associated with a project in the model zoo

runs(→ List[str])

List attributes of all runs associated with a project in the model zoo

get_artifact(name, project[, host, alias])

Retrieve an artifact associated with a project in the model zoo

get_dataset_by_model(→ List[str])

List all datasets associated with a model in the model zoo

get_model_by_dataset(→ List[str])

List all models associated with a dataset in the model zoo

load_model(→ grelu.lightning.LightningModel)

Download and load a model from the model zoo

Package Contents#

class grelu.resources.LightningModel(model_params: dict, train_params: dict = {})[source]#

Bases: pytorch_lightning.LightningModule

Wrapper for predictive sequence models

Parameters:
  • model_params – Dictionary of parameters specifying model architecture

  • train_params – Dictionary specifying training parameters
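
A minimal construction sketch follows; the keys shown in model_params and train_params are illustrative assumptions, since the accepted keys depend on the architectures and training options available in your gReLU installation.

    from grelu.resources import LightningModel

    # Hypothetical parameter dictionaries; consult the gReLU model and training
    # documentation for the keys supported by your installed version.
    model_params = {
        "model_type": "ConvModel",  # assumed architecture name
        "n_tasks": 1,               # assumed number of output tasks
    }
    train_params = {
        "lr": 1e-4,                 # assumed learning rate
        "batch_size": 256,          # assumed training batch size
    }

    model = LightningModel(model_params=model_params, train_params=train_params)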

model_params[source]#
train_params[source]#
data_params[source]#
performance[source]#
build_model() None[source]#

Build a model from the parameter dictionary

initialize_loss() None[source]#

Create the specified loss function.

initialize_activation() None[source]#

Add a task-specific activation function to the model.

initialize_metrics()[source]#

Initialize the appropriate metrics for the given task.

update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None[source]#

Update metrics after each pass

format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor[source]#

Extract the one-hot encoded sequence from the input

forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor[source]#

Forward pass

training_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
validation_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#
on_validation_epoch_end()[source]#

Calculate metrics for entire validation set

test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor[source]#

Calculate metrics after a single test step

on_test_epoch_end() None[source]#

Calculate metrics for entire test set

configure_optimizers() None[source]#

Configure the optimizer for training

count_params() int[source]#

Return the number of gradient-enabled (trainable) parameters in the model

parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]][source]#

Parses the devices argument and returns a tuple of accelerator and devices.

Parameters:

devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.

Returns:

A tuple of accelerator and devices.

parse_logger() str[source]#

Parses the name of the logger supplied in train_params.

add_transform(prediction_transform: Callable) None[source]#

Add a prediction transform

reset_transform() None[source]#

Remove a prediction transform

make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for training

make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for validation and testing

make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable[source]#

Make dataloader for prediction

train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#

Train model and optionally log metrics to wandb.

Parameters:
  • train_dataset (Dataset) – Dataset object that yields training examples

  • val_dataset (Dataset) – Dataset object that yields validation examples

  • checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.

Returns:

PyTorch Lightning Trainer
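
A rough usage sketch (train_ds and val_ds are placeholders for whichever gReLU Dataset objects you are using; the checkpoint path is hypothetical):

    # Train from scratch; returns the PyTorch Lightning Trainer used for training.
    trainer = model.train_on_dataset(train_ds, val_ds)

    # Resume from a saved checkpoint; the optimizer is restored to its saved state.
    trainer = model.train_on_dataset(
        train_ds, val_ds, checkpoint_path="checkpoints/last.ckpt"
    )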

_get_dataset_attrs(dataset: Callable) None[source]#

Read data parameters from a dataset object

change_head(n_tasks: int, final_pool_func: str) None[source]#

Build a new head with the desired number of tasks

tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#

Fine-tune a pretrained model on a new dataset.

Parameters:
  • train_dataset – Dataset object that yields training examples

  • val_dataset – Dataset object that yields validation examples

  • final_act_func – Name of the final activation layer

  • final_pool_func – Name of the final pooling layer

  • freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.

Returns:

PyTorch Lightning Trainer
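
A fine-tuning sketch under the same assumptions (tune_ds and tune_val_ds are placeholder Dataset objects for the new task; the pooling function name is an assumption):

    # Freeze the pretrained embedding layers and train only the newly built head.
    trainer = model.tune_on_dataset(
        tune_ds,
        tune_val_ds,
        final_pool_func="avg",   # assumed pooling function name
        freeze_embedding=True,
    )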

on_save_checkpoint(checkpoint: dict) None[source]#
on_load_checkpoint(checkpoint: dict) None[source]#
predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray[source]#

Return model predictions directly on a single batch of sequences provided in string format.

Parameters:
  • x – DNA sequences as a string or list of strings.

  • device – Index of the device to use

Returns:

A numpy array of predictions.
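
For example (the sequences below are arbitrary and should match the input length the model expects):

    # Predict directly on DNA strings without building a Dataset object.
    preds = model.predict_on_seqs(["ACGT" * 100, "TGCA" * 100], device="cpu")
    print(preds.shape)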

predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#

Predict for a dataset of sequences or variants

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • augment_aggfunc – Function used to aggregate predictions across augmented versions of each sequence (e.g. 'mean')

  • compare_func – Function used to compare the predictions for alternate and reference alleles of variants

  • return_df – If True, return the predictions as a Pandas dataframe

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

Returns:

Model predictions as a numpy array or dataframe
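
A prediction sketch, assuming ds is a gReLU Dataset that yields one-hot encoded sequences or variants:

    preds = model.predict_on_dataset(
        ds,
        devices=0,               # a single GPU; use "cpu" if none is available
        batch_size=512,
        augment_aggfunc="mean",  # average over augmented copies of each sequence
        return_df=True,          # return a dataframe instead of a numpy array
    )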

test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None, write_path: str | None = None)[source]#

Run test loop for a dataset

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • devices – Device IDs to use for inference

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

  • precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.

  • write_path – Path to write a new model checkpoint containing test data parameters and performance.

Returns:

Dataframe containing all calculated metrics on the test set.
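
A test-loop sketch, assuming test_ds is a labeled gReLU Dataset (the checkpoint path is hypothetical):

    metrics = model.test_on_dataset(
        test_ds,
        devices="cpu",
        batch_size=256,
        write_path="model_with_test_metrics.ckpt",  # optional new checkpoint
    )
    print(metrics)  # dataframe of metrics calculated on the test set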

embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#

Return embeddings for a dataset of sequences

Parameters:
  • dataset – Dataset object that yields one-hot encoded sequences

  • device – Device ID to use

  • num_workers – Number of workers for data loader

  • batch_size – Batch size for data loader

Returns:

Numpy array of shape (B, T, L) containing embeddings.
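
An embedding sketch, assuming ds yields one-hot encoded sequences:

    emb = model.embed_on_dataset(ds, device="cpu", batch_size=128)
    print(emb.shape)  # (B, T, L): batch size, embedding channels, output length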

get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int][source]#

Given a task name or metadata entry, get the task index. If integers are provided, return them unchanged.

Parameters:
  • tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers

  • key – Key in model.data_params[“tasks”] under which the relevant task metadata is stored. “name” is used by default.

  • invert – Get indices for all tasks except those listed in tasks

Returns:

The index or indices of the corresponding task(s) in the model’s output.
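
A lookup sketch; the task names used here are hypothetical and should match entries under the chosen key in model.data_params["tasks"]:

    idx = model.get_task_idxs("K562")                  # single name -> single index
    idxs = model.get_task_idxs(["K562", "HepG2"])      # list of names -> list of indices
    rest = model.get_task_idxs(["K562"], invert=True)  # indices of all other tasks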

input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int[source]#

Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.

Parameters:
  • input_coord – Genomic coordinate of the input position

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Index of the output bin containing the given position.

output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int[source]#

Given the index of a bin in the output, get its corresponding start or end coordinate.

Parameters:
  • output_bin – Index of the bin in the model’s output

  • return_pos – “start” or “end”

  • start_pos – Genomic coordinate of the first base in the input sequence

Returns:

Genomic coordinate corresponding to the start (if return_pos = “start”) or end (if return_pos = “end”) of the bin.
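
A conversion sketch using this function together with input_coord_to_output_bin (the coordinates are arbitrary examples):

    # Map a genomic position to the output bin that covers it, then back to the
    # bin's start and end coordinates.
    bin_idx = model.input_coord_to_output_bin(input_coord=1000500, start_pos=1000000)
    bin_start = model.output_bin_to_input_coord(bin_idx, return_pos="start", start_pos=1000000)
    bin_end = model.output_bin_to_input_coord(bin_idx, return_pos="end", start_pos=1000000)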

input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame[source]#

Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.

Parameters:

intervals – A dataframe of genomic intervals

Returns:

A dataframe containing the genomic intervals corresponding to the model output from each input interval.

input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None[source]#

Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.

Parameters:
  • intervals – A dataframe of genomic intervals

  • start_pos – The start position of the sequence input to the model.

Returns:

Start and end indices of the output bins corresponding to each input interval.
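
An interval-mapping sketch; the column names (chrom, start, end) follow the usual gReLU convention for interval dataframes and are an assumption here:

    import pandas as pd

    intervals = pd.DataFrame(
        {"chrom": ["chr1"], "start": [1000000], "end": [1002048]}
    )
    # Intervals covered by the model output for each input interval.
    out_intervals = model.input_intervals_to_output_intervals(intervals)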

grelu.resources.DEFAULT_WANDB_ENTITY = 'grelu'[source]#
grelu.resources.DEFAULT_WANDB_HOST = 'https://api.wandb.ai'[source]#
grelu.resources.get_meme_file_path(meme_motif_db: str) str[source]#

Return the path to a MEME file.

Parameters:

meme_motif_db – Path to a MEME file or the name of a MEME file included with gReLU. Current name options are “jaspar” and “consensus”.

Returns:

Path to the specified MEME file.
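
For example:

    from grelu.resources import get_meme_file_path

    # Resolve the path of a motif database shipped with gReLU.
    meme_path = get_meme_file_path("consensus")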

grelu.resources.get_blacklist_file(genome: str) str[source]#

Return the path to a blacklist file

Parameters:

genome – Name of a genome whose blacklist file is included with gReLU. Current name options are “hg19”, “hg38” and “mm10”.

Returns:

Path to the specified blacklist file.
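
For example:

    from grelu.resources import get_blacklist_file

    # Resolve the path of the bundled blacklist for a supported genome.
    blacklist_path = get_blacklist_file("hg38")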

grelu.resources._check_wandb(host: str = DEFAULT_WANDB_HOST) None[source]#

Check that the user is logged into Weights and Biases

Parameters:

host – URL of the Weights & Biases host

grelu.resources.projects(host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all projects in the model zoo

Parameters:

host – URL of the Weights & Biases host

Returns:

List of project names

grelu.resources.artifacts(project: str, host: str = DEFAULT_WANDB_HOST, type_is: str | None = None, type_contains: str | None = None) List[str][source]#

List all artifacts associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

  • type_is – Return only artifacts with this type

  • type_contains – Return only artifacts whose type contains this string

Returns:

List of artifact names

grelu.resources.models(project: str, host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all models associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

Returns:

List of model names

grelu.resources.datasets(project: str, host: str = DEFAULT_WANDB_HOST) List[str][source]#

List all datasets associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

Returns:

List of dataset names
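
A browsing sketch combining the listing functions above (a Weights & Biases login is required, and the artifact type string passed to type_contains is an assumption):

    from grelu.resources import projects, models, datasets, artifacts

    available = projects()          # names of all projects in the model zoo
    project = available[0]          # inspect the first project as an example
    print(models(project))          # models linked to this project
    print(datasets(project))        # datasets linked to this project
    print(artifacts(project, type_contains="model"))  # filtered artifact listing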

grelu.resources.runs(project: str, host: str = DEFAULT_WANDB_HOST, field: str = 'id', filters: Dict[str, Any] | None = None) List[str][source]#

List attributes of all runs associated with a project in the model zoo

Parameters:
  • project – Name of the project to search

  • host – URL of the Weights & Biases host

  • field – Field to return from the run metadata

  • filters – Dictionary of filters to pass to api.runs

Returns:

List of run attributes

grelu.resources.get_artifact(name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest')[source]#

Retrieve an artifact associated with a project in the model zoo

Parameters:
  • name – Name of the artifact

  • project – Name of the project containing the artifact

  • host – URL of the Weights & Biases host

  • alias – Alias of the artifact

Returns:

The specified artifact

grelu.resources.get_dataset_by_model(model_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str][source]#

List all datasets associated with a model in the model zoo

Parameters:
  • model_name – Name of the model

  • project – Name of the project containing the model

  • host – URL of the Weights & Biases host

  • alias – Alias of the model artifact

Returns:

A list containing the names of all datasets linked to the model

grelu.resources.get_model_by_dataset(dataset_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str][source]#

List all models associated with a dataset in the model zoo

Parameters:
  • dataset_name – Name of the dataset

  • project – Name of the project containing the dataset

  • host – URL of the Weights & Biases host

  • alias – Alias of the dataset artifact

Returns:

A list containing the names of all models linked to the dataset
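
A cross-referencing sketch; the artifact and project names are placeholders for artifacts actually present in the zoo:

    from grelu.resources import get_dataset_by_model, get_model_by_dataset

    linked_datasets = get_dataset_by_model("my_model", project="my_project")
    linked_models = get_model_by_dataset("my_dataset", project="my_project")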

grelu.resources.load_model(project: str, model_name: str, device: str | int = 'cpu', host: str = DEFAULT_WANDB_HOST, alias: str = 'latest', checkpoint_file: str = 'model.ckpt') grelu.lightning.LightningModel[source]#

Download and load a model from the model zoo

Parameters:
  • project – Name of the project containing the model

  • model_name – Name of the model

  • device – Device index on which to load the model.

  • host – URL of the Weights & Biases host

  • alias – Alias of the model artifact

  • checkpoint_file – Name of the checkpoint file contained in the model artifact

Returns:

A LightningModel object
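
An end-to-end sketch; the project and model names are placeholders, and a Weights & Biases login is required (use projects() and models() above to discover what is available):

    from grelu.resources import load_model

    model = load_model(project="some-project", model_name="some-model", device="cpu")
    preds = model.predict_on_seqs("ACGT" * 100)  # sequence length is illustrative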