scimilarity.triplet_selector
- class scimilarity.triplet_selector.TripletLoss(*args, **kwargs)
Bases: TripletMarginLoss
Wrapper for PyTorch TripletMarginLoss. Triplets are generated by a TripletSelector object, which takes embeddings and labels and returns triplets.
- Parameters:
margin (float) – Triplet loss margin.
sample_across_studies (bool, default: True) – Whether to require anchor-positive pairs to come from different studies.
negative_selection (str) – Method for negative selection: {“semihard”, “hardest”, “random”}.
perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.
perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.
Examples
>>> triplet_loss = TripletLoss(margin=0.05)
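As a rough illustration of what the wrapped loss computes, the sketch below evaluates the standard triplet margin loss, the mean over triplets of max(d(a, p) - d(a, n) + margin, 0), in plain Python. It is not scimilarity's or PyTorch's code: the helper names are hypothetical, Euclidean distance stands in for PyTorch's default p=2 norm, and the small eps that PyTorch adds inside its distance computation is omitted.

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))


def triplet_margin_loss(anchors, positives, negatives, margin):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over all triplets (sketch)."""
    losses = [
        # A triplet contributes zero once the negative is at least `margin`
        # farther from the anchor than the positive is.
        max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)
```

With margin=0.05, a triplet whose negative is already much farther away than its positive contributes 0, while one whose negative is closer than its positive contributes d(a, p) - d(a, n) + 0.05.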
- class scimilarity.triplet_selector.TripletSelector(margin, negative_selection='semihard', perturb_labels=False, perturb_labels_fraction=0.5)
Bases: object
For each anchor-positive pair, mine negative samples to create a triplet.
- Parameters:
margin (float) – Triplet loss margin.
negative_selection (str, default: "semihard") – Method for negative selection: {“semihard”, “hardest”, “random”}.
perturb_labels (bool, default: False) – Whether to perturb the ontology labels by coarse graining one level up.
perturb_labels_fraction (float, default: 0.5) – The fraction of labels to perturb.
Examples
>>> triplet_selector = TripletSelector(margin=0.05, negative_selection="semihard")
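The perturb_labels and perturb_labels_fraction options can be pictured roughly as follows. This is a minimal sketch under stated assumptions, not the library's implementation: it assumes a hypothetical `parent` dictionary giving one ontology parent per label, and it coarse-grains a randomly chosen fraction of individual cells. How scimilarity handles labels with several parents, or whether it samples cells or label classes, is not documented here.

```python
import random


def perturb_labels_sketch(labels, parent, fraction=0.5, rng=None):
    """Replace a random fraction of labels with their ontology parent.

    parent: hypothetical dict mapping a label to a single parent label.
    Labels without an entry (e.g. ontology roots) are left unchanged.
    """
    if rng is None:
        rng = random.Random()
    labels = list(labels)  # do not modify the caller's list
    n_perturb = int(round(fraction * len(labels)))
    # Coarse-grain the chosen cells one level up the ontology.
    for i in rng.sample(range(len(labels)), n_perturb):
        labels[i] = parent.get(labels[i], labels[i])
    return labels
```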
- get_asw(embeddings, labels, int2label, metric='cosine')
Get the average silhouette width of cell types, taking the cell ontology into account such that
ancestors are not considered inter-cluster and descendants are considered intra-cluster.
- Parameters:
embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.
labels (List[str]) – Cell type names.
int2label (dict) – Dictionary to map labels in integer form to strings.
metric (str, default: "cosine") – The distance metric to use for scipy.spatial.distance.cdist().
- Returns:
asw – The average silhouette width.
- Return type:
float
Examples
>>> asw = triplet_selector.get_asw(embeddings, labels, int2label, metric="cosine")
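One way to read the ontology-aware silhouette described above is sketched below with cosine distance. It is an illustration, not the method's code: `descendants` is a hypothetical dictionary standing in for the ontology lookups, cells labelled with an ancestor of a cell's own type are simply ignored for that cell, and cells with no intra-cluster or no inter-cluster neighbours are skipped. For each remaining cell, a is the mean intra-cluster distance, b is the smallest mean distance to any other cell type, and the score is (b - a) / max(a, b).

```python
import math
from collections import defaultdict


def cosine_distance(u, v):
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


def ontology_asw_sketch(embeddings, labels, descendants):
    """Ontology-aware average silhouette width (illustrative only).

    descendants: hypothetical dict mapping a label to the set of ALL its
    ontology descendants (not only direct children), excluding itself.
    """
    # Ancestor sets are the inverse of the descendant relation.
    ancestors = defaultdict(set)
    for label, desc in descendants.items():
        for d in desc:
            ancestors[d].add(label)

    scores = []
    for i, own in enumerate(labels):
        same = {own} | descendants.get(own, set())  # counted as intra-cluster
        intra, inter = [], defaultdict(list)
        for j, other in enumerate(labels):
            if j == i:
                continue
            dist = cosine_distance(embeddings[i], embeddings[j])
            if other in same:
                intra.append(dist)
            elif other not in ancestors[own]:  # ancestors are not inter-cluster
                inter[other].append(dist)
        if not intra or not inter:
            continue  # silhouette undefined for this cell; skip it
        a = sum(intra) / len(intra)  # mean intra-cluster distance
        b = min(sum(ds) / len(ds) for ds in inter.values())  # nearest other type
        denom = max(a, b)
        scores.append(0.0 if denom == 0 else (b - a) / denom)
    return sum(scores) / len(scores) if scores else 0.0
```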
- get_triplets_idx(embeddings, labels, int2label, studies=None)
Get triplets as anchor, positive, and negative cell indices.
- Parameters:
embeddings (numpy.ndarray, torch.Tensor) – Cell embeddings.
labels (numpy.ndarray, torch.Tensor) – Cell labels in integer form.
int2label (dict) – Dictionary to map labels in integer form to strings.
studies (numpy.ndarray, torch.Tensor, optional, default: None) – Studies metadata for each cell.
- Returns:
triplets (Tuple[List, List, List]) – A tuple of lists containing anchor, positive, and negative cell indices.
num_hard_triplets (int) – Number of hard triplets.
num_viable_triplets (int) – Number of viable triplets.
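The overall shape of triplet mining can be sketched as follows: enumerate anchor-positive pairs that share a label, optionally require them to come from different studies, score every candidate negative by its triplet loss, and let a selection function pick one. This is a simplified, hypothetical sketch rather than the method itself: it takes a precomputed distance matrix instead of embeddings, treats every cell with a different integer label as a candidate negative (ignoring the ontology relations the real method can consult through int2label), and does not reproduce num_hard_triplets or num_viable_triplets, whose exact definitions are not given here.

```python
from itertools import combinations


def get_triplets_sketch(dist, labels, margin, select, studies=None):
    """Enumerate anchor-positive pairs and mine one negative per pair.

    dist: square distance matrix (list of lists).
    labels: integer label per cell.
    select: function(loss_values) -> index into loss_values, or None.
    """
    anchors, positives, negatives = [], [], []
    n = len(labels)
    for label in sorted(set(labels)):
        members = [i for i in range(n) if labels[i] == label]
        candidates = [i for i in range(n) if labels[i] != label]
        if not candidates:
            continue
        for a, p in combinations(members, 2):
            # Optionally require the positive to come from a different study.
            if studies is not None and studies[a] == studies[p]:
                continue
            # Loss for each candidate negative: d(a, p) - d(a, n) + margin.
            loss_values = [dist[a][p] - dist[a][neg] + margin for neg in candidates]
            choice = select(loss_values)
            if choice is None:
                continue  # no acceptable negative for this pair
            anchors.append(a)
            positives.append(p)
            negatives.append(candidates[choice])
    return anchors, positives, negatives
```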
- hardest_negative(loss_values)
Get the hardest negative.
- Parameters:
loss_values (numpy.ndarray) – Triplet loss values of all candidate negatives for a given anchor-positive pair.
- Returns:
Index of the selected negative.
- Return type:
int
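The three negative_selection strategies can be expressed as rules over the per-negative loss values d(a, p) - d(a, n) + margin. The sketch below follows a convention common in triplet-mining code and is an assumption, not confirmed scimilarity behaviour: "hardest" takes the largest positive loss, "semihard" picks at random among negatives with 0 < loss < margin (that is, d(a, p) < d(a, n) < d(a, p) + margin), and "random" picks at random among negatives with positive loss. Unlike the documented int return type, these sketches return None when no negative qualifies, and they assume loss_values is non-empty.

```python
import random


def hardest_negative(loss_values):
    """Index of the negative with the largest loss, or None if no loss is positive."""
    best = max(range(len(loss_values)), key=lambda i: loss_values[i])
    return best if loss_values[best] > 0 else None


def semihard_negative(loss_values, margin, rng=random):
    """Random negative with 0 < loss < margin, i.e. d(a,p) < d(a,n) < d(a,p) + margin."""
    pool = [i for i, v in enumerate(loss_values) if 0 < v < margin]
    return rng.choice(pool) if pool else None


def random_hard_negative(loss_values, rng=random):
    """Random negative among those that still violate the margin (loss > 0)."""
    pool = [i for i, v in enumerate(loss_values) if v > 0]
    return rng.choice(pool) if pool else None
```

Semi-hard negatives are often preferred because the hardest ones can be dominated by mislabelled or noisy cells early in training.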
- pdist(vectors)
Get pairwise distances between all cell embeddings.
- Parameters:
vectors (numpy.ndarray) – Cell embeddings.
- Returns:
Distance matrix of cell embeddings.
- Return type:
numpy.ndarray
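A common way to build such a distance matrix in triplet-mining code is the squared Euclidean expansion ||x||^2 + ||y||^2 - 2 x·y, which reuses one set of dot products instead of differencing every pair. The sketch below uses that choice in plain Python; which metric pdist itself uses is not stated here, so treat the squared Euclidean metric and the function name as assumptions.

```python
def pairwise_sq_euclidean(vectors):
    """Pairwise squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y."""
    sq_norms = [sum(x * x for x in v) for v in vectors]
    n = len(vectors)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(x * y for x, y in zip(vectors[i], vectors[j]))
            # Clamp tiny negative values caused by floating-point cancellation.
            out[i][j] = max(sq_norms[i] + sq_norms[j] - 2.0 * dot, 0.0)
    return out
```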