grelu.resources#
Attributes#
| DEFAULT_WANDB_HOST |
Classes#
| LightningModel | Wrapper for predictive sequence models |
Functions#
| get_meme_file_path | Return the path to a MEME file. |
| get_blacklist_file | Return the path to a blacklist file. |
| _check_wandb | Check that the user is logged into Weights & Biases. |
| projects | List all projects in the model zoo. |
| artifacts | List all artifacts associated with a project in the model zoo. |
| models | List all models associated with a project in the model zoo. |
| datasets | List all datasets associated with a project in the model zoo. |
| runs | List attributes of all runs associated with a project in the model zoo. |
| get_artifact | Retrieve an artifact associated with a project in the model zoo. |
| get_dataset_by_model | List all datasets associated with a model in the model zoo. |
| get_model_by_dataset | List all models associated with a dataset in the model zoo. |
| load_model | Download and load a model from the model zoo. |
Package Contents#
- class grelu.resources.LightningModel(model_params: dict, train_params: dict = {}, data_params: dict = {})[source]#
Bases:
pytorch_lightning.LightningModule
Wrapper for predictive sequence models
- Parameters:
model_params – Dictionary of parameters specifying model architecture
train_params – Dictionary specifying training parameters
data_params – Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training.
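For orientation, here is a minimal sketch of constructing a LightningModel. The specific keys shown inside model_params and train_params (such as "model_type", "n_tasks" or "lr") are illustrative assumptions; the accepted keys depend on the chosen architecture and training setup.
```python
from grelu.resources import LightningModel

# Hedged sketch: the dictionary keys below are illustrative assumptions,
# not an authoritative list of accepted parameters.
model = LightningModel(
    model_params={
        "model_type": "ConvModel",  # assumed architecture name
        "n_tasks": 1,               # number of output tasks (assumed key)
    },
    train_params={
        "lr": 1e-4,        # learning rate (assumed key)
        "batch_size": 64,  # assumed key
    },
)
```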
- update_metrics(metrics: dict, y_hat: torch.Tensor, y: torch.Tensor) None [source]#
Update metrics after each pass
- format_input(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor) torch.Tensor [source]#
Extract the one-hot encoded sequence from the input
- forward(x: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor | str | List[str], logits: bool = False) torch.Tensor [source]#
Forward pass
- test_step(batch: torch.Tensor, batch_idx: int) torch.Tensor [source]#
Calculate metrics after a single test step
- parse_devices(devices: str | int | List[int]) Tuple[str, str | List[int]] [source]#
Parses the devices argument and returns a tuple of accelerator and devices.
- Parameters:
devices – Either “cpu” or an integer or list of integers representing the indices of the GPUs for training.
- Returns:
A tuple of accelerator and devices.
- make_train_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for training
- make_test_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for validation and testing
- make_predict_loader(dataset: Callable, batch_size: int | None = None, num_workers: int | None = None) Callable [source]#
Make dataloader for prediction
- train_on_dataset(train_dataset: Callable, val_dataset: Callable, checkpoint_path: str | None = None)[source]#
Train model and optionally log metrics to wandb.
- Parameters:
train_dataset (Dataset) – Dataset object that yields training examples
val_dataset (Dataset) – Dataset object that yields validation examples
checkpoint_path (str) – Path to model checkpoint from which to resume training. The optimizer will be set to its checkpointed state.
- Returns:
PyTorch Lightning Trainer
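A hedged usage sketch, assuming train_ds and val_ds are gReLU Dataset objects that yield labelled training and validation examples:
```python
# train_ds and val_ds are assumed to be Dataset objects yielding
# (one-hot encoded sequence, label) pairs.
trainer = model.train_on_dataset(train_dataset=train_ds, val_dataset=val_ds)

# Resuming from a checkpoint also restores the optimizer state
# ("last.ckpt" is an illustrative file name):
# trainer = model.train_on_dataset(train_ds, val_ds, checkpoint_path="last.ckpt")
```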
- change_head(n_tasks: int, final_pool_func: str) None [source]#
Build a new head with the desired number of tasks
- tune_on_dataset(train_dataset: Callable, val_dataset: Callable, final_act_func: str | None = None, final_pool_func: str | None = None, freeze_embedding: bool = False)[source]#
Fine-tune a pretrained model on a new dataset.
- Parameters:
train_dataset – Dataset object that yields training examples
val_dataset – Dataset object that yields validation examples
final_act_func – Name of the final activation layer
final_pool_func – Name of the final pooling layer
freeze_embedding – If True, all the embedding layers of the pretrained model will be frozen and only the head will be trained.
- Returns:
PyTorch Lightning Trainer
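A minimal fine-tuning sketch, assuming a pretrained model and new task datasets; the activation and pooling layer names are illustrative assumptions:
```python
# new_train_ds / new_val_ds are assumed Dataset objects for the new task.
trainer = model.tune_on_dataset(
    train_dataset=new_train_ds,
    val_dataset=new_val_ds,
    final_act_func="sigmoid",  # assumed activation name for a binary task
    final_pool_func="avg",     # assumed pooling layer name
    freeze_embedding=True,     # train only the new head
)
```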
- predict_on_seqs(x: str | List[str], device: str | int = 'cpu') numpy.ndarray [source]#
A simple function to return model predictions directly on a single batch of sequences in string format.
- Parameters:
x – DNA sequences as a string or list of strings.
device – Index of the device to use
- Returns:
A numpy array of predictions.
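A quick sketch; the sequence length shown is arbitrary and must match the model's expected input length:
```python
preds = model.predict_on_seqs(["ACGT" * 500, "TTCA" * 500], device="cpu")
# preds is a numpy array with one entry per input sequence.
```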
- predict_on_dataset(dataset: Callable, devices: int | str | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, augment_aggfunc: str | Callable = 'mean', compare_func: str | Callable | None = None, return_df: bool = False, precision: str | None = None)[source]#
Predict for a dataset of sequences or variants
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
augment_aggfunc – Function used to aggregate predictions across all augmented versions of each sequence
compare_func – Function used to compare predictions for the alternate and reference alleles of variants, e.g. their difference
return_df – If True, return the predictions as a Pandas dataframe
precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.
- Returns:
Model predictions as a numpy array or dataframe
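A hedged sketch, assuming seq_ds is a Dataset yielding one-hot encoded sequences:
```python
preds = model.predict_on_dataset(
    seq_ds,
    devices=0,        # GPU index, or "cpu"
    batch_size=128,
    return_df=True,   # labelled pandas DataFrame instead of a numpy array
)
```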
- test_on_dataset(dataset: Callable, devices: str | int | List[int] = 'cpu', num_workers: int = 1, batch_size: int = 256, precision: str | None = None)[source]#
Run test loop for a dataset
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
devices – Device IDs to use for inference
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
precision – Precision of the trainer e.g. ‘32’ or ‘bf16-mixed’.
- Returns:
Dataframe containing all calculated metrics on the test set.
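A short sketch, assuming test_ds yields one-hot encoded sequences with labels:
```python
metrics = model.test_on_dataset(test_ds, devices="cpu", batch_size=256)
print(metrics)  # dataframe of metrics computed on the test set
```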
- embed_on_dataset(dataset: Callable, device: str | int = 'cpu', num_workers: int = 1, batch_size: int = 256)[source]#
Return embeddings for a dataset of sequences
- Parameters:
dataset – Dataset object that yields one-hot encoded sequences
device – Device ID to use
num_workers – Number of workers for data loader
batch_size – Batch size for data loader
- Returns:
Numpy array of shape (B, T, L) containing embeddings.
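For example (seq_ds again assumed to yield one-hot encoded sequences):
```python
embeddings = model.embed_on_dataset(seq_ds, device=0, batch_size=128)
# numpy array of shape (B, T, L), as described above
```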
- get_task_idxs(tasks: int | str | List[int] | List[str], key: str = 'name', invert: bool = False) int | List[int] [source]#
Given a task name or metadata entry, get the task index. If integers are provided, they are returned unchanged.
- Parameters:
tasks – A string corresponding to a task name or metadata entry, or an integer indicating the index of a task, or a list of strings/integers
key – key to model.data_params[“tasks”] in which the relevant task data is stored. “name” will be used by default.
invert – Get indices for all tasks except those listed in tasks
- Returns:
The index or indices of the corresponding task(s) in the model’s output.
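A sketch, assuming the model's data_params["tasks"]["name"] entries include the cell type names shown (illustrative):
```python
idx = model.get_task_idxs("HepG2")                      # single task name -> int
idxs = model.get_task_idxs(["HepG2", "K562"])           # list of names -> list of ints
others = model.get_task_idxs(["HepG2"], invert=True)    # all tasks except HepG2
```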
- input_coord_to_output_bin(input_coord: int, start_pos: int = 0) int [source]#
Given the position of a base in the input, get the index of the corresponding bin in the model’s prediction.
- Parameters:
input_coord – Genomic coordinate of the input position
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Index of the output bin containing the given position.
- output_bin_to_input_coord(output_bin: int, return_pos: str = 'start', start_pos: int = 0) int [source]#
Given the index of a bin in the output, get its corresponding start or end coordinate.
- Parameters:
output_bin – Index of the bin in the model’s output
return_pos – “start” or “end”
start_pos – Genomic coordinate of the first base in the input sequence
- Returns:
Genomic coordinate corresponding to the start (if return_pos="start") or end (if return_pos="end") of the bin.
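The two conversions complement each other; a sketch with arbitrary illustrative coordinates:
```python
# Map a genomic position to the output bin that contains it...
bin_idx = model.input_coord_to_output_bin(input_coord=1_000_500, start_pos=1_000_000)
# ...and map an output bin back to its start coordinate.
bin_start = model.output_bin_to_input_coord(bin_idx, return_pos="start", start_pos=1_000_000)
```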
- input_intervals_to_output_intervals(intervals: pandas.DataFrame) pandas.DataFrame [source]#
Given a dataframe containing intervals corresponding to the input sequences, return a dataframe containing intervals corresponding to the model output.
- Parameters:
intervals – A dataframe of genomic intervals
- Returns:
A dataframe containing the genomic intervals corresponding to the model output from each input interval.
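A sketch, assuming the interval dataframe uses conventional chrom/start/end columns (an assumption about the expected format) and an illustrative interval:
```python
import pandas as pd

intervals = pd.DataFrame(
    {"chrom": ["chr1"], "start": [1_000_000], "end": [1_002_048]}
)
out_intervals = model.input_intervals_to_output_intervals(intervals)
```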
- input_intervals_to_output_bins(intervals: pandas.DataFrame, start_pos: int = 0) None [source]#
Given a dataframe of genomic intervals, add columns indicating the indices of output bins that overlap the start and end of each interval.
- Parameters:
intervals – A dataframe of genomic intervals
start_pos – The start position of the sequence input to the model.
- Returns:
Start and end indices of the output bins corresponding to each input interval.
- grelu.resources.get_meme_file_path(meme_motif_db: str) str [source]#
Return the path to a MEME file.
- Parameters:
meme_motif_db – Path to a MEME file or the name of a MEME file included with gReLU. Current name options are “jaspar” and “consensus”.
- Returns:
Path to the specified MEME file.
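For example:
```python
from grelu.resources import get_meme_file_path

meme_path = get_meme_file_path("jaspar")  # bundled database: "jaspar" or "consensus"
# Alternatively, pass the path to your own MEME-format file:
# meme_path = get_meme_file_path("/path/to/motifs.meme")
```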
- grelu.resources.get_blacklist_file(genome: str) str [source]#
Return the path to a blacklist file
- Parameters:
genome – Name of a genome whose blacklist file is included with gReLU. Current name options are “hg19”, “hg38” and “mm10”.
- Returns:
Path to the specified blacklist file.
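For example:
```python
from grelu.resources import get_blacklist_file

blacklist_path = get_blacklist_file("hg38")  # also accepts "hg19" and "mm10"
```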
- grelu.resources._check_wandb(host: str = DEFAULT_WANDB_HOST) None [source]#
Check that the user is logged in to Weights & Biases
- Parameters:
host – URL of the Weights & Biases host
- grelu.resources.projects(host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all projects in the model zoo
- Parameters:
host – URL of the Weights & Biases host
- Returns:
List of project names
- grelu.resources.artifacts(project: str, host: str = DEFAULT_WANDB_HOST, type_is: str | None = None, type_contains: str | None = None) List[str] [source]#
List all artifacts associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
type_is – Return only artifacts with this type
type_contains – Return only artifacts whose type contains this string
- Returns:
List of artifact names
- grelu.resources.models(project: str, host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all models associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
- Returns:
List of model names
- grelu.resources.datasets(project: str, host: str = DEFAULT_WANDB_HOST) List[str] [source]#
List all datasets associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
- Returns:
List of dataset names
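A browsing sketch for the zoo listing functions; the project name used below is illustrative:
```python
from grelu.resources import projects, models, datasets

print(projects())                   # all projects in the model zoo
print(models("example-project"))    # models in one project (name is illustrative)
print(datasets("example-project"))  # datasets in the same project
```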
- grelu.resources.runs(project: str, host: str = DEFAULT_WANDB_HOST, field: str = 'id', filters: Dict[str, Any] | None = None) List[str] [source]#
List attributes of all runs associated with a project in the model zoo
- Parameters:
project – Name of the project to search
host – URL of the Weights & Biases host
field – Field to return from the run metadata
filters – Dictionary of filters to pass to api.runs
- Returns:
List of run attributes
- grelu.resources.get_artifact(name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest')[source]#
Retrieve an artifact associated with a project in the model zoo
- Parameters:
name – Name of the artifact
project – Name of the project containing the artifact
host – URL of the Weights & Biases host
alias – Alias of the artifact
- Returns:
The specified artifact
- grelu.resources.get_dataset_by_model(model_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str] [source]#
List all datasets associated with a model in the model zoo
- Parameters:
model_name – Name of the model
project – Name of the project containing the model
host – URL of the Weights & Biases host
alias – Alias of the model artifact
- Returns:
A list containing the names of all datasets linked to the model
- grelu.resources.get_model_by_dataset(dataset_name: str, project: str, host: str = DEFAULT_WANDB_HOST, alias: str = 'latest') List[str] [source]#
List all models associated with a dataset in the model zoo
- Parameters:
dataset_name – Name of the dataset
project – Name of the project containing the dataset
host – URL of the Weights & Biases host
alias – Alias of the dataset artifact
- Returns:
A list containing the names of all models linked to the dataset
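A sketch of retrieving an artifact and cross-referencing models and datasets; all names below are illustrative:
```python
from grelu.resources import get_artifact, get_dataset_by_model, get_model_by_dataset

artifact = get_artifact("example-model", project="example-project")
linked_datasets = get_dataset_by_model("example-model", project="example-project")
linked_models = get_model_by_dataset("example-dataset", project="example-project")
```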
- grelu.resources.load_model(project: str, model_name: str, device: str | int = 'cpu', host: str = DEFAULT_WANDB_HOST, alias: str = 'latest', checkpoint_file: str = 'model.ckpt') grelu.lightning.LightningModel [source]#
Download and load a model from the model zoo
- Parameters:
project – Name of the project containing the model
model_name – Name of the model
device – Device index on which to load the model.
host – URL of the Weights & Biases host
alias – Alias of the model artifact
checkpoint_file – Name of the checkpoint file contained in the model artifact
- Returns:
A LightningModel object
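Finally, a sketch of loading a model from the zoo; the project and model names are illustrative, so use projects() and models() to discover what is actually available:
```python
from grelu.resources import load_model

model = load_model(project="example-project", model_name="example-model", device="cpu")
preds = model.predict_on_seqs("ACGT" * 500)  # sequence length must match the model's input length
```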