scimilarity.triplet_selector

class scimilarity.triplet_selector.TripletLoss(margin, sample_across_studies=True, negative_selection='semihard', perturb_labels=False, perturb_labels_fraction=0.5)

Bases: TripletMarginLoss

Wrapper for PyTorch's TripletMarginLoss. Triplets are generated by a TripletSelector object, which takes embeddings and labels and returns triplets.
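For reference, torch.nn.TripletMarginLoss scores each triplet of anchor a, positive p, and negative n as below. This assumes PyTorch's defaults of the L2 distance (p=2) and reduction='mean' over the T triplets in a batch; whether this wrapper keeps those defaults is not stated here, and the small eps PyTorch adds inside the distance is omitted.

```latex
\ell(a, p, n) = \max\bigl\{\lVert a - p \rVert_2 - \lVert a - n \rVert_2 + \mathrm{margin},\; 0\bigr\},
\qquad
\mathcal{L} = \frac{1}{T} \sum_{t=1}^{T} \ell(a_t, p_t, n_t)
```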

Parameters:
  • margin (float) – Triplet loss margin.

  • sample_across_studies (bool, default: True) – Whether to require anchor-positive pairs to come from different studies.

  • negative_selection (str, default: "semihard") – Method for negative selection: {“semihard”, “hardest”, “random”}.

  • perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.

  • perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.

Examples

>>> triplet_loss = TripletLoss(margin=0.05)

forward(embeddings, labels, int2label, studies)

Compute the triplet loss for a batch of cell embeddings. Triplets are mined from the embeddings and labels by the TripletSelector and then scored with the triplet margin loss.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:
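  • embeddings (Tensor) – Cell embeddings.

  • labels (Tensor) – Cell labels in integer form.

  • int2label (dict) – Dictionary to map labels in integer form to strings.

  • studies (Tensor) – Study metadata for each cell.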
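A minimal sketch of the loss computation behind forward, assuming triplets come as three parallel index lists (the shape get_triplets_idx documents) and using plain Python lists instead of tensors. The helper names euclidean and batch_triplet_loss are hypothetical, and returning 0.0 for a batch with no triplets is an assumption of this sketch, not documented library behavior.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def euclidean(x: Vector, y: Vector) -> float:
    # L2 distance, PyTorch's default (p=2) for TripletMarginLoss; eps omitted.
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))


def batch_triplet_loss(embeddings: Sequence[Vector],
                       triplets: Tuple[List[int], List[int], List[int]],
                       margin: float) -> float:
    # Hypothetical helper: mean triplet margin loss over index triplets.
    anchors, positives, negatives = triplets
    if not anchors:
        return 0.0  # assumption: an empty batch contributes no loss
    losses = [
        max(euclidean(embeddings[a], embeddings[p])
            - euclidean(embeddings[a], embeddings[n]) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)  # 'mean' reduction


emb = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [0.5, 0.0]]
print(batch_triplet_loss(emb, ([0, 0], [1, 1], [2, 3]), margin=0.5))  # 0.5
```

The first triplet already satisfies the margin and contributes zero; only the second, whose negative sits closer to the anchor than the positive, drives the loss.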

class scimilarity.triplet_selector.TripletSelector(margin, negative_selection='semihard', perturb_labels=False, perturb_labels_fraction=0.5)

Bases: object

For each anchor-positive pair, mine negative samples to create a triplet.
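The mining step for a single anchor-positive pair can be sketched as follows, assuming a precomputed distance matrix and a selection strategy that maps a list of loss values to a position in that list (in the spirit of hardest_negative, semihard_negative, and random_negative). The helpers loss_values_for_pair, mine_negative, and hardest are hypothetical, and computing loss values as d(a, p) - d(a, n) + margin before clamping is an assumption.

```python
from typing import Callable, List, Optional, Sequence

Matrix = Sequence[Sequence[float]]


def loss_values_for_pair(dist: Matrix, anchor: int, positive: int,
                         negatives: Sequence[int], margin: float) -> List[float]:
    # Unclamped triplet loss of every candidate negative for one pair:
    # d(a, p) - d(a, n) + margin.
    ap = dist[anchor][positive]
    return [ap - dist[anchor][n] + margin for n in negatives]


def mine_negative(dist: Matrix, anchor: int, positive: int,
                  negatives: Sequence[int], margin: float,
                  select: Callable[[List[float]], Optional[int]]) -> Optional[int]:
    # The strategy returns a position in loss_values (or None if nothing
    # qualifies); map it back to the negative's cell index.
    loss_values = loss_values_for_pair(dist, anchor, positive, negatives, margin)
    k = select(loss_values)
    return None if k is None else negatives[k]


def hardest(loss_values: List[float]) -> int:
    # "Hardest" strategy: position of the largest loss.
    return max(range(len(loss_values)), key=lambda k: loss_values[k])


dist = [[0.0, 1.0, 3.0, 0.5],
        [1.0, 0.0, 2.0, 0.5],
        [3.0, 2.0, 0.0, 2.5],
        [0.5, 0.5, 2.5, 0.0]]
print(mine_negative(dist, anchor=0, positive=1, negatives=[2, 3],
                    margin=0.5, select=hardest))  # 3
```

Keeping the strategy as a plug-in function mirrors the negative_selection parameter: the loss values are computed once, and only the choice among them changes.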

Parameters:
  • margin (float) – Triplet loss margin.

  • negative_selection (str, default: "semihard") – Method for negative selection: {“semihard”, “hardest”, “random”}.

  • perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.

  • perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.

Examples

>>> triplet_selector = TripletSelector(margin=0.05, negative_selection="semihard")

get_asw(embeddings, labels, int2label, metric='cosine')
Get the average silhouette width of cell types, taking the cell ontology into account: ancestors are not considered inter-cluster and descendants are considered intra-cluster.
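A minimal sketch of an ontology-aware silhouette, assuming a precomputed distance matrix, string labels, and a hypothetical descendants mapping from each label to all of its ontology descendants; the documented method instead takes embeddings, int2label, and a cdist metric. For each cell, its own label and descendant labels form the intra-cluster set, ancestor labels are excluded, and the nearest remaining label gives the inter-cluster distance. The function name ontology_aware_asw is hypothetical, and skipping cells whose silhouette is undefined is an assumption of the sketch.

```python
from typing import Dict, List, Sequence, Set


def ontology_aware_asw(dist: Sequence[Sequence[float]],
                       labels: Sequence[str],
                       descendants: Dict[str, Set[str]]) -> float:
    # Hypothetical sketch: descendants maps a label to ALL descendant labels.
    # Invert the descendant map to get each label's ancestors.
    ancestors: Dict[str, Set[str]] = {}
    for parent, kids in descendants.items():
        for kid in kids:
            ancestors.setdefault(kid, set()).add(parent)

    n = len(labels)
    scores: List[float] = []
    for i in range(n):
        own = labels[i]
        # Own label plus descendants count as the same cluster.
        intra_labels = {own} | descendants.get(own, set())
        # Ancestors are not treated as competing clusters.
        excluded = intra_labels | ancestors.get(own, set())

        intra = [dist[i][j] for j in range(n)
                 if j != i and labels[j] in intra_labels]
        others: Dict[str, List[float]] = {}
        for j in range(n):
            if labels[j] not in excluded:
                others.setdefault(labels[j], []).append(dist[i][j])
        if not intra or not others:
            continue  # silhouette undefined for this cell; skipped (assumption)

        a = sum(intra) / len(intra)                        # mean intra-cluster distance
        b = min(sum(d) / len(d) for d in others.values())  # nearest other cluster
        denom = max(a, b)
        scores.append(0.0 if denom == 0 else (b - a) / denom)
    # Assumption: 0.0 when no cell has a defined silhouette.
    return sum(scores) / len(scores) if scores else 0.0


# 1-D toy embeddings; "CD4 T cell" is a descendant of "T cell".
coords = [0, 1, 10, 11]
dist = [[abs(x - y) for y in coords] for x in coords]
labels = ["T cell", "CD4 T cell", "B cell", "B cell"]
print(ontology_aware_asw(dist, labels, {"T cell": {"CD4 T cell"}}))
```

In the toy data, the CD4 T cell counts toward the T cell's own cluster rather than as a competing cluster, so the T cell's score is computed against the B cells only.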

Parameters:
  • embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.

  • labels (List[str]) – Cell type names.

  • int2label (dict) – Dictionary to map labels in integer form to strings.

  • metric (str, default: "cosine") – The distance metric to use for scipy.spatial.distance.cdist().

Returns:

asw – The average silhouette width.

Return type:

float

Examples

>>> asw = triplet_selector.get_asw(embeddings, labels, int2label, metric="cosine")

get_triplets_idx(embeddings, labels, int2label, studies=None)

Get triplets as anchor, positive, and negative cell indices.

Parameters:
  • embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.

  • labels (numpy.ndarray, torch.Tensor) – Cell labels in integer form.

  • int2label (dict) – Dictionary to map labels in integer form to strings.

  • studies (numpy.ndarray, torch.Tensor, optional, default: None) – Studies metadata for each cell.

Returns:

  • triplets (Tuple[List, List, List]) – A tuple of lists containing anchor, positive, and negative cell indices.

  • num_hard_triplets (int) – Number of hard triplets.

  • num_viable_triplets (int) – Number of viable triplets.

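Anchor-positive pairs, the starting point for each triplet, can be enumerated as in the sketch below. It assumes that a pair shares an integer label and that, when studies is given, the two cells must come from different studies (matching the sample_across_studies description on TripletLoss). The helper anchor_positive_pairs is hypothetical, and ontology-aware label handling and label perturbation are not modeled.

```python
from itertools import combinations
from typing import List, Optional, Sequence, Tuple


def anchor_positive_pairs(labels: Sequence[int],
                          studies: Optional[Sequence[str]] = None
                          ) -> List[Tuple[int, int]]:
    # Hypothetical helper: unordered index pairs that share a label.
    pairs: List[Tuple[int, int]] = []
    for i, j in combinations(range(len(labels)), 2):
        if labels[i] != labels[j]:
            continue
        # With study metadata, require the two cells to differ in study.
        if studies is not None and studies[i] == studies[j]:
            continue
        pairs.append((i, j))
    return pairs


print(anchor_positive_pairs([0, 0, 1, 0]))                        # [(0, 1), (0, 3), (1, 3)]
print(anchor_positive_pairs([0, 0, 1, 0], ["a", "b", "a", "a"]))  # [(0, 1), (1, 3)]
```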

hardest_negative(loss_values)

Get hardest negative.

Parameters:

loss_values (numpy.ndarray) – Triplet loss of all negatives for the given anchor-positive pair.

Returns:

Index of the selected negative within loss_values.

Return type:

int
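Assuming each loss value is d(a, p) - d(a, n) + margin, the largest value belongs to the negative closest to the anchor relative to the positive. A minimal sketch, assuming hardest_negative reduces to an argmax over loss_values; it accepts a list or a 1-D numpy array.

```python
from typing import Sequence


def hardest_negative(loss_values: Sequence[float]) -> int:
    # Largest triplet loss marks the negative that most violates the margin.
    return max(range(len(loss_values)), key=lambda k: loss_values[k])


print(hardest_negative([-1.5, 1.0, 0.2]))  # 1
```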

pdist(vectors)

Get the pairwise distances between all cell embeddings.

Parameters:

vectors (numpy.ndarray) – Cell embeddings.

Returns:

Distance matrix of cell embeddings.

Return type:

numpy.ndarray
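A pure-Python sketch of a pairwise distance matrix, assuming Euclidean distance; the metric pdist actually uses is not stated in this documentation, and the real method returns a numpy.ndarray rather than nested lists. math.dist requires Python 3.8 or later.

```python
import math
from typing import List, Sequence


def pdist(vectors: Sequence[Sequence[float]]) -> List[List[float]]:
    # Assumption: Euclidean metric. Compute each pair once and mirror it,
    # since the matrix is symmetric with a zero diagonal.
    n = len(vectors)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            out[i][j] = out[j][i] = math.dist(vectors[i], vectors[j])
    return out


print(pdist([[0, 0], [3, 4]]))  # [[0.0, 5.0], [5.0, 0.0]]
```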

random_negative(loss_values)

Get random negative.

Parameters:

loss_values (numpy.ndarray) – Triplet loss of all negatives for the given anchor-positive pair.

Returns:

Index of the selected negative within loss_values.

Return type:

int
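A minimal sketch, assuming random_negative draws uniformly over all candidate negatives; the library may restrict the pool, which is not documented here. The optional rng argument is an addition of this sketch so that a seeded random.Random makes the draw reproducible.

```python
import random
from typing import Optional, Sequence


def random_negative(loss_values: Sequence[float],
                    rng: Optional[random.Random] = None) -> int:
    # Uniform draw over positions in loss_values (assumption).
    rng = rng if rng is not None else random.Random()
    return rng.randrange(len(loss_values))


print(random_negative([0.1, -0.2, 0.3], random.Random(0)))
```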

semihard_negative(loss_values)

Get a random semihard negative.

Parameters:

loss_values (numpy.ndarray) – Triplet loss of all negatives for the given anchor-positive pair.

Returns:

Index of the selected negative within loss_values.

Return type:

int
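Assuming loss values of the form d(a, p) - d(a, n) + margin, a semihard negative has 0 < loss < margin: it is farther from the anchor than the positive, but by less than the margin. The sketch below picks one such negative at random. The explicit margin and rng arguments, and returning None when no semihard negative exists, are assumptions; the documented method takes only loss_values and returns an int.

```python
import random
from typing import Optional, Sequence


def semihard_negative(loss_values: Sequence[float], margin: float,
                      rng: Optional[random.Random] = None) -> Optional[int]:
    # 0 < d(a,p) - d(a,n) + margin < margin  <=>  d(a,p) < d(a,n) < d(a,p) + margin
    candidates = [k for k, v in enumerate(loss_values) if 0 < v < margin]
    if not candidates:
        return None  # assumption: no semihard negative available
    rng = rng if rng is not None else random.Random()
    return rng.choice(candidates)


print(semihard_negative([-1.5, 0.02, 0.04, 0.06], margin=0.05, rng=random.Random(1)))
```

Here positions 1 and 2 qualify; position 0 is too easy (loss below zero) and position 3 is harder than the positive (loss at or above the margin).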