Retrieval & Post-processing

RetrievalResults

class oml.retrieval.retrieval_results.RetrievalResults(distances: Sequence[FloatTensor], retrieved_ids: Sequence[LongTensor], gt_ids: Optional[Sequence[LongTensor]] = None)

Bases: object

__init__(distances: Sequence[FloatTensor], retrieved_ids: Sequence[LongTensor], gt_ids: Optional[Sequence[LongTensor]] = None)
Parameters
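  • distances – Sequence of length n_query whose i-th element holds the distances from the i-th query to its retrieved gallery items, sorted in ascending order.

  • retrieved_ids – Sequence of length n_query whose i-th element holds the gallery indices retrieved for the i-th query, in the same order as the corresponding distances. Every index is within the range [0, n_gallery - 1].

  • gt_ids – Sequence of length n_query whose i-th element holds the gallery indices relevant to the i-th query (ground truth). Every element is within the range [0, n_gallery - 1].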
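The constructor expects ragged, per-query sequences rather than one rectangular matrix, so different queries may have different numbers of retrieved items. The pure-Python sketch below checks the documented invariants on plain lists. `check_retrieval_inputs` is a hypothetical helper, not part of OML, and lists stand in for the FloatTensor / LongTensor elements.

```python
from typing import Optional, Sequence


def check_retrieval_inputs(
    distances: Sequence[Sequence[float]],
    retrieved_ids: Sequence[Sequence[int]],
    n_gallery: int,
    gt_ids: Optional[Sequence[Sequence[int]]] = None,
) -> None:
    """Hypothetical validator mirroring the documented constraints."""
    n_query = len(distances)
    assert len(retrieved_ids) == n_query, "one entry per query is expected"
    if gt_ids is not None:
        assert len(gt_ids) == n_query, "one gt entry per query is expected"
    for dists, ids in zip(distances, retrieved_ids):
        # Each query may have its own number of retrieved items.
        assert len(dists) == len(ids)
        # Distances are sorted in ascending order (closest item first).
        assert all(a <= b for a, b in zip(dists, dists[1:]))
        # Every retrieved index points into the gallery.
        assert all(0 <= i <= n_gallery - 1 for i in ids)
    if gt_ids is not None:
        for ids in gt_ids:
            assert all(0 <= i <= n_gallery - 1 for i in ids)


# Two queries and a gallery of 5 items; the second query retrieved fewer items.
check_retrieval_inputs(
    distances=[[0.1, 0.3, 0.7], [0.2, 0.4]],
    retrieved_ids=[[4, 0, 2], [1, 3]],
    n_gallery=5,
    gt_ids=[[4], [3, 1]],
)
```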

classmethod from_embeddings(embeddings: FloatTensor, dataset: Union[IQueryGalleryDataset, IQueryGalleryLabeledDataset], n_items: int = 100, verbose: bool = False) → RetrievalResults
Parameters
  • embeddings – The result of inference with the shape of [dataset_len, emb_dim].

  • dataset – Dataset having query/gallery split.

  • n_items – Number of the closest gallery items to retrieve. It may be clipped to the gallery size if needed. Note that some queries may get fewer than n_items retrieved items if they do not have enough gallery items available.

  • verbose – Set True to show a progress bar.
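Conceptually, from_embeddings takes the embeddings of the whole dataset, splits them into queries and gallery items according to the dataset, and keeps the n_items closest gallery items for every query. The sketch below is not OML's implementation. It assumes Euclidean distance and represents the split with two hypothetical boolean masks, `is_query` and `is_gallery`. It also assumes that an item belonging to both sets is never retrieved for itself, which is one way a query can end up with fewer than n_items results.

```python
import math
from typing import List, Sequence, Tuple


def retrieve_from_embeddings(
    embeddings: Sequence[Sequence[float]],
    is_query: Sequence[bool],
    is_gallery: Sequence[bool],
    n_items: int = 100,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical brute-force kNN over a dataset with a query/gallery split."""
    query_idx = [i for i, q in enumerate(is_query) if q]
    gallery_idx = [i for i, g in enumerate(is_gallery) if g]
    # n_items cannot exceed the gallery size.
    n_items = min(n_items, len(gallery_idx))
    all_distances, all_ids = [], []
    for qi in query_idx:
        candidates = []
        for gpos, gi in enumerate(gallery_idx):
            if gi == qi:
                continue  # assumption: an item is never retrieved for itself
            candidates.append((math.dist(embeddings[qi], embeddings[gi]), gpos))
        candidates.sort()
        top = candidates[:n_items]
        # Retrieved ids are positions within the gallery, i.e. in [0, n_gallery - 1].
        all_distances.append([d for d, _ in top])
        all_ids.append([g for _, g in top])
    return all_distances, all_ids


dists, ids = retrieve_from_embeddings(
    embeddings=[[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]],
    is_query=[True, False, True],
    is_gallery=[True, True, True],
    n_items=5,
)
print(ids)  # [[1, 2], [0, 1]]
```

Here n_items is clipped from 5 to the gallery size of 3. Because each query skips itself, every query gets only 2 items.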

classmethod from_embeddings_qg(embeddings_query: FloatTensor, embeddings_gallery: FloatTensor, dataset_query: Union[IBaseDataset, ILabeledDataset], dataset_gallery: Union[IBaseDataset, ILabeledDataset], n_items: int = 100, verbose: bool = False) → RetrievalResults
Parameters
  • embeddings_query – The result of inference with the shape of [n_queries, emb_dim].

  • embeddings_gallery – The result of inference with the shape of [n_galleries, emb_dim].

  • dataset_query – Dataset of queries with the length of n_queries.

  • dataset_gallery – Dataset of gallery items with the length of n_galleries.

  • n_items – Number of the closest gallery items to retrieve. It may be clipped to the gallery size if needed. Note that some queries may get fewer than n_items retrieved items if they do not have enough gallery items available.

  • verbose – Set True to show a progress bar.

visualize(query_ids: List[int], dataset: IQueryGalleryDataset, n_galleries_to_show: int = 5, n_gt_to_show: int = 2, verbose: bool = False, show: bool = False) → Figure
Parameters
  • query_ids – Query indices within the range [0, n_query - 1].

  • dataset – Dataset that provides query-gallery split and supports visualisation.

  • n_galleries_to_show – Number of closest gallery items to show.

  • n_gt_to_show – Number of ground truth gallery items to show for reference (if available).

  • verbose – Set True to enable printing.

  • show – Set True to display the resulting figure immediately.
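The two counting parameters decide how many images appear in each row of the figure. The hypothetical helper `select_items_to_show` below sketches that selection on plain lists. It assumes that the first n_galleries_to_show retrieved ids are shown and, if ground truth is available, the first n_gt_to_show relevant ids. It only illustrates the parameters and does not reproduce OML's drawing or layout logic.

```python
from typing import Dict, List, Optional, Sequence


def select_items_to_show(
    query_ids: Sequence[int],
    retrieved_ids: Sequence[Sequence[int]],
    gt_ids: Optional[Sequence[Sequence[int]]] = None,
    n_galleries_to_show: int = 5,
    n_gt_to_show: int = 2,
) -> Dict[int, Dict[str, List[int]]]:
    """Hypothetical helper: pick the gallery ids that would appear in each row."""
    rows: Dict[int, Dict[str, List[int]]] = {}
    for q in query_ids:
        # Closest retrieved items first; fewer are shown if fewer were retrieved.
        shown = list(retrieved_ids[q][:n_galleries_to_show])
        # Ground-truth references are only available if gt_ids were provided.
        gt = list(gt_ids[q][:n_gt_to_show]) if gt_ids is not None else []
        rows[q] = {"retrieved": shown, "gt": gt}
    return rows


rows = select_items_to_show(
    query_ids=[0, 1],
    retrieved_ids=[[4, 0, 2, 1, 3], [1, 3]],
    gt_ids=[[4, 2, 3], [3]],
    n_galleries_to_show=3,
)
print(rows[0])  # {'retrieved': [4, 0, 2], 'gt': [4, 2]}
print(rows[1])  # {'retrieved': [1, 3], 'gt': [3]}
```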

visualize_qg(query_ids: List[int], dataset_query: IVisualizableDataset, dataset_gallery: IVisualizableDataset, n_galleries_to_show: int = 5, n_gt_to_show: int = 2, verbose: bool = False, show: bool = False) → Figure
Parameters
  • query_ids – Query indices within the range [0, n_query - 1].

  • dataset_query – Dataset of queries supporting visualisation, with the length of n_query.

  • dataset_gallery – Dataset of gallery items supporting visualisation, with the length of n_gallery.

  • n_galleries_to_show – Number of closest gallery items to show.

  • n_gt_to_show – Number of ground truth gallery items to show for reference (if available).

  • verbose – Set True to enable printing.

  • show – Set True to display the resulting figure immediately.

visualize_with_functions(query_ids: List[int], visualize_query_fn: Callable[[int, Tuple[int, int, int]], ndarray], visualize_gallery_fn: Callable[[int, Tuple[int, int, int]], ndarray], n_galleries_to_show: int = 5, n_gt_to_show: int = 2, show: bool = False) → Figure
Parameters
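  • query_ids – Query indices within the range [0, n_query - 1].

  • visualize_query_fn – Function that takes the index i of a query and an RGB color, and returns the image of the i-th query (as an ndarray) drawn with respect to the given color.

  • visualize_gallery_fn – Function that takes the index j of a gallery item and an RGB color, and returns the image of the j-th gallery item (as an ndarray) drawn with respect to the given color.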

  • n_galleries_to_show – Number of closest gallery items to show.

  • n_gt_to_show – Number of ground truth gallery items to show for reference (if available).

  • show – Set True to display the resulting figure immediately.
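visualize_query_fn and visualize_gallery_fn share one contract: given an item index and an RGB color, return the item's image as a NumPy array. A minimal sketch of such a function follows. The factory `make_placeholder_fn`, the solid placeholder image, its size, and the use of the color as a frame are all assumptions for illustration. A real function would load the actual image of the requested item from its dataset.

```python
from typing import Callable, Tuple

import numpy as np


def make_placeholder_fn(
    size: int = 64, border: int = 4
) -> Callable[[int, Tuple[int, int, int]], np.ndarray]:
    """Hypothetical factory for a function with the (index, color) -> ndarray contract."""

    def visualize_fn(idx: int, color: Tuple[int, int, int]) -> np.ndarray:
        # Gray canvas whose shade depends on the index, so items look different.
        img = np.full((size, size, 3), fill_value=(40 * idx) % 256, dtype=np.uint8)
        # Draw a frame in the requested RGB color.
        img[:border, :] = color
        img[-border:, :] = color
        img[:, :border] = color
        img[:, -border:] = color
        return img

    return visualize_fn


fn = make_placeholder_fn(size=8, border=1)
image = fn(2, (255, 0, 0))
print(image.shape)  # (8, 8, 3)
```

The same factory can supply both visualize_query_fn and visualize_gallery_fn, since the two callables have identical signatures.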

is_empty() → bool
deepcopy() → RetrievalResults
property n_retrieved_items: int

Returns: Number of items retrieved for each query. If queries have different numbers of retrieved items, the maximum is returned.
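Because each query keeps its own sequence of retrieved items, this property amounts to a maximum over the per-query lengths. Below is a plain-Python equivalent, assuming lists in place of tensors. The function name mirrors the property only for readability, and returning 0 for an empty input is an assumption of this sketch.

```python
from typing import Sequence


def n_retrieved_items(retrieved_ids: Sequence[Sequence[int]]) -> int:
    # The longest per-query list defines the result (0 for no queries: assumption).
    return max((len(ids) for ids in retrieved_ids), default=0)


print(n_retrieved_items([[4, 0, 2], [1, 3], []]))  # 3
```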

property distances: Tuple[FloatTensor, ...]

Returns: Sequence of length n_query whose i-th element holds the distances from the i-th query to its retrieved gallery items, sorted in ascending order.

property retrieved_ids: Tuple[LongTensor, ...]

Returns: Sequence of length n_query whose i-th element holds the gallery indices retrieved for the i-th query. Every index is within the range [0, n_gallery - 1].

property gt_ids: Optional[Tuple[LongTensor, ...]]

Returns: Sequence of length n_query whose i-th element holds the gallery indices relevant to the i-th query (ground truth). Every element is within the range [0, n_gallery - 1].

PairwiseReranker

class oml.retrieval.postprocessors.pairwise.PairwiseReranker(top_n: int, pairwise_model: IPairwiseModel, num_workers: int, batch_size: int, verbose: bool = False, use_fp16: bool = False)

Bases: IRetrievalPostprocessor

__init__(top_n: int, pairwise_model: IPairwiseModel, num_workers: int, batch_size: int, verbose: bool = False, use_fp16: bool = False)
Parameters
  • top_n – The pairwise model is applied to num_queries * top_n pairs, formed by each query and its top_n most relevant gallery items.

  • pairwise_model – Model that takes two items as inputs and estimates the distance (not in a strictly mathematical sense) between them.

  • num_workers – Number of workers in the DataLoader.

  • batch_size – Batch size used in the DataLoader.

  • verbose – Set True to show a progress bar during inference.

  • use_fp16 – Set True to use half precision.
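Together, top_n and batch_size determine how much work the pairwise model does. The sketch below uses two hypothetical helpers, `form_pairs` and `batched`, to list the (query, gallery) pairs that get scored and split them into batches. It assumes that a query with fewer than top_n retrieved items contributes fewer pairs. In the actual class, batching is handled by a DataLoader.

```python
from typing import Iterator, List, Sequence, Tuple


def form_pairs(retrieved_ids: Sequence[Sequence[int]], top_n: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: one pair per query and each of its top_n closest gallery items."""
    return [(q, g) for q, ids in enumerate(retrieved_ids) for g in ids[:top_n]]


def batched(pairs: List[Tuple[int, int]], batch_size: int) -> Iterator[List[Tuple[int, int]]]:
    """Hypothetical helper: consecutive chunks of at most batch_size pairs."""
    for start in range(0, len(pairs), batch_size):
        yield pairs[start:start + batch_size]


retrieved_ids = [[3, 2, 1, 0, 4], [0, 4, 1, 2, 3]]
pairs = form_pairs(retrieved_ids, top_n=3)
# 2 queries * 3 items = 6 pairs, scored in batches of at most 4.
print([len(b) for b in batched(pairs, batch_size=4)])  # [4, 2]
```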

process(rr: RetrievalResults, dataset: IQueryGalleryDataset) → RetrievalResults
Parameters
  • rr – RetrievalResults object.

  • dataset – Dataset having query/gallery split.

Returns

After re-ranking is applied to the top_n retrieved items, the updated RetrievalResults are returned. In other words, the first top_n items are permuted and their distances are replaced by the (possibly rescaled) distances from the pairwise model, while the rest remain untouched.

Example 1 (for one query):

rr.retrieved_ids = [[3,   2,   1,   0,   4  ]]
rr.distances     = [[0.1, 0.2, 0.5, 0.6, 0.7]]

# Let's say a postprocessor has been applied to the
# first 3 elements and the new distances are: [0.4, 0.2, 0.3]

# In this case, the updated values will be:
rr.retrieved_ids = [[2,   1,   3,   0,   4  ]]
rr.distances     = [[0.2, 0.3, 0.4, 0.6, 0.7]]

Example 2 (for one query):

# Note, the new distances to the top_n items produced by the pairwise model
# may be rescaled to keep the distances order. Here is an example:
rr.distances = [[0.1, 0.2, 0.3, 0.5, 0.6]]
top_n = 3

# Imagine, the postprocessor didn't change the order of the first 3 items
# (it's just a convenient example, the general logic remains the same),
# however the new values have a bigger scale:
distances_new = [[1, 2, 5, 0.5, 0.6]]

# Thus, we need to downscale the first 3 distances so that they do not exceed 0.5:
scale = 0.5 / 5 = 0.1
# Finally, let's apply the found scale to the top 3 distances:
rr_upd.distances = [[0.1, 0.2, 0.5, 0.5, 0.6]]

# Note, if the new top_n distances and the remaining old distances are already sorted, no scaling is applied.
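Both examples can be reproduced for a single query with the sketch below. `rerank_top_n` is a hypothetical helper and not OML's implementation, which may differ in details such as tie handling or numerical margins. It sorts the first top_n items by their new distances. When the largest new distance exceeds the first untouched distance, it multiplies the top_n distances by the ratio of the two.

```python
from typing import List, Sequence, Tuple


def rerank_top_n(
    retrieved_ids: Sequence[int],
    distances: Sequence[float],
    new_top_distances: Sequence[float],
) -> Tuple[List[int], List[float]]:
    """Hypothetical single-query sketch of the re-ranking rule shown in the examples."""
    top_n = len(new_top_distances)
    # Sort the first top_n items by the distances from the pairwise model.
    order = sorted(range(top_n), key=lambda k: new_top_distances[k])
    top_ids = [retrieved_ids[k] for k in order]
    top_dists = [float(new_top_distances[k]) for k in order]
    rest_ids = list(retrieved_ids[top_n:])
    rest_dists = list(distances[top_n:])
    # If the largest new distance exceeds the first untouched distance,
    # downscale the top_n distances so that they do not exceed it.
    if top_dists and rest_dists and top_dists[-1] > rest_dists[0]:
        scale = rest_dists[0] / top_dists[-1]
        top_dists = [d * scale for d in top_dists]
    return top_ids + rest_ids, top_dists + rest_dists


# Example 1: re-ordering without rescaling.
print(rerank_top_n([3, 2, 1, 0, 4], [0.1, 0.2, 0.5, 0.6, 0.7], [0.4, 0.2, 0.3]))
# ([2, 1, 3, 0, 4], [0.2, 0.3, 0.4, 0.6, 0.7])

# Example 2: same order, but the new distances are downscaled by 0.5 / 5.
print(rerank_top_n([0, 1, 2, 3, 4], [0.1, 0.2, 0.3, 0.5, 0.6], [1, 2, 5]))
# distances ≈ [0.1, 0.2, 0.5, 0.5, 0.6]
```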
property top_n: int

Returns: Number of gallery items closest to each query to process.

ConstantThresholding

class oml.retrieval.postprocessors.algo.ConstantThresholding(th: float)

Bases: IRetrievalPostprocessor

__init__(th: float)
Parameters

th – Distance threshold used to cut off retrieved items that are too far from their query.

process(rr: RetrievalResults) → RetrievalResults
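Below is a minimal sketch of constant thresholding on plain per-query lists. `constant_threshold` is a hypothetical helper, and keeping only distances strictly below th is an assumption about how the boundary is handled. Note that a query can end up with no retrieved items at all.

```python
from typing import List, Sequence, Tuple


def constant_threshold(
    distances: Sequence[Sequence[float]],
    retrieved_ids: Sequence[Sequence[int]],
    th: float,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical sketch: drop retrieved items whose distance is not below th."""
    distances_upd, ids_upd = [], []
    for dists, ids in zip(distances, retrieved_ids):
        # Assumption: items at distance >= th are removed.
        keep = [k for k, d in enumerate(dists) if d < th]
        distances_upd.append([dists[k] for k in keep])
        ids_upd.append([ids[k] for k in keep])
    return distances_upd, ids_upd


d, ids = constant_threshold([[0.1, 0.4, 0.9], [0.6, 0.8]], [[3, 2, 1], [0, 4]], th=0.5)
print(ids)  # [[3, 2], []]
```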

AdaptiveThresholding

class oml.retrieval.postprocessors.algo.AdaptiveThresholding(n_std: float)

Bases: IRetrievalPostprocessor

__init__(n_std: float)

This postprocessor cuts RetrievalResults if a big gap between consecutive distances has been found. A big gap is defined as a gap greater than n_std * avg_gap.

Parameters

n_std – The smaller the value, the fewer retrieved items remain in the RetrievalResults.

process(rr: RetrievalResults) → RetrievalResults
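The gap rule can be sketched for a single query as follows. `adaptive_threshold` is a hypothetical helper. It assumes that avg_gap is the mean of the query's own consecutive gaps, and that the cut happens at the first big gap, keeping only the items before it. OML's implementation may compute these quantities differently.

```python
from typing import List, Sequence, Tuple


def adaptive_threshold(
    distances: Sequence[float],
    retrieved_ids: Sequence[int],
    n_std: float,
) -> Tuple[List[float], List[int]]:
    """Hypothetical per-query sketch: cut at the first gap larger than n_std * avg_gap."""
    if len(distances) < 2:
        # No consecutive pairs, so there is no gap to cut at.
        return list(distances), list(retrieved_ids)
    gaps = [b - a for a, b in zip(distances, distances[1:])]
    avg_gap = sum(gaps) / len(gaps)  # assumption: averaged per query
    for k, gap in enumerate(gaps):
        if gap > n_std * avg_gap:
            # Gap k lies between items k and k + 1, so keep items 0..k.
            return list(distances[: k + 1]), list(retrieved_ids[: k + 1])
    return list(distances), list(retrieved_ids)


dists = [0.1, 0.12, 0.14, 0.9, 0.95]
ids = [5, 3, 8, 1, 0]
print(adaptive_threshold(dists, ids, n_std=2.0)[1])  # [5, 3, 8]
print(adaptive_threshold(dists, ids, n_std=4.0)[1])  # [5, 3, 8, 1, 0]
```

With n_std = 2.0, the jump from 0.14 to 0.9 counts as a big gap and the last two items are cut. With n_std = 4.0, no gap qualifies and all items remain, which matches the rule that a smaller n_std leaves fewer items.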