scimilarity.training_models

class scimilarity.training_models.MetricLearning(*args, **kwargs)

Bases: LightningModule

A class encapsulating the metric learning model: an encoder-decoder network trained with a mix of triplet loss and reconstruction loss.

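Parameters:
  • n_genes (int) – Number of genes, i.e. the input dimension.

  • latent_dim (int) – Dimension of the latent embedding produced by the encoder.

  • hidden_dim (List[int]) – Sizes of the hidden layers.

  • dropout (float) – Dropout rate for the hidden layers.

  • input_dropout (float) – Dropout rate for the input layer.

  • triplet_loss_weight (float) – Weight of the triplet loss relative to the reconstruction loss in the mixed loss.

  • margin (float) – Margin for the triplet loss.

  • negative_selection (str) – Strategy for selecting negatives when mining triplets.

  • sample_across_studies (bool) – Whether to use study metadata when sampling triplets.

  • perturb_labels (bool) – Whether to perturb labels during triplet mining.

  • perturb_labels_fraction (float) – Fraction of labels to perturb when perturb_labels is True.

  • lr (float) – Learning rate.

  • l1 (float) – L1 penalty weight.

  • l2 (float) – L2 penalty weight.

  • max_epochs (int) – Maximum number of training epochs.

  • residual (bool) – Whether to use residual connections.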
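As a rough illustration of how n_genes, hidden_dim and latent_dim might fit together, the sketch below computes the (in, out) sizes of each linear layer. It assumes, without confirmation from this page, that the encoder runs from n_genes through each hidden_dim entry to latent_dim and that the decoder mirrors it in reverse. layer_dims is a hypothetical helper, not part of scimilarity, and residual connections are ignored.

```python
from typing import List, Tuple

LayerSizes = List[Tuple[int, int]]


def layer_dims(
    n_genes: int, hidden_dim: List[int], latent_dim: int
) -> Tuple[LayerSizes, LayerSizes]:
    """Return (in, out) sizes for each linear layer of a mirrored autoencoder.

    Hypothetical helper: assumes the decoder reverses the encoder's layers.
    """
    # Encoder: n_genes -> hidden_dim[0] -> ... -> hidden_dim[-1] -> latent_dim
    enc_sizes = [n_genes, *hidden_dim, latent_dim]
    encoder = list(zip(enc_sizes[:-1], enc_sizes[1:]))
    # Decoder (assumed): the same sizes traversed in reverse, back to n_genes
    dec_sizes = enc_sizes[::-1]
    decoder = list(zip(dec_sizes[:-1], dec_sizes[1:]))
    return encoder, decoder


encoder, decoder = layer_dims(n_genes=500, hidden_dim=[256, 64], latent_dim=16)
print(encoder)  # [(500, 256), (256, 64), (64, 16)]
print(decoder)  # [(16, 64), (64, 256), (256, 500)]
```

Under this assumption, z returned by forward has latent_dim features and x_hat has n_genes features, which is what allows the MSE reconstruction loss to compare x_hat directly against the input.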

configure_optimizers()

PyTorch Lightning hook that configures the optimizers for training.

forward(x)

Run the forward pass through the encoder and decoder.

Parameters:

x (torch.Tensor) – Input tensor for the input layer.

Returns:

  • z (torch.Tensor) – Output tensor corresponding to the last encoder layer.

  • x_hat (torch.Tensor) – Output tensor corresponding to the last decoder layer.

get_losses(batch, use_studies=True)

Calculate the triplet and reconstruction losses.

Parameters:
  • batch – A batch as produced by a PyTorch DataLoader.

  • use_studies (bool, default: True) – Whether to use study metadata when mining triplets and calculating the triplet loss.

Returns:

  • triplet_loss (torch.Tensor) – Triplet loss.

  • mse (torch.Tensor) – MSE reconstruction loss.

  • num_hard_triplets (torch.Tensor) – Number of hard triplets.

  • num_viable_triplets (torch.Tensor) – Number of viable triplets.

get_mixed_loss(triplet_loss, mse)

Combine the triplet and reconstruction losses into a single training loss.

Parameters:
  • triplet_loss (torch.Tensor) – Triplet loss.

  • mse (torch.Tensor) – MSE reconstruction loss.

Returns:

Mixed loss combining the triplet and reconstruction losses.

Return type:

torch.Tensor
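How get_mixed_loss weights the two terms is not documented here. Given the triplet_loss_weight constructor parameter, one plausible scheme is a convex combination; the sketch below assumes exactly that (weight w on the triplet loss, 1 - w on the MSE). mixed_loss is a hypothetical plain-float stand-in, not the library's code.

```python
def mixed_loss(triplet_loss: float, mse: float, triplet_loss_weight: float) -> float:
    """Hypothetical convex combination of triplet and reconstruction losses.

    Assumption: weight w on the triplet loss and (1 - w) on the MSE.
    """
    if not 0.0 <= triplet_loss_weight <= 1.0:
        raise ValueError("triplet_loss_weight must lie in [0, 1]")
    w = triplet_loss_weight
    return w * triplet_loss + (1.0 - w) * mse


print(mixed_loss(triplet_loss=2.0, mse=1.0, triplet_loss_weight=0.25))  # 1.25
```

Under this assumption, a weight of 1.0 trains on the triplet loss alone, while 0.0 reduces the objective to that of a plain autoencoder.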

load_state(encoder_filename, decoder_filename, use_gpu=False, freeze=False)

Load the encoder and decoder states from files.

Parameters:
  • encoder_filename (str) – Filename containing the encoder model state.

  • decoder_filename (str) – Filename containing the decoder model state.

  • use_gpu (bool, default: False) – Whether to use GPUs.

  • freeze (bool, default: False) – Freeze all but the bottleneck layer; used when pretraining the encoder.

on_test_epoch_end()

PyTorch Lightning hook that runs evaluation at the end of a test epoch.

on_test_epoch_start()

PyTorch Lightning hook called at the start of a test epoch.

on_validation_epoch_end()

PyTorch Lightning hook that runs evaluation at the end of a validation epoch.

on_validation_epoch_start()

PyTorch Lightning hook called at the start of a validation epoch.

test_step(batch, batch_idx)

PyTorch Lightning test step.

Parameters:
  • batch – A batch as produced by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.

training_step(batch, batch_idx)

PyTorch Lightning training step.

Parameters:
  • batch – A batch as produced by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.

validation_step(batch, batch_idx)

PyTorch Lightning validation step.

Parameters:
  • batch – A batch as produced by a PyTorch DataLoader.

  • batch_idx – The index of the batch, as provided by PyTorch Lightning.

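class scimilarity.training_models.TripletLoss(margin, sample_across_studies=True, negative_selection='semihard', perturb_labels=True, perturb_labels_fraction=0.5)

Bases: TripletMarginLoss

Wrapper for the PyTorch TripletMarginLoss. Triplets are generated by a TripletSelector object, which takes embeddings and labels and returns triplets.

Parameters:
  • margin (float) – Margin for the triplet loss.

  • sample_across_studies (bool, default: True) – Whether to use study metadata when sampling triplets.

  • negative_selection (str, default: 'semihard') – Strategy for selecting negatives when mining triplets.

  • perturb_labels (bool, default: True) – Whether to perturb labels during triplet mining.

  • perturb_labels_fraction (float, default: 0.5) – Fraction of labels to perturb when perturb_labels is True.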
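Because TripletLoss subclasses torch.nn.TripletMarginLoss, the loss on each selected triplet follows PyTorch's definition, max(d(a, p) - d(a, n) + margin, 0), averaged over triplets by default. The pure-Python sketch below reproduces that formula with Euclidean distance, assuming TripletLoss keeps PyTorch's defaults (p=2, reduction='mean'). The helper names are hypothetical, and PyTorch's small eps term inside the distance is left out.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def euclidean(x: Vector, y: Vector) -> float:
    """L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_margin_loss(
    anchors: Sequence[Vector],
    positives: Sequence[Vector],
    negatives: Sequence[Vector],
    margin: float,
) -> float:
    """Mean hinge loss max(d(a, p) - d(a, n) + margin, 0) over all triplets.

    Mirrors torch.nn.TripletMarginLoss with p=2 and reduction='mean',
    except that PyTorch's eps term inside the distance is omitted.
    """
    losses = [
        max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    if not losses:
        raise ValueError("need at least one triplet")
    return sum(losses) / len(losses)


anchors = [[0.0, 0.0], [0.0, 0.0]]
positives = [[1.0, 0.0], [1.0, 0.0]]
negatives = [[3.0, 0.0], [1.5, 0.0]]
print(triplet_margin_loss(anchors, positives, negatives, margin=1.0))  # 0.25
```

The first triplet already separates the negative by more than the margin and contributes nothing; the second contributes 0.5, giving a mean of 0.25.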

forward(embeddings, labels, int2label, studies)

Compute the triplet loss for a batch: select triplets from the embeddings and labels, then apply the margin loss.

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

Parameters:
  • embeddings (Tensor) – Embeddings for the batch.

  • labels (Tensor) – Labels for the embeddings.

  • int2label (dict) – Mapping from integer labels to label names.

  • studies (Tensor) – Study identifiers for the embeddings.
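TripletLoss defaults to negative_selection='semihard'. Under the usual definition of semi-hard negatives (introduced with FaceNet), a negative n for an anchor-positive pair (a, p) is semi-hard when d(a, p) < d(a, n) < d(a, p) + margin. Such a negative lies farther away than the positive but still inside the margin, so it yields a non-zero loss without being the hardest possible example. The sketch below applies that rule to a small batch. It is a hypothetical illustration, not scimilarity's TripletSelector: it ignores int2label, studies and label perturbation, and choosing the closest semi-hard negative is an assumption.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def euclidean(x: Vector, y: Vector) -> float:
    """L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def select_semihard_triplets(
    embeddings: Sequence[Vector], labels: Sequence[int], margin: float
) -> List[Tuple[int, int, int]]:
    """Return (anchor, positive, negative) index triplets via semi-hard mining.

    For every anchor-positive pair sharing a label, pick the closest negative
    k with d(a, p) < d(a, k) < d(a, p) + margin; skip pairs with none.
    """
    triplets: List[Tuple[int, int, int]] = []
    size = len(embeddings)
    for a in range(size):
        for p in range(size):
            if p == a or labels[p] != labels[a]:
                continue
            d_ap = euclidean(embeddings[a], embeddings[p])
            # Candidate negatives: every point with a different label
            candidates = [
                (euclidean(embeddings[a], embeddings[k]), k)
                for k in range(size)
                if labels[k] != labels[a]
            ]
            # Keep only negatives inside the semi-hard band
            semihard = [(d, k) for d, k in candidates if d_ap < d < d_ap + margin]
            if semihard:
                triplets.append((a, p, min(semihard)[1]))
    return triplets


batch = [[0.0], [1.0], [1.5], [5.0]]
print(select_semihard_triplets(batch, [0, 0, 1, 1], margin=1.0))
# [(0, 1, 2), (3, 2, 1)]
```

A hardest-negative strategy would instead take the closest negative with a different label, regardless of the band.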