Source code for oml.samplers.category_balance

import math
from collections import Counter
from typing import Dict, Iterator, List, Union

import numpy as np

from oml.interfaces.samplers import IBatchSampler
from oml.utils.misc import smart_sample


[docs]class CategoryBalanceSampler(IBatchSampler): """ This sampler takes ``n_instances`` for each of the ``n_labels`` for each of the ``n_categories`` to form the batches. Thus, the batch size is ``n_instances x n_labels x n_categories``. Note, to form an epoch of batches we simply sample ``L / n_labels`` batches with repetition. """
[docs] def __init__( self, labels: Union[List[int], np.ndarray], label2category: Dict[int, Union[str, int]], n_categories: int, n_labels: int, n_instances: int, resample_labels: bool = False, weight_categories: bool = True, ): """ Args: labels: Labels to sample from label2category: Mapping from label to category n_categories: The desired number of categories to sample for each batch n_labels: The desired number of labels to sample for each category in batch n_instances: The desired number of samples to sample for each label in batch resample_labels: If ``True`` sample with repetition otherwise, otherwise raise an error in case of the labels lack in any category weight_categories: If ``True`` sample categories for each batch with weights proportional to the number of unique labels in the categories """ unique_labels = set(labels) unique_categories = set(label2category.values()) category2labels = { category: {label for label, cat in label2category.items() if category == cat} for category in unique_categories } for param in [n_categories, n_labels, n_instances]: if not isinstance(param, int): raise TypeError(f"{param.__name__} must be int, {type(param)} given") if not 1 <= n_categories <= len(unique_categories): raise ValueError(f"must be 1 <= n_categories <= {len(unique_categories)}, {n_categories} given") if not 1 < n_labels <= len(unique_labels): raise ValueError(f"must be 1 < n_labels <= {len(unique_labels)}, {n_labels} given") if n_instances <= 1: raise ValueError(f"must be not less than 1, {n_instances} given") if any(label not in label2category.keys() for label in unique_labels): raise ValueError("All the labels must have category") if any(label not in unique_labels for label in label2category.keys()): raise ValueError("All the labels from label2category mapping must be in the labels") if any(n <= 1 for n in Counter(labels).values()): raise ValueError("Each class must contain at least 2 instances to fit") if not resample_labels: if any(len(list(labs)) < n_labels for labs in category2labels.values()): raise ValueError(f"All the categories must have at least {n_labels} unique labels") unique_categories = sorted(list(unique_categories)) self._labels = np.array(labels) self._label2category = label2category self.n_categories = n_categories self.n_labels = n_labels self.n_instances = n_instances self._batch_size = self.n_categories * self.n_labels * self.n_instances self._weight_categories = weight_categories self._label2index = { label: np.arange(len(self._labels))[self._labels == label].tolist() for label in sorted(list(unique_labels)) } self._category2labels = { category: {label for label, cat in self._label2category.items() if category == cat} for category in unique_categories } category_weights = {cat: len(labels) / len(unique_labels) for cat, labels in self._category2labels.items()} self._category_weights = ( [category_weights[cat] for cat in unique_categories] if self._weight_categories else None ) self._batch_number = math.ceil(len(unique_labels) / self.n_labels)
@property def batch_size(self) -> int: return self._batch_size def __len__(self) -> int: return self._batch_number def __iter__(self) -> Iterator[List[int]]: epoch_indices = [] for _ in range(self._batch_number): categories = np.random.choice( list(self._category2labels.keys()), size=self.n_categories, replace=False, p=self._category_weights, ) batch_indices = [] for category in categories: labels_available = list(self._category2labels[category]) labels = smart_sample(array=labels_available, k=self.n_labels) for label in labels: indices = self._label2index[label] samples_indices = smart_sample(array=indices, k=self.n_instances) batch_indices.extend(samples_indices) epoch_indices.append(batch_indices) return iter(epoch_indices)
__all__ = ["CategoryBalanceSampler"]