scimilarity.training_models
- class scimilarity.training_models.MetricLearning(*args, **kwargs)
Bases: LightningModule
A class encapsulating metric learning: an encoder-decoder model trained with a mix of triplet loss and reconstruction loss.
- Parameters:
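n_genes (int) – Number of genes, i.e. the input dimension.
latent_dim (int) – Dimension of the latent (bottleneck) embedding.
hidden_dim (List[int]) – Sizes of the hidden layers.
dropout (float) – Dropout rate for the hidden layers.
input_dropout (float) – Dropout rate for the input layer.
triplet_loss_weight (float) – Weight of the triplet loss relative to the reconstruction loss in the mixed loss.
margin (float) – Margin of the triplet loss.
negative_selection (str) – Strategy for selecting negatives when mining triplets.
sample_across_studies (bool) – Whether to use study metadata when sampling triplets.
perturb_labels (bool) – Whether to perturb labels when mining triplets.
perturb_labels_fraction (float) – Fraction of labels to perturb when perturb_labels is True.
lr (float) – Learning rate.
l1 (float) – L1 regularization weight.
l2 (float) – L2 regularization weight.
max_epochs (int) – Maximum number of training epochs.
residual (bool) – Whether to use residual connections.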
- forward(x)
Run a forward pass through the encoder and decoder.
- Parameters:
x (torch.Tensor) – Input tensor corresponding to the input layer.
- Returns:
z (torch.Tensor) – Output tensor corresponding to the last encoder layer.
x_hat (torch.Tensor) – Output tensor corresponding to the last decoder layer.
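The following is a minimal NumPy sketch of a forward pass that returns both z and x_hat. It assumes a plain fully connected encoder (n_genes → hidden_dim → latent_dim) mirrored by the decoder, with ReLU between layers. The layer sizes and the helpers make_layers and run are hypothetical; dropout, normalization, and residual connections used by the real module are not reproduced.

```python
import numpy as np

# Hypothetical sizes; in MetricLearning these would come from
# n_genes, hidden_dim and latent_dim (assumed layout, not the library's code).
N_GENES, HIDDEN_DIM, LATENT_DIM = 20, [16, 8], 4

rng = np.random.default_rng(0)


def make_layers(dims):
    """Create (weight, bias) pairs for consecutive layer dimensions."""
    return [
        (rng.normal(scale=0.1, size=(d_in, d_out)), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]


# Encoder: n_genes -> hidden layers -> latent; the decoder mirrors it.
encoder_dims = [N_GENES, *HIDDEN_DIM, LATENT_DIM]
encoder = make_layers(encoder_dims)
decoder = make_layers(encoder_dims[::-1])


def run(layers, h):
    """Apply affine layers with ReLU between them (none after the last)."""
    for i, (w, b) in enumerate(layers):
        h = h @ w + b
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)
    return h


def forward(x):
    """Return (z, x_hat): the latent embedding and the reconstruction."""
    z = run(encoder, x)
    x_hat = run(decoder, z)
    return z, x_hat


x = rng.normal(size=(5, N_GENES))  # a batch of 5 cells
z, x_hat = forward(x)
print(z.shape, x_hat.shape)  # (5, 4) (5, 20)
```

z is the embedding used for metric learning, and x_hat has the same shape as x so that a reconstruction loss can be computed against it.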
- get_losses(batch, use_studies=True)
Calculate the triplet and reconstruction losses.
- Parameters:
batch – A batch as produced by a PyTorch DataLoader.
use_studies (bool, default: True) – Whether to use study metadata when mining triplets and calculating the triplet loss.
- Returns:
triplet_loss (torch.Tensor) – Triplet loss.
mse (torch.Tensor) – MSE reconstruction loss.
num_hard_triplets (torch.Tensor) – Number of hard triplets.
num_viable_triplets (torch.Tensor) – Number of viable triplets.
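A hypothetical NumPy sketch of the four quantities get_losses returns, computed from a toy batch. It assumes triplets have already been mined as (anchor, positive, negative) index rows, that every mined triplet counts as viable, and that a triplet is hard when its negative is closer to the anchor than its positive. These definitions, and the function get_losses_sketch, are assumptions for illustration, not the library's implementation.

```python
import numpy as np


def get_losses_sketch(x, x_hat, z, triplets, margin):
    """Hypothetical re-creation of the quantities get_losses returns.

    triplets: int array of shape (T, 3) holding (anchor, positive, negative)
    row indices into z. Assumed definitions: every mined triplet is
    "viable"; a triplet is "hard" if its negative is closer to the anchor
    than its positive.
    """
    # Reconstruction loss between input and decoder output.
    mse = float(np.mean((x - x_hat) ** 2))

    # Euclidean anchor-positive and anchor-negative distances.
    a, p, n = z[triplets[:, 0]], z[triplets[:, 1]], z[triplets[:, 2]]
    d_ap = np.linalg.norm(a - p, axis=1)
    d_an = np.linalg.norm(a - n, axis=1)

    # Hinge-style triplet margin loss, averaged over triplets.
    triplet_loss = float(np.mean(np.maximum(d_ap - d_an + margin, 0.0)))

    num_hard = int(np.sum(d_an < d_ap))
    num_viable = len(triplets)
    return triplet_loss, mse, num_hard, num_viable


x = np.ones((4, 3))
x_hat = np.zeros((4, 3))
z = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
triplets = np.array([[0, 1, 2], [0, 1, 3]])
print(get_losses_sketch(x, x_hat, z, triplets, margin=0.05))
```

In this toy batch the first triplet's negative (row 2) is closer to the anchor than the positive, so it counts as hard, while the second triplet already satisfies the margin and contributes zero loss.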
- get_mixed_loss(triplet_loss, mse)
Combine the triplet and reconstruction losses into a single mixed loss.
- Parameters:
triplet_loss (torch.Tensor) – Triplet loss.
mse (torch.Tensor) – MSE reconstruction loss.
- Returns:
Mixed loss.
- Return type:
torch.Tensor
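One plausible form of the mixed loss is a convex combination, assuming triplet_loss_weight (a MetricLearning parameter) weights the triplet term and its complement weights the reconstruction term. This is an assumption; the exact formula should be checked against the source.

```python
def get_mixed_loss_sketch(triplet_loss, mse, triplet_loss_weight):
    """Hypothetical mixed loss: w * triplet + (1 - w) * reconstruction.

    Assumes triplet_loss_weight lies in [0, 1]; not the library's code.
    """
    w = triplet_loss_weight
    return w * triplet_loss + (1.0 - w) * mse


print(get_mixed_loss_sketch(0.5, 0.25, triplet_loss_weight=0.5))  # 0.375
```

Under this form, a small triplet_loss_weight lets reconstruction dominate, while values near 1 emphasize the embedding geometry.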
- load_state(encoder_filename, decoder_filename, use_gpu=False, freeze=False)
Load model state.
- Parameters:
encoder_filename (str) – Filename containing the encoder model state.
decoder_filename (str) – Filename containing the decoder model state.
use_gpu (bool, default: False) – Whether to use GPUs.
freeze (bool, default: False) – Whether to freeze all but the bottleneck layer; used when pretraining the encoder.
- test_step(batch, batch_idx)
PyTorch Lightning test step.
- Parameters:
batch – A batch as produced by a PyTorch DataLoader.
batch_idx – The batch index as provided by PyTorch Lightning.
- class scimilarity.training_models.TripletLoss(margin, sample_across_studies=True, negative_selection='semihard', perturb_labels=True, perturb_labels_fraction=0.5)
Bases: TripletMarginLoss
Wrapper for PyTorch TripletMarginLoss. Triplets are generated by a TripletSelector object, which takes embeddings and labels and returns triplets.
- Parameters:
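margin (float) – Margin of the triplet loss.
sample_across_studies (bool, default: True) – Whether to use study metadata when sampling triplets.
negative_selection (str, default: 'semihard') – Strategy for selecting negatives when mining triplets.
perturb_labels (bool, default: True) – Whether to perturb labels when mining triplets.
perturb_labels_fraction (float, default: 0.5) – Fraction of labels to perturb when perturb_labels is True.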
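For reference, PyTorch's TripletMarginLoss, with its default p=2 and reduction='mean', computes the hinge below for each triplet (a, p, n) of anchor, positive, and negative embeddings (up to a small eps term PyTorch adds inside the distance) and averages it over the triplets. The wrapper supplies the triplets mined by TripletSelector.

```latex
\mathcal{L}(a, p, n) = \max\left\{ \lVert a - p \rVert_2 - \lVert a - n \rVert_2 + \text{margin},\; 0 \right\}
```

A triplet contributes zero loss once its negative is at least margin farther from the anchor than its positive.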
- forward(embeddings, labels, int2label, studies)
Mine triplets from the embeddings and labels, then compute the triplet loss on them.
Note
Call the TripletLoss instance rather than forward() directly: calling the instance runs the registered hooks, while calling forward() silently ignores them.
- Parameters:
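embeddings (Tensor) – Embeddings from which triplets are mined.
labels (Tensor) – Integer-encoded labels of the embeddings.
int2label (dict) – Mapping from integer labels to label names.
studies (Tensor) – Study of origin for each embedding.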
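With negative_selection='semihard', the usual definition of a semihard negative is one that lies farther from the anchor than the positive but still inside the margin. The NumPy sketch below is a hypothetical stand-in for TripletSelector, not its implementation: it mines (anchor, positive, negative) index triplets, picks the closest semihard negative for each anchor-positive pair, and, as an assumed reading of sample_across_studies, optionally requires anchor and positive to come from different studies. Label perturbation is not modeled.

```python
import numpy as np


def mine_semihard_triplets(embeddings, labels, margin, studies=None):
    """Hypothetical semihard triplet miner (not scimilarity's TripletSelector).

    For each anchor-positive pair with the same label, choose the closest
    negative (different label) whose distance lies in (d_ap, d_ap + margin).
    If studies is given, anchor and positive must come from different
    studies (an assumed reading of sample_across_studies).
    """
    # Pairwise Euclidean distance matrix.
    dist = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    triplets = []
    n = len(labels)
    for a in range(n):
        for p in range(n):
            if a == p or labels[a] != labels[p]:
                continue
            if studies is not None and studies[a] == studies[p]:
                continue
            d_ap = dist[a, p]
            # Semihard: different label, farther than the positive, within the margin.
            mask = (labels != labels[a]) & (dist[a] > d_ap) & (dist[a] < d_ap + margin)
            candidates = np.flatnonzero(mask)
            if candidates.size:
                neg = candidates[np.argmin(dist[a, candidates])]
                triplets.append((a, p, int(neg)))
    return triplets


emb = np.array([[0.0], [1.0], [1.2], [5.0]])
labels = np.array([0, 0, 1, 1])
print(mine_semihard_triplets(emb, labels, margin=0.5))  # [(0, 1, 2), (3, 2, 1)]
```

Anchor-positive pairs with no negative in the semihard band, such as (1, 0) here, yield no triplet, which is why the number of mined triplets can vary from batch to batch.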