scimilarity.zarr_data_models

class scimilarity.zarr_data_models.MetricLearningDataModule(*args, **kwargs)

Bases: LightningDataModule

A class to encapsulate a collection of zarr datasets to train the model.

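Parameters:
  • train_path (str) – Path to the training datasets in zarr format.

  • gene_order (str)

  • val_path (Optional[str]) – Path to the validation datasets in zarr format, if any.

  • obs_field (str)

  • batch_size (int) – Number of samples per batch.

  • num_workers (int) – Number of worker processes used by each DataLoader.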

collate(batch)

Collate a batch of samples into tensors.

Parameters:

batch – Batch to collate.

Returns:

A Tuple[torch.Tensor, torch.Tensor, list] holding the collated batch.

Return type:

tuple
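The layout of the items passed to collate is not documented here. As a minimal sketch, assuming each item is a hypothetical (expression vector, integer label, study name) triple, a collate function with the documented Tuple[torch.Tensor, torch.Tensor, list] return shape could stack the expression vectors into one float tensor, turn the labels into a long tensor, and keep the study names as a plain list. The name collate_sketch and the item layout are illustrative assumptions, not scimilarity's code.

```python
import numpy as np
import torch


def collate_sketch(batch):
    """Hypothetical collate: each item is assumed to be (expression, label, study).

    This item layout is an illustration only, not scimilarity's actual format.
    """
    expressions, labels, studies = zip(*batch)
    # Stack per-cell expression vectors into a (batch_size, n_genes) float tensor.
    x = torch.as_tensor(np.stack(expressions), dtype=torch.float32)
    # Integer class labels become a 1-D long tensor.
    y = torch.as_tensor(labels, dtype=torch.long)
    # Study identifiers are strings, so they stay in a plain list.
    return x, y, list(studies)


batch = [
    (np.array([0.0, 1.0, 2.0]), 0, "study_a"),
    (np.array([3.0, 4.0, 5.0]), 1, "study_b"),
]
x, y, studies = collate_sketch(batch)
print(x.shape, y.tolist(), studies)  # torch.Size([2, 3]) [0, 1] ['study_a', 'study_b']
```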

get_sampler_weights(labels, studies=None)

Build a weighted random sampler from the sample labels and, optionally, their studies.

Parameters:
  • labels (list) – Label of each sample.

  • studies (Optional[list]) – Study of each sample, if available.

Returns:

A WeightedRandomSampler object.

Return type:

WeightedRandomSampler
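The weighting rule used by get_sampler_weights is not described above. A common choice, assumed here for illustration only, is inverse group frequency: each sample is weighted by 1 / (number of samples sharing its label), or sharing its (label, study) pair when studies are given, so rare groups are drawn about as often as common ones. The hypothetical helpers inverse_frequency_weights and make_sampler compute such weights and wrap them in torch.utils.data.WeightedRandomSampler.

```python
from collections import Counter
from typing import List, Optional

import torch
from torch.utils.data import WeightedRandomSampler


def inverse_frequency_weights(labels: List, studies: Optional[List] = None) -> List[float]:
    """Hypothetical weighting scheme: 1 / size of each sample's group."""
    # Group by label alone, or by (label, study) pairs when studies are provided.
    keys = list(labels) if studies is None else list(zip(labels, studies))
    counts = Counter(keys)
    return [1.0 / counts[k] for k in keys]


def make_sampler(labels: List, studies: Optional[List] = None) -> WeightedRandomSampler:
    """Hypothetical helper: wrap the weights in a WeightedRandomSampler."""
    weights = inverse_frequency_weights(labels, studies)
    # Draw as many samples per epoch as there are cells, with replacement.
    return WeightedRandomSampler(
        torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(weights),
        replacement=True,
    )


labels = ["T cell", "T cell", "B cell", "B cell", "B cell", "B cell"]
print(inverse_frequency_weights(labels))  # [0.5, 0.5, 0.25, 0.25, 0.25, 0.25]
print(len(make_sampler(labels)))  # 6
```

Because the weights within each group sum to 1, every label (or label/study pair) gets the same expected share of draws per epoch; sampling with replacement lets rare cells appear more than once.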

test_dataloader()

Load the test dataset.

Returns:

A DataLoader object containing the test dataset.

Return type:

DataLoader

train_dataloader()

Load the training dataset.

Returns:

A DataLoader object containing the training dataset.

Return type:

DataLoader

val_dataloader()

Load the validation dataset.

Returns:

A DataLoader object containing the validation dataset.

Return type:

DataLoader
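Each of these methods returns a torch.utils.data.DataLoader. The sketch below shows one typical way such loaders are assembled: a weighted sampler plus a custom collate function for training, and plain sequential reading for evaluation. ToyCells, collate, train_loader and eval_loader are hypothetical stand-ins, and whether the actual methods pass exactly these arguments is an assumption.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class ToyCells(Dataset):
    """Hypothetical in-memory stand-in for a single-cell dataset."""

    def __init__(self, n_cells: int = 10, n_genes: int = 5):
        self.x = torch.rand(n_cells, n_genes)
        # Imbalanced labels: the last two cells are class 1, the rest class 0.
        self.labels = [0] * (n_cells - 2) + [1] * 2

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.x[idx], self.labels[idx]


def collate(batch):
    # Hypothetical collate: stack features, turn labels into a long tensor.
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.tensor(ys, dtype=torch.long)


def train_loader(ds: ToyCells, batch_size: int = 4, num_workers: int = 0) -> DataLoader:
    # Inverse label frequency weights (an assumed scheme, for illustration).
    counts = Counter(ds.labels)
    weights = torch.tensor([1.0 / counts[y] for y in ds.labels], dtype=torch.double)
    sampler = WeightedRandomSampler(weights, num_samples=len(ds), replacement=True)
    # The custom sampler replaces shuffle=True; DataLoader rejects both together.
    return DataLoader(ds, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers, collate_fn=collate)


def eval_loader(ds: ToyCells, batch_size: int = 4, num_workers: int = 0) -> DataLoader:
    # Validation/test data are read in order, without weighting.
    return DataLoader(ds, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, collate_fn=collate)


ds = ToyCells()
for xb, yb in train_loader(ds):
    print(xb.shape, yb.tolist())
```

DataLoader raises an error if a custom sampler is combined with shuffle=True, which is why the training loader relies on the sampler alone for randomization.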

class scimilarity.zarr_data_models.scDataset(data_list, obs_celltype='celltype_name', obs_study='study')

Bases: Dataset

A class that represents a collection of single-cell datasets in zarr format.
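Presenting several datasets as one requires mapping a global sample index to a (dataset, local index) pair. A minimal sketch of that bookkeeping, using cumulative sizes and binary search, follows; it illustrates the general technique under the assumption that scDataset does something similar, not scDataset's actual implementation. The class ConcatenatedCollection is hypothetical, its data_list argument merely mirrors the documented parameter name, and plain Python lists stand in for the zarr datasets. PyTorch's torch.utils.data.ConcatDataset implements the same idea.

```python
import bisect
from itertools import accumulate
from typing import List, Sequence, Tuple


class ConcatenatedCollection:
    """Hypothetical sketch: index several datasets as one flat sequence."""

    def __init__(self, data_list: List[Sequence]):
        self.data_list = data_list
        # Cumulative end offsets, e.g. sizes [3, 2, 4] -> [3, 5, 9].
        self.offsets = list(accumulate(len(d) for d in data_list))

    def __len__(self) -> int:
        return self.offsets[-1] if self.offsets else 0

    def locate(self, idx: int) -> Tuple[int, int]:
        """Map a global index to (dataset index, index within that dataset)."""
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose end offset is strictly greater than idx.
        ds = bisect.bisect_right(self.offsets, idx)
        start = self.offsets[ds - 1] if ds > 0 else 0
        return ds, idx - start

    def __getitem__(self, idx: int):
        ds, local = self.locate(idx)
        return self.data_list[ds][local]


coll = ConcatenatedCollection([["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2", "c3"]])
print(len(coll))       # 9
print(coll.locate(4))  # (1, 1)
print(coll[5])         # c0
```

Using bisect_right on the end offsets skips empty datasets automatically and keeps each lookup logarithmic in the number of datasets.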