scimilarity.anndata_data_models
- class scimilarity.anndata_data_models.MetricLearningDataModule(*args, **kwargs)
Bases: LightningDataModule
A LightningDataModule that encapsulates the AnnData objects needed to train the model.
- Parameters:
train_path (str) – Path to the training h5ad file.
val_path (str, optional, default: None) – Path to the validation h5ad file.
label_column (str, default: "celltype_name") – The column name containing ontology-compliant cell type names.
study_column (str, default: "study") – The column name containing study identifiers.
gene_order_file (str, optional) – Path to a file that specifies the gene order to use instead of the training dataset’s gene order, with one gene symbol per line.
batch_size (int, default: 1000) – Batch size.
num_workers (int, default: 1) – The number of worker processes for the DataLoaders.
sparse (bool, default: False) – If True, use sparse matrices.
remove_singleton_classes (bool, default: True) – If True, exclude cells whose class occurs in only one study.
pin_memory (bool, default: False) – If True, the DataLoaders use pinned (page-locked) memory.
persistent_workers (bool, default: False) – If True, the DataLoaders keep their worker processes alive between epochs. Set to False when num_workers is 0.
multiprocessing_context (str, default: "fork") – Multiprocessing context for dataloaders: [“spawn”, “fork”].
Examples
>>> from scimilarity.anndata_data_models import MetricLearningDataModule
>>> datamodule = MetricLearningDataModule(
...     batch_size=1000,
...     num_workers=1,
...     label_column="celltype_name",
...     train_path="train.h5ad",
... )
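The gene_order_file parameter expects one gene symbol per line. The sketch below shows one way such a file could be read and used to reorder an expression row. Both helper names are hypothetical, and filling genes missing from the dataset with 0.0 is an assumption, not necessarily what MetricLearningDataModule does.

```python
from typing import Dict, List


def read_gene_order(path: str) -> List[str]:
    """Hypothetical helper: read one gene symbol per line, skipping blank lines."""
    with open(path) as fh:
        return [line.strip() for line in fh if line.strip()]


def align_row(row: List[float], dataset_genes: List[str], gene_order: List[str]) -> List[float]:
    """Hypothetical helper: reorder one cell's expression values to match gene_order.

    Genes in gene_order that are absent from the dataset are filled with 0.0
    (an assumption made for this sketch).
    """
    # Map each dataset gene symbol to its column position.
    position: Dict[str, int] = {gene: i for i, gene in enumerate(dataset_genes)}
    return [row[position[gene]] if gene in position else 0.0 for gene in gene_order]
```

For example, align_row([1.0, 2.0, 3.0], ["A", "B", "C"], ["C", "A", "D"]) gives [3.0, 1.0, 0.0]: the columns follow the file's order, and the unknown gene "D" becomes zero.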
- get_sampler_weights(labels, studies=None, class_target_sum=10000.0, study_target_sum=1000000.0)
Create a WeightedRandomSampler whose weights are derived from the label counts and, optionally, the study counts.
- Parameters:
labels (list) – The list of labels.
studies (list, optional, default: None) – A list of study identifiers. If given, study counts are also incorporated into the sampler weights.
class_target_sum (float, default: 1e4) – Target sum for normalization of class counts.
study_target_sum (float, default: 1e6) – Target sum for normalization of study counts.
- Returns:
A WeightedRandomSampler object.
- Return type:
WeightedRandomSampler
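The exact weighting formula is not documented here. The following is a minimal sketch under one plausible reading of the parameters: counts are rescaled to sum to class_target_sum (or study_target_sum), and each cell's weight is the inverse of its rescaled class count, divided further by its rescaled study count when studies are given. The function name compute_sampler_weights is hypothetical, and the method's actual normalization may differ.

```python
from collections import Counter
from typing import Hashable, List, Optional, Sequence


def compute_sampler_weights(
    labels: Sequence[Hashable],
    studies: Optional[Sequence[Hashable]] = None,
    class_target_sum: float = 1e4,
    study_target_sum: float = 1e6,
) -> List[float]:
    """Hypothetical inverse-frequency weights, one per cell."""
    # Rescale class counts so they sum to class_target_sum (assumed normalization).
    class_counts = Counter(labels)
    total = sum(class_counts.values())
    class_norm = {c: n / total * class_target_sum for c, n in class_counts.items()}
    # Rare classes get larger per-cell weights.
    weights = [1.0 / class_norm[label] for label in labels]

    if studies is not None:
        # Apply the same treatment to study counts and combine multiplicatively.
        study_counts = Counter(studies)
        s_total = sum(study_counts.values())
        study_norm = {s: n / s_total * study_target_sum for s, n in study_counts.items()}
        weights = [w / study_norm[s] for w, s in zip(weights, studies)]
    return weights
```

Without studies, every class receives the same total weight, so each class is drawn about equally often. The resulting list could then be passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True).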
- remove_singleton_label_ids(data, n_studies=2)
Keep only cells whose labels occur in at least n_studies distinct studies.
- Parameters:
data (anndata.AnnData) – Annotated data to subset by the number of studies in which each label occurs.
n_studies (int, default: 2) – The number of studies a label must exist in to be valid.
- Return type:
anndata.AnnData
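The filtering rule can be sketched as a boolean mask over cells: count the distinct studies per label, then keep cells whose label reaches the threshold. The function name labels_in_enough_studies_mask is hypothetical, and the sketch assumes labels and studies are parallel per-cell sequences (for example, the label and study columns of data.obs).

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Set


def labels_in_enough_studies_mask(
    labels: Sequence[Hashable],
    studies: Sequence[Hashable],
    n_studies: int = 2,
) -> List[bool]:
    """Hypothetical helper: True for cells whose label occurs in >= n_studies studies."""
    # Collect the distinct studies in which each label appears.
    studies_per_label: Dict[Hashable, Set[Hashable]] = defaultdict(set)
    for label, study in zip(labels, studies):
        studies_per_label[label].add(study)
    return [len(studies_per_label[label]) >= n_studies for label in labels]
```

With an AnnData object, a mask like this could be applied as data[mask].copy() to obtain the subset.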
- subset_valid_terms(data)
Keep only cells whose cell type labels have a valid ontology ID.
- Parameters:
data (anndata.AnnData) – Annotated data to subset by valid ontology ID.
- Returns:
An AnnData object containing only the cells whose cell type labels have a valid ontology ID.
- Return type:
anndata.AnnData
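How the method looks up ontology IDs is not described here. As a minimal sketch, assuming a hypothetical mapping from cell type names to ontology IDs and a hypothetical set of IDs considered valid, the check reduces to a per-cell membership test:

```python
from typing import Dict, Hashable, List, Sequence, Set


def valid_term_mask(
    labels: Sequence[str],
    name_to_id: Dict[str, str],
    valid_ids: Set[str],
) -> List[bool]:
    """Hypothetical helper: True for cells whose label maps to a valid ontology ID.

    Labels missing from name_to_id map to None, which is never in valid_ids,
    so those cells are dropped.
    """
    return [name_to_id.get(label) in valid_ids for label in labels]


# Illustrative Cell Ontology terms: CL:0000084 is "T cell", CL:0000236 is "B cell".
name_to_id = {"T cell": "CL:0000084", "B cell": "CL:0000236"}
valid_ids = {"CL:0000084", "CL:0000236"}
mask = valid_term_mask(["T cell", "B cell", "unknown"], name_to_id, valid_ids)
```

Here mask is [True, True, False]; the unmapped label "unknown" is excluded.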
- test_dataloader()
Load the test dataset.
- Returns:
A DataLoader object containing the test dataset.
- Return type:
DataLoader
- class scimilarity.anndata_data_models.scCollator(label2int, sparse=False)
Bases: object
A class to collate individual samples into batches, mapping string labels to class integers with label2int.
- Parameters:
label2int (dict) – A dictionary that maps string labels to class integers.
sparse (bool, default: False) – Use sparse matrices.
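A collator is typically passed as a DataLoader's collate_fn and called with a list of samples. The dense-only sketch below assumes each sample is an (expression, label, study) tuple; that layout and the class name SketchCollator are hypothetical, and the sparse option is not modeled. A PyTorch version would convert the stacked rows and label IDs to tensors.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Assumed sample layout: (expression vector, string label, study identifier or None).
Sample = Tuple[Sequence[float], str, Optional[str]]


class SketchCollator:
    """Hypothetical collator: stacks expression rows and maps labels to integers."""

    def __init__(self, label2int: Dict[str, int]):
        self.label2int = label2int

    def __call__(self, batch: Sequence[Sample]) -> Tuple[List[List[float]], List[int], List[Optional[str]]]:
        xs = [list(x) for x, _, _ in batch]              # one row per cell
        ys = [self.label2int[y] for _, y, _ in batch]    # string label -> class integer
        studies = [s for _, _, s in batch]               # passed through unchanged
        return xs, ys, studies
```

An unknown label raises KeyError, which surfaces labels missing from label2int instead of silently mislabeling cells.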
- class scimilarity.anndata_data_models.scDataset(X, Y, study=None)
Bases: Dataset
A class that represents a single cell dataset.
- Parameters:
X (numpy.ndarray) – Gene expression vectors for every cell.
Y (numpy.ndarray) – Text labels for every cell.
study (numpy.ndarray, optional, default: None) – The study identifier for every cell.
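A map-style dataset only needs __len__ and __getitem__. The plain-Python sketch below mirrors that protocol without subclassing torch.utils.data.Dataset. The class name SketchDataset and the (X[i], Y[i], study[i]) return layout are assumptions, not necessarily what scDataset returns.

```python
from typing import Any, Optional, Sequence, Tuple


class SketchDataset:
    """Hypothetical map-style dataset: one (expression, label, study) item per cell."""

    def __init__(self, X: Sequence[Any], Y: Sequence[Any], study: Optional[Sequence[Any]] = None):
        # All per-cell arrays must have one entry per cell.
        if len(X) != len(Y):
            raise ValueError("X and Y must have one entry per cell")
        if study is not None and len(study) != len(Y):
            raise ValueError("study must have one entry per cell")
        self.X, self.Y, self.study = X, Y, study

    def __len__(self) -> int:
        return len(self.Y)

    def __getitem__(self, idx: int) -> Tuple[Any, Any, Optional[Any]]:
        # study is optional, so return None when it was not provided.
        study = self.study[idx] if self.study is not None else None
        return self.X[idx], self.Y[idx], study
```

Validating lengths in the constructor catches misaligned inputs early, rather than at an arbitrary batch during training.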