decima.core package

Submodules

decima.core.metadata module

class decima.core.metadata.CellMetadata(name, cell_type, tissue, organ, disease, study, dataset, region, subregion, celltype_coarse, n_cells, total_counts, n_genes, size_factor, train_pearson, val_pearson, test_pearson)[source]

Bases: object

Metadata for a cell in the dataset.

name

Cell identifier

cell_type

Detailed cell type

tissue

Tissue identifier

organ

Organ name

disease

Disease state

study

Study identifier

dataset

Dataset identifier

region

Anatomical region

subregion

Anatomical subregion

celltype_coarse

Coarse cell type classification

n_cells

Number of cells

total_counts

Total count of transcripts

n_genes

Number of genes detected

size_factor

Size normalization factor

train_pearson

Pearson correlation in training set

val_pearson

Pearson correlation in validation set

test_pearson

Pearson correlation in test set

__annotations__ = {'cell_type': <class 'str'>, 'celltype_coarse': typing.Optional[str], 'dataset': <class 'str'>, 'disease': <class 'str'>, 'n_cells': <class 'int'>, 'n_genes': <class 'int'>, 'name': <class 'str'>, 'organ': <class 'str'>, 'region': typing.Optional[str], 'size_factor': <class 'float'>, 'study': <class 'str'>, 'subregion': typing.Optional[str], 'test_pearson': <class 'float'>, 'tissue': <class 'str'>, 'total_counts': <class 'float'>, 'train_pearson': <class 'float'>, 'val_pearson': <class 'float'>}
__dataclass_fields__ = {'cell_type': Field(name='cell_type',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'celltype_coarse': Field(name='celltype_coarse',type=typing.Optional[str],default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'dataset': Field(name='dataset',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'disease': Field(name='disease',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'n_cells': Field(name='n_cells',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'n_genes': Field(name='n_genes',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'name': Field(name='name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'organ': Field(name='organ',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE 
object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'region': Field(name='region',type=typing.Optional[str],default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'size_factor': Field(name='size_factor',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'study': Field(name='study',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'subregion': Field(name='subregion',type=typing.Optional[str],default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'test_pearson': Field(name='test_pearson',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'tissue': Field(name='tissue',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'total_counts': Field(name='total_counts',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'train_pearson': Field(name='train_pearson',type=<class 'float'>,default=<dataclasses._MISSING_TYPE 
object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'val_pearson': Field(name='val_pearson',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__eq__(other)

Return self==value.

__hash__ = None
__init__(name, cell_type, tissue, organ, disease, study, dataset, region, subregion, celltype_coarse, n_cells, total_counts, n_genes, size_factor, train_pearson, val_pearson, test_pearson)
__match_args__ = ('name', 'cell_type', 'tissue', 'organ', 'disease', 'study', 'dataset', 'region', 'subregion', 'celltype_coarse', 'n_cells', 'total_counts', 'n_genes', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson')
__repr__()

Return repr(self).

cell_type: str
celltype_coarse: Optional[str]
dataset: str
disease: str
classmethod from_series(name, series)[source]

Create CellMetadata from a pandas Series.

Return type:

CellMetadata

n_cells: int
n_genes: int
name: str
organ: str
region: Optional[str]
size_factor: float
study: str
subregion: Optional[str]
test_pearson: float
tissue: str
total_counts: float
train_pearson: float
val_pearson: float
class decima.core.metadata.GeneMetadata(name, chrom, start, end, strand, gene_type, frac_nan, mean_counts, n_tracks, gene_start, gene_end, gene_length, gene_mask_start, gene_mask_end, frac_N, fold, dataset, gene_id, pearson, size_factor_pearson)[source]

Bases: object

Metadata for a gene in the dataset.

name

Gene name

chrom

Chromosome where the gene is located

start

Start position, on the chromosome, of the region around the gene that is used for predictions

end

End position, on the chromosome, of the region around the gene that is used for predictions

strand

Strand orientation (+ or -)

gene_type

Type of gene (e.g., protein_coding)

frac_nan

Fraction of NaN values

mean_counts

Mean count across samples

n_tracks

Number of tracks

gene_start

Gene start position

gene_end

Gene end position

gene_length

Length of the gene

gene_mask_start

Start position of the gene mask

gene_mask_end

End position of the gene mask

frac_N

Fraction of N bases

fold

Cross-validation fold

dataset

Dataset identifier

gene_id

Ensembl gene ID

pearson

Pearson correlation

size_factor_pearson

Size factor Pearson correlation

__annotations__ = {'chrom': <class 'str'>, 'dataset': <class 'str'>, 'end': <class 'int'>, 'fold': typing.List[str], 'frac_N': <class 'float'>, 'frac_nan': <class 'float'>, 'gene_end': <class 'int'>, 'gene_id': <class 'str'>, 'gene_length': <class 'int'>, 'gene_mask_end': <class 'int'>, 'gene_mask_start': <class 'int'>, 'gene_start': <class 'int'>, 'gene_type': <class 'str'>, 'mean_counts': <class 'float'>, 'n_tracks': <class 'int'>, 'name': <class 'str'>, 'pearson': <class 'float'>, 'size_factor_pearson': <class 'float'>, 'start': <class 'int'>, 'strand': <class 'str'>}
__dataclass_fields__ = {'chrom': Field(name='chrom',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'dataset': Field(name='dataset',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'end': Field(name='end',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'fold': Field(name='fold',type=typing.List[str],default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'frac_N': Field(name='frac_N',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'frac_nan': Field(name='frac_nan',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_end': Field(name='gene_end',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_id': Field(name='gene_id',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE 
object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_length': Field(name='gene_length',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_mask_end': Field(name='gene_mask_end',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_mask_start': Field(name='gene_mask_start',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_start': Field(name='gene_start',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'gene_type': Field(name='gene_type',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'mean_counts': Field(name='mean_counts',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'n_tracks': Field(name='n_tracks',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'name': Field(name='name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE 
object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'pearson': Field(name='pearson',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'size_factor_pearson': Field(name='size_factor_pearson',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'start': Field(name='start',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'strand': Field(name='strand',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__eq__(other)

Return self==value.

__hash__ = None
__init__(name, chrom, start, end, strand, gene_type, frac_nan, mean_counts, n_tracks, gene_start, gene_end, gene_length, gene_mask_start, gene_mask_end, frac_N, fold, dataset, gene_id, pearson, size_factor_pearson)
__match_args__ = ('name', 'chrom', 'start', 'end', 'strand', 'gene_type', 'frac_nan', 'mean_counts', 'n_tracks', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'frac_N', 'fold', 'dataset', 'gene_id', 'pearson', 'size_factor_pearson')
__repr__()

Return repr(self).

chrom: str
dataset: str
end: int
fold: List[str]
frac_N: float
frac_nan: float
classmethod from_series(name, series)[source]

Create GeneMetadata from a pandas Series.

Return type:

GeneMetadata

gene_end: int
gene_id: str
gene_length: int
gene_mask_end: int
gene_mask_start: int
gene_start: int
gene_type: str
mean_counts: float
n_tracks: int
name: str
pearson: float
size_factor_pearson: float
start: int
strand: str

decima.core.result module

class decima.core.result.DecimaResult(anndata)[source]

Bases: object

Container for Decima results and model predictions.

This class provides a unified interface for loading pre-trained Decima models and associated metadata, making predictions, and performing attribution analyses.

The DecimaResult object contains:
  • An AnnData object with gene expression and metadata

  • A trained model for making predictions

  • Methods for attribution analysis and interpretation

Parameters:

anndata – AnnData object containing gene expression data and metadata

Examples

>>> # Load default pre-trained model and metadata
>>> result = DecimaResult.load()
>>> result.load_model(
...     model=0
... )
>>> # Perform attribution analysis
>>> attributions = result.attributions(
...     gene="SPI1",
...     tasks='cell_type == "classical monocyte"',
... )
Properties:

  • model: Decima model

  • genes: List of gene names

  • cells: List of cell names

  • cell_metadata: Cell metadata

  • gene_metadata: Gene metadata

  • shape: Shape of the expression matrix

  • attributions: Attributions for a gene

__annotations__ = {}
__init__(anndata)[source]
__repr__()[source]

Return repr(self).

attributions(gene, tasks=None, off_tasks=None, transform='specificity', method='inputxgradient', threshold=0.0005, min_seqlet_len=4, max_seqlet_len=25, additional_flanks=0)[source]

Get attributions for a specific gene.

Parameters:
  • gene (str) – Gene name

  • tasks (Optional[List[str]]) – List of cells to use as on task

  • off_tasks (Optional[List[str]]) – List of cells to use as off task

  • transform (str) – Attribution transform method

  • method (str) – Attribution method

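  • threshold (float) – Threshold used when calling seqlets (default: 0.0005)

  • min_seqlet_len (int) – Minimum seqlet length (default: 4)

  • max_seqlet_len (int) – Maximum seqlet length (default: 25)

  • additional_flanks (int) – Additional flanking positions to add to each seqlet (default: 0)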

Returns:

Container with inputs, predictions, attribution scores and TSS position

Return type:

Attribution

property cell_metadata: DataFrame

Cell metadata including annotations, metrics, etc.

property cells: List[str]

List of cell identifiers in the dataset.

property gene_metadata: DataFrame

Gene metadata.

gene_sequence(gene, stranded=True)[source]

Get sequence for a gene.

Return type:

str

property genes: List[str]

List of gene names in the dataset.

get_cell_metadata(cell)[source]

Get metadata for a specific cell.

Return type:

CellMetadata

get_gene_metadata(gene)[source]

Get metadata for a specific gene.

Return type:

GeneMetadata

classmethod load(anndata_path=None)[source]

Load a DecimaResult object from an AnnData object or from a path to an anndata file. If no argument is given, the default Decima metadata is loaded.

Parameters:

anndata_path (Union[str, AnnData, None]) – Path to anndata file or anndata object

Returns:

DecimaResult object

Examples

>>> result = DecimaResult.load()  # Load default decima metadata
>>> result = DecimaResult.load(
...     "path/to/anndata.h5ad"
... )  # Load custom anndata object from file
load_model(model=0, device='cpu')[source]

Load a trained model from a checkpoint path or, for the pre-trained models, from a replicate number.

Parameters:
  • model (Union[str, int, None]) – Path to model checkpoint or replicate number (0-3) for pre-trained models

  • device (str) – Device to load model on

Returns:

self

Examples

>>> result = DecimaResult.load()
>>> result.load_model()  # Load default model (rep0)
>>> result.load_model(
...     model="path/to/checkpoint.ckpt"
... )
>>> result.load_model(
...     model=2
... )
property model

Decima model.

predicted_expression_matrix(genes=None)[source]

Get predicted expression matrix for all or specific genes.

Parameters:

genes (Optional[List[str]]) – Optional list of genes to get predictions for. If None, returns all genes.

Returns:

Predicted expression matrix (cells x genes)

Return type:

pd.DataFrame

prepare_one_hot(gene, variants=None)[source]

Prepare one-hot encoding for a gene.

Parameters:

gene (str) – Gene name

Returns:

One-hot encoding of the gene

Return type:

torch.Tensor

query_cells(query)[source]
query_tasks(tasks=None, off_tasks=None)[source]
property shape: tuple

Shape of the expression matrix (n_cells, n_genes).

Module contents