# Source code for oml.samplers.balance

from collections import Counter, defaultdict
from typing import Iterator, List, Union

import numpy as np

from oml.interfaces.samplers import IBatchSampler
from oml.utils.misc import smart_sample


class BalanceSampler(IBatchSampler):
    """
    This sampler takes ``n_instances`` for each of the ``n_labels`` to form the batches.
    Thus, the batch size is ``n_instances x n_labels``. This type of sampling can be found
    in the classical Person Re-Id paper -
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737

    The strategy for the dataset with ``L`` unique labels is the following:

    - Select ``n_labels`` of ``L`` labels for the 1st batch
    - Select ``n_instances`` for each label for the 1st batch
    - Select ``n_labels`` of ``L - n_labels`` remaining labels for 2nd batch
    - Select ``n_instances`` instances for each label for the 2nd batch
    - ...
    - The epoch ends after ``L // n_labels`` batches.

    Thus, in each epoch, every label is selected at most once (exactly once when
    ``L`` is divisible by ``n_labels``), but this does not mean that all the
    instances will be picked.

    Behavior in corner cases:

    - If some label does not contain ``n_instances``, a choice will be made with repetition
      (the per-label sampling is delegated to ``smart_sample``).
    - If ``L % n_labels != 0`` then we drop the last batch.

    """

    def __init__(self, labels: Union[List[int], np.ndarray], n_labels: int, n_instances: int):
        """
        Args:
            labels: List of the labels for each element in the dataset
            n_labels: The desired number of labels in a batch, should be > 1
            n_instances: The desired number of instances of each label in a batch, should be > 1

        Raises:
            AssertionError: If ``n_labels`` or ``n_instances`` is not an ``int``, if
                ``n_labels`` is not in ``(1, number of unique labels]``, if
                ``n_instances`` is not greater than 1, or if some label has fewer
                than 2 samples.

        """
        unq_labels = set(labels)

        assert isinstance(n_labels, int) and isinstance(
            n_instances, int
        ), "n_labels and n_instances should be integers"
        assert (1 < n_labels <= len(unq_labels)) and (
            1 < n_instances
        ), "n_labels should be in (1, number of unique labels] and n_instances should be > 1"
        assert all(n > 1 for n in Counter(labels).values()), "Each label should contain at least 2 samples"

        self._labels = np.array(labels)
        self.n_labels = n_labels
        self.n_instances = n_instances

        self._batch_size = self.n_labels * self.n_instances
        self._unq_labels = unq_labels

        # Group dataset indices by label. The already converted array is reused
        # instead of converting ``labels`` a second time.
        lbl2idx = defaultdict(list)
        for idx, label in enumerate(self._labels):
            lbl2idx[label].append(idx)
        self.lbl2idx = dict(lbl2idx)

        # Integer division: labels left over after the last full batch are dropped.
        self._batches_in_epoch = len(self._unq_labels) // self.n_labels

    @property
    def batch_size(self) -> int:
        """Number of indices in every batch: ``n_labels * n_instances``."""
        return self._batch_size

    def __len__(self) -> int:
        """Number of batches produced per epoch: ``L // n_labels``."""
        return self._batches_in_epoch

    def __iter__(self) -> Iterator[List[int]]:
        """
        Build all the batches of one epoch and return an iterator over them.

        Returns:
            Iterator over batches, where each batch is a list of dataset indices
            collected label by label.

        """
        inds_epoch = []
        # Work on a copy so that every epoch starts from the full set of labels.
        labels_rest = self._unq_labels.copy()

        for _ in range(len(self)):
            ids_batch = []

            # Draw the labels of this batch without replacement ...
            labels_for_batch = set(
                np.random.choice(list(labels_rest), size=min(self.n_labels, len(labels_rest)), replace=False)
            )
            # ... and exclude them from the following batches of the epoch.
            labels_rest -= labels_for_batch

            for cls in labels_for_batch:
                cls_ids = self.lbl2idx[cls]
                selected_inds = smart_sample(cls_ids, self.n_instances)
                ids_batch.extend(selected_inds)

            inds_epoch.append(ids_batch)

        return iter(inds_epoch)  # type: ignore
# Names exported by ``from <this module> import *``.
__all__ = ["BalanceSampler"]