Source code for decima.interpret.attributions

from typing import Optional, Union
import warnings
import numpy as np
import pandas as pd
import torch
import pyBigWig
from pyfaidx import Faidx
import genomepy
from captum.attr import InputXGradient, Saliency, IntegratedGradients
from grelu.interpret.motifs import scan_sequences
from grelu.sequence.format import convert_input_type, strings_to_one_hot
from grelu.transforms.prediction_transforms import Aggregate, Specificity
from tangermeme.seqlet import recursive_seqlets

from decima.constants import DECIMA_CONTEXT_SIZE
from decima.core.result import DecimaResult
from decima.hub import load_decima_model
from decima.model.lightning import LightningModel
from decima.utils import get_compute_device
from decima.plot.visualize import plot_peaks
from grelu.visualize import plot_attributions


def get_attribution_method(method: str):
    """Resolve an attribution method name to its captum attribution class.

    Args:
        method: Name of the attribution method ("saliency", "inputxgradient",
            or "integratedgradients"), matched case-insensitively. Any other
            value (e.g. an already-resolved captum attribution class) is
            returned unchanged so callers can supply their own implementation.

    Returns:
        The matching captum attribution class, or ``method`` itself when it
        does not match a known name.
    """
    if isinstance(method, str):
        # Case-insensitive match keeps "Saliency" / "SALIENCY" working too.
        key = method.lower()
        if key == "saliency":
            return Saliency
        if key == "inputxgradient":
            return InputXGradient
        if key == "integratedgradients":
            return IntegratedGradients
    # Unknown names and non-string values pass through untouched.
    return method
def attributions(
    inputs,
    tasks,
    off_tasks=None,
    model=0,
    transform="specificity",
    method="inputxgradient",
    device=None,
    **kwargs,
):
    """Compute per-base attribution scores for a batch of sequences.

    Args:
        inputs: One-hot encoded sequences of shape (batch_size, 5, 524288);
            the 5th channel is the binary gene mask.
        tasks: List of cell types (tasks) to analyze attributions for
        off_tasks: List of cell types to contrast against (used by the
            "specificity" transform)
        model: Model to use; either a ``LightningModel`` instance or an
            identifier accepted by ``load_decima_model``
        transform: Prediction transform to apply: "specificity" (on-tasks
            minus off-tasks) or "aggregate" (mean over ``tasks``)
        method: Attribution method; a name understood by
            ``get_attribution_method`` or a captum attribution class
        device: Device to use for attribution analysis
        **kwargs: Extra keyword arguments forwarded to the captum
            ``attribute`` call

    Returns:
        np.ndarray: Attribution scores of shape (batch_size, 4, 524288);
        the gene-mask channel is dropped.
    """
    assert inputs.shape[1] == 5, "`inputs` must be 5-dimensional with shape (batch_size, 5, 524288)"
    assert inputs.shape[2] == DECIMA_CONTEXT_SIZE, "`inputs` must have shape (batch_size, 5, 524288)"
    if not isinstance(model, LightningModel):
        model = load_decima_model(model, device)
    # Attach a prediction transform so gradients flow through the chosen
    # on-vs-off contrast (or mean aggregate) rather than raw predictions.
    if transform == "specificity":
        model.add_transform(
            Specificity(
                on_tasks=tasks,
                off_tasks=off_tasks,
                model=model,
                compare_func="subtract",
            )
        )
    elif transform == "aggregate":
        model.add_transform(Aggregate(tasks=tasks, task_aggfunc="mean", model=model))
    model = model.eval()
    device = get_compute_device(device)
    inputs.requires_grad = True
    attribution_method = get_attribution_method(method)
    attributer = attribution_method(model.to(device))
    if method == "saliency":
        # Saliency defaults to absolute gradients; keep signed values.
        kwargs = {**kwargs, "abs": False}
    # NOTE(review): captum re-enables gradients internally
    # (torch.autograd.set_grad_enabled(True)) inside .attribute, so this
    # no_grad does not prevent the attribution gradients — confirm before
    # relying on it for anything else.
    with torch.no_grad():
        attrs = attributer.attribute(inputs.to(device), **kwargs)
        # Drop the gene-mask channel, keeping only the 4 base channels.
        attrs = attrs.cpu().numpy()[:, :4]
    # Remove the transform so the model can be reused for plain predictions.
    model.reset_transform()
    return attrs
class Attribution:
    """Attribution analysis results for a gene.

    Holds the one-hot input sequence, per-base attribution scores, and the
    (optional) genomic location of the analyzed region, and calls seqlet
    "peaks" on the attribution track at construction time.

    Args:
        inputs: One-hot encoded sequence of shape (5, seq_len); the 5th row
            is a binary gene mask
        attrs: Attribution scores of shape (4, seq_len)
        gene: Gene name
        chrom: Chromosome name
        start: Start position
        end: End position
        strand: Strand
        threshold: Threshold for peak finding
        min_seqlet_len: Minimum seqlet length for peak finding
        max_seqlet_len: Maximum seqlet length for peak finding
        additional_flanks: Additional flanks to add around called peaks

    Examples:
        >>> attribution = Attribution(
        ...     inputs=inputs,
        ...     attrs=attrs,
        ...     gene="A1BG",
        ...     chrom="chr1",
        ...     start=100,
        ...     end=200,
        ...     strand="+",
        ...     threshold=5e-4,
        ...     min_seqlet_len=4,
        ...     max_seqlet_len=25,
        ...     additional_flanks=0,
        ... )
        >>> attribution.plot_peaks()
        >>> attribution.scan_motifs()
        >>> attribution.save_bigwig("attributions.bigwig")
        >>> attribution.peaks_to_bed()
    """
[docs] def __init__( self, inputs: torch.Tensor, attrs: np.ndarray, gene: Optional[str] = "", chrom: Optional[str] = None, start: Optional[int] = None, end: Optional[int] = None, strand: Optional[str] = None, threshold: Optional[float] = 5e-4, min_seqlet_len: Optional[int] = 4, max_seqlet_len: Optional[int] = 25, additional_flanks: Optional[int] = 0, ): """Initialize Attribution. Args: inputs: One-hot encoded sequence attrs: Attribution scores gene: Gene name chrom: Chromosome name start: Start position end: End position strand: Strand threshold: Threshold for peak finding min_seqlet_len: Minimum sequence length for peak finding max_seqlet_len: Maximum sequence length for peak finding additional_flanks: Additional flanks to add to the gene """ assert ( inputs.shape[0] == 5 ), "`inputs` must be 5-dimensional with shape (5, seq_len) where the last dimension is a binary mask." assert attrs.shape[0] == 4, "`attrs` must be 4-dimensional" assert inputs.shape[1] == attrs.shape[1], "`inputs` and `attrs` must have the same length" self.inputs = inputs self.attrs = attrs self.gene = gene self._chrom = chrom self._start = start self._end = end self._strand = strand assert self.end - self.start == self.inputs.shape[1], "`end` - `start` must be equal to the length of `inputs`" self.gene_mask_start = np.where(inputs[-1] == 1)[0][0] self.gene_mask_end = np.where(inputs[-1] == 1)[-1][0] self.peaks = self._find_peaks( threshold=threshold, min_seqlet_len=min_seqlet_len, max_seqlet_len=max_seqlet_len, additional_flanks=additional_flanks, )
@property def chrom(self) -> str: """Get the chromosome name.""" if self._chrom is None: return "custom" return self._chrom @property def start(self) -> int: """Get the start position.""" if self._start is None: return 0 return self._start @property def end(self) -> int: """Get the end position.""" if self._end is None: return self.inputs.shape[1] return self._end @property def strand(self) -> str: """Get the strand.""" if self._strand is None: return "+" return self._strand @property def gene_start(self) -> int: """Get the gene start position.""" if self.strand == "-": return self.end - self.gene_mask_end return self.start + self.gene_mask_start @property def gene_end(self) -> int: """Get the gene end position.""" if self.strand == "-": return self.end - self.gene_mask_start return self.start + self.gene_mask_end
    @classmethod
    def from_seq(
        cls,
        inputs: Union[str, torch.Tensor, np.ndarray],
        tasks: Optional[list] = None,
        off_tasks: Optional[list] = None,
        model: Optional[Union[str, int]] = 0,
        transform: str = "specificity",
        method: str = "inputxgradient",
        device: Optional[str] = None,
        result: Optional[DecimaResult] = None,
        gene: Optional[str] = "",
        chrom: Optional[str] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        strand: Optional[str] = None,
        gene_mask_start: Optional[int] = None,
        gene_mask_end: Optional[int] = None,
        threshold: Optional[float] = 5e-4,
        min_seqlet_len: Optional[int] = 4,
        max_seqlet_len: Optional[int] = 25,
        additional_flanks: Optional[int] = 0,
    ):
        """Build an Attribution by computing attributions for a sequence.

        Args:
            inputs: Sequence to analyze: a string, or a torch.Tensor /
                np.ndarray of shape (4, 524288) or (5, 524288) where the 5th
                row is a binary gene mask. For strings and 4-row arrays,
                ``gene_mask_start`` and ``gene_mask_end`` must be provided.
            tasks: List of cell types to analyze attributions for
            off_tasks: List of cell types to contrast against
            model: Model to use for attribution analysis
            transform: Transformation to apply ("specificity" or "aggregate")
            method: Attribution method to use
            device: Device to use for attribution analysis
            result: DecimaResult used to resolve task queries; loaded from the
                default source when None
            gene: Gene name
            chrom: Chromosome name
            start: Start position
            end: End position
            strand: Strand
            gene_mask_start: Gene mask start position within the sequence
            gene_mask_end: Gene mask end position within the sequence
            threshold: Threshold for peak finding
            min_seqlet_len: Minimum seqlet length for peak finding
            max_seqlet_len: Maximum seqlet length for peak finding
            additional_flanks: Additional flanks to add around called peaks
        """
        # Normalize all accepted input forms to a (5, seq_len) tensor whose
        # last row is the gene mask.
        if isinstance(inputs, np.ndarray):
            inputs = torch.from_numpy(inputs).float()
        if isinstance(inputs, torch.Tensor):
            if (inputs.shape[0] == 4) and (gene_mask_start is not None) and (gene_mask_end is not None):
                # 4-row one-hot: synthesize the mask row from the coordinates.
                mask = torch.zeros(1, DECIMA_CONTEXT_SIZE)
                mask[0, gene_mask_start:gene_mask_end] = 1.0
                inputs = torch.vstack([inputs, mask])
            elif inputs.shape[0] == 5:
                # Mask already embedded; coordinates, if given, are redundant.
                if (gene_mask_start is not None) or (gene_mask_end is not None):
                    warnings.warn("Gene mask will be ignored as sequence is 5-dimensional.")
                pass
            else:
                raise ValueError(
                    "Sequence must be 4-dimensional with shape (4, seq_len) "
                    "and gene start and end must be provided, or 5-dimensional "
                    "with shape (5, seq_len) where the last dimension is a binary mask."
                )
        elif isinstance(inputs, str):
            inputs = strings_to_one_hot(inputs)
            assert (gene_mask_start is not None) and (
                gene_mask_end is not None
            ), "Gene start and end must be provided when seq is a string."
            mask = torch.zeros(1, DECIMA_CONTEXT_SIZE)
            mask[0, gene_mask_start:gene_mask_end] = 1.0
            inputs = torch.vstack([inputs, mask])
        else:
            raise ValueError("`inputs` must be a string, torch.Tensor, or np.ndarray")
        if result is None:
            result = DecimaResult.load()
        # Expand task queries (e.g. metadata filters) into concrete task lists.
        tasks, off_tasks = result.query_tasks(tasks, off_tasks)
        # attributions() expects a batch dimension; add and strip it here.
        attrs = attributions(
            inputs=inputs.unsqueeze(0),
            tasks=tasks,
            off_tasks=off_tasks,
            model=model,
            transform=transform,
            method=method,
            device=device,
        ).squeeze(0)
        return cls(
            inputs=inputs,
            attrs=attrs,
            gene=gene,
            chrom=chrom,
            start=start,
            end=end,
            strand=strand,
            threshold=threshold,
            min_seqlet_len=min_seqlet_len,
            max_seqlet_len=max_seqlet_len,
            additional_flanks=additional_flanks,
        )
[docs] @staticmethod def find_peaks(attrs, threshold=5e-4, min_seqlet_len=4, max_seqlet_len=25, additional_flanks=0): return recursive_seqlets( attrs.sum(0, keepdims=True), threshold=threshold, min_seqlet_len=min_seqlet_len, max_seqlet_len=max_seqlet_len, additional_flanks=additional_flanks, ).reset_index(drop=True)
def _find_peaks(self, threshold=5e-4, min_seqlet_len=4, max_seqlet_len=25, additional_flanks=0): df = self.find_peaks(self.attrs, threshold, min_seqlet_len, max_seqlet_len, additional_flanks) del df["example_idx"] df["from_tss"] = df["start"] - self.gene_mask_start df["peak"] = self.gene + "@" + df["from_tss"].astype(str) return df[["peak", "start", "end", "attribution", "p-value", "from_tss"]]
    def scan_motifs(self, motifs: str = "hocomoco_v12", window: int = 18, pthresh: float = 5e-4) -> pd.DataFrame:
        """Scan for motifs in peak regions.

        Args:
            motifs: Motif database to use
            window: Half-window size (bp on each side of a peak midpoint) for
                extracting sequence and attributions around peaks
            pthresh: P-value threshold for motif matches

        Returns:
            pd.DataFrame: Motif scan results sorted by p-value, with motif
            coordinates mapped back to full-sequence positions.
        """
        # Extract a +/- `window` bp slice of attributions and one-hot sequence
        # around each peak's midpoint.
        mid = (self.peaks["start"] + self.peaks["end"]) // 2
        peak_attrs = np.stack([self.attrs[:, i - window : i + window] for i in mid])
        peak_seqs = torch.stack([self.inputs[:4, i - window : i + window] for i in mid])
        df = scan_sequences(
            seqs=convert_input_type(peak_seqs, "strings"),
            seq_ids=self.peaks["peak"].tolist(),
            motifs=motifs,
            pthresh=pthresh,
            rc=True,  # also scan the reverse complement
            attrs=peak_attrs,
        ).rename(columns={"sequence": "peak"})
        # Attach per-peak metadata (mid, from_tss) so window-relative motif
        # hit coordinates can be mapped back to full-sequence coordinates.
        df = df.merge(
            self.peaks[["peak", "from_tss", "start"]].assign(mid=mid).reset_index(drop=True),
            on="peak",
            suffixes=("", "_peak"),
        )
        # Shift motif start/end from window-relative to full-sequence positions.
        df["start"] += df["mid"] - window
        df["end"] += df["mid"] - window
        df["from_tss"] = df["start"] - self.gene_mask_start
        # Drop merge/bookkeeping columns that are no longer meaningful.
        del df["start_peak"]
        del df["seq_idx"]
        del df["mid"]
        return df.sort_values("p-value")
[docs] def plot_peaks(self, overlapping_min_dist=1000, figsize=(10, 2)): """Plot attribution scores and highlight peaks. Args: overlapping_min_dist: Minimum distance between peaks to consider them overlapping figsize: Figure size in inches (width, height) Returns: plotnine.ggplot: The plotted figure showing attribution scores with highlighted peaks """ return plot_peaks( self.attrs, self.gene_mask_start, self.peaks, overlapping_min_dist=overlapping_min_dist, figsize=figsize, )
[docs] def __repr__(self): return f"Attribution(gene={self.gene})"
def __str__(self): return f"Attribution(gene={self.gene})"
    def save_bigwig(self, bigwig_path: str):
        """
        Save attribution scores as a bigwig file.

        Args:
            bigwig_path: Path to save bigwig file.
        """
        # Collapse the 4 base channels to one score per position.
        attrs = self.attrs.sum(axis=0)
        # Attributions are stored in sequence orientation; flip to genome
        # orientation for minus-strand genes.
        if self.strand == "-":
            attrs = attrs[::-1]
        bw = pyBigWig.open(bigwig_path, "w")
        if self._chrom is not None:
            name = self.chrom
            # NOTE(review): genome build is hard-coded to hg38 — confirm this
            # matches the coordinate system used for this object.
            sizes = genomepy.Genome("hg38").sizes
            bw.addHeader([(chrom, size) for chrom, size in sizes.items()])
        else:
            # No genomic location: write a single custom contig spanning the
            # sequence, named after the gene.
            name = self.gene or "custom"
            bw.addHeader([(name, self.end - self.start)])
        bw.addEntries(name, self.start, values=attrs, span=1, step=1)
        bw.close()
[docs] def fasta_str(self): """ Get attribution scores as a fasta string. """ seq = convert_input_type(self.inputs[:4], "strings") name = self.gene or "custom" return f">{name}\n{seq}\n"
[docs] def save_fasta(self, fasta_path: str): """ Save attribution scores as a fasta file. """ with open(fasta_path, "w") as f: f.write(self.fasta_str()) Faidx(fasta_path, build_index=True)
[docs] def peaks_to_bed(self): """ Convert peaks to bed format. Returns: pd.DataFrame: Peaks in bed format where columns are: - chrom: Chromosome name - start: Start position in genome - end: End position in genome - name: Peak name in format "gene@from_tss" - score: Score (-log10(p-value)) clipped to 0-100 based on the seqlet calling - strand: Strand == '.' """ df = self.peaks.copy().rename(columns={"peak": "name"}) df["chrom"] = self.chrom if self.strand == "+": df["start"], df["end"] = self.start + df["start"], self.start + df["end"] else: df["start"], df["end"] = self.end - df["end"], self.end - df["start"] df["strand"] = "." df["score"] = -np.log10(df["p-value"] + 1e-50) df["score"] = df["score"].astype(int).clip(lower=0, upper=50) return df[["chrom", "start", "end", "name", "score", "strand"]]
[docs] def save_peaks(self, bed_path: str): """ Save peaks to bed file. Args: bed_path: Path to save bed file. """ self.peaks_to_bed().to_csv(bed_path, sep="\t", header=False, index=False)