Source code for oml.interfaces.datasets

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union

import numpy as np
from torch import LongTensor
from torch.utils.data import Dataset

from oml.const import INPUT_TENSORS_KEY_1, INPUT_TENSORS_KEY_2, LABELS_KEY, TColor


class IIndexedDataset(Dataset, ABC):
    """
    This is an interface for the datasets whose items report their own index.

    The name of the field holding the index is stored in ``index_key``.

    """

    index_key: str

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """

        Args:
            item: Idx of the sample

        Returns:
            Dictionary having the following key:
            ``self.index_key: int = item``

        """
        # Interface stub: concrete datasets provide the real lookup.
        raise NotImplementedError

    def __len__(self) -> int:
        """

        Returns:
            Size of the dataset.

        """
        # Interface stub: concrete datasets provide the real size.
        raise NotImplementedError
class IBaseDataset(IIndexedDataset, ABC):
    """
    This is an interface for the indexed datasets whose items also carry input tensors.

    The name of the field holding the input tensors is stored in ``input_tensors_key``.

    """

    input_tensors_key: str

    # container for storing extra records having the same size as the dataset
    extra_data: Dict[str, Any]

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """

        Args:
            item: Idx of the sample

        Returns:
            Dictionary including the following keys:

            ``self.input_tensors_key``

            ``self.index_key: int = item``

        """
        # Interface stub: concrete datasets provide the real lookup.
        raise NotImplementedError
class ILabeledDataset(IBaseDataset, ABC):
    """
    This is an interface for the datasets which provide labels of containing items.

    """

    labels_key: str = LABELS_KEY

    def __getitem__(self, item: int) -> Dict[str, Any]:
        """

        Args:
            item: Idx of the sample

        Returns:
            Dictionary including the following keys:

            ``self.input_tensors_key``

            ``self.index_key: int = item``

            ``self.labels_key``

        """
        # Interface stub: concrete datasets provide the real lookup.
        raise NotImplementedError

    @abstractmethod
    def get_labels(self) -> np.ndarray:
        """

        Returns:
            Labels of the dataset as an array
            (presumably one entry per item -- confirm against implementations).

        """
        raise NotImplementedError

    def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
        """

        Returns:
            Mapping from label to category if known.

        """
        # Unlike ``get_labels`` this one is not abstract, so subclasses are
        # allowed to leave it unimplemented.
        raise NotImplementedError
class IQueryGalleryDataset(IBaseDataset, ABC):
    """
    This is an interface for the datasets which hold the information on how to split
    the data into the query and gallery. The query and gallery ids may overlap.
    It doesn't need the ground truth labels, so it can be used for prediction on not annotated data.

    """

    @abstractmethod
    def get_query_ids(self) -> LongTensor:
        """

        Returns:
            Ids of the items which belong to the query part of the split.

        """
        raise NotImplementedError()

    def get_gallery_ids(self) -> LongTensor:
        """
        Counterpart of ``get_query_ids`` for the gallery part of the split.

        Returns:
            Ids of the items which belong to the gallery part of the split.

        Raises:
            NotImplementedError: if the subclass has not overridden this method.

        """
        # Deliberately NOT decorated with @abstractmethod: subclasses written
        # against the interface before this accessor was declared must remain
        # instantiable. They only fail if the gallery ids are actually requested.
        raise NotImplementedError()
class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC):
    """
    This interface is similar to `IQueryGalleryDataset`, but there are ground truth labels.

    It adds nothing of its own: it only combines `IQueryGalleryDataset` and `ILabeledDataset`.

    """
class IPairDataset(IIndexedDataset):
    """
    This is an interface for the datasets which return pair of something.

    The two members of a pair are stored under ``input_tensors_key_1`` and ``input_tensors_key_2``.

    """

    input_tensors_key_1: str = INPUT_TENSORS_KEY_1
    input_tensors_key_2: str = INPUT_TENSORS_KEY_2

    @abstractmethod
    def __getitem__(self, item: int) -> Dict[str, Any]:
        """

        Args:
            item: Idx of the sample

        Returns:
            Dictionary with the following keys:

            ``self.input_tensors_key_1``

            ``self.input_tensors_key_2``

            ``self.index_key``

        """
        raise NotImplementedError
class IVisualizableDataset(Dataset, ABC):
    """
    Base class for the datasets which know how to visualise their items.

    """

    @abstractmethod
    def visualize(self, item: int, color: TColor) -> np.ndarray:
        """

        Args:
            item: Idx of the sample to visualise
            color: Color to use in the visualisation
                (exact usage is up to the implementation)

        Returns:
            Visualisation of the item as an array.

        """
        raise NotImplementedError
class IHTMLVisualizableDataset(Dataset, ABC):
    """
    Base class for the datasets which know how to visualise their items as HTML.

    """

    @abstractmethod
    def visualize_as_html(self, item: int, title: str, color: TColor) -> str:
        """

        Args:
            item: Idx of the sample to visualise
            title: Title to attach to the rendered item
                (exact placement is up to the implementation)
            color: Color to use in the visualisation
                (exact usage is up to the implementation)

        Returns:
            HTML representation of the item.

        """
        # The result is annotated as ``str`` because HTML is text; ``np.ndarray``
        # is what the image-based ``IVisualizableDataset.visualize`` returns.
        raise NotImplementedError()
# Public API of the module: the names exported by a star-import.
__all__ = [
    "IIndexedDataset",
    "IBaseDataset",
    "ILabeledDataset",
    "IQueryGalleryLabeledDataset",
    "IQueryGalleryDataset",
    "IPairDataset",
    "IVisualizableDataset",
    "IHTMLVisualizableDataset",
]