grelu.resources#
grelu.resources contains additional data files that can be used by gReLU functions. It also contains functions to load these files, as well as externally stored files such as model checkpoints and datasets in the model zoo.
Attributes#
| DEFAULT_WANDB_HOST | Default URL of the Weights & Biases host |
Classes#
| LightningModel | Wrapper for predictive sequence models |
Functions#
| get_meme_file_path | Return the path to a MEME file |
| get_blacklist_file | Return the path to a blacklist file |
| _check_wandb | Check that the user is logged into Weights & Biases |
| projects | List all projects in the model zoo |
| artifacts | List all artifacts associated with a project in the model zoo |
| models | List all models associated with a project in the model zoo |
| datasets | List all datasets associated with a project in the model zoo |
| runs | List attributes of all runs associated with a project in the model zoo |
| get_artifact | Retrieve an artifact associated with a project in the model zoo |
| get_dataset_by_model | List all datasets associated with a model in the model zoo |
| get_model_by_dataset | List all models associated with a dataset in the model zoo |
| load_model | Download and load a model from the model zoo |
Package Contents#
- class grelu.resources.LightningModel(model_params: dict, train_params: dict = {})[source]#
Bases:
pytorch_lightning.LightningModule
Wrapper for predictive sequence models
- Parameters:
model_params – Dictionary of parameters specifying model architecture
train_params – Dictionary specifying training parameters
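For illustration, a model can be constructed by passing both dictionaries. The keys shown below are hypothetical placeholders; the accepted keys depend on the chosen architecture and training task:

```python
from grelu.lightning import LightningModel

# A minimal sketch. The dictionary keys are illustrative placeholders,
# not a guaranteed schema.
model = LightningModel(
    model_params={
        "model_type": "ConvModel",  # hypothetical architecture name
        "n_tasks": 1,               # number of output tasks
    },
    train_params={
        "lr": 1e-4,         # learning rate
        "batch_size": 256,  # training batch size
    },
)
```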
- update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None [source]#
Update metrics after each pass
- format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor [source]#
Extract the one-hot encoded sequence from the input
- forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor [source]#
Forward pass
- test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor [source]#
Calculate metrics after a single test step
- parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]] [source]#
Parses the devices argument and returns a tuple of accelerator and devices.
- Parameters:
devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.
- Returns:
A tuple of accelerator and devices.
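A brief sketch of the call; the exact values returned depend on the argument:

```python
# Assumes `model` is a LightningModel. Passing GPU indices yields an
# accelerator string plus the device list; "cpu" selects the CPU.
accelerator, devices = model.parse_devices([0, 1])
```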
- make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for training
- make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for validation and testing
- make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for prediction
- train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#
Train model and optionally log metrics to wandb.
- Parameters:
train_dataset (Dataset) – Dataset object that yields training examples
val_dataset (Dataset) – Dataset object that yields validation examples
checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.
- Returns:
PyTorch Lightning Trainer
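A sketch of a typical training call, assuming train_ds and val_ds are Dataset objects prepared elsewhere:

```python
# Assumes `model` is a LightningModel and `train_ds`/`val_ds` yield
# training and validation examples respectively.
trainer = model.train_on_dataset(
    train_dataset=train_ds,
    val_dataset=val_ds,
    checkpoint_path=None,  # or a checkpoint path to resume training
)
```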
- change_head(n_tasks: int, final_pool_func: str) None [source]#
Build a new head with the desired number of tasks
- tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#
Fine-tune a pretrained model on a new dataset.
- Parameters:
train_dataset – Dataset object that yields training examples
val_dataset – Dataset object that yields validation examples
final_act_func – Name of the final activation layer
final_pool_func – Name of the final pooling layer
freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.
- Returns:
PyTorch Lightning Trainer
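A sketch of fine-tuning with a frozen embedding, again assuming train_ds and val_ds exist:

```python
# Assumes `model` is a pretrained LightningModel. With
# freeze_embedding=True, only the head is trained.
trainer = model.tune_on_dataset(
    train_dataset=train_ds,
    val_dataset=val_ds,
    freeze_embedding=True,
)
```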
- predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray [source]#
Return model predictions directly on a single batch of sequences in string format.
- Parameters:
x – DNA sequences as a string or list of strings.
device – Index of the device to use
- Returns:
A numpy array of predictions.
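For example (the sequences below are toy strings; real inputs must match the model's expected input length):

```python
# Assumes `model` is a LightningModel whose input length matches
# the sequences provided; the strings here are illustrative only.
preds = model.predict_on_seqs(
    ["ACGTACGTACGT", "TTTTAAAACCCC"],
    device="cpu",
)
print(preds.shape)  # numpy array of predictions
```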
- predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#
Predict for a dataset of sequences or variants
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
augment_aggfunc – Function used to aggregate predictions across all augmented versions of each sequence (e.g. “mean”)
compare_func – Function used to compare predictions for the alternate and reference alleles of variants
return_df – Return the predictions as a Pandas dataframe
precision – Precision of the trainer, e.g. ‘32’ or ‘bf16-mixed’.
- Returns:
Model predictions as a numpy array or dataframe
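A sketch of a typical call, assuming ds is a Dataset yielding one-hot encoded sequences:

```python
# Assumes `model` is a LightningModel and `ds` yields one-hot
# encoded sequences.
preds = model.predict_on_dataset(
    dataset=ds,
    devices=0,       # GPU index; use "cpu" to stay on the CPU
    batch_size=256,
    return_df=True,  # return predictions as a pandas DataFrame
)
```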
- test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None, write_path: str | None = None)[source]#
Run test loop for a dataset
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use for inference
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
precision – Precision of the trainer, e.g. ‘32’ or ‘bf16-mixed’.
write_path – Path to write a new model checkpoint containing test data parameters and performance.
- Returns:
Dataframe containing all calculated metrics on the test set.
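For example, assuming test_ds is a labeled test Dataset:

```python
# Assumes `model` is a LightningModel and `test_ds` yields one-hot
# encoded sequences with labels.
metrics = model.test_on_dataset(
    dataset=test_ds,
    devices="cpu",
    batch_size=256,
)
print(metrics)  # dataframe of metrics computed on the test set
```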
- embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#
Return embeddings for a dataset of sequences
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
device – Device ID to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
- Returns:
Numpy array of shape (B, T, L) containing embeddings.
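For example:

```python
# Assumes `model` is a LightningModel and `ds` yields one-hot
# encoded sequences.
emb = model.embed_on_dataset(dataset=ds, device="cpu", batch_size=256)
print(emb.shape)  # (B, T, L): batch x embedding channels x length
```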
- get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int] [source]#
Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.
- Parameters:
tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers
key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.
invert – Get indices for all tasks except those listed in tasks
- Returns:
The index or indices of the corresponding task(s) in the model’s output.
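For example (the task names below are hypothetical):

```python
# Assumes `model` is a LightningModel whose tasks include the
# hypothetical names "K562" and "HepG2".
idx = model.get_task_idxs("K562")                  # single name -> int
idxs = model.get_task_idxs(["K562", "HepG2"])      # list -> list of ints
rest = model.get_task_idxs(["K562"], invert=True)  # all other tasks
```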
- input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int [source]#
Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.
- Parameters:
input_coord – Genomic coordinate of the input position
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Index of the output bin containing the given position.
- output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int [source]#
Given the index of a bin in the output, get its corresponding start or end coordinate.
- Parameters:
output_bin – Index of the bin in the model’s output
return_pos – “start” or “end”
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Genomic coordinate corresponding to the start (if return_pos=“start”) or end (if return_pos=“end”) of the bin.
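The two methods are inverses of one another; for example, with hypothetical coordinates:

```python
# Assumes `model` is a LightningModel; coordinates are illustrative.
bin_idx = model.input_coord_to_output_bin(
    input_coord=1_000_500, start_pos=1_000_000
)
bin_start = model.output_bin_to_input_coord(
    output_bin=bin_idx, return_pos="start", start_pos=1_000_000
)
```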
- input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame [source]#
Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.
- Parameters:
intervals – A dataframe of genomic intervals
- Returns:
A dataframe containing the genomic intervals corresponding to the model output from each input interval.
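For example, assuming the intervals dataframe uses the conventional chrom/start/end columns (an assumption about the expected layout):

```python
import pandas as pd

# Assumes `model` is a LightningModel; column names and interval
# length are illustrative assumptions.
intervals = pd.DataFrame(
    {"chrom": ["chr1"], "start": [1_000_000], "end": [1_002_048]}
)
out_intervals = model.input_intervals_to_output_intervals(intervals)
```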
- input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None [source]#
Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.
- Parameters:
intervals – A dataframe of genomic intervals
start_pos – The start position of the sequence input to the model.
- Returns:
None. The start and end indices of the output bins corresponding to each input interval are added as new columns to the dataframe.
- grelu.resources.get_meme_file_path(meme_motif_db: str) str [source]#
Return the path to a MEME file.
- Parameters:
meme_motif_db – Path to a MEME file or the name of a MEME file included with gReLU. Current name options are “jaspar” and “consensus”.
- Returns:
Path to the specified MEME file.
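For example:

```python
from grelu.resources import get_meme_file_path

meme_path = get_meme_file_path("consensus")  # bundled motif database
# A path to your own MEME-format file also works:
# meme_path = get_meme_file_path("/path/to/motifs.meme")
```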
- grelu.resources.get_blacklist_file(genome: str) str [source]#
Return the path to a blacklist file.
- Parameters:
genome – Name of a genome whose blacklist file is included with gReLU. Current name options are “hg19”, “hg38” and “mm10”.
- Returns:
Path to the specified blacklist file.
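For example:

```python
from grelu.resources import get_blacklist_file

blacklist_path = get_blacklist_file("hg38")  # also "hg19" or "mm10"
```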
- grelu.resources._check_wandb(host: str = DEFAULT_WANDB_HOST) None [source]#
Check that the user is logged into Weights & Biases
- Parameters:
host – URL of the Weights & Biases host
- grelu.resources.projects(host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all projects in the model zoo
- Parameters:
host – URL of the Weights & Biases host
- Returns:
List of project names
- grelu.resources.artifacts(project: str, host: str = DEFAULT_WANDB_HOST, type_is: str | None = None, type_contains: str | None = None) List[str] [source]#
List all artifacts associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
type_is – Return only artifacts with this type
type_contains – Return only artifacts whose type contains this string
- Returns:
List of artifact names
- grelu.resources.models(project: str, host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all models associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
- Returns:
List of model names
- grelu.resources.datasets(project: str, host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all datasets associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
- Returns:
List of dataset names
- grelu.resources.runs(project: str, host: str = DEFAULT_WANDB_HOST, field: str = 'id', filters: Dict[str, Any] | None = None) List[str] [source]#
List attributes of all runs associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
field – Field to return from the run metadata
filters – Dictionary of filters to pass to api.runs
- Returns:
List of run attributes
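The listing functions can be chained to browse the zoo; for example (assumes a Weights & Biases login):

```python
import grelu.resources

# Browse the zoo: list projects, then the models and datasets
# belonging to one of them.
all_projects = grelu.resources.projects()
model_names = grelu.resources.models(project=all_projects[0])
dataset_names = grelu.resources.datasets(project=all_projects[0])
```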
- grelu.resources.get_artifact(name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest')[source]#
Retrieve an artifact associated with a project in the model zoo
- Parameters:
name – Name of the artifact
project – Name of the project containing the artifact
host – URL of the Weights & Biases host
alias – Alias of the artifact
- Returns:
The specific artifact
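A sketch with placeholder names; the returned object is a wandb Artifact, which can be downloaded locally:

```python
import grelu.resources

# Artifact and project names are placeholders.
art = grelu.resources.get_artifact(name="my-dataset", project="my-project")
local_dir = art.download()  # wandb artifacts can be downloaded locally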
- grelu.resources.get_dataset_by_model(model_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str] [source]#
List all datasets associated with a model in the model zoo
- Parameters:
model_name – Name of the model
project – Name of the project containing the model
host – URL of the Weights & Biases host
alias – Alias of the model artifact
- Returns:
A list containing the names of all datasets linked to the model
- grelu.resources.get_model_by_dataset(dataset_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str] [source]#
List all models associated with a dataset in the model zoo
- Parameters:
dataset_name – Name of the dataset
project – Name of the project containing the dataset
host – URL of the Weights & Biases host
alias – Alias of the dataset artifact
- Returns:
A list containing the names of all models linked to the dataset
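These two functions mirror each other; for example, with placeholder names:

```python
import grelu.resources

# Cross-reference models and datasets; all names are placeholders.
linked_datasets = grelu.resources.get_dataset_by_model(
    model_name="some-model", project="some-project"
)
linked_models = grelu.resources.get_model_by_dataset(
    dataset_name="some-dataset", project="some-project"
)
```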
- grelu.resources.load_model(project: str, model_name: str, device: str | int = 'cpu', host: str = DEFAULT_WANDB_HOST, alias: str = 'latest', checkpoint_file: str = 'model.ckpt') grelu.lightning.LightningModel [source]#
Download and load a model from the model zoo
- Parameters:
project – Name of the project containing the model
model_name – Name of the model
device – Device index on which to load the model.
host – URL of the Weights & Biases host
alias – Alias of the model artifact
checkpoint_file – Name of the checkpoint file contained in the model artifact
- Returns:
A LightningModel object
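Putting it together, a model can be downloaded and used for prediction. The project and model names below are placeholders; use projects() and models() to discover real ones:

```python
import grelu.resources

model = grelu.resources.load_model(
    project="some-project",   # placeholder; see projects()
    model_name="some-model",  # placeholder; see models(project)
    device="cpu",
)
preds = model.predict_on_seqs("ACGT" * 100)  # length is illustrative
```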