from typing import Dict, List, Tuple, Union
from torch import Tensor
from oml.const import INDEX_KEY, INPUT_TENSORS_KEY_1, INPUT_TENSORS_KEY_2
from oml.interfaces.datasets import IBaseDataset, IPairDataset
class PairDataset(IPairDataset):
    """
    Dataset that yields pairs of items taken from an underlying dataset (of any modality).

    Element ``k`` of this dataset corresponds to ``pair_ids[k] == (i1, i2)``: both items
    are fetched from ``base_dataset`` and only their input tensors are kept.
    """

    def __init__(
        self,
        base_dataset: IBaseDataset,
        pair_ids: List[Tuple[int, int]],
        input_tensors_key_1: str = INPUT_TENSORS_KEY_1,
        input_tensors_key_2: str = INPUT_TENSORS_KEY_2,
        index_key: str = INDEX_KEY,
    ):
        """
        Args:
            base_dataset: Dataset the paired items are read from. Its ``input_tensors_key``
                attribute selects which field of each fetched item is returned.
            pair_ids: Pairs of indices into ``base_dataset``; element ``k`` of this dataset
                is built from ``pair_ids[k]``.
            input_tensors_key_1: Output key for the tensors of the first item of a pair.
            input_tensors_key_2: Output key for the tensors of the second item of a pair.
            index_key: Output key under which the position of the pair in ``pair_ids``
                is stored.
        """
        # Only references are stored here; no items are fetched until __getitem__.
        self.base_dataset = base_dataset
        self.pair_ids = pair_ids
        self.input_tensors_key_1 = input_tensors_key_1
        self.input_tensors_key_2 = input_tensors_key_2
        self.index_key: str = index_key

    def __getitem__(self, item: int) -> Dict[str, Union[Tensor, int]]:
        """
        Return the input tensors of both items of pair number ``item``.

        The result maps ``input_tensors_key_1`` / ``input_tensors_key_2`` to the tensors
        of the first / second item, and ``index_key`` to ``item`` itself.
        """
        first_idx, second_idx = self.pair_ids[item]
        tensors_key = self.base_dataset.input_tensors_key

        # The first item is fetched before the second.
        first_tensors = self.base_dataset[first_idx][tensors_key]
        second_tensors = self.base_dataset[second_idx][tensors_key]

        return {
            self.input_tensors_key_1: first_tensors,
            self.input_tensors_key_2: second_tensors,
            self.index_key: item,
        }

    def __len__(self) -> int:
        """Return the number of pairs, i.e. ``len(pair_ids)``."""
        pairs = self.pair_ids
        return len(pairs)
# Public names exported by this module (controls `from ... import *`).
__all__ = ["PairDataset"]