# Source code for oml.samplers.distinct_category_balance

from collections import Counter, defaultdict
from copy import deepcopy
from typing import Dict, Iterator, List, Set, Union

import numpy as np

from oml.interfaces.samplers import IBatchSampler
from oml.utils.misc import smart_sample


class DistinctCategoryBalanceSampler(IBatchSampler):
    """
    Batch sampler that balances categories, labels and instances.

    This sampler takes ``n_instances`` for each of the ``n_labels`` for each of the ``n_categories``
    to form the batches. Thus, the batch size is ``n_instances x n_labels x n_categories``.

    The strategy for the dataset with ``L`` unique labels and ``C`` unique categories is the following:

    - Select ``n_categories`` of ``C`` for the 1st batch
    - Select ``n_labels`` for each of the chosen categories for the 1st batch
    - Select ``n_instances`` for each of the chosen labels for the 1st batch
    - Define the set of available for the 2nd batch labels ``L^``: these are all the labels ``L``
      except the ones chosen for the 1st batch
    - Define set of available categories ``C^``: these are all the categories corresponding
      to labels from ``L^``
    - Select ``n_categories`` from ``C^`` for the 2nd batch
    - Select ``n_labels`` for each category from ``L^`` for the 2nd batch
    - Select ``n_instances`` for each label for the 2nd batch
    - ...
    - Epoch ends after ``epoch_size`` steps

    Behavior in corner cases:

    - If fewer than ``n_categories`` categories still have unused labels before ``epoch_size``
      steps are done, the sampler resets its state (all labels become unused again) and goes on
      sampling.
    - Instances of a label are drawn with ``smart_sample``; presumably it samples with repetition
      when a label has fewer than ``n_instances`` instances (confirm against ``smart_sample``).
    - If the chosen category does not contain ``n_labels`` unused labels, all the unused labels
      will be added to a batch and the missing ones will be sampled from the used labels of that
      category without repetition.

    NOTE(review): earlier documentation stated that one label is dropped when
    ``L % n_labels == 1``; no such step exists in this implementation. Every chosen category
    always contributes exactly ``n_labels`` labels (and ``n_labels > 1`` is enforced), topped up
    from already used labels when needed.

    """

    def __init__(
        self,
        labels: Union[List[int], np.ndarray],
        label2category: Dict[int, Union[str, int]],
        n_categories: int,
        n_labels: int,
        n_instances: int,
        epoch_size: int,
    ):
        """
        Validate the arguments and precompute the lookup tables used for sampling.

        Args:
            labels: Labels to sample from
            label2category: Mapping from label to category
            n_categories: The desired number of categories to sample for each batch
            n_labels: The desired number of labels to sample for each category in batch
            n_instances: The desired number of samples to sample for each label in batch
            epoch_size: The desired number of batches in epoch

        Raises:
            TypeError: If ``n_categories``, ``n_labels`` or ``n_instances`` is not an ``int``.
            ValueError: If ``n_categories`` is outside ``[1, number of categories]``,
                if ``n_labels`` is outside ``(1, number of labels]``, if ``n_instances <= 1``,
                if ``labels`` and ``label2category`` do not describe the same set of labels,
                if some label has fewer than 2 instances, or if some category has fewer
                than ``n_labels`` labels.

        """
        unique_labels = set(labels)
        unique_categories = set(label2category.values())

        # Group labels by category in a single pass; keys are inserted in sorted category order.
        category2labels: Dict[Union[str, int], Set[int]] = {
            category: set() for category in sorted(unique_categories)
        }
        for label, category in label2category.items():
            category2labels[category].add(label)

        # Names are paired with values explicitly: the values themselves carry no ``__name__``,
        # so the message cannot be derived from them.
        for name, value in (
            ("n_categories", n_categories),
            ("n_labels", n_labels),
            ("n_instances", n_instances),
        ):
            if not isinstance(value, int):
                raise TypeError(f"{name} must be int, {type(value)} given")

        if not 1 <= n_categories <= len(unique_categories):
            raise ValueError(f"must be 1 <= n_categories <= {len(unique_categories)}, {n_categories} given")
        if not 1 < n_labels <= len(unique_labels):
            raise ValueError(f"must be 1 < n_labels <= {len(unique_labels)}, {n_labels} given")
        if n_instances <= 1:
            raise ValueError(f"n_instances must be greater than 1, {n_instances} given")
        if any(label not in label2category for label in unique_labels):
            raise ValueError("All the labels must have category")
        if any(label not in unique_labels for label in label2category):
            raise ValueError("All the labels from label2category mapping must be in the labels")
        if any(n <= 1 for n in Counter(labels).values()):
            raise ValueError("Each class must contain at least 2 instances to fit")
        if any(len(labs) < n_labels for labs in category2labels.values()):
            raise ValueError(f"All the categories must have at least {n_labels} unique labels")

        self._labels = np.array(labels)
        self._label2category = label2category
        self.n_categories = n_categories
        self.n_labels = n_labels
        self.n_instances = n_instances
        self._epoch_size = epoch_size
        self._batch_size = self.n_categories * self.n_labels * self.n_instances

        # Positions of every label inside ``labels``; the index vector is built once and reused.
        positions = np.arange(len(self._labels))
        self._label2index = {
            label: positions[self._labels == label].tolist() for label in sorted(unique_labels)
        }
        self._category2labels = category2labels

    @property
    def batch_size(self) -> int:
        """Number of indices in every batch: ``n_categories * n_labels * n_instances``."""
        return self._batch_size

    def __len__(self) -> int:
        """Return the number of batches in one epoch (``epoch_size``)."""
        return self._epoch_size

    def __iter__(self) -> Iterator[List[int]]:
        """
        Build all the batches of one epoch and return an iterator over them.

        Returns:
            Iterator over ``epoch_size`` batches, each one is a list of indices into ``labels``.

        """
        # Working copy: labels are removed from it as they get used during the epoch.
        category2labels = deepcopy(self._category2labels)
        used_labels: Dict[Union[str, int], Set[int]] = defaultdict(set)
        epoch_indices = []

        for _ in range(self._epoch_size):
            # Not enough categories with unused labels are left: start over with a fresh state.
            if len(category2labels) < self.n_categories:
                category2labels = deepcopy(self._category2labels)
                used_labels = defaultdict(set)

            categories_available = list(category2labels.keys())
            categories = np.random.choice(
                categories_available, size=min(self.n_categories, len(categories_available)), replace=False
            )

            batch_indices = []
            for category in categories:
                labels_available = list(category2labels[category])
                labels_available_number = len(labels_available)

                if self.n_labels <= labels_available_number:
                    labels = np.random.choice(labels_available, size=self.n_labels, replace=False).tolist()
                else:
                    # Take every unused label and fill the gap with labels already used
                    # for this category.
                    labels = (
                        labels_available
                        + np.random.choice(
                            list(used_labels[category]), size=self.n_labels - labels_available_number, replace=False
                        ).tolist()
                    )

                for label in labels:
                    indices = self._label2index[label]
                    samples_indices = smart_sample(array=indices, k=self.n_instances)
                    batch_indices.extend(samples_indices)

                category2labels[category] -= set(labels)
                used_labels[category].update(labels)
                # A category with no unused labels is not available until the next reset.
                if not category2labels[category]:
                    category2labels.pop(category)

            epoch_indices.append(batch_indices)

        return iter(epoch_indices)
# Names exported by ``from <module> import *``: only the sampler class.
__all__ = ["DistinctCategoryBalanceSampler"]