scimilarity.anndata_data_models
- class scimilarity.anndata_data_models.MetricLearningDataModule(*args, **kwargs)
Bases: LightningDataModule
A class to encapsulate the anndata needed to train the model.
- Parameters:
train_path (str) – Path to the training h5ad file.
val_path (str, optional, default: None) – Path to the validation h5ad file.
obs_field (str, default: "celltype_name") – The obs key name containing celltype labels.
batch_size (int, default: 1000) – Batch size.
num_workers (int, default: 1) – The number of worker processes for dataloaders.
gene_order_file (str, optional) – Use a given gene order as described in the specified file rather than using the training dataset’s gene order. One gene symbol per line.
Examples
>>> datamodule = MetricLearningDataModule(
...     batch_size=1000,
...     num_workers=1,
...     obs_field="celltype_name",
...     train_path="train.h5ad",
... )
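The gene_order_file format described in the parameters (one gene symbol per line) can be produced and read back with plain Python. This is a minimal sketch: the file name gene_order.tsv, the gene symbols, and both helper functions are illustrative assumptions, not part of scimilarity.

```python
import os
import tempfile


def write_gene_order(path, genes):
    """Write one gene symbol per line (hypothetical helper)."""
    with open(path, "w", encoding="utf-8") as fh:
        for gene in genes:
            fh.write(f"{gene}\n")


def read_gene_order(path):
    """Read gene symbols back in file order, skipping blank lines."""
    with open(path, encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]


# Example gene symbols; any ordered list of symbols works the same way.
genes = ["CD3E", "CD4", "CD8A", "MS4A1"]
path = os.path.join(tempfile.mkdtemp(), "gene_order.tsv")
write_gene_order(path, genes)
gene_order = read_gene_order(path)
```

The resulting path could then be passed as gene_order_file=path when constructing the data module.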
- collate(batch)
Collate tensors.
- Parameters:
batch – Batch to collate.
- Returns:
A Tuple[torch.Tensor, torch.Tensor, list] containing information on the collated tensors.
- Return type:
tuple
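The core of a collate step is turning a list of per-cell samples into per-field columns. The sketch below shows only that transposition, using plain lists to stay dependency-free; the documented method returns two torch.Tensor objects and a list instead. The (x, y, study) sample layout and the name collate_sketch are assumptions made for illustration.

```python
def collate_sketch(batch):
    """Transpose a list of (x, y, study) samples into three columns.

    Assumes a non-empty batch in which every sample has three fields.
    """
    xs, ys, studies = zip(*batch)
    return list(xs), list(ys), list(studies)


# Two toy cells with two genes each.
batch = [
    ([0.0, 1.5], "T cell", "study_a"),
    ([2.0, 0.0], "B cell", "study_b"),
]
X, Y, studies = collate_sketch(batch)
```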
- get_sampler_weights(labels, studies=None)
Get weighted random sampler.
- Parameters:
dataset (scDataset) – Single cell dataset.
labels (list)
studies (list | None)
- Returns:
A WeightedRandomSampler object.
- Return type:
WeightedRandomSampler
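The weighting scheme used by this method is not described here. One common way to build per-sample weights for a weighted random sampler is inverse class frequency, so that rare cell types are drawn about as often as common ones. The sketch below assumes that scheme and ignores the studies argument; inverse_frequency_weights is a hypothetical helper, not the library's implementation.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Give each sample a weight of 1 / (count of its label).

    With these weights, every label carries the same total weight.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Three cells of a common type and one cell of a rare type.
labels = ["T cell", "T cell", "T cell", "B cell"]
weights = inverse_frequency_weights(labels)
```

A list like weights is the kind of input that torch.utils.data.WeightedRandomSampler accepts as its weights argument, together with num_samples.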
- subset_valid_terms(data)
Keep cells whose celltype labels have a valid ontology ID.
- Parameters:
data (anndata.AnnData) – Annotated data to subset by valid ontology ID.
- Returns:
An object containing only the cells whose celltype labels have a valid ontology ID.
- Return type:
anndata.AnnData
- test_dataloader()
Load the test dataset.
- Returns:
A DataLoader object containing the test dataset.
- Return type:
DataLoader
- class scimilarity.anndata_data_models.scDataset(X, Y, study=None)
Bases: Dataset
A class that represents a single cell dataset.
- Parameters:
X (numpy.ndarray) – Gene expression vectors for every cell.
Y (numpy.ndarray) – Text labels for every cell.
study (numpy.ndarray) – The study identifier for every cell.
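A map-style dataset of this shape needs little more than a length and an indexed lookup. The sketch below mirrors the X, Y, study constructor arguments with a plain Python class; ToyCellDataset is a hypothetical stand-in, and the (x, y, study) tuple returned per item is an assumption about the layout, not the documented behavior of scDataset.

```python
class ToyCellDataset:
    """Minimal map-style dataset over per-cell expression, label, and study."""

    def __init__(self, X, Y, study=None):
        if len(X) != len(Y):
            raise ValueError("X and Y must have one entry per cell")
        if study is not None and len(study) != len(X):
            raise ValueError("study must have one entry per cell")
        self.X = X
        self.Y = Y
        self.study = study

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Return None for the study when no study identifiers were given.
        study = self.study[idx] if self.study is not None else None
        return self.X[idx], self.Y[idx], study


dataset = ToyCellDataset(
    X=[[0.0, 1.5], [2.0, 0.0]],
    Y=["T cell", "B cell"],
    study=["study_a", "study_b"],
)
unlabeled_study = ToyCellDataset(X=[[1.0]], Y=["NK cell"])
```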