decima.data package

Submodules

decima.data.dataset module

PyTorch datasets for Decima.

This module contains the datasets for Decima, including:
  • HDF5Dataset – Dataset for HDF5 files.

  • GeneDataset – Dataset for gene expression prediction.

  • SeqDataset – Dataset for sequence prediction.

  • VariantDataset – Dataset for variant effect prediction.

class decima.data.dataset.GeneDataset(genes=None, metadata_anndata=None, max_seq_shift=0, seed=0, augment_mode='random', genome='hg38')[source]

Bases: Dataset

Dataset for gene expression prediction.

Parameters:
  • genes – List of genes to include in the dataset.

  • metadata_anndata (Optional[str]) – AnnData object to use for extracting gene metadata.

  • max_seq_shift (int) – Maximum sequence shift.

  • seed (int) – Seed for the random number generator.

  • augment_mode (str) – Augmentation mode.

  • genome (str) – Name of the genome.

Returns:

Dataset for gene expression prediction.

Return type:

Dataset

Examples

>>> import torch
>>> from decima.data.dataset import GeneDataset
>>> genes = [
...     "SPI1",
...     "SPI2",
... ]
>>> dataset = (
...     GeneDataset(
...         genes=genes
...     )
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
__getitem__(idx)[source]
__init__(genes=None, metadata_anndata=None, max_seq_shift=0, seed=0, augment_mode='random', genome='hg38')[source]
__len__()[source]
__parameters__ = ()
collate_fn(batch)[source]
class decima.data.dataset.HDF5Dataset(key, h5_file, ad=None, seq_len=524288, max_seq_shift=0, seed=0, augment_mode='random')[source]

Bases: Dataset

Dataset for HDF5 files.

Parameters:
  • key – Key to use to access the data in the HDF5 file.

  • h5_file – Path to the HDF5 file.

  • ad – AnnData object to use for extracting tasks.

  • seq_len – Length of the sequence.

  • max_seq_shift – Maximum sequence shift.

  • seed – Seed for the random number generator.

  • augment_mode – Augmentation mode.

Returns:

Dataset for HDF5 files.

Return type:

Dataset
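
The seq_len, max_seq_shift, seed, and augment_mode parameters suggest a shift augmentation: a window of seq_len positions is cropped from a slightly longer stored sequence at a seeded random offset. The sketch below illustrates that idea only. It assumes the stored sequence is padded by max_seq_shift on each side; the helper crop_with_shift and the "none" mode are hypothetical and are not part of Decima's API.

```python
import random


def crop_with_shift(padded_seq, seq_len, max_seq_shift, rng, augment_mode="random"):
    """Crop seq_len characters from a sequence padded by max_seq_shift on both sides.

    Hypothetical helper for illustration; not Decima's implementation.
    """
    expected = seq_len + 2 * max_seq_shift
    if len(padded_seq) != expected:
        raise ValueError(f"expected length {expected}, got {len(padded_seq)}")
    if augment_mode == "random" and max_seq_shift > 0:
        # Offset in [0, 2 * max_seq_shift]; an offset of max_seq_shift means no shift.
        offset = rng.randint(0, 2 * max_seq_shift)
    else:
        # Any other mode (assumed here) keeps the centered window.
        offset = max_seq_shift
    return padded_seq[offset:offset + seq_len]


rng = random.Random(0)  # plays the role of the seed parameter
padded = "NN" + "ACGTACGT" + "NN"
window = crop_with_shift(padded, seq_len=8, max_seq_shift=2, rng=rng)
centered = crop_with_shift(padded, seq_len=8, max_seq_shift=2, rng=rng, augment_mode="none")
```

Passing the random generator in explicitly keeps the augmentation reproducible for a given seed.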

__annotations__ = {}
__getitem__(idx)[source]
__init__(key, h5_file, ad=None, seq_len=524288, max_seq_shift=0, seed=0, augment_mode='random')[source]
__len__()[source]
__parameters__ = ()
close()[source]
extract_label(idx)[source]
extract_seq(idx)[source]
extract_tasks(ad=None)[source]
class decima.data.dataset.SeqDataset(seqs, gene_mask_starts, gene_mask_ends, genes=None, max_seq_shift=0, seed=0, augment_mode='random')[source]

Bases: Dataset

Dataset for sequence prediction with the masked gene sequence.

Parameters:
  • seqs (List[str]) – List of sequences as strings.

  • gene_mask_starts (List[int]) – List of gene mask starts.

  • gene_mask_ends (List[int]) – List of gene mask ends.

  • genes (List[str]) – List of gene names.

  • max_seq_shift (int) – Maximum sequence shift.

  • seed (int) – Seed for the random number generator.

  • augment_mode (str) – Augmentation mode.

Returns:

Dataset for sequence prediction with the masked gene sequence.

Return type:

Dataset

Examples

>>> import pandas as pd
>>> import torch
>>> from decima.data.dataset import SeqDataset
>>> seqs = [
...     "ATCG...",
...     "ATCG..",
...     "ATCG...",
... ]
>>> gene_mask_starts = [
...     0,
...     0,
...     0,
... ]
>>> gene_mask_ends = [
...     4,
...     4,
...     4,
... ]
>>> dataset = SeqDataset(
...     seqs=seqs,
...     gene_mask_starts=gene_mask_starts,
...     gene_mask_ends=gene_mask_ends,
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
>>> dataset = SeqDataset.from_fasta(
...     fasta_file="example/seqs.fasta"
... )
>>> df = pd.DataFrame(
...     {
...         "seq": [
...             "ATCG..",
...             "ATCG...",
...             "ATCG......",
...         ],
...         "gene_mask_start": [
...             0,
...             0,
...             0,
...         ],
...         "gene_mask_end": [
...             4,
...             4,
...             4,
...         ],
...     }
... )
>>> dataset = SeqDataset.from_dataframe(
...     df
... )
__annotations__ = {}
__getitem__(idx)[source]
__init__(seqs, gene_mask_starts, gene_mask_ends, genes=None, max_seq_shift=0, seed=0, augment_mode='random')[source]
__len__()[source]
__parameters__ = ()
collate_fn(batch)[source]
classmethod from_dataframe(df, max_seq_shift=0, seed=0, augment_mode='random')[source]

Create a SeqDataset from a pandas DataFrame.

Parameters:
  • df (DataFrame) – pandas DataFrame containing seq, gene_mask_start, and gene_mask_end columns.

  • max_seq_shift (int) – Maximum sequence shift.

  • seed (int) – Seed for the random number generator.

  • augment_mode (str) – Augmentation mode.

Returns:

SeqDataset object.

Return type:

SeqDataset

Examples

>>> df = pd.DataFrame(
...     {
...         "seq": [
...             "ATCG..",
...             "ATCG...",
...             "ATCG......",
...         ],
...         "gene_mask_start": [
...             0,
...             0,
...             0,
...         ],
...         "gene_mask_end": [
...             4,
...             4,
...             4,
...         ],
...     }
... )
>>> dataset = SeqDataset.from_dataframe(
...     df
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
classmethod from_fasta(fasta_file, max_seq_shift=0, seed=0, augment_mode='random')[source]

Create a SeqDataset from a FASTA file.

Parameters:
  • fasta_file – Path to the FASTA file. Each record's header holds the gene name and the gene mask coordinates, and the record body holds the sequence, for example “>gene_name|gene_mask_start=10000|gene_mask_end=10000” followed by “ATACG…” on the next line.

  • max_seq_shift – Maximum sequence shift.

  • seed – Seed for the random number generator.

  • augment_mode – Augmentation mode.

Returns:

SeqDataset object.

Return type:

SeqDataset

Examples

>>> dataset = SeqDataset.from_fasta(
...     fasta_file="example/seqs.fasta"
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
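
To make the expected header layout concrete, the following sketch parses FASTA text whose headers follow the “>gene_name|gene_mask_start=…|gene_mask_end=…” pattern described above. It is an illustration under assumptions, not the parser used by from_fasta: the function names parse_fasta_records and _build_record are hypothetical, and the sketch assumes every header carries both mask fields.

```python
def _build_record(header, chunks):
    """Split one header of the form 'gene|key=value|key=value' into a record dict."""
    gene, *fields = header.split("|")
    attrs = dict(field.split("=", 1) for field in fields)
    return {
        "gene": gene,
        "gene_mask_start": int(attrs["gene_mask_start"]),
        "gene_mask_end": int(attrs["gene_mask_end"]),
        "seq": "".join(chunks),  # sequence lines may be wrapped across several lines
    }


def parse_fasta_records(text):
    """Return a list of record dicts parsed from FASTA-formatted text."""
    records = []
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # A new header closes the previous record, if there was one.
            if header is not None:
                records.append(_build_record(header, chunks))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append(_build_record(header, chunks))
    return records


text = (
    ">SPI1|gene_mask_start=0|gene_mask_end=4\n"
    "ATCG\n"
    "GGCA\n"
    ">GATA1|gene_mask_start=2|gene_mask_end=6\n"
    "TTAACCGG\n"
)
records = parse_fasta_records(text)
```

The resulting records carry the same three fields that from_dataframe expects as columns (seq, gene_mask_start, gene_mask_end), plus the gene name.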
classmethod from_one_hot(one_hot, gene_mask_starts=None, gene_mask_ends=None, max_seq_shift=0, seed=0, augment_mode='random')[source]

Create a SeqDataset from a one-hot encoded tensor.

Parameters:
  • one_hot (Tensor) – One-hot encoded tensor with shape (batch_size, 4 or 5, seq_len).

  • gene_mask_starts (List[int]) – List of gene mask starts.

  • gene_mask_ends (List[int]) – List of gene mask ends.

  • max_seq_shift (int) – Maximum sequence shift.

  • seed (int) – Seed for the random number generator.

  • augment_mode (str) – Augmentation mode.

Returns:

SeqDataset object.

Return type:

SeqDataset
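
The “4 or 5” channel dimension above can be pictured as four nucleotide channels plus an optional gene mask channel. The sketch below builds such an encoding with plain lists. It rests on assumptions that the source does not confirm: the channel order A, C, G, T, a half-open mask interval [gene_mask_start, gene_mask_end), and an all-zero column for unknown bases. The helper one_hot_with_mask is hypothetical.

```python
BASES = "ACGT"  # assumed channel order


def one_hot_with_mask(seq, gene_mask_start, gene_mask_end):
    """Return 5 rows (A, C, G, T, gene mask), each of length len(seq)."""
    seq = seq.upper()
    # One row per nucleotide; bases outside ACGT (such as N) stay all-zero.
    rows = [[1.0 if base == b else 0.0 for base in seq] for b in BASES]
    # Fifth row marks positions inside the assumed half-open gene interval.
    mask = [1.0 if gene_mask_start <= i < gene_mask_end else 0.0 for i in range(len(seq))]
    return rows + [mask]


encoded = one_hot_with_mask("ATCGN", gene_mask_start=0, gene_mask_end=4)
```

Stacking several such encodings along a leading axis would give the (batch_size, 5, seq_len) layout named in the parameter description.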

class decima.data.dataset.VariantDataset(variants, metadata_anndata=None, max_seq_shift=0, seed=0, include_cols=None, gene_col=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf, model_name=None, reference_cache=True, genome='hg38')[source]

Bases: Dataset

Dataset for variant effect prediction

Parameters:
  • variants (pd.DataFrame) – DataFrame with variants.

  • metadata_anndata (AnnData) – AnnData object with gene metadata.

  • seq_len (int) – Length of the sequence.

  • max_seq_shift (int) – Maximum sequence shift.

  • include_cols (list) – List of columns to include in the output.

  • gene_col (str) – Column name for gene names.

  • min_from_end (int) – Minimum distance from the end of the gene.

  • distance_type (str) – Type of distance.

  • min_distance (int) – Minimum distance from the TSS.

  • max_distance (int) – Maximum distance from the TSS.

Returns:

Dataset for variant effect prediction

Return type:

Dataset

Examples

>>> import pandas as pd
>>> import torch
>>> from decima.data.dataset import (
...     VariantDataset,
... )
>>> variants = pd.read_csv(
...     "variants.csv"
... )
>>> dataset = (
...     VariantDataset(
...         variants
...     )
... )
>>> dataset[0]
{'seq': tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000],
                [0.0000, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 1.0000,  ..., 0.0000, 1.0000, 0.0000],
                [0.0000, 0.0000, 0.0000,  ..., 1.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000]]), 'warning': []}
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
{
    'seq': tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000],
                    [0.0000, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 1.0000,  ..., 0.0000, 1.0000, 0.0000],
                    [0.0000, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000]]),
    'warning': [],
    'pred_expression': tensor([[0.0000, 0.0000],
                    [0.0000, 0.0000],
                    [0.0000, 0.0000],
                    [0.0000, 0.0000]]),
}
DEFAULT_COLUMNS = ['chrom', 'pos', 'ref', 'alt', 'gene', 'start', 'end', 'strand', 'gene_mask_start', 'gene_mask_end', 'rel_pos', 'ref_tx', 'alt_tx', 'tss_dist']
__annotations__ = {}
__getitem__(idx)[source]
__init__(variants, metadata_anndata=None, max_seq_shift=0, seed=0, include_cols=None, gene_col=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf, model_name=None, reference_cache=True, genome='hg38')[source]
__len__()[source]
__parameters__ = ()
__repr__()[source]

Return repr(self).

collate_fn(batch)[source]
static overlap_genes(df_variants, df_genes, gene_col=None, include_cols=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf)[source]

Overlap genes with variants.

Parameters:
  • df_variants – pandas DataFrame containing variants.

  • df_genes – pandas DataFrame containing genes.

  • gene_col – Column name for gene names.

  • include_cols – List of columns to include in the output.

  • min_from_end – Minimum distance from the end of the gene.

  • distance_type – Type of distance.

  • min_distance – Minimum distance from the TSS.

  • max_distance – Maximum distance from the TSS.

Returns:

pandas DataFrame containing the overlap between genes and variants.

Examples

>>> import pandas as pd
>>> from decima.data.dataset import VariantDataset
>>> df_variants = pd.DataFrame(
...     {
...         "chrom": [
...             "1",
...             "1",
...             "1",
...         ],
...         "pos": [
...             10000,
...             10000,
...             10000,
...         ],
...         "ref": [
...             "A",
...             "A",
...             "A",
...         ],
...         "alt": [
...             "G",
...             "G",
...             "G",
...         ],
...         "gene": [
...             "SPI1",
...             "SPI2",
...             "SPI3",
...         ],
...     }
... )
>>> df_genes = pd.DataFrame(
...     {
...         "gene": [
...             "SPI1",
...             "SPI2",
...             "SPI3",
...         ],
...         "start": [
...             10000,
...             10000,
...             10000,
...         ],
...         "end": [
...             10000,
...             10000,
...             10000,
...         ],
...         "strand": [
...             "+",
...             "+",
...             "+",
...         ],
...         "gene_mask_start": [
...             0,
...             0,
...             0,
...         ],
...         "gene_mask_end": [
...             4,
...             4,
...             4,
...         ],
...     }
... )
>>> df = VariantDataset.overlap_genes(
...     df_variants,
...     df_genes,
... )
>>> print(df)
   chrom  pos ref alt gene  start end strand  gene_mask_start  gene_mask_end  rel_pos ref_tx alt_tx  tss_dist
0     1  10000   A   G  SPI1  10000 10000      +              0              4      0      A     G        0
1     1  10000   A   G  SPI2  10000 10000      +              0              4      0      A     G        0
2     1  10000   A   G  SPI3  10000 10000      +              0              4      0      A     G        0
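
The tss_dist column and the min_distance and max_distance filters can be illustrated with a small stand-alone sketch. It is not the logic of overlap_genes: it assumes that the TSS is the gene start on the “+” strand and the gene end on the “-” strand, that distance is measured in the direction of transcription, and that filtering uses the absolute distance. The helpers tss_distance and filter_by_distance, the gene name GENE_MINUS, and the coordinates are hypothetical.

```python
import math


def tss_distance(pos, start, end, strand):
    """Signed distance from the assumed TSS, positive in the direction of transcription."""
    if strand == "+":
        return pos - start
    return end - pos


def filter_by_distance(rows, min_distance=0, max_distance=math.inf):
    """Keep variant-gene pairs whose absolute TSS distance lies in the given range."""
    kept = []
    for row in rows:
        dist = tss_distance(row["pos"], row["start"], row["end"], row["strand"])
        if min_distance <= abs(dist) <= max_distance:
            kept.append({**row, "tss_dist": dist})
    return kept


rows = [
    {"gene": "SPI1", "pos": 10000, "start": 10000, "end": 12000, "strand": "+"},
    {"gene": "SPI1", "pos": 10500, "start": 10000, "end": 12000, "strand": "+"},
    {"gene": "GENE_MINUS", "pos": 11900, "start": 10000, "end": 12000, "strand": "-"},
]
near = filter_by_distance(rows, max_distance=200)
```

With these inputs the variant 500 bases from the TSS is dropped, while the variants at distance 0 and 100 are kept.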
predicted_expression_cache(gene)[source]

Get predicted expression for a gene.

Parameters:

gene – Gene name.

Returns:

Dictionary of predicted expression for each model.

Return type:

dict

validate_allele_seq(gene, variant)[source]

decima.data.preprocess module

decima.data.preprocess.aggregate_anndata(ad, by_cols=['cell_type', 'tissue', 'organ', 'disease', 'study', 'dataset', 'region', 'subregion', 'celltype_coarse'], sum_cols=['n_cells'])[source]
decima.data.preprocess.assign_borzoi_folds(ad, splits)[source]
decima.data.preprocess.change_values(df, col, value_dict)[source]
decima.data.preprocess.get_frac_N(interval, genome='hg38')[source]
decima.data.preprocess.load_ncbi_string(string, allow_dups=False, verbose=False)[source]
decima.data.preprocess.make_inputs(gene, ad)[source]
decima.data.preprocess.match_cellranger_2024(ad, genes24)[source]
decima.data.preprocess.match_ncbi(ad, ncbi)[source]
decima.data.preprocess.match_ref_ad(ad, ref_ad)[source]
decima.data.preprocess.merge_transcripts(gtf)[source]
decima.data.preprocess.return_ensembl(queries, on='gene_id')[source]
decima.data.preprocess.var_to_intervals(ad, chr_end_pad=10000, genome='hg38', seq_len=524288, crop_coords=163840)[source]

decima.data.read_hdf5 module

decima.data.read_hdf5.count_genes(h5_file, key=None)[source]
decima.data.read_hdf5.extract_gene_data(h5_file, gene, seq_len=524288, merge=True)[source]
decima.data.read_hdf5.get_gene_idx(h5_file, gene, key=None)[source]
decima.data.read_hdf5.index_genes(h5_file, key=None)[source]
decima.data.read_hdf5.list_genes(h5_file, key=None)[source]
decima.data.read_hdf5.mutate(seq, allele, pos)[source]

decima.data.write_hdf5 module

decima.data.write_hdf5.write_hdf5(file, ad, pad=0, genome='hg38')[source]

Write AnnData object to HDF5 file.

Parameters:
  • file – Path to the HDF5 file to write.

  • ad – AnnData object containing the data.

  • pad – Amount of padding to add. Defaults to 0.

  • genome – Genome name or path to the genome FASTA file. Defaults to “hg38”.

Module contents