Source code for oml.miners.inbatch_all_tri

from itertools import combinations, product
from operator import itemgetter
from random import sample
from sys import maxsize
from typing import List

import numpy as np
import torch

from oml.interfaces.miners import ITripletsMinerInBatch, TLabels, TTripletsIds
from oml.utils.misc import find_value_ids


[docs]class AllTripletsMiner(ITripletsMinerInBatch): """ This miner selects all the possible triplets for the given batch. """
[docs] def __init__(self, max_output_triplets: int = maxsize, device: str = "cpu"): """ Args: max_output_triplets: Number of all of the possible triplets in the batch can be very large, so we can limit them vis this parameter. device: the device where to perform computations. """ self._max_out_triplets = max_output_triplets self._device = device
def _sample(self, *_: torch.Tensor, labels: List[int]) -> TTripletsIds: # type: ignore """ Args: labels: Labels of the samples in the batch *_: Note, that we ignore features argument Returns: Indices of the triplets """ return get_available_triplets(labels, max_out_triplets=self._max_out_triplets, device=self._device)
def get_available_triplets(labels: TLabels, max_out_triplets: int = maxsize, device: str = "cpu") -> TTripletsIds: """ Generate a random subset of size ``max_out_triplets`` from the set of all possible triplets (anchor, positive, negative). Args: labels: Labels of the samples. max_out_triplets: maximal number of triplets to generate device: the desired device of returned tensor Returns: lists of indexes of the anchors, positives and negatives. """ labels = torch.tensor(labels, device=device, requires_grad=False) # First we need to get initial ingridients of triplets: # the arrays of all (anchor, positive) and (anchor, negavite) pairs mask_same_label: torch.Tensor = labels[:, None] == labels[None, :] idx_anch_neg = torch.nonzero(torch.logical_not(mask_same_label)) mask_same_label.fill_diagonal_(False) idx_anch_pos = torch.nonzero(mask_same_label) del mask_same_label # Next we need to group (anchor, negative) pairs # with respect to the anchor. As a result we get an array called neg_groups, # such that neg_groups[i] is all the negatives, that come along with # the anchor number i. # For simplicity, lets assume that # idx_anch_pos = [[0, 1], # [0, 2], # [1, 3], # [1, 4]] # and idx_anch_neg = [[0, 5], # [0, 6], # [1, 7]] # [1, 8]] # We want to get the following correspondence: # 0 -> [5, 6] # 1 -> [7, 8] # We know that idx_neg is sorted lexicographically. So, we can fint the number # of all unique anchors in idx_anch_neg, and use torch.split to get the desired # groups. _, neg_counts = torch.unique(idx_anch_neg[:, 0], return_counts=True, sorted=True) neg_groups = torch.split(idx_anch_neg[:, 1], neg_counts.tolist()) # Now for a pair (anchor, positive)we have # neg_group[i], where i is the number of the anchor, with all # the negatives that come along with that anchor. # In order to get all possible triplets that contains (anchor, positive) # we just need to repeat it as much as we have negative elements # in the neg_groups[i]. # For example above, we want to generate the following triplets: # # we need to repeat [0, 1] as much as there are | [0, 1, 5] | # negatives that come along with anchor 0 | [0, 1, 6] | and we repeat [5, 6] as much # [0, 2, 5] | as there are anchor 0 in idx_anch_pos # [0, 2, 6] | # [1, 3, 7] # [1, 3, 8] # [1, 4, 7] # [1, 4, 8] # Repeat the negatives as much as there are correspondent anchors. idx_neg = torch.cat(tuple(itemgetter(*idx_anch_pos[:, 0].cpu().tolist())(neg_groups))) # Repeat the positive pairs as much as there are correspondent netatives and get idx_anch, idx_pos. n_negs_in_group = [len(g) for g in neg_groups] n_repeats = torch.tensor(tuple(itemgetter(*idx_anch_pos[:, 0].cpu().numpy())(n_negs_in_group))).to(idx_anch_pos) idx_anch_pos = idx_anch_pos.repeat_interleave(n_repeats, dim=0) idx_anch, idx_pos = idx_anch_pos.T # Choose randomly the required number of triplets. random_idx = np.random.choice(len(idx_anch), min(len(idx_anch), max_out_triplets), replace=False) random_idx = torch.tensor(random_idx).to(device) idx_anch = idx_anch[random_idx] idx_pos = idx_pos[random_idx] idx_neg = idx_neg[random_idx] idx_anch = idx_anch.tolist() idx_pos = idx_pos.tolist() idx_neg = idx_neg.tolist() return idx_anch, idx_pos, idx_neg def get_available_triplets_naive(labels: List[int], max_out_triplets: int = maxsize) -> TTripletsIds: """ For each label, the function generates all possible positive and negative pairs of triplets. """ num_labels = len(labels) triplets = [] for label in set(labels): ids_pos_cur = find_value_ids(labels, label) ids_neg_cur = set(range(num_labels)) - set(ids_pos_cur) # (l0, l1, n) and (l1, l0, n) are 2 different triplets # and we want both of them pos_pairs = list(combinations(ids_pos_cur, r=2)) + list(combinations(ids_pos_cur[::-1], r=2)) tri = [(a, p, n) for (a, p), n in product(pos_pairs, ids_neg_cur)] triplets.extend(tri) triplets = sample(triplets, min(len(triplets), max_out_triplets)) ids_anchor, ids_pos, ids_neg = zip(*triplets) return list(ids_anchor), list(ids_pos), list(ids_neg) __all__ = ["AllTripletsMiner", "get_available_triplets"]