scimilarity.zarr_data_models

class scimilarity.zarr_data_models.MetricLearningDataModule(*args, **kwargs)

Bases: LightningDataModule

A class that encapsulates a collection of zarr datasets used to train the model.

Parameters:
  • train_path (str) – Path to the folder containing all training datasets. All datasets should be in zarr format, aligned to a known gene space, and cleaned to contain only valid cell ontology terms.

  • gene_order (str) – Path to a file specifying the gene order, with one gene symbol per line. IMPORTANT: the zarr datasets should already be in this gene order after preprocessing.

  • val_path (str, optional, default: None) – Path to the folder containing all validation datasets.

  • obs_field (str, default: "celltype_name") – The obs key containing cell type labels.

  • batch_size (int, default: 1000) – Batch size.

  • num_workers (int, default: 1) – The number of worker processes used by the dataloaders.

Examples

>>> datamodule = MetricLearningDataModule(
...     batch_size=1000,
...     num_workers=1,
...     obs_field="celltype_name",
...     train_path="train",
...     gene_order="gene_order.tsv",
... )
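
Because the zarr datasets must already follow the gene order file, it can help to check this before training. The following is a minimal sketch, assuming you can obtain each dataset's gene symbols as a list of strings (for example from the var index of the zarr store; how you read it is up to you). The helpers parse_gene_order and check_gene_order are hypothetical and not part of scimilarity.

```python
from typing import Iterable, List, Sequence


def parse_gene_order(lines: Iterable[str]) -> List[str]:
    """Parse a gene order file's lines: one gene symbol per line, blank lines skipped."""
    return [line.strip() for line in lines if line.strip()]


def check_gene_order(dataset_genes: Sequence[str], gene_order: Sequence[str]) -> None:
    """Raise ValueError unless dataset_genes matches gene_order exactly, in order."""
    # Report the first position where the two orders disagree.
    for i, (got, want) in enumerate(zip(dataset_genes, gene_order)):
        if got != want:
            raise ValueError(f"gene mismatch at position {i}: {got!r} != {want!r}")
    # Same prefix but different lengths is also a mismatch.
    if len(dataset_genes) != len(gene_order):
        raise ValueError(
            f"length mismatch: {len(dataset_genes)} genes vs {len(gene_order)} expected"
        )


# Usage (file name taken from the example above; the gene list is a placeholder):
# with open("gene_order.tsv") as f:
#     gene_order = parse_gene_order(f)
# check_gene_order(genes_from_your_dataset, gene_order)
```

Parsing from an iterable of lines rather than a path keeps the helper easy to test and works directly with an open file object.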
collate(batch)

Collate a batch of samples into tensors.

Parameters:

batch – Batch to collate.

Returns:

A Tuple[torch.Tensor, torch.Tensor, list] holding the collated batch.

Return type:

tuple
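
The return type suggests that collation transposes a list of per-cell samples into two tensors plus a list. The following is a minimal plain-Python sketch of that transposition. It assumes that each sample is a (profile, cell type label, study) triple and that text labels are mapped to integer ids through a label2int dictionary; the sample layout, label2int, and collate_sketch are all assumptions. In the actual method, the first two outputs are torch.Tensor objects rather than lists.

```python
from typing import Dict, List, Sequence, Tuple

# Assumed sample layout: (expression profile, cell type label, study name).
Sample = Tuple[List[float], str, str]


def collate_sketch(
    batch: Sequence[Sample], label2int: Dict[str, int]
) -> Tuple[List[List[float]], List[int], List[str]]:
    """Transpose a non-empty batch of samples into (profiles, integer labels, studies)."""
    # zip(*batch) groups the i-th field of every sample together.
    profiles, labels, studies = (list(field) for field in zip(*batch))
    # Map text labels to integer class ids; the real method would build tensors here.
    int_labels = [label2int[label] for label in labels]
    return profiles, int_labels, studies
```

A function like this is what gets passed as collate_fn to a PyTorch DataLoader, which calls it once per batch of samples.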

get_sampler_weights(labels, studies=None)

Get a weighted random sampler.

Parameters:
  • labels (list) – Cell type label of each sample.

  • studies (list | None, default: None) – Study of each sample, or None.

Returns:

A WeightedRandomSampler object.

Return type:

WeightedRandomSampler
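
One common way to derive sampler weights is inverse-frequency weighting: each sample is weighted by one over the number of samples that share its label and, optionally, its study. The sketch below shows plain inverse-frequency weighting and may differ from the exact formula scimilarity uses, in particular in how study counts are combined with label counts. inverse_frequency_weights is a hypothetical helper.

```python
from collections import Counter
from typing import List, Optional


def inverse_frequency_weights(
    labels: List[str], studies: Optional[List[str]] = None
) -> List[float]:
    """Per-sample weights that up-weight rare labels (and rare studies, if given)."""
    label_counts = Counter(labels)
    weights = [1.0 / label_counts[label] for label in labels]
    if studies is not None:
        study_counts = Counter(studies)
        # Also divide by study frequency so large studies do not dominate.
        weights = [w / study_counts[s] for w, s in zip(weights, studies)]
    return weights


# The weights can then be handed to PyTorch, for example:
# torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
```

With labels only, every label receives the same total weight (count × 1/count = 1), so the sampler draws each cell type about equally often regardless of how abundant it is.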

test_dataloader()

Load the test dataset.

Returns:

A DataLoader object containing the test dataset.

Return type:

DataLoader

train_dataloader()

Load the training dataset.

Returns:

A DataLoader object containing the training dataset.

Return type:

DataLoader
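
The page does not say how the training DataLoader is assembled. A plausible setup combines the training dataset, a weighted sampler, and the collate method, but this is an assumption. The stdlib-only sketch below shows what such a loader does during one epoch: it draws one index per sample with replacement, in proportion to the weights (matching WeightedRandomSampler with its default replacement=True and num_samples set to the dataset size), then chunks the indices into batches and collates each batch. weighted_batches is a hypothetical helper and is not part of scimilarity or PyTorch.

```python
import random
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
B = TypeVar("B")


def weighted_batches(
    dataset: Sequence[T],
    weights: Sequence[float],
    batch_size: int,
    collate: Callable[[List[T]], B],
    seed: int = 0,
) -> Iterator[B]:
    """Yield one epoch of batches drawn with replacement in proportion to weights."""
    rng = random.Random(seed)
    # Draw len(dataset) indices with replacement, weighted per sample.
    indices = rng.choices(range(len(dataset)), weights=weights, k=len(dataset))
    # Chunk into batches; the last batch may be smaller (like drop_last=False).
    for start in range(0, len(indices), batch_size):
        chunk = indices[start : start + batch_size]
        yield collate([dataset[i] for i in chunk])
```

Because sampling is done with replacement, a rare cell type with a high weight can appear several times in one epoch, while abundant types are subsampled.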

val_dataloader()

Load the validation dataset.

Returns:

A DataLoader object containing the validation dataset.

Return type:

DataLoader

class scimilarity.zarr_data_models.scDataset(data_list, obs_celltype='celltype_name', obs_study='study')

Bases: Dataset

A class that represents a collection of single-cell datasets in zarr format.

Parameters:
  • data_list (list) – List of single-cell datasets.

  • obs_celltype (str, default: "celltype_name") – The obs key containing cell type names.

  • obs_study (str, default: "study") – The obs key containing study names.
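
A dataset over a collection of datasets needs to map one global index to a (dataset, local row) pair. The sketch below illustrates this with a binary search over cumulative sizes, the same approach torch.utils.data.ConcatDataset uses. It assumes each dataset is a dict holding an "X" list of expression rows plus obs columns keyed by obs_celltype and obs_study. Real zarr-backed data would be read lazily from disk instead. ConcatCellDataset is a hypothetical name, and returning a (profile, cell type, study) triple per cell is an assumption.

```python
import bisect
from itertools import accumulate
from typing import Dict, List, Sequence, Tuple


class ConcatCellDataset:
    """Hypothetical stand-in for scDataset: one index space over several datasets."""

    def __init__(
        self,
        data_list: Sequence[Dict[str, list]],
        obs_celltype: str = "celltype_name",
        obs_study: str = "study",
    ) -> None:
        self.data_list = list(data_list)
        self.obs_celltype = obs_celltype
        self.obs_study = obs_study
        # Cumulative sizes, e.g. [3, 5] for datasets with 3 and 2 cells.
        self.offsets: List[int] = list(accumulate(len(d["X"]) for d in self.data_list))

    def __len__(self) -> int:
        return self.offsets[-1] if self.offsets else 0

    def __getitem__(self, idx: int) -> Tuple[list, str, str]:
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Find the dataset holding global index idx, then the row within it.
        which = bisect.bisect_right(self.offsets, idx)
        local = idx - (self.offsets[which - 1] if which > 0 else 0)
        d = self.data_list[which]
        return d["X"][local], d[self.obs_celltype][local], d[self.obs_study][local]
```

Using bisect_right on the cumulative sizes also skips over empty datasets correctly, since an empty dataset adds no new offset value.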