Source code for oml.miners.inbatch_nhard_tri

from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from oml.interfaces.miners import ITripletsMinerInBatch, TTripletsIds
from oml.utils.misc_torch import pairwise_dist, take_2d


[docs]class NHardTripletsMiner(ITripletsMinerInBatch): """ This miner selects hard triplets based on distances between features: - hard `positive` samples have large distance to the anchor sample - hard `negative` samples have small distance to the anchor sample Toward the end of the training, annotation errors can affect final metric. If you are not sure about the quality of your dataset, you can use range instead of integer value for parameters and exclude combinations with the largest distances. For example instead picking 5 positive examples, you can use examples from the 2nd hardest to the 5th one. """
[docs] def __init__( self, n_positive: Union[Tuple[int, int], List[int], int] = 1, n_negative: Union[Tuple[int, int], List[int], int] = 1, ): """ Args: n_positive: keep ``n_positive`` positive samples with large distances. If the value is a range, minimal value has to be less than the available amount of labels in batches n_negative: keep ``n_negative`` negative pipelines with small distances Note: If both parameters are 1, the miner is equivalent to ``HardTripletsMiner``. If both parameters are large enough, the miner can be equivalent to ``AllTripletsMiner`` """ self.positive_slice = slice(*self._parse_input_arg(n_positive)) self.negative_slice = slice(*self._parse_input_arg(n_negative))
@staticmethod def _parse_input_arg(n: Union[Tuple[int, int], List[int], int]) -> List[int]: if isinstance(n, int): n = [0, n] elif isinstance(n, (list, tuple)): n = list(n) n[0] -= 1 else: raise TypeError("Unsupported type of argument. Must be int, tuple or list") assert n[1] > n[0] assert n[0] >= 0 return n def _sample( self, features: Tensor, labels: List[int], *_: Any, ignore_anchor_mask: Optional[Union[List[int], Tensor, np.ndarray]] = None, ) -> TTripletsIds: """ This method samples the hard triplets inside the batch based on distance between features. Args: features: Features with the shape of ``[batch_size, feature_size]`` labels: Labels with the size of ``batch_size`` *_: Anything here will be ignored ignore_anchor_mask: Parameter allows you to specify the ids of features that cannot be used as anchors. Can be useful with memory banks to create triplets in which at least one vector will have gradients. Returns: The batch of the triplets in the order below: ``(anchor, positive, negative)`` """ assert features.shape[0] == len(labels) dist_mat = pairwise_dist(x1=features, x2=features, p=2) ids_anchor, ids_pos, ids_neg = self._sample_from_distmat( distmat=dist_mat, labels=labels, ignore_anchor_mask=ignore_anchor_mask ) return ids_anchor, ids_pos, ids_neg def _sample_from_distmat( self, distmat: Tensor, labels: List[int], *_: Any, ignore_anchor_mask: Optional[Union[List[int], Tensor, np.ndarray]] = None, ) -> TTripletsIds: if isinstance(labels, (list, tuple)): labels = torch.tensor(labels, device=distmat.device).long() if ignore_anchor_mask is None: ignore_anchor_mask = torch.zeros(len(distmat), dtype=torch.bool, device=distmat.device) all_ids_reduced = torch.arange(len(labels), device=distmat.device)[torch.logical_not(ignore_anchor_mask)] distmat_reduced = distmat[torch.logical_not(ignore_anchor_mask)] ids_highest_distance = torch.argsort(distmat_reduced, descending=True, dim=1) mask_same_label = labels[:, None] == labels[None, :] # type: ignore mask_same_label.fill_diagonal_(False) mask_same_label = mask_same_label[torch.logical_not(ignore_anchor_mask)] mask_same_label_sorted_by_dist = take_2d(mask_same_label, ids_highest_distance) idx_anch_pos_reduced, idx_pos_sorted_by_dist = torch.nonzero(mask_same_label_sorted_by_dist, as_tuple=True) hardest_positive = ids_highest_distance[idx_anch_pos_reduced, idx_pos_sorted_by_dist] idx_anch_pos = all_ids_reduced[idx_anch_pos_reduced] ids_smallest_distance = torch.flip(ids_highest_distance, dims=(1,)) mask_diff_label = labels[:, None] != labels[None, :] # type: ignore mask_diff_label = mask_diff_label[torch.logical_not(ignore_anchor_mask)] diff_label_sorted_by_dist = take_2d(mask_diff_label, ids_smallest_distance) idx_anch_neg_reduced, idx_neg_sorted_by_dist = torch.nonzero(diff_label_sorted_by_dist, as_tuple=True) hardest_negative = ids_smallest_distance[idx_anch_neg_reduced, idx_neg_sorted_by_dist] idx_anch_neg = all_ids_reduced[idx_anch_neg_reduced] ids_a: List[int] = [] ids_p: List[int] = [] ids_n: List[int] = [] for idx_anch in torch.arange(len(labels), device=distmat.device)[torch.logical_not(ignore_anchor_mask)]: positives = hardest_positive[idx_anch_pos == idx_anch][self.positive_slice] negatives = hardest_negative[idx_anch_neg == idx_anch][self.negative_slice] i_pos, i_neg = list(zip(*torch.cartesian_prod(positives, negatives).tolist())) ids_a.extend([idx_anch.item()] * len(i_pos)) ids_p.extend(i_pos) ids_n.extend(i_neg) return ids_a, ids_p, ids_n
__all__ = ["NHardTripletsMiner"]