Source code for oml.miners.inbatch_hard_tri

from typing import List

import torch
from torch import Tensor

from oml.interfaces.miners import ITripletsMinerInBatch, TTripletsIds
from oml.utils.misc_torch import pairwise_dist


class HardTripletsMiner(ITripletsMinerInBatch):
    """
    This miner selects the hardest triplets based on the distances between the features:

    - The hardest `positive` sample has the `maximal` distance to the anchor sample
    - The hardest `negative` sample has the `minimal` distance to the anchor sample

    """

    def _sample(self, features: Tensor, labels: List[int]) -> TTripletsIds:
        """
        This method samples the hardest triplets inside the batch.

        Args:
            features: Features with the shape of ``[batch_size, feature_size]``
            labels: Labels with the size of ``batch_size``

        Returns:
            The batch of the triplets in the order below:
            ``(anchor, positive, negative)``

        """
        assert features.shape[0] == len(labels)

        # Work on a detached copy: mining only produces indices, so the
        # distance computation does not need to be tracked by autograd.
        features = features.clone().detach()
        dist_mat = pairwise_dist(x1=features, x2=features, p=2)

        return self._sample_from_distmat(distmat=dist_mat, labels=labels)

    @staticmethod
    def _sample_from_distmat(distmat: Tensor, labels: List[int]) -> TTripletsIds:
        """
        This method samples the hardest triplets based on the given distances matrix.
        It chooses each sample in the batch as an anchor and then finds the hardest
        positive and negative pair.

        The anchor itself is never picked as its own positive as long as another
        sample with the same label exists in the batch. If an anchor is the only
        sample of its label, its own index is returned as the positive. If the
        batch contains no sample with a different label, the returned negative
        index is not meaningful (all candidates are masked out).

        Args:
            distmat: Matrix with the shape of ``[batch_size, batch_size]``
            labels: Labels with the size of ``batch_size``

        Returns:
            The batch of the triplets with the size of ``batch_size``,
            order is the following: ``(anchor, positive, negative)``

        """
        # ``as_tensor`` accepts both plain lists and tensors without the
        # copy-construct warning that ``torch.tensor(tensor)`` emits.
        labels_t = torch.as_tensor(labels, device=distmat.device)
        batch_size = labels_t.size(0)

        if batch_size == 0:
            # Nothing to mine; argmax/argmin cannot reduce over an empty dimension.
            return [], [], []

        # same_label[i, j] is True when samples i and j share a label.
        same_label = labels_t.unsqueeze(0) == labels_t.unsqueeze(1)

        # Positives are same-label samples other than the anchor itself.
        not_self = ~torch.eye(batch_size, dtype=torch.bool, device=distmat.device)
        is_positive = same_label & not_self

        # Hardest positive: the farthest valid positive in each row.
        dist_pos = distmat.masked_fill(~is_positive, float("-inf"))
        ids_pos = torch.argmax(dist_pos, dim=1)

        # Anchors without any other positive keep pointing at themselves instead
        # of at an arbitrary index taken from a fully masked row.
        ids_self = torch.arange(batch_size, device=distmat.device)
        ids_pos = torch.where(is_positive.any(dim=1), ids_pos, ids_self)

        # Hardest negative: the closest sample with a different label in each row.
        dist_neg = distmat.masked_fill(same_label, float("inf"))
        ids_neg = torch.argmin(dist_neg, dim=1)

        ids_anchor = list(range(batch_size))

        return ids_anchor, ids_pos.cpu().tolist(), ids_neg.cpu().tolist()
# Names exported by a star-import of this module.
__all__ = ["HardTripletsMiner"]