decima.data package
Submodules
decima.data.dataset module
PyTorch Datasets for Decima.
This module contains the datasets for Decima, including:
- HDF5Dataset: Dataset for HDF5 files.
- GeneDataset: Dataset for gene expression prediction.
- SeqDataset: Dataset for sequence prediction.
- VariantDataset: Dataset for variant effect prediction.
- class decima.data.dataset.GeneDataset(genes=None, metadata_anndata=None, max_seq_shift=0, seed=0, augment_mode='random', genome='hg38')
Bases: Dataset
Dataset for gene expression prediction.
- Parameters:
genes – List of genes to include in the dataset.
metadata_anndata (Optional[str]) – AnnData object to use for extracting gene metadata.
max_seq_shift (int) – Maximum sequence shift.
seed (int) – Seed for the random number generator.
augment_mode (str) – Augmentation mode.
genome (str) – Name of the genome. Defaults to 'hg38'.
- Returns:
Dataset for gene expression prediction.
- Return type:
Dataset
Examples
>>> import torch
>>> from decima.data.dataset import GeneDataset
>>> genes = ["SPI1", "SPI2"]
>>> dataset = GeneDataset(genes=genes)
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
- __init__(genes=None, metadata_anndata=None, max_seq_shift=0, seed=0, augment_mode='random', genome='hg38')
- __parameters__ = ()
- class decima.data.dataset.HDF5Dataset(key, h5_file, ad=None, seq_len=524288, max_seq_shift=0, seed=0, augment_mode='random')
Bases: Dataset
Dataset for HDF5 files.
- Parameters:
key – Key to use to access the data in the HDF5 file.
h5_file – Path to the HDF5 file.
ad – AnnData object to use for extracting tasks.
seq_len – Length of the sequence.
max_seq_shift – Maximum sequence shift.
seed – Seed for the random number generator.
augment_mode – Augmentation mode.
- Returns:
Dataset for HDF5 files.
- Return type:
Dataset
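The parameters above do not spell out how seq_len and max_seq_shift interact. A minimal sketch of one common way to implement shift augmentation, assuming each stored sequence carries max_seq_shift extra bases on each side so that a window of exactly seq_len bases can be cropped at a random offset. The helper name shifted_window is hypothetical and not part of decima, and whether HDF5Dataset stores or crops sequences this way is an assumption.

```python
import random


def shifted_window(seq, seq_len, max_seq_shift, rng):
    """Crop a seq_len-long window from a sequence padded by max_seq_shift per side.

    Hypothetical helper: assumes len(seq) == seq_len + 2 * max_seq_shift and
    draws the shift uniformly from [-max_seq_shift, max_seq_shift].
    """
    expected = seq_len + 2 * max_seq_shift
    if len(seq) != expected:
        raise ValueError(f"expected length {expected}, got {len(seq)}")
    # random.Random.randint is inclusive on both ends.
    shift = rng.randint(-max_seq_shift, max_seq_shift)
    # With shift == 0 the window is centred, skipping the padding on both sides.
    start = max_seq_shift + shift
    return seq[start:start + seq_len], shift


rng = random.Random(0)  # seeded, mirroring the `seed` parameter
window, shift = shifted_window("NNACGTACGTNN", seq_len=8, max_seq_shift=2, rng=rng)
print(shift, window)
```

With max_seq_shift=0 (the default) the shift is always 0 and the whole stored sequence is returned unchanged.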
- __annotations__ = {}
- __init__(key, h5_file, ad=None, seq_len=524288, max_seq_shift=0, seed=0, augment_mode='random')
- __parameters__ = ()
- class decima.data.dataset.SeqDataset(seqs, gene_mask_starts, gene_mask_ends, genes=None, max_seq_shift=0, seed=0, augment_mode='random')
Bases: Dataset
Dataset for sequence prediction with the masked gene sequence.
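- Parameters:
seqs – List of DNA sequences.
gene_mask_starts – Start position of the gene mask for each sequence.
gene_mask_ends – End position of the gene mask for each sequence.
genes – Optional list of gene names, one per sequence.
max_seq_shift – Maximum sequence shift.
seed – Seed for the random number generator.
augment_mode – Augmentation mode.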
- Returns:
Dataset for sequence prediction with the masked gene sequence.
- Return type:
Dataset
Examples
>>> import torch
>>> from decima.data.dataset import SeqDataset
>>> seqs = ["ATCG...", "ATCG..", "ATCG..."]
>>> gene_mask_starts = [0, 0, 0]
>>> gene_mask_ends = [4, 4, 4]
>>> dataset = SeqDataset(
...     seqs=seqs,
...     gene_mask_starts=gene_mask_starts,
...     gene_mask_ends=gene_mask_ends,
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
>>> dataset = SeqDataset.from_fasta(
...     fasta_file="example/seqs.fasta"
... )
>>> import pandas as pd
>>> df = pd.DataFrame(
...     {
...         "seq": ["ATCG..", "ATCG...", "ATCG......"],
...         "gene_mask_start": [0, 0, 0],
...         "gene_mask_end": [4, 4, 4],
...     }
... )
>>> dataset = SeqDataset.from_dataframe(df)
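The trailing dimension of 5 in the example output is consistent with four one-hot nucleotide channels plus one gene-mask channel. A minimal sketch of building such an input with plain Python lists, assuming channel order A, C, G, T, an all-zero row for unknown bases such as N, and a mask over the half-open interval [gene_mask_start, gene_mask_end). These conventions and the helper name encode_with_gene_mask are assumptions, not SeqDataset's confirmed encoding.

```python
def encode_with_gene_mask(seq, gene_mask_start, gene_mask_end, alphabet="ACGT"):
    """Encode seq as rows of len(alphabet) one-hot channels plus one mask channel.

    Assumptions: channel order A, C, G, T; unknown bases (e.g. N) leave the
    base channels all zero; the mask covers [gene_mask_start, gene_mask_end).
    """
    index = {base: i for i, base in enumerate(alphabet)}
    rows = []
    for pos, base in enumerate(seq.upper()):
        row = [0.0] * (len(alphabet) + 1)
        i = index.get(base)
        if i is not None:
            row[i] = 1.0  # one-hot nucleotide channel
        if gene_mask_start <= pos < gene_mask_end:
            row[-1] = 1.0  # gene-mask channel
        rows.append(row)
    return rows


x = encode_with_gene_mask("ACGTN", gene_mask_start=0, gene_mask_end=2)
print(len(x), len(x[0]))  # 5 positions, 5 channels
```

A tensor library would store the same layout as an array of shape (len(seq), 5); nested lists keep the sketch dependency-free.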
- __annotations__ = {}
- __init__(seqs, gene_mask_starts, gene_mask_ends, genes=None, max_seq_shift=0, seed=0, augment_mode='random')
- __parameters__ = ()
- classmethod from_dataframe(df, max_seq_shift=0, seed=0, augment_mode='random')
Create a SeqDataset from a pandas DataFrame.
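- Parameters:
df – pandas DataFrame with the columns "seq", "gene_mask_start", and "gene_mask_end".
max_seq_shift – Maximum sequence shift.
seed – Seed for the random number generator.
augment_mode – Augmentation mode.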
- Returns:
SeqDataset object.
- Return type:
SeqDataset
Examples
>>> import pandas as pd
>>> import torch
>>> df = pd.DataFrame(
...     {
...         "seq": ["ATCG..", "ATCG...", "ATCG......"],
...         "gene_mask_start": [0, 0, 0],
...         "gene_mask_end": [4, 4, 4],
...     }
... )
>>> dataset = SeqDataset.from_dataframe(df)
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
- classmethod from_fasta(fasta_file, max_seq_shift=0, seed=0, augment_mode='random')
Create a SeqDataset from a FASTA file.
- Parameters:
fasta_file – Path to the FASTA file. Each header holds the gene name and the gene mask coordinates, and the record body holds the sequence, e.g.:
>gene_name|gene_mask_start=10000|gene_mask_end=10000
ATACG…
max_seq_shift – Maximum sequence shift.
seed – Seed for the random number generator.
augment_mode – Augmentation mode.
- Returns:
SeqDataset object.
- Return type:
SeqDataset
Examples
>>> import torch
>>> dataset = SeqDataset.from_fasta(
...     fasta_file="example/seqs.fasta"
... )
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
(2, 524288, 5)
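A minimal stdlib sketch of reading the header format described for from_fasta (`>gene_name|gene_mask_start=...|gene_mask_end=...`) into the per-record values that the SeqDataset constructor takes. The function name parse_masked_fasta and the gene name GENE_B are hypothetical; from_fasta's own parser may differ, for example in how it handles line-wrapped sequences or extra header fields.

```python
import io


def _make_record(header, chunks):
    # Header layout: gene_name|gene_mask_start=<int>|gene_mask_end=<int>
    name, *fields = header.split("|")
    attrs = dict(field.split("=", 1) for field in fields)
    return (
        name,
        int(attrs["gene_mask_start"]),
        int(attrs["gene_mask_end"]),
        "".join(chunks),  # join line-wrapped sequence lines
    )


def parse_masked_fasta(handle):
    """Yield (gene_name, gene_mask_start, gene_mask_end, seq) for each record."""
    header, chunks = None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield _make_record(header, chunks)
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        yield _make_record(header, chunks)


# Two toy records; GENE_B is a hypothetical gene name.
fasta = io.StringIO(
    ">SPI1|gene_mask_start=2|gene_mask_end=5\nACGT\nACGT\n"
    ">GENE_B|gene_mask_start=0|gene_mask_end=3\nTTTT\n"
)
genes, starts, ends, seqs = map(list, zip(*parse_masked_fasta(fasta)))
print(genes, starts, ends, seqs)
```

The four lists line up with the constructor's seqs, gene_mask_starts, gene_mask_ends and genes arguments, so they could be passed to SeqDataset directly.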
- classmethod from_one_hot(one_hot, gene_mask_starts=None, gene_mask_ends=None, max_seq_shift=0, seed=0, augment_mode='random')
Create a SeqDataset from a one-hot encoded tensor.
- class decima.data.dataset.VariantDataset(variants, metadata_anndata=None, max_seq_shift=0, seed=0, include_cols=None, gene_col=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf, model_name=None, reference_cache=True, genome='hg38')
Bases: Dataset
Dataset for variant effect prediction.
- Parameters:
variants (pd.DataFrame) – DataFrame with variants
metadata_anndata (AnnData) – AnnData object with gene metadata
max_seq_shift (int) – Maximum sequence shift
seed (int) – Seed for the random number generator
include_cols (list) – List of columns to include in the output
gene_col (str) – Column name for gene names
min_from_end (int) – Minimum distance from the end of the gene
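distance_type (str) – Type of distance (default: 'tss')
min_distance (int) – Minimum distance from the TSS
max_distance (int) – Maximum distance from the TSS (default: inf, i.e. no upper bound)
genome (str) – Name of the genome (default: 'hg38')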
- Returns:
Dataset for variant effect prediction
- Return type:
Dataset
Examples
>>> import pandas as pd
>>> import anndata as ad
>>> import torch
>>> from decima.data.dataset import VariantDataset
>>> variants = pd.read_csv("variants.csv")
>>> dataset = VariantDataset(variants)
>>> dataset[0]
{'seq': tensor([[1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, ..., 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, ..., 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, ..., 1.0000, 0.0000, 0.0000]]), 'warning': []}
>>> dl = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=1,
...     shuffle=True,
...     collate_fn=dataset.collate_fn,
... )
>>> for batch in dl:
...     print(batch)
{'seq': tensor([[1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, ..., 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, ..., 1.0000, 0.0000, 0.0000]]), 'warning': [], 'pred_expression': tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000]])}
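The examples pass collate_fn=dataset.collate_fn, and the printed items are dicts holding a tensor field ('seq') and a list field ('warning'). A minimal sketch of such a dict collate function, assuming list fields are concatenated across samples and every other field is gathered per sample. Plain lists stand in for tensors here; a real implementation would stack tensors (for example with torch.stack). The name collate_dicts and this merge rule are assumptions, not VariantDataset.collate_fn's confirmed behavior.

```python
def collate_dicts(items, concat_keys=("warning",)):
    """Merge per-sample dicts into one batch dict.

    Fields named in concat_keys are assumed to be lists and are concatenated;
    all other fields are collected per sample (a stand-in for torch.stack).
    """
    if not items:
        return {}
    batch = {}
    for key in items[0]:
        values = [item[key] for item in items]
        if key in concat_keys:
            batch[key] = [entry for sample in values for entry in sample]
        else:
            batch[key] = values
    return batch


items = [
    {"seq": [[1, 0]], "warning": []},
    {"seq": [[0, 1]], "warning": ["ref mismatch"]},  # hypothetical warning text
]
print(collate_dicts(items))
```

Concatenating the warning lists keeps the batch dict flat even when most samples carry no warnings.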
- DEFAULT_COLUMNS = ['chrom', 'pos', 'ref', 'alt', 'gene', 'start', 'end', 'strand', 'gene_mask_start', 'gene_mask_end', 'rel_pos', 'ref_tx', 'alt_tx', 'tss_dist']
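DEFAULT_COLUMNS includes ref, alt, rel_pos, ref_tx and alt_tx. A minimal sketch of how such columns could be used, assuming rel_pos is the 0-based offset of the variant within the input window and that ref_tx/alt_tx are the alleles reverse-complemented onto the transcript strand for '-' strand genes. Both readings are inferred from the column names only; apply_variant and allele_on_transcript are hypothetical helpers, and the mismatch message merely illustrates what the 'warning' field might carry.

```python
# Complement table; N maps to itself.
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq):
    return seq.translate(_COMPLEMENT)[::-1]


def allele_on_transcript(allele, strand):
    """Assumed meaning of ref_tx/alt_tx: the allele read on the gene's strand."""
    return reverse_complement(allele) if strand == "-" else allele


def apply_variant(window, rel_pos, ref, alt):
    """Replace ref with alt at 0-based rel_pos; return (alt_window, warnings)."""
    warnings = []
    observed = window[rel_pos:rel_pos + len(ref)]
    if observed.upper() != ref.upper():
        # Hypothetical message; the real 'warning' contents are not documented.
        warnings.append(f"ref mismatch at {rel_pos}: expected {ref}, found {observed}")
    return window[:rel_pos] + alt + window[rel_pos + len(ref):], warnings


print(apply_variant("ACGTACGT", rel_pos=2, ref="G", alt="T"))  # ('ACTTACGT', [])
print(allele_on_transcript("AG", "-"))  # CT
```

Slicing by len(ref) rather than assuming a single base lets the same helper handle simple insertions and deletions, although the window length then changes.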
- __annotations__ = {}
- __init__(variants, metadata_anndata=None, max_seq_shift=0, seed=0, include_cols=None, gene_col=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf, model_name=None, reference_cache=True, genome='hg38')
- __parameters__ = ()
- static overlap_genes(df_variants, df_genes, gene_col=None, include_cols=None, min_from_end=0, distance_type='tss', min_distance=0, max_distance=inf)
Overlap genes with variants.
- Parameters:
df_variants – pandas DataFrame containing variants.
df_genes – pandas DataFrame containing genes.
gene_col – Column name for gene names.
include_cols – List of columns to include in the output.
min_from_end – Minimum distance from the end of the gene.
distance_type – Type of distance.
min_distance – Minimum distance from the TSS.
max_distance – Maximum distance from the TSS.
- Returns:
pandas DataFrame containing the overlap between genes and variants.
Examples
>>> import pandas as pd
>>> df_variants = pd.DataFrame(
...     {
...         "chrom": ["1", "1", "1"],
...         "pos": [10000, 10000, 10000],
...         "ref": ["A", "A", "A"],
...         "alt": ["G", "G", "G"],
...         "gene": ["SPI1", "SPI2", "SPI3"],
...     }
... )
>>> df_genes = pd.DataFrame(
...     {
...         "gene": ["SPI1", "SPI2", "SPI3"],
...         "start": [10000, 10000, 10000],
...         "end": [10000, 10000, 10000],
...         "strand": ["+", "+", "+"],
...         "gene_mask_start": [0, 0, 0],
...         "gene_mask_end": [4, 4, 4],
...     }
... )
>>> df = VariantDataset.overlap_genes(
...     df_variants,
...     df_genes,
... )
>>> print(df)
  chrom    pos ref alt  gene  start    end strand  gene_mask_start  gene_mask_end  rel_pos ref_tx alt_tx  tss_dist
0     1  10000   A   G  SPI1  10000  10000      +                0              4        0      A      G         0
1     1  10000   A   G  SPI2  10000  10000      +                0              4        0      A      G         0
2     1  10000   A   G  SPI3  10000  10000      +                0              4        0      A      G         0
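The example reports tss_dist = 0 for variants sitting at the start of '+' strand genes. A minimal sketch of a strand-aware TSS distance and the min_distance/max_distance filter, assuming the TSS is start for '+' strand genes and end for '-' strand genes, distances compared by absolute value, and both bounds inclusive. These conventions and the helper names tss_distance and within_distance are assumptions; overlap_genes may use different coordinates or bounds, and other distance_type values may select other reference points.

```python
import math


def tss_distance(pos, start, end, strand):
    """Signed distance from the TSS to pos, positive downstream on the gene's strand.

    Assumes the TSS is `start` for '+' strand genes and `end` for '-' strand genes,
    with pos, start and end in the same coordinate system.
    """
    if strand == "+":
        return pos - start
    if strand == "-":
        return end - pos
    raise ValueError(f"unknown strand: {strand!r}")


def within_distance(dist, min_distance=0, max_distance=math.inf):
    """Keep a variant if min_distance <= |dist| <= max_distance (bounds assumed inclusive)."""
    return min_distance <= abs(dist) <= max_distance


print(tss_distance(10000, 10000, 10000, "+"))  # 0, as in the example table
print(tss_distance(90, 100, 200, "-"))  # 110
```

Using math.inf as the default upper bound mirrors the documented max_distance=inf, so no variant is excluded on distance alone unless a bound is set.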
decima.data.preprocess module
decima.data.read_hdf5 module
decima.data.write_hdf5 module
- decima.data.write_hdf5.write_hdf5(file, ad, pad=0, genome='hg38')
Write AnnData object to HDF5 file.
- Parameters:
file – Path to the HDF5 file to write.
ad – AnnData object containing the data.
pad – Amount of padding to add. Defaults to 0.
genome – Genome name or path to the genome FASTA file. Defaults to "hg38".