grelu.transforms.prediction_transforms
Classes to perform transformations on the output of a predictive model.
The input to the forward method of these classes is a tensor of shape (B, T, L), where B is the batch size, T is the number of tasks, and L is the length of the output. The output should also be a 3-D tensor.
Classes
Aggregate | A class to filter and aggregate the model output over desired tasks and/or positions.
Specificity | Filter to calculate cell type specificity.
Module Contents
- class grelu.transforms.prediction_transforms.Aggregate(tasks: List[int] | List[str] | None = None, except_tasks: List[int] | List[str] | None = None, positions: List[int] | None = None, length_aggfunc: Callable | None = None, task_aggfunc: Callable | None = None, model: Callable | None = None, weight: float | None = None)
Bases: torch.nn.Module
A class to filter and aggregate the model output over desired tasks and/or positions.
- Parameters:
tasks – A list of task names or indices to include. If task names are supplied, “model” should not be None. If tasks and except_tasks are both None, all tasks will be considered.
except_tasks – A list of task names or indices to exclude if tasks is None. If task names are supplied, “model” should not be None. If tasks and except_tasks are both None, all tasks will be considered.
positions – A list of positions to include along the length axis. If None, all positions will be included.
length_aggfunc – A function or name of a function to apply along the length axis. Accepted values are “sum”, “mean”, “min” or “max”.
task_aggfunc – A function or name of a function to apply along the task axis. Accepted values are “sum”, “mean”, “min” or “max”.
model – A trained LightningModel object. Needed only if task names are supplied.
weight – A weight by which to multiply the aggregated prediction.
- filter(x: torch.Tensor | numpy.ndarray) -> torch.Tensor | numpy.ndarray
Filter the relevant tasks and positions in the predictions.
- torch_aggregate(x: torch.Tensor) -> torch.Tensor
Aggregate predictions in the form of a tensor.
- numpy_aggregate(x: numpy.ndarray) -> numpy.ndarray
Aggregate predictions in the form of a numpy array.
- forward(x: torch.Tensor) -> torch.Tensor
Forward pass
- Parameters:
x – Output of the model forward pass
- compute(x: numpy.ndarray) -> numpy.ndarray
Compute the output score on a numpy array.
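The filter-then-aggregate behaviour described above can be illustrated with a minimal NumPy sketch. This is not the gReLU implementation: the function name `aggregate_sketch`, the order of the two reductions (length axis first, then task axis), and keeping the reduced axes as size-1 dimensions so that the result stays 3-D are all assumptions made for illustration.

```python
import numpy as np

# Names of the accepted aggregation functions mapped to NumPy reductions.
AGG_FUNCS = {"sum": np.sum, "mean": np.mean, "min": np.min, "max": np.max}


def aggregate_sketch(x, tasks=None, positions=None,
                     length_aggfunc=None, task_aggfunc=None, weight=None):
    """Hypothetical stand-in for Aggregate; not the library's code."""
    # Filter: keep the requested tasks (axis 1) and positions (axis 2).
    if tasks is not None:
        x = x[:, tasks, :]
    if positions is not None:
        x = x[:, :, positions]
    # Aggregate along the length axis, then the task axis (order is assumed).
    if length_aggfunc is not None:
        x = AGG_FUNCS[length_aggfunc](x, axis=2, keepdims=True)
    if task_aggfunc is not None:
        x = AGG_FUNCS[task_aggfunc](x, axis=1, keepdims=True)
    # Optionally scale the aggregated prediction.
    if weight is not None:
        x = x * weight
    return x


# Toy predictions: batch of 2, 3 tasks, 4 positions.
preds = np.arange(24, dtype=float).reshape(2, 3, 4)
score = aggregate_sketch(preds, tasks=[0, 2], positions=[1, 2],
                         length_aggfunc="mean", task_aggfunc="sum", weight=0.5)
print(score.shape)  # (2, 1, 1)
```

Here tasks are selected by index only; resolving task names would require the task metadata of a trained model, which this sketch leaves out.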
- class grelu.transforms.prediction_transforms.Specificity(on_tasks: List[int] | List[str], off_tasks: List[int] | List[str] | None = None, on_aggfunc: str | Callable = 'mean', off_aggfunc: str | Callable = 'mean', off_weight: float | None = 1.0, off_thresh: float | None = None, positions: List[int] = None, length_aggfunc: str | Callable = 'sum', compare_func: str | Callable = 'divide', model: Callable | None = None)
Bases: torch.nn.Module
Filter to calculate cell type specificity
- Parameters:
on_tasks – A list of task names or indices for foreground tasks.
off_tasks – A list of task names or indices for background tasks. If None, all tasks other than on_tasks will be considered part of the background.
on_aggfunc – A function or name of a function to aggregate predictions for the foreground tasks. Accepted values are “sum”, “mean”, “min” or “max”.
off_aggfunc – A function or name of a function to aggregate predictions for the background tasks. Accepted values are “sum”, “mean”, “min” or “max”.
off_weight – Relative weight of the background tasks. If this is equal to 1, the background and foreground predictions will be equally weighted. If off_thresh is provided, the weight will be applied only to off-target predictions exceeding off_thresh.
off_thresh – A maximum threshold for the prediction in off_tasks.
positions – A list of positions to include along the length axis. If None, all positions will be included.
length_aggfunc – A function or name of a function to apply along the length axis. Accepted values are “sum”, “mean”, “min” or “max”.
compare_func – A function or name of a function to calculate specificity. Accepted values are “subtract” or “divide”.
model – A trained LightningModel object. Needed if task names are supplied.
- weight_off(x: numpy.ndarray | torch.Tensor) -> None
Apply a weight to the off-target predictions.
- forward(x: torch.Tensor) -> torch.Tensor
Forward pass
- Parameters:
x – Output of the model forward pass
- compute(x: numpy.ndarray) -> numpy.ndarray
Compute the output score on a numpy array.
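As a rough illustration of how the parameters above combine, the following NumPy sketch computes a specificity score from foreground and background tasks. It is not the gReLU implementation: the name `specificity_sketch`, the order of operations (length aggregation, then per-group task aggregation, then weighting, then comparison), and the size-1 output axes are assumptions, and the off_thresh behaviour is deliberately omitted.

```python
import numpy as np

AGG_FUNCS = {"sum": np.sum, "mean": np.mean, "min": np.min, "max": np.max}
COMPARE_FUNCS = {"subtract": np.subtract, "divide": np.divide}


def specificity_sketch(x, on_tasks, off_tasks=None, on_aggfunc="mean",
                       off_aggfunc="mean", off_weight=1.0, positions=None,
                       length_aggfunc="sum", compare_func="divide"):
    """Hypothetical stand-in for Specificity; not the library's code."""
    # Default background: every task that is not a foreground task.
    if off_tasks is None:
        off_tasks = [t for t in range(x.shape[1]) if t not in on_tasks]
    # Keep the requested positions, then aggregate along the length axis.
    if positions is not None:
        x = x[:, :, positions]
    x = AGG_FUNCS[length_aggfunc](x, axis=2, keepdims=True)
    # Aggregate foreground and background tasks separately.
    on = AGG_FUNCS[on_aggfunc](x[:, on_tasks, :], axis=1, keepdims=True)
    off = AGG_FUNCS[off_aggfunc](x[:, off_tasks, :], axis=1, keepdims=True)
    # Weight the background (assumed to happen after aggregation),
    # then compare foreground against background.
    return COMPARE_FUNCS[compare_func](on, off * off_weight)


# Toy predictions: 1 sequence, 3 tasks, 2 positions.
preds = np.array([[[4.0, 4.0], [1.0, 1.0], [3.0, 3.0]]])
ratio = specificity_sketch(preds, on_tasks=[0])
diff = specificity_sketch(preds, on_tasks=[0], compare_func="subtract")
print(ratio.ravel(), diff.ravel())  # [2.] [4.]
```

With “divide”, the score is a ratio and is sensitive to background values near zero; “subtract” gives a difference on the same scale as the predictions.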