scimilarity.triplet_selector
- class scimilarity.triplet_selector.TripletLoss(margin, sample_across_studies=True, negative_selection='semihard', perturb_labels=False, perturb_labels_fraction=0.5)
Bases: TripletMarginLoss
Wrapper for PyTorch's TripletMarginLoss. Triplets are generated by a TripletSelector object, which takes embeddings and labels and returns triplets.
- Parameters:
margin (float) – Triplet loss margin.
sample_across_studies (bool, default: True) – Whether to enforce anchor-positive pairs being from different studies.
negative_selection (str, default: "semihard") – Method for negative selection: {“semihard”, “hardest”, “random”}.
perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.
perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.
Examples
>>> triplet_loss = TripletLoss(margin=0.05)
- forward(embeddings, labels, int2label, studies)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
embeddings (Tensor)
labels (Tensor)
int2label (dict)
studies (Tensor)
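To make the quantity being minimized concrete, the following is a minimal NumPy sketch of a triplet margin loss evaluated on already-selected triplets. It assumes Euclidean distance and mean reduction, which are the PyTorch TripletMarginLoss defaults, and it leaves out PyTorch's small numerical-stability term. The helper name triplet_margin_loss and the toy data are illustrative only and are not part of scimilarity.

```python
import numpy as np

def triplet_margin_loss(embeddings, triplets, margin):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over all triplets."""
    anchors, positives, negatives = (np.asarray(idx) for idx in triplets)
    # Euclidean distance between each anchor and its positive / negative.
    d_ap = np.linalg.norm(embeddings[anchors] - embeddings[positives], axis=1)
    d_an = np.linalg.norm(embeddings[anchors] - embeddings[negatives], axis=1)
    return float(np.mean(np.maximum(d_ap - d_an + margin, 0.0)))

# Toy embeddings; the triplets are given as (anchors, positives, negatives).
embeddings = np.array([[0.0, 0.0], [0.0, 1.0], [3.0, 4.0], [0.0, 1.2]])
triplets = ([0, 0], [1, 1], [2, 3])
# A large margin is used here only so that one triplet contributes a non-zero term.
loss = triplet_margin_loss(embeddings, triplets, margin=0.5)
```

In this toy case the first triplet already satisfies the margin and contributes zero, while the second contributes a positive term because its negative lies almost as close to the anchor as the positive does.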
- class scimilarity.triplet_selector.TripletSelector(margin, negative_selection='semihard', perturb_labels=False, perturb_labels_fraction=0.5)
Bases: object
For each anchor-positive pair, mine negative samples to create a triplet.
- Parameters:
margin (float) – Triplet loss margin.
negative_selection (str, default: "semihard") – Method for negative selection: {“semihard”, “hardest”, “random”}.
perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.
perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.
Examples
>>> triplet_selector = TripletSelector(margin=0.05, negative_selection="semihard")
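As an illustration of what "coarse graining one level up" could look like, the sketch below replaces a random fraction of labels with their direct parent term. How the library actually walks the ontology is not shown in this section, so the parent_of mapping, the function name perturb_labels, and the seeding are all hypothetical.

```python
import random

def perturb_labels(labels, parent_of, fraction=0.5, seed=0):
    """Replace a random fraction of labels with their direct parent term."""
    rng = random.Random(seed)
    n_perturb = int(len(labels) * fraction)
    chosen = set(rng.sample(range(len(labels)), n_perturb))
    # Labels without a known parent are left unchanged.
    return [
        parent_of.get(label, label) if i in chosen else label
        for i, label in enumerate(labels)
    ]

# Hypothetical two-level hierarchy, for illustration only.
parent_of = {"CD4-positive T cell": "T cell", "CD8-positive T cell": "T cell"}
labels = [
    "CD4-positive T cell",
    "CD8-positive T cell",
    "CD4-positive T cell",
    "CD8-positive T cell",
]
perturbed = perturb_labels(labels, parent_of, fraction=0.5)
```

With fraction=0.5, two of the four labels are coarsened to "T cell" and the other two keep their original, finer label.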
- get_asw(embeddings, labels, int2label, metric='cosine')
Get the average silhouette width of cell types, taking the cell ontology into account: ancestors are not considered inter-cluster and descendants are considered intra-cluster.
- Parameters:
embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.
labels (List[str]) – Celltype names.
int2label (dict) – Dictionary mapping labels in integer form to their string names.
metric (str, default: "cosine") – The distance metric to use for scipy.spatial.distance.cdist().
- Returns:
asw – The average silhouette width.
- Return type:
float
Examples
>>> asw = triplet_selector.get_asw(embeddings, labels, int2label, metric="cosine")
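For readers unfamiliar with the metric, here is a minimal sketch of a plain (not ontology-aware) average silhouette width under cosine distance. It deliberately omits the ancestor and descendant handling described above, assumes at least two distinct labels and non-zero embeddings, and uses a hypothetical function name.

```python
import numpy as np

def average_silhouette_width(embeddings, labels):
    """Plain average silhouette width using cosine distance."""
    x = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    unit = x / np.linalg.norm(x, axis=1, keepdims=True)
    dist = 1.0 - unit @ unit.T  # pairwise cosine distance
    unique_labels = set(labels.tolist())
    scores = []
    for i, label in enumerate(labels):
        same = labels == label
        same[i] = False  # exclude the cell itself
        if not same.any():
            scores.append(0.0)  # singleton clusters score 0 by convention
            continue
        # a: mean distance to the cell's own cluster.
        a = dist[i, same].mean()
        # b: mean distance to the nearest other cluster.
        b = min(
            dist[i, labels == other].mean()
            for other in unique_labels
            if other != label
        )
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))

# Two well-separated toy clusters.
toy_embeddings = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]]
toy_labels = ["T cell", "T cell", "B cell", "B cell"]
toy_asw = average_silhouette_width(toy_embeddings, toy_labels)
```

Values close to 1 indicate tight, well-separated clusters. An ontology-aware variant would additionally change which cells count as intra-cluster or inter-cluster for each label.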
- get_triplets_idx(embeddings, labels, int2label, studies=None)
Get triplets as anchor, positive, and negative cell indices.
- Parameters:
embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.
labels (numpy.ndarray, torch.Tensor) – Cell labels in integer form.
int2label (dict) – Dictionary mapping labels in integer form to their string names.
studies (numpy.ndarray, torch.Tensor, optional, default: None) – Studies metadata for each cell.
- Returns:
triplets (Tuple[List, List, List]) – A tuple of lists containing anchor, positive, and negative cell indices.
num_hard_triplets (int) – Number of hard triplets.
num_viable_triplets (int) – Number of viable triplets.
- hardest_negative(loss_values)
Get hardest negative.
- Parameters:
loss_values (numpy.ndarray) – Triplet loss of all negatives for given anchor positive pair.
- Returns:
Index of selection.
- Return type:
int
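The three negative_selection options can be sketched as follows, assuming the conventional definitions: each candidate negative has a loss value d(anchor, positive) - d(anchor, negative) + margin, "hardest" takes the largest positive loss, "semihard" samples from losses strictly between 0 and the margin, and "random" samples from any positive loss. The function select_negative is a hypothetical helper, and the library's exact rules may differ.

```python
import numpy as np

def select_negative(loss_values, margin, method="semihard", rng=None):
    """Return the index of a chosen negative, or None if none qualifies."""
    if rng is None:
        rng = np.random.default_rng(0)
    loss_values = np.asarray(loss_values, dtype=float)
    if method == "hardest":
        # The negative that violates the margin the most.
        idx = int(np.argmax(loss_values))
        return idx if loss_values[idx] > 0 else None
    if method == "semihard":
        # Farther than the positive, but still inside the margin.
        candidates = np.flatnonzero((loss_values > 0) & (loss_values < margin))
    elif method == "random":
        # Assumption: sample among all negatives with a positive loss.
        candidates = np.flatnonzero(loss_values > 0)
    else:
        raise ValueError(f"unknown method: {method}")
    return int(rng.choice(candidates)) if candidates.size else None

loss_values = [-0.2, 0.03, 0.4, 0.01]
hardest = select_negative(loss_values, margin=0.05, method="hardest")
semihard = select_negative(loss_values, margin=0.05, method="semihard")
```

Here the hardest negative is index 2, while the semihard choice is drawn from indices 1 and 3, whose losses fall between 0 and the margin.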
- pdist(vectors)
Get pair-wise distance between all cell embeddings.
- Parameters:
vectors (numpy.ndarray) – Cell embeddings.
- Returns:
Distance matrix of cell embeddings.
- Return type:
numpy.ndarray
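One common way to compute such a matrix in a vectorized fashion is sketched below, assuming Euclidean distance. The metric used by pdist is not stated in this section, so treat this as an illustration of the technique rather than the library's implementation. The name pairwise_distances is hypothetical.

```python
import numpy as np

def pairwise_distances(vectors):
    """Euclidean distance matrix of shape (n_cells, n_cells)."""
    vectors = np.asarray(vectors, dtype=float)
    sq_norms = np.sum(vectors ** 2, axis=1)
    # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, evaluated for all pairs at once.
    sq_dist = sq_norms[:, None] - 2.0 * vectors @ vectors.T + sq_norms[None, :]
    # Clip tiny negative values caused by floating-point rounding.
    return np.sqrt(np.maximum(sq_dist, 0.0))

dist = pairwise_distances([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
```

The result is symmetric with a zero diagonal, and entry (i, j) is the distance between cell i and cell j.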