scimilarity.anndata_data_models

class scimilarity.anndata_data_models.MetricLearningDataModule(*args, **kwargs)

Bases: LightningDataModule

A LightningDataModule that encapsulates the AnnData needed to train the model.

Parameters:
  • train_path (str) – Path to the training h5ad file.

  • val_path (str, optional, default: None) – Path to the validation h5ad file.

  • label_column (str, default: "celltype_name") – The column name containing ontology-compliant cell type names.

  • study_column (str, default: "study") – The column name containing study identifiers.

  • gene_order_file (str, optional) – Path to a file containing one gene symbol per line. If given, this gene order is used instead of the training dataset’s gene order.

  • batch_size (int, default: 1000) – Batch size.

  • num_workers (int, default: 1) – The number of worker processes for the DataLoaders.

  • sparse (bool, default: False) – Use sparse matrices.

  • remove_singleton_classes (bool, default: True) – Exclude cells with classes that exist in only one study.

  • pin_memory (bool, default: False) – If True, the DataLoaders copy batches into pinned (page-locked) memory.

  • persistent_workers (bool, default: False) – If True, the DataLoaders keep their worker processes alive between epochs. Set to False if num_workers is 0.

  • multiprocessing_context (str, default: "fork") – Multiprocessing context for the DataLoaders; either "spawn" or "fork".
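The gene_order_file option reorders expression columns to match a fixed gene list. The following is a minimal sketch of what reading such a file and aligning a dense cells x genes matrix to it might look like. read_gene_order and align_to_gene_order are hypothetical helpers and are not part of this module. Filling genes that are absent from the data with zeros is also an assumption.

```python
import numpy as np


def read_gene_order(path):
    """Hypothetical helper: read one gene symbol per line, skipping blank lines."""
    with open(path) as fh:
        return [line.strip() for line in fh if line.strip()]


def align_to_gene_order(X, var_names, gene_order):
    """Hypothetical helper: reorder the columns of a dense cells x genes matrix.

    Genes listed in gene_order but missing from var_names are filled with
    zeros (an assumption, not documented behavior).
    """
    index = {gene: i for i, gene in enumerate(var_names)}
    out = np.zeros((X.shape[0], len(gene_order)), dtype=X.dtype)
    for j, gene in enumerate(gene_order):
        i = index.get(gene)
        if i is not None:
            out[:, j] = X[:, i]
    return out
```

For example, aligning columns ["A", "B"] to the order ["B", "C", "A"] moves B first, inserts an all-zero column for C, and moves A last.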

Examples

>>> from scimilarity.anndata_data_models import MetricLearningDataModule
>>> datamodule = MetricLearningDataModule(
...     batch_size=1000,
...     num_workers=1,
...     label_column="celltype_name",
...     train_path="train.h5ad",
... )

get_sampler_weights(labels, studies=None, class_target_sum=10000.0, study_target_sum=1000000.0)

Get a weighted random sampler based on the label counts and, optionally, the study counts.

Parameters:
  • labels (list) – The list of labels.

  • studies (list, optional, default: None) – A list of studies. If given, study counts are also incorporated into the sampler weights.

  • class_target_sum (float, default: 1e4) – Target sum for normalization of class counts.

  • study_target_sum (float, default: 1e6) – Target sum for normalization of study counts.

Returns:

A WeightedRandomSampler object.

Return type:

WeightedRandomSampler
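One common way to build such a sampler is inverse-frequency weighting: each sample is weighted by the reciprocal of its group size, so rare labels are drawn about as often as common ones. The following is a minimal sketch of that technique, not this method's implementation. It omits the class_target_sum and study_target_sum normalization, whose exact form is not documented here. inverse_frequency_weights is a hypothetical name.

```python
from collections import Counter


def inverse_frequency_weights(labels, studies=None):
    """Hypothetical: weight each sample by 1 / (size of its label group).

    If studies is given, each weight is further divided by the size of
    the sample's study group.
    """
    label_counts = Counter(labels)
    weights = [1.0 / label_counts[label] for label in labels]
    if studies is not None:
        study_counts = Counter(studies)
        weights = [w / study_counts[s] for w, s in zip(weights, studies)]
    return weights
```

These weights could be passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True). That sampler draws indices with probability proportional to their weights.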

remove_singleton_label_ids(data, n_studies=2)

Keep only the cells whose labels occur in at least n_studies distinct studies.

Parameters:
  • data (anndata.AnnData) – Annotated data to filter by the number of studies each label occurs in.

  • n_studies (int, default: 2) – The number of studies a label must exist in to be valid.

Return type:

anndata.AnnData
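The filtering rule above can be sketched as counting the distinct studies per label and keeping the cells whose label reaches the threshold. This is a minimal illustration on plain Python lists, assuming the labels and studies come from the label_column and study_column of data.obs. labels_in_min_studies is a hypothetical helper.

```python
from collections import defaultdict


def labels_in_min_studies(labels, studies, n_studies=2):
    """Hypothetical: boolean mask of cells whose label occurs in >= n_studies studies."""
    studies_per_label = defaultdict(set)
    for label, study in zip(labels, studies):
        studies_per_label[label].add(study)
    valid = {label for label, s in studies_per_label.items() if len(s) >= n_studies}
    return [label in valid for label in labels]
```

With an AnnData object, such a mask could be applied as data[np.asarray(mask)] to select the matching cells.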

subset_valid_terms(data)

Keep cells whose cell type labels have a valid ontology ID.

Parameters:

data (anndata.AnnData) – Annotated data to subset by valid ontology ID.

Returns:

The subset of cells whose cell type labels have a valid ontology ID.

Return type:

anndata.AnnData
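How this method resolves ontology IDs is not described here. The following is a minimal sketch under two assumptions. First, a name-to-ID mapping is available (name_to_id is hypothetical). Second, a valid ID follows the Cell Ontology "CL:" plus seven digits pattern. valid_term_mask is a hypothetical helper.

```python
import re

# Assumed Cell Ontology ID format: "CL:" followed by seven digits.
CL_ID = re.compile(r"^CL:\d{7}$")


def valid_term_mask(labels, name_to_id):
    """Hypothetical: True where a label maps to a well-formed ontology ID."""
    return [bool(CL_ID.match(name_to_id.get(label, ""))) for label in labels]
```

Labels that are missing from the mapping, or that map to a malformed ID, are dropped.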

test_dataloader()

Load the test dataset.

Returns:

A DataLoader object containing the test dataset.

Return type:

DataLoader

train_dataloader()

Load the training dataset.

Returns:

A DataLoader object containing the training dataset.

Return type:

DataLoader

val_dataloader()

Load the validation dataset.

Returns:

A DataLoader object containing the validation dataset.

Return type:

DataLoader
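The batch_size, num_workers, pin_memory, persistent_workers, and multiprocessing_context options are presumably forwarded to these PyTorch DataLoaders. The following is a minimal sketch of how such keyword arguments might be assembled, assuming that forwarding. dataloader_kwargs is a hypothetical helper. The sketch respects two PyTorch constraints: DataLoader raises a ValueError if persistent_workers=True or a multiprocessing_context is given while num_workers is 0.

```python
def dataloader_kwargs(batch_size=1000, num_workers=1, pin_memory=False,
                      persistent_workers=False, multiprocessing_context="fork"):
    """Hypothetical helper: keyword arguments for torch.utils.data.DataLoader."""
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        "persistent_workers": persistent_workers and num_workers > 0,
    }
    # A multiprocessing context only applies to multi-process loading.
    if num_workers > 0:
        kwargs["multiprocessing_context"] = multiprocessing_context
    return kwargs
```

The result could be unpacked as DataLoader(dataset, collate_fn=collator, **kwargs).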

class scimilarity.anndata_data_models.scCollator(label2int, sparse=False)

Bases: object

A class to collate individual samples into a batch.

Parameters:
  • label2int (dict) – A dictionary that maps string labels to class integers.

  • sparse (bool, default: False) – Use sparse matrices.
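The following is a minimal sketch of dense collation with a label2int mapping. It assumes that each sample is an (expression, label, study) tuple; that format is a hypothesis, not documented behavior. DenseCollatorSketch is a hypothetical name, and the sketch does not handle the sparse option.

```python
import numpy as np


class DenseCollatorSketch:
    """Hypothetical dense collator; assumes each sample is (x, label, study)."""

    def __init__(self, label2int):
        self.label2int = label2int

    def __call__(self, batch):
        xs, labels, studies = zip(*batch)
        # Stack expression vectors into a (batch, genes) float32 matrix.
        X = np.stack([np.asarray(x, dtype=np.float32) for x in xs])
        # Map string labels to integer class IDs.
        y = np.array([self.label2int[label] for label in labels], dtype=np.int64)
        return X, y, list(studies)
```

A label2int dictionary could be built from the training labels as {label: i for i, label in enumerate(sorted(set(train_labels)))}.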

class scimilarity.anndata_data_models.scDataset(X, Y, study=None)

Bases: Dataset

A class that represents a single cell dataset.

Parameters:
  • X (numpy.ndarray) – Gene expression vectors for every cell.

  • Y (numpy.ndarray) – Text labels for every cell.

  • study (numpy.ndarray, optional, default: None) – The study identifier for every cell.
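A map-style dataset only needs __len__ and __getitem__. The following is a minimal sketch of that protocol for per-cell expression, label, and study arrays. The (expression, label, study) return format is an assumption, and SingleCellDatasetSketch is a hypothetical name. It is a plain class to avoid a torch dependency; subclassing torch.utils.data.Dataset, as scDataset does, would not change the logic.

```python
import numpy as np


class SingleCellDatasetSketch:
    """Hypothetical map-style dataset returning (expression, label, study) per cell."""

    def __init__(self, X, Y, study=None):
        # X.shape[0] works for both dense arrays and scipy sparse matrices.
        if X.shape[0] != len(Y):
            raise ValueError("X and Y must describe the same number of cells")
        self.X, self.Y, self.study = X, Y, study

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        study = None if self.study is None else self.study[idx]
        return self.X[idx], self.Y[idx], study
```

When study is None, the third element of each sample is None.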