scimilarity.training_models#
- class scimilarity.training_models.MetricLearning(*args, **kwargs)[source]#
Bases: LightningModule
A class encapsulating metric learning.
- Parameters:
n_genes (int) – The number of genes in the gene space, representing the input dimensions.
latent_dim (int, default: 128) – The dimensionality of the latent space.
hidden_dim (List[int], default: [1024, 1024]) – A list of hidden layer dimensions, describing the number of layers and their dimensions. Hidden layers are constructed in the order of the list for the encoder and in reverse for the decoder.
dropout (float, default: 0.5) – The dropout rate for hidden layers.
input_dropout (float, default: 0.4) – The dropout rate for the input layer.
triplet_loss_weight (float, default: 0.001) – The weighting of triplet loss vs reconstruction loss. The two weights sum to 1: triplet loss is weighted by triplet_loss_weight and reconstruction loss by (1 - triplet_loss_weight).
margin (float, default: 0.05) – The margin parameter in triplet loss.
negative_selection ({"semihard", "hardest", "random"}, default: "semihard") – The negative selection function.
sample_across_studies (bool, default: True) – Whether to enforce anchor-positive pairs being from different studies.
perturb_labels (bool, default: False) – Whether to perturb celltype labels by coarse graining the label based on cell ontology.
perturb_labels_fraction (float, default: 0.5) – The fraction of cells per batch to perform label perturbation.
lr (float, default: 5e-3) – The initial learning rate.
l1 (float, default: 1e-4) – The l1 penalty lambda. A value of 0 will disable l1 penalty.
l2 (float, default: 1e-2) – The l2 penalty lambda (weight decay). A value of 0 will disable l2 penalty.
max_epochs (int, default: 500) – The maximum number of epochs, used by the scheduler to determine the LR annealing rate.
cosine_annealing_tmax (int, optional, default: None) – The number of epochs for T_max in cosine LR annealing. If None, max_epochs is used.
track_triplets (str, optional, default: None) – Track the triplet composition used in triplet loss and store the files in this directory.
track_triplets_above_step (int, default: -1) – When tracking triplet composition, only track when the global step is above the given value.
Examples
>>> datamodule = MetricLearningZarrDataModule(
...     batch_size=1000,
...     num_workers=1,
...     obs_field="celltype_name",
...     train_path="train",
...     gene_order="gene_order.tsv",
... )
>>> model = MetricLearning(datamodule.n_genes)
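A hedged sketch of fitting the model with a standard pytorch-lightning Trainer; the trainer arguments are illustrative and not prescribed by this class.
>>> import pytorch_lightning as pl
>>> trainer = pl.Trainer(max_epochs=500)
>>> trainer.fit(model, datamodule=datamodule)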
- forward(x)[source]#
Forward.
- Parameters:
x (torch.Tensor) – Input tensor corresponding to the input layer.
- Returns:
z (torch.Tensor) – Output tensor corresponding to the last encoder layer.
x_hat (torch.Tensor) – Output tensor corresponding to the last decoder layer.
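A minimal forward-pass sketch, assuming the model constructed in the Examples above; the batch size of 8 is illustrative.
>>> import torch
>>> x = torch.randn(8, datamodule.n_genes)
>>> z, x_hat = model(x)  # z: latent embedding, x_hat: reconstruction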
- get_losses(batch, use_studies=True, val_metrics=False)[source]#
Calculate the triplet and reconstruction loss.
- Parameters:
batch – A batch as defined by a pytorch DataLoader.
use_studies (bool, default: True) – Whether to use study metadata in mining triplets and calculating triplet loss.
val_metrics (bool, default: False) – Whether to include extra validation metrics.
- Returns:
triplet_loss (torch.Tensor) – Triplet loss.
mse (torch.Tensor) – MSE reconstruction loss.
num_hard_triplets (torch.Tensor) – Number of hard triplets.
num_viable_triplets (torch.Tensor) – Number of viable triplets.
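A hedged sketch of computing the individual losses for one batch, assuming the data module follows the standard LightningDataModule convention of exposing train_dataloader().
>>> batch = next(iter(datamodule.train_dataloader()))
>>> triplet_loss, mse, num_hard, num_viable = model.get_losses(batch)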
- get_mixed_loss(triplet_loss, mse)[source]#
Calculate the mixed loss.
- Parameters:
triplet_loss (torch.Tensor) – Triplet loss.
mse (torch.Tensor) – MSE reconstruction loss.
- Returns:
Mixed loss.
- Return type:
torch.Tensor
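Per the triplet_loss_weight description above, the mixed loss is presumably the convex combination triplet_loss_weight * triplet_loss + (1 - triplet_loss_weight) * mse. A minimal sketch, continuing from the get_losses example:
>>> mixed_loss = model.get_mixed_loss(triplet_loss, mse)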
- load_state(encoder_filename, decoder_filename, use_gpu=False, freeze=False)[source]#
Load model state.
- Parameters:
encoder_filename (str) – Filename containing the encoder model state.
decoder_filename (str) – Filename containing the decoder model state.
use_gpu (bool, default: False) – Boolean indicating whether or not to use GPUs.
freeze (bool, default: False) – Freeze all but the bottleneck layer; used when the encoder is pretrained.
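A hedged sketch of warm-starting from saved weights; the filenames are hypothetical placeholders.
>>> model.load_state(
...     encoder_filename="encoder.ckpt",
...     decoder_filename="decoder.ckpt",
...     use_gpu=False,
...     freeze=True,
... )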
- test_step(batch, batch_idx)[source]#
Pytorch-lightning test step.
- Parameters:
batch – A batch as defined by a pytorch DataLoader.
batch_idx – A batch index as defined by pytorch-lightning.
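test_step is called by the pytorch-lightning framework rather than invoked directly; a hedged sketch of running the test loop with a Trainer, assuming the data module provides a test dataloader.
>>> trainer.test(model, datamodule=datamodule)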