scimilarity.training_models

class scimilarity.training_models.MetricLearning(*args, **kwargs)

Bases: LightningModule

A PyTorch Lightning module encapsulating metric learning.

Parameters:
  • n_genes (int) – The number of genes in the gene space, representing the input dimensions.

  • latent_dim (int, default: 128) – The number of latent space dimensions.

  • hidden_dim (List[int], default: [1024, 1024]) – A list of hidden layer dimensions, describing the number of layers and their dimensions. Hidden layers are constructed in the order of the list for the encoder and in reverse for the decoder.

  • dropout (float, default: 0.5) – The dropout rate for the hidden layers.

  • input_dropout (float, default: 0.4) – The dropout rate for the input layer.

  • triplet_loss_weight (float, default: 0.001) – The weight of the triplet loss relative to the reconstruction loss. The two weights sum to 1: the triplet loss is weighted by triplet_loss_weight and the reconstruction loss by (1 - triplet_loss_weight).

  • margin (float, default: 0.05) – The margin parameter in triplet loss.

  • negative_selection ({"semihard", "hardest", "random"}, default: "semihard") – The negative selection function.

  • sample_across_studies (bool, default: True) – Whether to require anchor-positive pairs to come from different studies.

  • perturb_labels (bool, default: False) – Whether to perturb cell type labels by coarse-graining them based on the cell ontology.

  • perturb_labels_fraction (float, default: 0.5) – The fraction of cells per batch on which to perform label perturbation.

  • lr (float, default: 5e-3) – The initial learning rate.

  • l1 (float, default: 1e-4) – The L1 penalty lambda. A value of 0 disables the L1 penalty.

  • l2 (float, default: 1e-2) – The L2 penalty lambda (weight decay). A value of 0 disables the L2 penalty.

  • max_epochs (int, default: 500) – The maximum number of epochs, used by the scheduler to determine the LR annealing rate.

  • cosine_annealing_tmax (int, optional, default: None) – The number of epochs for T_max in cosine LR annealing. If None, max_epochs is used.

  • track_triplets (str, optional, default: None) – If set, track the composition of the triplets used in the triplet loss and store the files in this directory.

  • track_triplets_above_step (int, default: -1) – When tracking triplet composition, only track global steps above the given value.
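The negative_selection and sample_across_studies options can be illustrated with a small, dependency-free sketch. It assumes a precomputed pairwise distance matrix between embeddings and uses common textbook definitions of the three strategies: "hardest" picks the negative closest to the anchor, "semihard" picks a negative that is farther than the positive but still within margin (here, the closest such one), and "random" picks any negative. The helpers valid_positives and select_negative are hypothetical and not part of scimilarity; the library's own triplet selection may differ in details such as tie-breaking or the distance it uses.

```python
import random
from typing import List, Optional, Sequence


def valid_positives(
    labels: Sequence[str],
    studies: Sequence[str],
    anchor: int,
    sample_across_studies: bool = True,
) -> List[int]:
    """Indices that may serve as positives for `anchor` (hypothetical helper)."""
    return [
        j
        for j in range(len(labels))
        if j != anchor
        and labels[j] == labels[anchor]
        # With sample_across_studies, the positive must come from another study.
        and (not sample_across_studies or studies[j] != studies[anchor])
    ]


def select_negative(
    dist: Sequence[Sequence[float]],
    labels: Sequence[str],
    anchor: int,
    positive: int,
    margin: float = 0.05,
    strategy: str = "semihard",
    rng: Optional[random.Random] = None,
) -> Optional[int]:
    """Pick a negative for one (anchor, positive) pair; None if none qualifies."""
    # Candidate negatives: every cell whose label differs from the anchor's.
    negatives = [j for j in range(len(labels)) if labels[j] != labels[anchor]]
    if not negatives:
        return None
    d_ap = dist[anchor][positive]
    if strategy == "hardest":
        # The negative closest to the anchor.
        return min(negatives, key=lambda j: dist[anchor][j])
    if strategy == "semihard":
        # Semi-hard: farther than the positive, but still inside the margin.
        semihard = [j for j in negatives if d_ap < dist[anchor][j] < d_ap + margin]
        return min(semihard, key=lambda j: dist[anchor][j]) if semihard else None
    if strategy == "random":
        return (rng or random.Random()).choice(negatives)
    raise ValueError(f"unknown negative_selection: {strategy!r}")
```

Working from a full distance matrix keeps the sketch readable; a batched implementation would typically vectorize the same comparisons over the whole matrix at once.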

Examples

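>>> from scimilarity.training_models import MetricLearning
>>> from scimilarity.zarr_data_models import MetricLearningZarrDataModule
>>> datamodule = MetricLearningZarrDataModule(
...     batch_size=1000,
...     num_workers=1,
...     obs_field="celltype_name",
...     train_path="train",
...     gene_order="gene_order.tsv",
... )
>>> model = MetricLearning(datamodule.n_genes)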

configure_optimizers()

Configure the optimizer and the cosine-annealing learning-rate scheduler.
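How max_epochs and cosine_annealing_tmax shape the learning rate can be sketched with the closed-form cosine-annealing formula (the one documented for PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR). This is an illustration under assumptions: eta_min = 0 (PyTorch's default) and one scheduler step per epoch. The function cosine_annealing_lrs is hypothetical, and the sketch does not reproduce the optimizer itself.

```python
import math
from typing import List, Optional


def cosine_annealing_lrs(
    lr: float = 5e-3,
    max_epochs: int = 500,
    cosine_annealing_tmax: Optional[int] = None,
    eta_min: float = 0.0,  # assumption: PyTorch's default minimum LR
) -> List[float]:
    """Learning rate for each epoch under closed-form cosine annealing."""
    # Documented fallback: T_max is cosine_annealing_tmax if given, else max_epochs.
    t_max = cosine_annealing_tmax if cosine_annealing_tmax is not None else max_epochs
    return [
        eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2
        for epoch in range(max_epochs)
    ]
```

With the defaults (lr=5e-3, max_epochs=500), the rate is halved to 2.5e-3 at epoch 250 and approaches 0 by the final epoch. Setting cosine_annealing_tmax below max_epochs makes the rate reach its minimum earlier.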

forward(x)

Run the forward pass through the encoder and decoder.

Parameters:

x (torch.Tensor) – Input tensor corresponding to input layer.

Returns:

  • z (torch.Tensor) – Output tensor corresponding to the last encoder layer.

  • x_hat (torch.Tensor) – Output tensor corresponding to the last decoder layer.
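Following the hidden_dim description (hidden layers in list order for the encoder, reversed for the decoder), the layer sizes that determine the shapes of z and x_hat can be derived as below. This sketch only computes (in_features, out_features) pairs; it assumes one fully connected layer per transition and ignores dropout, activations, and any normalization the real modules may contain. autoencoder_layer_dims is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def autoencoder_layer_dims(
    n_genes: int,
    latent_dim: int = 128,
    hidden_dim: Sequence[int] = (1024, 1024),
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Return (in_features, out_features) for each encoder and decoder layer."""
    # Encoder: n_genes -> hidden_dim[0] -> ... -> hidden_dim[-1] -> latent_dim
    enc_sizes = [n_genes, *hidden_dim, latent_dim]
    # Decoder: hidden layers in reverse order, reconstructing n_genes at the end.
    dec_sizes = [latent_dim, *reversed(hidden_dim), n_genes]
    encoder = list(zip(enc_sizes[:-1], enc_sizes[1:]))
    decoder = list(zip(dec_sizes[:-1], dec_sizes[1:]))
    return encoder, decoder
```

For example, with n_genes=2000 and the default hidden_dim, z has shape (batch_size, 128) and x_hat has shape (batch_size, 2000), matching the input so that a reconstruction loss can be computed.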

get_losses(batch, use_studies=True, val_metrics=False)

Calculate the triplet and reconstruction losses.

Parameters:
  • batch – A batch as yielded by a PyTorch DataLoader.

  • use_studies (bool, default: True) – Whether to use study metadata when mining triplets and calculating the triplet loss.

  • val_metrics (bool, default: False) – Whether to include extra validation metrics.

Returns:

  • triplet_loss (torch.Tensor) – Triplet loss.

  • mse (torch.Tensor) – MSE reconstruction loss.

  • num_hard_triplets (torch.Tensor) – Number of hard triplets.

  • num_viable_triplets (torch.Tensor) – Number of viable triplets.
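A dependency-free sketch of the quantities returned here, under stated assumptions: triplets are given as (anchor, positive, negative) indices into a precomputed distance matrix, and the triplet loss is the standard hinge max(d(a, p) - d(a, n) + margin, 0) averaged over triplets. Counting every supplied triplet as viable and every triplet with a positive hinge as hard are assumptions made for illustration, not the library's definitions of num_viable_triplets and num_hard_triplets. triplet_and_mse is a hypothetical helper.

```python
from typing import Sequence, Tuple

Triplet = Tuple[int, int, int]  # (anchor, positive, negative) indices


def triplet_and_mse(
    dist: Sequence[Sequence[float]],
    triplets: Sequence[Triplet],
    x: Sequence[Sequence[float]],
    x_hat: Sequence[Sequence[float]],
    margin: float = 0.05,
) -> Tuple[float, float, int, int]:
    """Return (triplet_loss, mse, num_hard, num_viable) for one batch."""
    # Per-triplet hinge: max(d(a, p) - d(a, n) + margin, 0).
    losses = [max(dist[a][p] - dist[a][n] + margin, 0.0) for a, p, n in triplets]
    # Assumption: every supplied triplet is viable.
    num_viable = len(triplets)
    # Assumption: a "hard" triplet still violates the margin (loss > 0).
    num_hard = sum(1 for loss in losses if loss > 0)
    triplet_loss = sum(losses) / num_viable if num_viable else 0.0
    # Mean squared reconstruction error over every entry of the batch.
    sq_errs = [
        (xi - yi) ** 2 for row, row_hat in zip(x, x_hat) for xi, yi in zip(row, row_hat)
    ]
    mse = sum(sq_errs) / len(sq_errs) if sq_errs else 0.0
    return triplet_loss, mse, num_hard, num_viable
```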

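get_mixed_loss(triplet_loss, mse)

Calculate the mixed loss: the triplet loss weighted by triplet_loss_weight plus the reconstruction (MSE) loss weighted by (1 - triplet_loss_weight).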

Parameters:
  • triplet_loss (torch.Tensor) – Triplet loss.

  • mse (torch.Tensor) – MSE reconstruction loss.

Returns:

Mixed loss.

Return type:

torch.Tensor
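Per the triplet_loss_weight description, the mixed loss is a convex combination of the two terms. A minimal sketch on plain floats (mixed_loss is a hypothetical name, not the method itself):

```python
def mixed_loss(triplet_loss: float, mse: float, triplet_loss_weight: float = 0.001) -> float:
    """Weighted sum whose two weights add up to 1."""
    # Triplet term gets triplet_loss_weight; reconstruction term gets the remainder.
    return triplet_loss_weight * triplet_loss + (1.0 - triplet_loss_weight) * mse
```

The same expression works unchanged when triplet_loss and mse are scalar torch.Tensor values. With the default weight of 0.001, the reconstruction term dominates unless the triplet loss is much larger than the MSE.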

load_state(encoder_filename, decoder_filename, use_gpu=False, freeze=False)

Load the encoder and decoder model states from files.

Parameters:
  • encoder_filename (str) – Filename containing the encoder model state.

  • decoder_filename (str) – Filename containing the decoder model state.

  • use_gpu (bool, default: False) – Whether to use GPUs.

  • freeze (bool, default: False) – Freeze all but the bottleneck layer; used when pretraining the encoder.

on_test_epoch_end()

PyTorch Lightning hook: evaluation at the end of a test epoch.

on_test_epoch_start()

PyTorch Lightning hook called at the start of a test epoch.

on_validation_epoch_end()

PyTorch Lightning hook: evaluation at the end of a validation epoch.

on_validation_epoch_start()

PyTorch Lightning hook called at the start of a validation epoch.

test_step(batch, batch_idx)

PyTorch Lightning test step.

Parameters:
  • batch – A batch as yielded by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.

training_step(batch, batch_idx)

PyTorch Lightning training step.

Parameters:
  • batch – A batch as yielded by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.

validation_step(batch, batch_idx)

PyTorch Lightning validation step.

Parameters:
  • batch – A batch as yielded by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.