from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import Tensor, no_grad
from oml.interfaces.miners import ITripletsMiner
from oml.miners.inbatch_nhard_tri import NHardTripletsMiner
from oml.utils.misc_torch import OnlineAvgDict
[docs]class MinerWithBank(ITripletsMiner):
"""
This is a class for cross-batch memory. This implementation uses only samples from the current batch as anchors
and finds positive and negative pairs from the bank and the current batch using miner.
"""
[docs] def __init__(
self,
bank_size_in_batches: int,
miner: NHardTripletsMiner,
need_logs: bool = True,
):
"""
Args:
bank_size_in_batches: Size of the bank.
miner: Miner, for now we only support ``NHardTripletsMiner``
need_logs: Set ``True`` to store some information to track in ``self.last_logs`` property.
"""
assert isinstance(bank_size_in_batches, int)
assert bank_size_in_batches >= 1
assert isinstance(miner, NHardTripletsMiner)
self.miner = miner
self.bank_size_in_batches = bank_size_in_batches
self.bank_features: Optional[Tensor] = None
self.bank_labels: Optional[Tensor] = None
self.bank_size = -1 # bank size in the number of vectors
self.is_accumulated = False # indicates if we filled out the bank
# We maintain a queue of feature vectors and ptr indicates the place to insert another batch.
# After the batch has been inserted, we move the pointer.
self.ptr = 0
self.need_logs = need_logs
self._last_logs: Dict[str, float] = {}
@no_grad()
def __allocate_if_needed(self, features: Tensor, labels: Tensor) -> None:
if self.bank_features is None:
assert len(features) == len(labels)
bs = features.shape[0]
self.feat_dim = features.shape[-1]
self.bank_size = self.bank_size_in_batches * bs
self.bank_labels = torch.empty(self.bank_size, dtype=torch.long, device=features.device)
self.bank_features = torch.empty(
(self.bank_size, self.feat_dim), dtype=features.dtype, device=features.device
)
@no_grad()
def _update_bank(self, features: Tensor, labels: Tensor) -> None:
bs = features.shape[0]
if not self.is_accumulated:
self.is_accumulated = self.ptr + bs >= self.bank_size
self.bank_features[self.ptr : self.ptr + bs] = features.clone().detach()
self.bank_labels[self.ptr : self.ptr + bs] = labels.clone()
self.ptr = (self.ptr + bs) % self.bank_size
[docs] def sample(self, features: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
features: Features with the shape ``[batch_size, features_dim]``
labels: Labels with the size of ``batch_size``
Returns:
Batch of triplets in the following order: anchor, positive, negative
"""
if isinstance(labels, (list, tuple)):
labels = torch.tensor(labels, dtype=torch.long, device=features.device)
self.__allocate_if_needed(features=features, labels=labels)
if self.is_accumulated:
len_batch = len(labels)
features_miner = torch.cat([features, self.bank_features], dim=0)
labels_miner = torch.cat([labels, self.bank_labels], dim=0)
ignore_anchor_mask = torch.zeros(len(labels_miner), dtype=torch.bool, device=features.device)
ignore_anchor_mask[len_batch:] = True
else:
features_miner = features
labels_miner = labels
ignore_anchor_mask = torch.zeros(len(labels_miner), dtype=torch.bool, device=features.device)
ids_a, ids_p, ids_n = self.miner._sample(
features=features_miner, labels=labels_miner, ignore_anchor_mask=ignore_anchor_mask
)
if self.need_logs:
self._last_logs = self._prepare_logs(
ids_a=ids_a, ids_p=ids_p, ids_n=ids_n, ignore_anchor_mask=ignore_anchor_mask
)
self._update_bank(features=features, labels=labels)
return features_miner[ids_a], features_miner[ids_p], features_miner[ids_n]
@staticmethod
def _prepare_logs(
ids_a: List[int], ids_p: List[int], ids_n: List[int], ignore_anchor_mask: Tensor
) -> Dict[str, float]:
logs = OnlineAvgDict()
unq_triplets = set(zip(ids_a, ids_p, ids_n))
ids_anchor2positives = defaultdict(set)
ids_anchor2negatives = defaultdict(set)
for anch, pos, neg in unq_triplets:
ids_anchor2positives[anch].add(int(pos))
ids_anchor2negatives[anch].add(int(neg))
for anch in ids_anchor2positives.keys():
positives = ids_anchor2positives[anch]
positives_from_bank = ignore_anchor_mask[list(positives)].sum().item()
positives_from_batch = len(positives) - positives_from_bank
logs.update({"positives_from_bank": positives_from_bank, "positives_from_batch": positives_from_batch})
negatives = ids_anchor2negatives[anch]
negatives_from_bank = ignore_anchor_mask[list(negatives)].sum().item()
negatives_from_batch = len(negatives) - negatives_from_bank
logs.update({"negatives_from_bank": negatives_from_bank, "negatives_from_batch": negatives_from_batch})
return logs.get_dict_with_results()
@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistic calculated for the last batch.
"""
return self._last_logs
__all__ = ["MinerWithBank"]