grelu.visualize
Functions
| _collect_preds_and_labels | Collect predictions and labels for a region of interest into a dataframe. |
| plot_distribution | Given a 1-D sequence of values, plot a histogram or density plot of their distribution. |
| plot_pred_distribution | Plot the density of predictions and regression labels for a given task. |
| plot_pred_scatter | Plot a scatterplot of predictions and regression labels for a given task. |
| plot_binary_preds | Plot a box plot of predictions for each classification label. |
| plot_calibration_curve | Plot a calibration curve for a classification model. |
| add_highlights | Add highlights to a matplotlib axis. |
| plot_evolution | Plot change in scores and predictions over multiple rounds of directed evolution. |
| plot_gc_match | Plot a histogram comparing GC content distribution in positive and negative regions. |
| plot_attributions | Plot base-level importance scores across a sequence. |
| plot_ISM | Return in silico mutagenesis plot. |
| plot_tracks | Plot genomic coverage tracks. |
| plot_attention_matrix | Plot a bin x bin matrix of attention weights derived from transformer layers in a model. |
Module Contents
- grelu.visualize._collect_preds_and_labels(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None) -> pandas.DataFrame
Collect predictions and labels for a region of interest into a dataframe.
- grelu.visualize.plot_distribution(values: List | numpy.ndarray | pandas.Series, title: str = 'metric', method: str = 'histogram', figsize: Tuple[int, int] = (4, 3), **kwargs)
Given a 1-D sequence of values, plot a histogram or density plot of their distribution.
- Parameters:
values – 1-D sequence of numbers to plot
title – Plot title
method – Either “histogram” or “density”
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_histogram (if method == “histogram”) or geom_density (if method == “density”).
- Returns:
histogram or density plot
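A minimal usage sketch for the signature above. The data are synthetic, the wrapper name `show_distribution` is hypothetical, and the gReLU import is deferred into the function body so that the snippet can be loaded without gReLU installed.

```python
import random

# Synthetic 1-D values standing in for a real metric (assumption for illustration)
random.seed(0)
values = [random.gauss(0.0, 1.0) for _ in range(1000)]


def show_distribution(values, method="histogram"):
    """Hypothetical wrapper around plot_distribution; imports gReLU lazily."""
    # The documentation lists exactly these two options for `method`
    if method not in ("histogram", "density"):
        raise ValueError("method must be 'histogram' or 'density'")
    from grelu.visualize import plot_distribution

    return plot_distribution(values, title="score", method=method, figsize=(4, 3))
```

Calling `show_distribution(values, method="density")` should return the density variant. Extra keyword arguments would be forwarded to the underlying geom, as described in the parameter list.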
- grelu.visualize.plot_pred_distribution(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, figsize: Tuple[int, int] = (4, 3), **kwargs)
Plot the density of predictions and regression labels for a given task.
- Parameters:
preds – Model predictions
labels – True labels
tasks – List of task names or indices. If None, all tasks will be used.
bins – List of relevant bins in the predictions and labels. If None, all bins will be used.
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_density()
- Returns:
Density plots
- grelu.visualize.plot_pred_scatter(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, density: bool = False, figsize: Tuple[int, int] = (4, 3), **kwargs)
Plot a scatterplot of predictions and regression labels for a given task.
- Parameters:
preds – Model predictions
labels – True labels
tasks – List of task names or indices. If None, all tasks will be used.
bins – List of relevant bins in the predictions and labels. If None, all bins will be used.
density – If True, color the points by local density.
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_point (if density = False) or geom_pointdensity (if density = True).
- Returns:
Scatter plots
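The sketch below shows one way the `tasks` and `bins` arguments might be used with array inputs. It assumes predictions and labels are arrays of shape (examples, tasks, bins), which this page does not state explicitly. The data are synthetic and the wrapper `show_pred_scatter` is a hypothetical helper.

```python
import numpy as np

# Synthetic regression labels and noisy predictions.
# Assumed layout: (examples, tasks, bins)
rng = np.random.default_rng(0)
n_examples, n_tasks, n_bins = 100, 2, 1
labels = rng.normal(size=(n_examples, n_tasks, n_bins))
preds = labels + rng.normal(scale=0.3, size=labels.shape)


def show_pred_scatter(preds, labels, task=0):
    """Hypothetical wrapper; imports gReLU only when called."""
    from grelu.visualize import plot_pred_scatter

    # With plain arrays there are no task names, so tasks are selected by index
    return plot_pred_scatter(preds, labels, tasks=[task], bins=[0], density=False)
```

The same inputs should also work with `plot_pred_distribution`, which shares the `preds`, `labels`, `tasks` and `bins` parameters.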
- grelu.visualize.plot_binary_preds(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, figsize: Tuple[int, int] = (4, 3), **kwargs)
Plot a box plot of predictions for each classification label
- Parameters:
preds – Model predictions
labels – True labels
tasks – List of task names or indices. If None, all tasks will be used.
bins – List of relevant bins in the predictions and labels. If None, all bins will be used.
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_boxplot
- Returns:
Box plots
- grelu.visualize.plot_calibration_curve(probs: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, aggregate: bool = True, figsize: Tuple[int, int] = (4, 3), show_legend: bool = True)
Plot a calibration curve for a classification model
- Parameters:
probs – Model predictions
labels – True classification labels
tasks – List of task names or indices. If None, all tasks will be used.
bins – List of relevant bins in the predictions and labels. If None, all bins will be used.
aggregate – If True, plot a single calibration curve for all tasks combined. If False, plot a separate curve for each task.
figsize – Tuple containing (width, height)
show_legend – If True, the legend is displayed. If False, no legend is displayed.
- Returns:
Line plots showing the calibration between true and predicted probabilities for each task (if aggregate=False) or for all tasks combined (if aggregate=True)
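For readers unfamiliar with calibration curves, the following pure-Python sketch computes the points such a curve is drawn from: predicted probabilities are grouped into equal-width intervals, and the mean prediction in each interval is compared with the observed fraction of positive labels. This is an illustration of the general technique, not gReLU's implementation. The function name and the `n_prob_bins` parameter are hypothetical, and `n_prob_bins` is unrelated to the `bins` argument above.

```python
def calibration_points(probs, labels, n_prob_bins=5):
    """Return (mean predicted probability, observed positive fraction)
    for each non-empty probability interval."""
    sums = [0.0] * n_prob_bins
    positives = [0] * n_prob_bins
    counts = [0] * n_prob_bins
    for p, y in zip(probs, labels):
        # Clamp so that p == 1.0 falls in the last interval
        idx = min(int(p * n_prob_bins), n_prob_bins - 1)
        sums[idx] += p
        positives[idx] += y
        counts[idx] += 1
    return [
        (sums[i] / counts[i], positives[i] / counts[i])
        for i in range(n_prob_bins)
        if counts[i] > 0
    ]


# Toy example: low predictions are all negatives, high predictions all positives
probs = [0.1, 0.1, 0.3, 0.7, 0.9, 0.9]
labels = [0, 0, 0, 1, 1, 1]
points = calibration_points(probs, labels, n_prob_bins=2)
```

For a well-calibrated model the two values in each pair are close, so the points lie near the diagonal.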
- grelu.visualize.add_highlights(ax, centers: int | List[int] | None = None, width: int | None = None, starts: int | List[int] | None = None, ends: int | List[int] | None = None, positions: int | List[int] | None = None, ymin: float = -10, ymax: float = 20, facecolor: str | None = 'yellow', alpha: float | None = 0.15, edgecolor: str | None = None) -> None
Add highlights to a matplotlib axis
- grelu.visualize.plot_evolution(df: pandas.DataFrame, figsize: Tuple[float, float] = (4, 3), **kwargs)
Plot change in scores and predictions over multiple rounds of directed evolution
- Parameters:
df – Dataframe produced by grelu.design.evolve
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_boxplot.
- grelu.visualize.plot_gc_match(positives: pandas.DataFrame, negatives: pandas.DataFrame, binwidth: float = 0.1, genome: str = 'hg38', figsize: Tuple[int, int] = (4, 3), **kwargs)
Plot a histogram comparing GC content distribution in positive and negative regions.
- Parameters:
positives – Genomic intervals
negatives – Genomic intervals
binwidth – Resolution at which to bin GC content
genome – Name of the genome
figsize – Tuple containing (width, height)
**kwargs – Additional arguments to pass to geom_bar
- Returns:
Bar plot
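The quantity being compared here can be sketched in a few lines of plain Python. The example below computes the GC fraction of each sequence and counts sequences per GC bin of width `binwidth`. It is an illustration under simplifying assumptions, not gReLU's implementation: it works on literal sequences rather than genomic intervals resolved against a genome, and the helper names are hypothetical.

```python
from collections import Counter


def gc_fraction(seq):
    """Fraction of G or C bases in a DNA sequence."""
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / len(seq)


def gc_bin_counts(seqs, binwidth=0.1):
    """Count sequences per GC-content bin, keyed by bin index."""
    n_bins = round(1 / binwidth)
    counts = Counter()
    for seq in seqs:
        # Clamp so that a GC fraction of exactly 1.0 lands in the last bin
        idx = min(int(gc_fraction(seq) * n_bins), n_bins - 1)
        counts[idx] += 1
    return dict(counts)


# Toy sequences standing in for positive and negative regions
positives = ["GGCA", "GCGC", "GGCT"]
negatives = ["AATG", "ATAT", "GGCA"]
pos_counts = gc_bin_counts(positives)
neg_counts = gc_bin_counts(negatives)
```

If the negative set is well matched to the positive set, the two sets of bin counts have a similar shape. In this toy example they clearly do not.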
- grelu.visualize.plot_attributions(attrs: numpy.ndarray, start_pos: int = 0, end_pos: int = -1, figsize: Tuple[int] = (20, 2), ticks: int = 10, highlight_centers: List[int] | None = None, highlight_width: int | List[int] = 5, highlight_positions: List[int] | None = None, ylim: Tuple[float, float] | None = None, facecolor: str | None = 'yellow', edgecolor: str | None = None, alpha: float | None = 0.15)
Plot base-level importance scores across a sequence.
- Parameters:
attrs – A numpy array of shape (4, L)
start_pos – Start position along the sequence
end_pos – End position along the sequence.
figsize – Tuple containing (width, height)
ticks – Frequency of ticks on the x-axis
highlight_centers – List of positions where highlights are centered
highlight_width – Width of each highlighted region
highlight_positions – List of individual positions to highlight.
ylim – Axis limits for the y-axis
facecolor – Face color for highlight box
edgecolor – Edge color for highlight box
alpha – Opacity of highlight box
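To make the expected `(4, L)` input concrete, the sketch below builds a synthetic attribution array by placing one score per position in the row of the observed base. The row order A, C, G, T is an assumption that this page does not confirm, and `show_attributions` is a hypothetical wrapper that imports gReLU only when called.

```python
import numpy as np

seq = "ACGTTGCA"
# One synthetic importance score per position (illustrative values only)
scores = np.linspace(-1.0, 1.0, num=len(seq))

# Assumed row order: A, C, G, T
attrs = np.zeros((4, len(seq)))
for pos, base in enumerate(seq):
    attrs["ACGT".index(base), pos] = scores[pos]


def show_attributions(attrs):
    """Hypothetical wrapper; highlights a 3-bp window centered on position 4."""
    from grelu.visualize import plot_attributions

    return plot_attributions(
        attrs, ticks=2, highlight_centers=[4], highlight_width=3
    )
```

In practice `attrs` would come from an attribution method applied to a trained model rather than from hand-made scores.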
- grelu.visualize.plot_ISM(ism_preds: pandas.DataFrame, start_pos: int | None = None, end_pos: int | None = None, figsize: Tuple[float, float] = (8, 1.5), method: str = 'heatmap', **kwargs)
Return in silico mutagenesis plot
- Parameters:
ism_preds – ISM dataframe produced by grelu.model.interpret.ISM_predict
start_pos – Start position of region to plot
end_pos – End position of region to plot
figsize – Tuple containing (width, height)
method – ‘heatmap’ or ‘logo’
**kwargs – Additional arguments to be passed to sns.heatmap (if method = ‘heatmap’) or plot_attributions (if method = ‘logo’)
- Returns:
Heatmap or sequence logo for the specified region.
- grelu.visualize.plot_tracks(tracks: numpy.ndarray, start_pos: int = 0, end_pos: int = None, titles: List[str] | None = None, figsize: Tuple[float, float] = (20, 1.5), highlight_intervals: pandas.DataFrame | None = None, facecolor: str | None = 'yellow', edgecolor: str | None = None, alpha: float | None = 0.15, annotations: Dict[str, pandas.DataFrame] = {}, annot_height_ratio: float = 1.0)
Plot genomic coverage tracks
- Parameters:
tracks – Numpy array of shape (T, L)
start_pos – Coordinate at which the tracks start
end_pos – Coordinate at which the tracks end
titles – List containing a title for each track
figsize – Tuple of (width, height)
highlight_intervals – A pandas dataframe containing genomic intervals to highlight
facecolor – Face color for highlight box
edgecolor – Edge color for highlight box
alpha – Opacity of highlight box
annotations – Dictionary of (key, value) pairs where the keys are strings and the values are pandas dataframes containing annotated genomic intervals
annot_height_ratio – Ratio between the height of an annotation and the height of a track. By default, both are of equal height.
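A short sketch of preparing a `(T, L)` input for this function, using synthetic coverage values. The wrapper `show_tracks` is hypothetical, and it assumes one value per base pair so that `end_pos` equals `start_pos` plus the track length. For binned tracks the end coordinate would need to be scaled accordingly.

```python
import numpy as np

# Two synthetic coverage tracks over a 200-position window: shape (T, L)
rng = np.random.default_rng(0)
tracks = rng.poisson(lam=3.0, size=(2, 200)).astype(float)
titles = ["task_0", "task_1"]  # one title per track


def show_tracks(tracks, titles, start_pos=0):
    """Hypothetical wrapper; imports gReLU only when called."""
    from grelu.visualize import plot_tracks

    # Assumes one value per coordinate (unbinned tracks)
    end_pos = start_pos + tracks.shape[1]
    return plot_tracks(tracks, start_pos=start_pos, end_pos=end_pos, titles=titles)
```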
- grelu.visualize.plot_attention_matrix(attn: numpy.ndarray, start_pos: int = 0, end_pos: int | None = None, highlight_intervals: pandas.DataFrame | None = None, figsize: Tuple[int, int] = (5, 4), **kwargs)
Plot a bin x bin matrix of attention weights derived from transformer layers in a model.
- Parameters:
attn – A square numpy array containing attention weights.
start_pos – The start coordinate of the genomic region
end_pos – The end coordinate of the genomic region
highlight_intervals – A pandas dataframe containing genomic intervals to highlight
figsize – A tuple containing (width, height)
**kwargs – Additional arguments to pass to sns.heatmap