grelu.visualize

Functions

_collect_preds_and_labels(→ pandas.DataFrame)

Function to collect predictions and labels for a region of interest into a dataframe

plot_distribution(values[, title, method, figsize])

Given a 1-D sequence of values, plot a histogram or density plot of their distribution.

plot_pred_distribution(preds, labels[, tasks, bins, ...])

Plot the density of predictions and regression labels for a given task.

plot_pred_scatter(preds, labels[, tasks, bins, ...])

Plot a scatterplot of predictions and regression labels for a given task.

plot_binary_preds(preds, labels[, tasks, bins, figsize])

Plot a box plot of predictions for each classification label

plot_calibration_curve(probs, labels[, tasks, bins, ...])

Plots a calibration curve for a classification model

add_highlights(→ None)

Add highlights to a matplotlib axis

plot_evolution(df[, figsize])

Plot change in scores and predictions over multiple rounds of directed evolution

plot_gc_match(positives, negatives[, binwidth, ...])

Plot a histogram comparing GC content distribution in positive and negative regions.

plot_attributions(attrs[, start_pos, end_pos, ...])

Plot base-level importance scores across a sequence.

plot_ISM(ism_preds[, start_pos, end_pos, figsize, method])

Plot the results of in silico mutagenesis (ISM) as a heatmap or sequence logo.

plot_tracks(tracks[, start_pos, end_pos, titles, ...])

Plot genomic coverage tracks

plot_attention_matrix(attn[, start_pos, end_pos, ...])

Plot a bin x bin matrix of attention weights derived from transformer layers in a model.

Module Contents

grelu.visualize._collect_preds_and_labels(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None) → pandas.DataFrame

Function to collect predictions and labels for a region of interest into a dataframe

grelu.visualize.plot_distribution(values: List | numpy.ndarray | pandas.Series, title: str = 'metric', method: str = 'histogram', figsize: Tuple[int, int] = (4, 3), **kwargs)

Given a 1-D sequence of values, plot a histogram or density plot of their distribution.

Parameters:
  • values – 1-D sequence of numbers to plot

  • title – Plot title

  • method – Either “histogram” or “density”

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_histogram (if method == “histogram”) or geom_density (if method == “density”).

Returns:

histogram or density plot

grelu.visualize.plot_pred_distribution(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, figsize: Tuple[int, int] = (4, 3), **kwargs)

Plot the density of predictions and regression labels for a given task.

Parameters:
  • preds – Model predictions

  • labels – True labels

  • tasks – List of task names or indices. If None, all tasks will be used.

  • bins – List of relevant bins in the predictions and labels. If None, all bins will be used.

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_density()

Returns:

Density plots

grelu.visualize.plot_pred_scatter(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, density: bool = False, figsize: Tuple[int, int] = (4, 3), **kwargs)

Plot a scatterplot of predictions and regression labels for a given task.

Parameters:
  • preds – Model predictions

  • labels – True labels

  • tasks – List of task names or indices. If None, all tasks will be used.

  • bins – List of relevant bins in the predictions and labels. If None, all bins will be used.

  • density – If true, color the points by local density.

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_point (if density = False) or geom_pointdensity (if density = True).

Returns:

Scatter plots

grelu.visualize.plot_binary_preds(preds: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, figsize: Tuple[int, int] = (4, 3), **kwargs)

Plot a box plot of predictions for each classification label

Parameters:
  • preds – Model predictions

  • labels – True labels

  • tasks – List of task names or indices. If None, all tasks will be used.

  • bins – List of relevant bins in the predictions and labels. If None, all bins will be used.

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_boxplot

Returns:

Box plots
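
To make concrete what such a box plot summarizes, the sketch below groups predictions by their classification label and reports the median of each group. It uses only the standard library; `median_by_label` is a hypothetical helper written for this illustration and is not part of grelu.

```python
from statistics import median
from typing import Dict, List


def median_by_label(preds: List[float], labels: List[int]) -> Dict[int, float]:
    """Group predictions by label and return the median of each group."""
    groups: Dict[int, List[float]] = {}
    for p, y in zip(preds, labels):
        groups.setdefault(y, []).append(p)
    return {y: median(values) for y, values in groups.items()}


# Toy data (made up for illustration): three negatives and three positives
preds = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9]
labels = [0, 0, 0, 1, 1, 1]
medians = median_by_label(preds, labels)
```

A well-separated classifier shows clearly different medians (and little overlap between boxes) for the two labels.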

grelu.visualize.plot_calibration_curve(probs: numpy.ndarray | pandas.DataFrame, labels: numpy.ndarray | anndata.AnnData, tasks: List[int] | List[str] | None = None, bins: List[int] | None = None, aggregate: bool = True, figsize: Tuple[int, int] = (4, 3), show_legend: bool = True)

Plots a calibration curve for a classification model

Parameters:
  • probs – Predicted probabilities from the model

  • labels – True classification labels

  • tasks – List of task names or indices. If None, all tasks will be used.

  • bins – List of relevant bins in the predictions and labels. If None, all bins will be used.

  • aggregate – If True, a single calibration curve is plotted for all tasks combined. If False, a separate curve is plotted for each task.

  • figsize – Tuple containing (width, height)

  • show_legend – If True, the legend is displayed. If False, no legend is displayed.

Returns:

Line plots showing the calibration between true and predicted probabilities for each task (if aggregate=False) or for all tasks combined (if aggregate=True)
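
As a rough illustration of what a calibration curve summarizes, the sketch below places predicted probabilities into equal-width bins and compares the mean predicted probability with the observed fraction of positives in each bin. It uses only the standard library; `calibration_points` and `n_bins` are hypothetical names, and this is not necessarily how grelu computes the curve.

```python
from typing import List, Tuple


def calibration_points(
    probs: List[float], labels: List[int], n_bins: int = 5
) -> List[Tuple[float, float]]:
    """Return (mean predicted probability, fraction of positives) per non-empty bin."""
    sums = [0.0] * n_bins      # sum of predicted probabilities per bin
    positives = [0] * n_bins   # number of positive labels per bin
    counts = [0] * n_bins      # number of examples per bin
    for p, y in zip(probs, labels):
        # Equal-width bins on [0, 1]; p == 1.0 goes into the last bin
        idx = min(int(p * n_bins), n_bins - 1)
        sums[idx] += p
        positives[idx] += y
        counts[idx] += 1
    return [
        (sums[i] / counts[i], positives[i] / counts[i])
        for i in range(n_bins)
        if counts[i] > 0
    ]


# Toy data (made up for illustration)
probs = [0.1, 0.1, 0.9, 0.9]
labels = [0, 0, 1, 1]
points = calibration_points(probs, labels, n_bins=2)
```

For a well-calibrated model the two values in each pair are close, so the points lie near the diagonal.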

grelu.visualize.add_highlights(ax, centers: int | List[int] | None = None, width: int | None = None, starts: int | List[int] | None = None, ends: int | List[int] | None = None, positions: int | List[int] | None = None, ymin: float = -10, ymax: float = 20, facecolor: str | None = 'yellow', alpha: float | None = 0.15, edgecolor: str | None = None) → None

Add highlights to a matplotlib axis
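
The signature accepts highlights either as centers plus a width or as explicit starts and ends. The sketch below shows one plausible way to convert the first form into the second. The centering and the half-open (start, end) convention are assumptions made for this illustration, and `highlight_intervals` is a hypothetical helper, not the library's code.

```python
from typing import List, Tuple, Union


def highlight_intervals(
    centers: Union[int, List[int]], width: int
) -> List[Tuple[int, int]]:
    """Convert highlight centers and a shared width into (start, end) pairs."""
    if isinstance(centers, int):
        centers = [centers]  # accept a single center as well as a list
    half = width // 2
    # Assumed convention: start is inclusive, end is exclusive
    return [(c - half, c - half + width) for c in centers]


intervals = highlight_intervals([10, 50], width=5)
```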

grelu.visualize.plot_evolution(df: pandas.DataFrame, figsize: Tuple[float, float] = (4, 3), **kwargs)

Plot change in scores and predictions over multiple rounds of directed evolution

Parameters:
  • df – Dataframe produced by grelu.design.evolve

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_boxplot.

grelu.visualize.plot_gc_match(positives: pandas.DataFrame, negatives: pandas.DataFrame, binwidth: float = 0.1, genome: str = 'hg38', figsize: Tuple[int, int] = (4, 3), **kwargs)

Plot a histogram comparing GC content distribution in positive and negative regions.

Parameters:
  • positives – Dataframe of genomic intervals for the positive regions

  • negatives – Dataframe of genomic intervals for the negative regions

  • binwidth – Resolution at which to bin GC content

  • genome – Name of the genome

  • figsize – Tuple containing (width, height)

  • **kwargs – Additional arguments to pass to geom_bar

Returns:

Bar plot
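
The comparison behind this plot can be sketched without any plotting: compute the GC fraction of each sequence and count sequences per GC bin for the two sets. The example below starts from sequences directly rather than from genomic intervals and a genome, which is a simplification; `gc_fraction` and `gc_histogram` are hypothetical helpers, not grelu functions.

```python
from collections import Counter
from typing import Dict, List


def gc_fraction(seq: str) -> float:
    """Fraction of G and C bases in a DNA sequence."""
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / len(seq)


def gc_histogram(seqs: List[str], binwidth: float = 0.1) -> Dict[int, int]:
    """Count sequences per GC bin; bin k covers [k * binwidth, (k + 1) * binwidth)."""
    n_bins = round(1 / binwidth)
    counts: Counter = Counter()
    for seq in seqs:
        # A GC fraction of exactly 1.0 is placed in the last bin
        k = min(int(gc_fraction(seq) * n_bins), n_bins - 1)
        counts[k] += 1
    return dict(counts)


# Toy sequences (made up for illustration)
positives = ["GGCC", "GGGA", "ACGT"]
negatives = ["AATT", "AAAG", "ACGT"]
pos_hist = gc_histogram(positives)
neg_hist = gc_histogram(negatives)
```

In this toy case the two histograms overlap only in the bin containing 0.5, which indicates a poor GC match between the sets.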

grelu.visualize.plot_attributions(attrs: numpy.ndarray, start_pos: int = 0, end_pos: int = -1, figsize: Tuple[int] = (20, 2), ticks: int = 10, highlight_centers: List[int] | None = None, highlight_width: int | List[int] = 5, highlight_positions: List[int] | None = None, ylim: Tuple[float, float] | None = None, facecolor: str | None = 'yellow', edgecolor: str | None = None, alpha: float | None = 0.15)

Plot base-level importance scores across a sequence.

Parameters:
  • attrs – A numpy array of shape (4, L)

  • start_pos – Start position along the sequence

  • end_pos – End position along the sequence.

  • figsize – Tuple containing (width, height)

  • ticks – Frequency of ticks on the x-axis

  • highlight_centers – List of positions where highlights are centered

  • highlight_width – Width of each highlighted region

  • highlight_positions – List of individual positions to highlight.

  • ylim – Axis limits for the y-axis

  • facecolor – Face color for highlight box

  • edgecolor – Edge color for highlight box

  • alpha – Opacity of highlight box
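
To illustrate the (4, L) layout expected for attrs, the sketch below builds such a matrix from a sequence and one score per position, using nested lists in place of a numpy array. The row order A, C, G, T is an assumption made for this example, and `scores_to_matrix` is a hypothetical helper; check the convention of the attribution method you use.

```python
from typing import List

BASES = "ACGT"  # assumed row order, not confirmed by the documentation above


def scores_to_matrix(seq: str, scores: List[float]) -> List[List[float]]:
    """Place each per-position score in the row of the observed base (4 x L)."""
    if len(seq) != len(scores):
        raise ValueError("seq and scores must have the same length")
    matrix = [[0.0] * len(seq) for _ in BASES]
    for pos, (base, s) in enumerate(zip(seq.upper(), scores)):
        matrix[BASES.index(base)][pos] = s
    return matrix


attrs = scores_to_matrix("ACGT", [0.5, -1.0, 2.0, 0.0])
```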

grelu.visualize.plot_ISM(ism_preds: pandas.DataFrame, start_pos: int | None = None, end_pos: int | None = None, figsize: Tuple[float, float] = (8, 1.5), method: str = 'heatmap', **kwargs)

Plot the results of in silico mutagenesis (ISM) as a heatmap or sequence logo.

Parameters:
  • ism_preds – ISM dataframe produced by grelu.model.interpret.ISM_predict

  • start_pos – Start position of region to plot

  • end_pos – End position of region to plot

  • figsize – Tuple containing (width, height)

  • method – ‘heatmap’ or ‘logo’

  • **kwargs – Additional arguments to pass to sns.heatmap (if method == ‘heatmap’) or plot_attributions (if method == ‘logo’).

Returns:

Heatmap or sequence logo for the specified region.
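
For readers unfamiliar with ISM, the sketch below shows the idea on a toy scale: every position is mutated to every base, and the change in a score relative to the reference sequence is recorded. The scoring function `gc_count` is a hypothetical stand-in for a model prediction, `toy_ism` is not a grelu function, and the layout of the dataframe produced by grelu.model.interpret.ISM_predict may differ from the dictionary used here.

```python
from typing import Callable, Dict, List

BASES = "ACGT"


def toy_ism(seq: str, score: Callable[[str], float]) -> Dict[str, List[float]]:
    """Score every single-base substitution relative to the reference sequence.

    Returns a mapping from each base to a list of length len(seq); entry i is
    score(mutant) - score(reference) with that base placed at position i.
    """
    ref_score = score(seq)
    effects: Dict[str, List[float]] = {b: [] for b in BASES}
    for i in range(len(seq)):
        for b in BASES:
            mutant = seq[:i] + b + seq[i + 1:]
            effects[b].append(score(mutant) - ref_score)
    return effects


def gc_count(s: str) -> float:
    """Hypothetical stand-in for a model prediction: the number of G/C bases."""
    return float(s.count("G") + s.count("C"))


effects = toy_ism("ACG", gc_count)
```

A heatmap of such a table has one row per base and one column per position; substituting the reference base at a position always gives an effect of zero.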

grelu.visualize.plot_tracks(tracks: numpy.ndarray, start_pos: int = 0, end_pos: int = None, titles: List[str] | None = None, figsize: Tuple[float, float] = (20, 1.5), highlight_intervals: pandas.DataFrame | None = None, facecolor: str | None = 'yellow', edgecolor: str | None = None, alpha: float | None = 0.15, annotations: Dict[str, pandas.DataFrame] = {})

Plot genomic coverage tracks

Parameters:
  • tracks – Numpy array of shape (T, L)

  • start_pos – Coordinate at which the tracks start

  • end_pos – Coordinate at which the tracks end

  • titles – List containing a title for each track

  • figsize – Tuple of (width, height)

  • highlight_intervals – A pandas dataframe containing genomic intervals to highlight

  • facecolor – Face color for highlight box

  • edgecolor – Edge color for highlight box

  • alpha – Opacity of highlight box

  • annotations – Dictionary of (key, value) pairs where the keys are strings and the values are pandas dataframes containing annotated genomic intervals

grelu.visualize.plot_attention_matrix(attn: numpy.ndarray, start_pos: int = 0, end_pos: int | None = None, highlight_intervals: pandas.DataFrame | None = None, figsize: Tuple[int, int] = (5, 4), **kwargs)

Plot a bin x bin matrix of attention weights derived from transformer layers in a model.

Parameters:
  • attn – A square numpy array containing attention weights.

  • start_pos – The start coordinate of the genomic region

  • end_pos – The end coordinate of the genomic region

  • highlight_intervals – A pandas dataframe containing genomic intervals to highlight

  • figsize – A tuple containing (width, height)

  • **kwargs – Additional arguments to pass to sns.heatmap
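
Because the matrix is indexed by bins while start_pos, end_pos and highlight_intervals are genomic coordinates, some mapping between the two is needed. The sketch below shows one simple mapping. It assumes that the region divides evenly into equal-sized bins and that intervals are half-open; `interval_to_bins` is a hypothetical helper and may not match how grelu performs this step.

```python
from typing import Tuple


def interval_to_bins(
    interval_start: int,
    interval_end: int,
    start_pos: int,
    end_pos: int,
    n_bins: int,
) -> Tuple[int, int]:
    """Return the (first, last) bin indices overlapped by a half-open interval."""
    # Assumes (end_pos - start_pos) is an exact multiple of n_bins
    bin_size = (end_pos - start_pos) // n_bins
    first = (interval_start - start_pos) // bin_size
    last = (interval_end - 1 - start_pos) // bin_size
    # Clip to the valid range of bin indices
    return max(first, 0), min(last, n_bins - 1)


bins = interval_to_bins(250, 520, start_pos=0, end_pos=1000, n_bins=10)
```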