grelu.transforms.prediction_transforms

Classes to perform transformations on the output of a predictive model.

The input to the forward method of these classes is a tensor of shape (B, T, L), where B is the batch size, T is the number of tasks and L is the length of the model output. The output is also a 3-D tensor.
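The following is a minimal NumPy sketch of the shape convention. It assumes that B indexes sequences in the batch, T indexes tasks and L indexes positions along the output length. It shows only how reducing an axis with `keepdims=True` leaves a 3-D result. It is not taken from the gReLU source, and the array values are made up.

```python
import numpy as np

# Made-up model output: B=2 sequences, T=3 tasks, L=4 output positions.
B, T, L = 2, 3, 4
preds = np.arange(B * T * L, dtype=float).reshape(B, T, L)

# Reducing with keepdims=True collapses an axis to size 1 instead of
# dropping it, so the result stays 3-D.
per_task_total = preds.sum(axis=2, keepdims=True)      # (B, T, 1)
overall_mean = preds.mean(axis=(1, 2), keepdims=True)  # (B, 1, 1)

print(per_task_total.shape)  # (2, 3, 1)
print(overall_mean.shape)    # (2, 1, 1)
```

Keeping the singleton axes means that a transformed output can be indexed and broadcast with the same (B, T, L) axis positions as the raw prediction.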

Classes

Aggregate

A class to filter and aggregate the model output over desired tasks and/or positions.

Specificity

A filter to calculate cell type specificity.

Module Contents

class grelu.transforms.prediction_transforms.Aggregate(tasks: List[int] | List[str] | None = None, except_tasks: List[int] | List[str] | None = None, positions: List[int] | None = None, length_aggfunc: Callable | None = None, task_aggfunc: Callable | None = None, model: Callable | None = None, weight: float | None = None)

Bases: torch.nn.Module

A class to filter and aggregate the model output over desired tasks and/or positions.

Parameters:
  • tasks – A list of task names or indices to include. If task names are supplied, “model” should not be None. If tasks and except_tasks are both None, all tasks will be considered.

  • except_tasks – A list of task names or indices to exclude if tasks is None. If task names are supplied, “model” should not be None. If tasks and except_tasks are both None, all tasks will be considered.

  • positions – A list of positions to include along the length axis. If None, all positions will be included.

  • length_aggfunc – A function or name of a function to apply along the length axis. Accepted values are “sum”, “mean”, “min” or “max”.

  • task_aggfunc – A function or name of a function to apply along the task axis. Accepted values are “sum”, “mean”, “min” or “max”.

  • model – A trained LightningModel object. Needed only if task names are supplied.

  • weight – A weight by which to multiply the aggregated prediction.
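The task-selection rules above can be sketched as a small resolver. `resolve_tasks` is a hypothetical helper, not part of gReLU. Its `task_names` argument stands in for the trained model, which the parameter list says is needed to look up task names. The precedence it encodes comes from the descriptions above: `tasks` is used if given, `except_tasks` is consulted only when `tasks` is None, and all tasks are kept when both are None.

```python
from typing import List, Optional, Sequence, Union

# Either a list of task indices, a list of task names, or None.
TaskSpec = Optional[Union[List[int], List[str]]]


def resolve_tasks(
    tasks: TaskSpec,
    except_tasks: TaskSpec,
    n_tasks: int,
    task_names: Optional[Sequence[str]] = None,
) -> List[int]:
    """Hypothetical helper: turn `tasks` / `except_tasks` into task indices.

    `task_names` is a stand-in for the model's task metadata (assumption).
    """

    def to_indices(spec: Union[List[int], List[str]]) -> List[int]:
        if all(isinstance(t, int) for t in spec):
            return list(spec)
        if task_names is None:
            raise ValueError("task names were given but nothing to resolve them against")
        names = list(task_names)
        return [names.index(t) for t in spec]  # ValueError for unknown names

    if tasks is not None:
        return to_indices(tasks)  # explicit inclusion list wins
    if except_tasks is not None:
        excluded = set(to_indices(except_tasks))
        return [i for i in range(n_tasks) if i not in excluded]
    return list(range(n_tasks))  # both None: keep every task


# Usage with made-up task names.
names = ["cell_a", "cell_b", "cell_c"]
print(resolve_tasks(["cell_b"], None, 3, names))  # [1]
print(resolve_tasks(None, [0], 3))                # [1, 2]
print(resolve_tasks(None, None, 3))               # [0, 1, 2]
```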

tasks
except_tasks
positions
task_aggfunc
length_aggfunc
task_aggfunc_numpy
length_aggfunc_numpy
filter(x: torch.Tensor | numpy.ndarray) → torch.Tensor | numpy.ndarray

Filter the relevant tasks and positions in the predictions.
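Here is a NumPy sketch of what filtering a (B, T, L) array by task indices and positions could look like. `filter_preds` is a hypothetical function written for illustration. It assumes that task names have already been converted to indices.

```python
from typing import List, Optional

import numpy as np


def filter_preds(
    x: np.ndarray,
    task_idxs: Optional[List[int]] = None,
    positions: Optional[List[int]] = None,
) -> np.ndarray:
    """Keep selected tasks (axis 1) and positions (axis 2) of a (B, T, L) array."""
    # Index one axis at a time: passing both lists in a single indexing
    # expression would make NumPy pair them element-wise instead.
    if task_idxs is not None:
        x = x[:, task_idxs, :]
    if positions is not None:
        x = x[:, :, positions]
    return x


preds = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
out = filter_preds(preds, task_idxs=[0, 2], positions=[1, 2])
print(out.shape)         # (2, 2, 2)
print(out[0].tolist())   # [[1.0, 2.0], [9.0, 10.0]]
```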

torch_aggregate(x: torch.Tensor) → torch.Tensor

Aggregate predictions in the form of a tensor.

numpy_aggregate(x: numpy.ndarray) → numpy.ndarray

Aggregate predictions in the form of a numpy array.
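The next sketch shows how the accepted aggregation names ("sum", "mean", "min", "max") could map to NumPy reductions. It also shows how the length and task aggregation and the `weight` multiplier could then be applied. The order used here (length first, then tasks) is an assumption. For "sum" or "mean" on both axes the order does not matter, but mixed choices such as a length mean with a task max can give different results. Custom callables are assumed to accept `axis` and `keepdims` keywords. `aggregate` and `get_aggfunc` are hypothetical names, not gReLU functions.

```python
from typing import Callable, Optional, Union

import numpy as np

# Accepted names mapped to NumPy reductions that take `axis` and `keepdims`.
AGG_FUNCS = {"sum": np.sum, "mean": np.mean, "min": np.min, "max": np.max}

AggSpec = Optional[Union[str, Callable]]


def get_aggfunc(spec: AggSpec) -> Optional[Callable]:
    """Return None, the callable itself, or the NumPy function for a name."""
    if spec is None or callable(spec):
        return spec
    return AGG_FUNCS[spec]  # KeyError for unsupported names


def aggregate(
    x: np.ndarray,
    length_aggfunc: AggSpec = None,
    task_aggfunc: AggSpec = None,
    weight: Optional[float] = None,
) -> np.ndarray:
    length_fn = get_aggfunc(length_aggfunc)
    task_fn = get_aggfunc(task_aggfunc)
    if length_fn is not None:
        x = length_fn(x, axis=2, keepdims=True)  # (B, T, L) -> (B, T, 1)
    if task_fn is not None:
        x = task_fn(x, axis=1, keepdims=True)    # (B, T, *) -> (B, 1, *)
    if weight is not None:
        x = x * weight
    return x


preds = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
out = aggregate(preds, length_aggfunc="sum", task_aggfunc="max", weight=0.5)
print(out.shape)             # (2, 1, 1)
print(out.ravel().tolist())  # [19.0, 43.0]
```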

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Output of the model forward pass

compute(x: numpy.ndarray) → numpy.ndarray

Compute the output score on a numpy array.

class grelu.transforms.prediction_transforms.Specificity(on_tasks: List[int] | List[str], off_tasks: List[int] | List[str] | None = None, on_aggfunc: str | Callable = 'mean', off_aggfunc: str | Callable = 'mean', off_weight: float | None = 1.0, off_thresh: float | None = None, positions: List[int] = None, length_aggfunc: str | Callable = 'sum', compare_func: str | Callable = 'divide', model: Callable | None = None)

Bases: torch.nn.Module

A filter to calculate cell type specificity.

Parameters:
  • on_tasks – A list of task names or indices for foreground tasks.

  • off_tasks – A list of task names or indices for background tasks. If None, all tasks other than on_tasks will be considered part of the background.

  • on_aggfunc – A function or name of a function to aggregate predictions for the foreground tasks. Accepted values are “sum”, “mean”, “min” or “max”.

  • off_aggfunc – A function or name of a function to aggregate predictions for the background tasks. Accepted values are “sum”, “mean”, “min” or “max”.

  • off_weight – Relative weight of the background tasks. If this is equal to 1, the background and foreground predictions will be equally weighted. If off_thresh is provided, the weight will be applied only to off-target predictions exceeding off_thresh.

  • off_thresh – A maximum threshold for predictions in off_tasks.

  • positions – A list of positions to include along the length axis. If None, all positions will be included.

  • length_aggfunc – A function or name of a function to apply along the length axis. Accepted values are “sum”, “mean”, “min” or “max”.

  • compare_func – A function or name of a function to calculate specificity. Accepted values are “subtract” or “divide”.

  • model – A trained LightningModel object. Needed only if task names are supplied.
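The sketch below combines the parameters above into a NumPy specificity score. It follows the documented defaults: "mean" for both task groups, "sum" along the length axis and "divide" as the comparison. When no off-target tasks are given, every task that is not an on-target task is used as background. Several details are assumptions and are not taken from the gReLU source: the length axis is collapsed before the task groups are aggregated, `off_weight` multiplies the aggregated background score, and `off_thresh` is left out. `specificity_score` is a hypothetical function.

```python
from typing import List, Optional

import numpy as np

AGG_FUNCS = {"sum": np.sum, "mean": np.mean, "min": np.min, "max": np.max}


def specificity_score(
    x: np.ndarray,
    on_idxs: List[int],
    off_idxs: Optional[List[int]] = None,
    on_aggfunc: str = "mean",
    off_aggfunc: str = "mean",
    length_aggfunc: str = "sum",
    compare_func: str = "divide",
    off_weight: float = 1.0,
) -> np.ndarray:
    """Sketch of a specificity score on a (B, T, L) array; returns (B, 1, 1)."""
    if off_idxs is None:
        # Background defaults to every task that is not in the foreground.
        on_set = set(on_idxs)
        off_idxs = [i for i in range(x.shape[1]) if i not in on_set]

    # Collapse the length axis first, then aggregate each task group.
    per_task = AGG_FUNCS[length_aggfunc](x, axis=2, keepdims=True)
    on = AGG_FUNCS[on_aggfunc](per_task[:, on_idxs, :], axis=1, keepdims=True)
    off = AGG_FUNCS[off_aggfunc](per_task[:, off_idxs, :], axis=1, keepdims=True)
    off = off_weight * off  # assumption: weight scales the background score

    if compare_func == "subtract":
        return on - off
    if compare_func == "divide":
        return on / off  # no guard against a zero background score
    raise ValueError(f"unsupported compare_func: {compare_func!r}")


x = np.array([[[4.0, 4.0], [1.0, 1.0], [1.0, 3.0]]])  # B=1, T=3, L=2
print(specificity_score(x, on_idxs=[0]).ravel().tolist())  # [2.6666666666666665], i.e. 8 / 3
print(specificity_score(x, on_idxs=[0], compare_func="subtract").ravel().tolist())  # [5.0]
```

A score well above 1 ("divide") or above 0 ("subtract") means that the foreground tasks are predicted higher than the weighted background.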

on_transform
off_transform
tasks
compare_func
compare_func_numpy
length_aggfunc
length_aggfunc_numpy
off_weight
off_thresh
weight_off(x: numpy.ndarray | torch.Tensor) → None

Apply a weight to the off-target predictions.
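This sketch is one literal reading of the off_weight and off_thresh descriptions. Without a threshold, every off-target prediction is scaled. With a threshold, only predictions that exceed it are scaled. The documented signature annotates the return as None, but this hypothetical `weight_off` function returns a new array. It is not claimed to match the library's behaviour.

```python
from typing import Optional

import numpy as np


def weight_off(
    x: np.ndarray, off_weight: float = 1.0, off_thresh: Optional[float] = None
) -> np.ndarray:
    """Scale off-target predictions; with a threshold, only values above it."""
    if off_thresh is None:
        return x * off_weight
    # Assumption: values at or below the threshold are left unchanged.
    return np.where(x > off_thresh, x * off_weight, x)


off_preds = np.array([0.5, 2.0, 3.0])
print(weight_off(off_preds, off_weight=2.0).tolist())                  # [1.0, 4.0, 6.0]
print(weight_off(off_preds, off_weight=2.0, off_thresh=1.0).tolist())  # [0.5, 4.0, 6.0]
```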

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Output of the model forward pass

compute(x: numpy.ndarray) → numpy.ndarray

Compute the output score on a numpy array.