Source code for oml.metrics.embeddings

import warnings
from copy import deepcopy
from pprint import pprint
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import FloatTensor, LongTensor

from oml.const import (
    CATEGORIES_COLUMN,
    EMBEDDINGS_KEY,
    LOG_TOPK_IMAGES_PER_ROW,
    LOG_TOPK_ROWS_PER_METRIC,
    OVERALL_CATEGORIES_KEY,
)
from oml.ddp.utils import is_main_process
from oml.functional.metrics import (
    TMetricsDict,
    calc_fnmr_at_fmr_by_distances,
    calc_retrieval_metrics,
    calc_topological_metrics,
    reduce_metrics,
)
from oml.interfaces.datasets import IQueryGalleryLabeledDataset, IVisualizableDataset
from oml.interfaces.metrics import IMetricVisualisable, TIndices
from oml.interfaces.retrieval import IRetrievalPostprocessor
from oml.metrics.accumulation import Accumulator
from oml.retrieval.retrieval_results import RetrievalResults
from oml.utils.misc import flatten_dict, pad_array_right, remove_unused_kwargs


[docs]def calc_retrieval_metrics_rr( rr: RetrievalResults, query_categories: Optional[Union[LongTensor, np.ndarray]] = None, cmc_top_k: Tuple[int, ...] = (5,), precision_top_k: Tuple[int, ...] = (5,), map_top_k: Tuple[int, ...] = (5,), reduce: bool = True, verbose: bool = True, ) -> TMetricsDict: """ Function to compute different retrieval metrics. Args: rr: An instance of `RetrievalResults`. query_categories: Categories of queries with the size of ``n_query`` to compute metrics for each category. cmc_top_k: Values of ``k`` to calculate ``cmc@k`` (`Cumulative Matching Characteristic`) precision_top_k: Values of ``k`` to calculate ``precision@k`` map_top_k: Values of ``k`` to calculate ``map@k`` (`Mean Average Precision`) reduce: If ``False`` return metrics for each query without averaging verbose: Set ``True`` to make the function verbose. Returns: Metrics dictionary. """ return calc_retrieval_metrics( retrieved_ids=rr.retrieved_ids, gt_ids=rr.gt_ids, cmc_top_k=cmc_top_k, precision_top_k=precision_top_k, map_top_k=map_top_k, query_categories=query_categories, reduce=reduce, verbose=verbose, )
[docs]def calc_fnmr_at_fmr_rr( rr: RetrievalResults, fmr_vals: Tuple[float, ...] = (0.1,), ) -> TMetricsDict: """ For more details see `calc_fnmr_at_fmr` docs. Args: rr: An instance of `RetrievalResults`.: fmr_vals: Values of `FMR` to calculate `FNMR` at. Returns: Metrics dictionary. """ max_size = max(len(d) for d in rr.distances) dist = np.stack([pad_array_right(np.array(d), max_size, val=-1) for d in rr.distances]) mask_gt = np.zeros(dist.shape, dtype=bool) mask_not_padding = np.ones(dist.shape, dtype=bool) for i, (ri, gt_id) in enumerate(zip(rr.retrieved_ids, rr.gt_ids)): is_correct = torch.isin(ri, gt_id) mask_gt[i, : len(is_correct)] = is_correct mask_not_padding[i, len(ri) :] = False pos_dist = dist[mask_gt & mask_not_padding].flatten() neg_dist = dist[~mask_gt & mask_not_padding].flatten() return calc_fnmr_at_fmr_by_distances(pos_dist=pos_dist, neg_dist=neg_dist, fmr_vals=fmr_vals)
[docs]class EmbeddingMetrics(IMetricVisualisable): """ This class is designed to accumulate model outputs produced for every batch. Since retrieval metrics are not additive, we can compute them only after all data has been collected. """ metric_name = ""
[docs] def __init__( self, dataset: Optional[IQueryGalleryLabeledDataset], cmc_top_k: Tuple[int, ...] = (5,), precision_top_k: Tuple[int, ...] = (5,), map_top_k: Tuple[int, ...] = (5,), fmr_vals: Tuple[float, ...] = tuple(), pcf_variance: Tuple[float, ...] = (0.5,), postprocessor: Optional[IRetrievalPostprocessor] = None, metrics_to_exclude_from_visualization: Iterable[str] = (), return_only_overall_category: bool = False, visualize_only_overall_category: bool = True, verbose: bool = True, ): """ Args: dataset: Annotated dataset having query-gallery split. cmc_top_k: Values of ``k`` to calculate ``cmc@k`` (`Cumulative Matching Characteristic`) precision_top_k: Values of ``k`` to calculate ``precision@k`` map_top_k: Values of ``k`` to calculate ``map@k`` (`Mean Average Precision`) fmr_vals: Values of ``fmr`` (measured in quantiles) to calculate ``fnmr@fmr`` (`False Non Match Rate at the given False Match Rate`). For example, if ``fmr_values`` is (0.2, 0.4) we will calculate ``fnmr@fmr=0.2`` and ``fnmr@fmr=0.4``. Note, computing this metric requires additional memory overhead, that is why it's turned off by default. pcf_variance: Values in range [0, 1]. Find the number of components such that the amount of variance that needs to be explained is greater than the percentage specified by ``pcf_variance``. postprocessor: Postprocessor which applies some techniques like query reranking metrics_to_exclude_from_visualization: Names of the metrics to exclude from the visualization. It will not affect calculations. return_only_overall_category: Set ``True`` if you want to return only the aggregated metrics visualize_only_overall_category: Set ``False`` if you want to visualize each category separately verbose: Set ``True`` if you want to print metrics """ self.dataset = dataset self.cmc_top_k = cmc_top_k self.precision_top_k = precision_top_k self.map_top_k = map_top_k self.fmr_vals = fmr_vals self.pcf_variance = pcf_variance self.postprocessor = postprocessor self.retrieval_results: Optional[RetrievalResults] = None self.metrics: Optional[TMetricsDict] = None self.metrics_unreduced: Optional[TMetricsDict] = None self.visualize_only_overall_category = visualize_only_overall_category self.return_only_overall_category = return_only_overall_category self.metrics_to_exclude_from_visualization = ["fnmr@fmr", "pcf", *metrics_to_exclude_from_visualization] self.verbose = verbose self._acc_embeddings_key = "__embeddings" self.acc = Accumulator(keys_to_accumulate=(self._acc_embeddings_key,)) if fmr_vals: warnings.warn("Note, computing FNMR@FMR may significantly decrease computation time and memory consuming!")
[docs] def setup(self, num_samples: Optional[int] = None) -> None: # type: ignore self.retrieval_results = None self.metrics = None self.metrics_unreduced = None num_samples = num_samples if num_samples is not None else len(self.dataset) self.acc.refresh(num_samples)
[docs] def update(self, embeddings: FloatTensor, indices: TIndices) -> None: """ Args: embeddings: Representations of the dataset items containing in the current batch. indices: Global indices of the dataset items within the range of ``(0, dataset_size - 1)``. Indices are needed to make sure that we can align dataset items and collected information. """ indices = indices if isinstance(indices, List) else indices.tolist() self.acc.update_data(data_dict={self._acc_embeddings_key: embeddings}, indices=indices)
def update_data(self, data: Dict[str, Any], indices: TIndices) -> Any: self.update(embeddings=data[EMBEDDINGS_KEY], indices=indices) def _compute_retrieval_results(self) -> None: # note, fmr requires a lot of compute fmr_vals = len(self.dataset.get_gallery_ids()) if self.fmr_vals else 1 max_k = max([*self.cmc_top_k, *self.precision_top_k, *self.map_top_k, fmr_vals]) if self.postprocessor: # todo: refactor how we deal with postprocessors after we have more examples top_n = getattr(self.postprocessor, "top_n", len(self.dataset.get_gallery_ids())) max_k = max(max_k, top_n) self.retrieval_results = RetrievalResults.from_embeddings( # type: ignore embeddings=self.acc.storage[self._acc_embeddings_key], dataset=self.dataset, n_items=max_k, verbose=self.verbose, ) if self.postprocessor: args = {"rr": self.retrieval_results, "dataset": self.dataset} args = remove_unused_kwargs(args, self.postprocessor.process) self.retrieval_results = self.postprocessor.process(**args)
[docs] def compute_metrics(self) -> TMetricsDict: # type: ignore self.acc = self.acc.sync() # gathering data from devices happens here if DDP if not self.acc.is_storage_full(): raise ValueError( f"Metrics have to be calculated on fully collected data. " f"The size of the current storage is less than num samples: " f"we've collected {self.acc.collected_samples} out of {self.acc.num_samples}." ) self._compute_retrieval_results() args_r = { "cmc_top_k": self.cmc_top_k, "precision_top_k": self.precision_top_k, "map_top_k": self.map_top_k, "rr": self.retrieval_results, "reduce": False, "verbose": self.verbose, } args_t = {"embeddings": self.acc.storage[self._acc_embeddings_key], "pcf_variance": self.pcf_variance} if CATEGORIES_COLUMN in self.dataset.extra_data: categories = np.array(self.dataset.extra_data[CATEGORIES_COLUMN]) query_categories = categories[self.dataset.get_query_ids()] metrics_r = calc_retrieval_metrics_rr(query_categories=query_categories, **args_r) # type: ignore metrics_t = calc_topological_metrics(categories=categories, **args_t) # type: ignore self.metrics_unreduced = {cat: {**metrics_r[cat], **metrics_t[cat]} for cat in metrics_r.keys()} else: metrics_r = calc_retrieval_metrics_rr(**args_r) # type: ignore metrics_t = calc_topological_metrics(**args_t) # type: ignore self.metrics_unreduced = {OVERALL_CATEGORIES_KEY: {**metrics_r, **metrics_t}} self.metrics_unreduced[OVERALL_CATEGORIES_KEY].update( calc_fnmr_at_fmr_rr(self.retrieval_results, self.fmr_vals) ) self.metrics = reduce_metrics(deepcopy(self.metrics_unreduced)) if self.return_only_overall_category: metric_to_return = {OVERALL_CATEGORIES_KEY: deepcopy(self.metrics[OVERALL_CATEGORIES_KEY])} else: metric_to_return = deepcopy(self.metrics) if self.verbose and is_main_process(): print("\nMetrics:") pprint(metric_to_return) return metric_to_return
def ready_to_visualize(self) -> bool: return isinstance(self.dataset, IVisualizableDataset)
[docs] def visualize(self) -> Tuple[Collection[plt.Figure], Collection[str]]: """ Visualize worst queries by all the available metrics. """ metrics_flat = flatten_dict(self.metrics, ignored_keys=self.metrics_to_exclude_from_visualization) figures = [] titles = [] for metric_name in metrics_flat: if self.visualize_only_overall_category and not metric_name.startswith(OVERALL_CATEGORIES_KEY): continue fig = self.get_plot_for_worst_queries( metric_name=metric_name, n_queries=LOG_TOPK_ROWS_PER_METRIC, n_instances=LOG_TOPK_IMAGES_PER_ROW ) log_str = f"top {LOG_TOPK_ROWS_PER_METRIC} worst by {metric_name}".replace("/", "_") figures.append(fig) titles.append(log_str) return figures, titles
[docs] def get_worst_queries_ids(self, metric_name: str, n_queries: int) -> List[int]: metric_values = flatten_dict(self.metrics_unreduced)[metric_name] # type: ignore return torch.topk(metric_values, min(n_queries, len(metric_values)), largest=False)[1].tolist()
[docs] def get_plot_for_worst_queries( self, metric_name: str, n_queries: int, n_instances: int, verbose: bool = False ) -> plt.Figure: query_ids = self.get_worst_queries_ids(metric_name=metric_name, n_queries=n_queries) return self.get_plot_for_queries(query_ids=query_ids, n_instances=n_instances, verbose=verbose)
[docs] def get_plot_for_queries(self, query_ids: List[int], n_instances: int, verbose: bool = True) -> plt.Figure: """ Args: query_ids: Indices of the queries n_instances: Amount of the retrieved items to show verbose: Set ``True`` for additional information """ assert self.retrieval_results is not None, "We are not ready to plot, because there are no retrieval results." assert self.metrics_unreduced is not None, "We are not ready to plot, because metrics were not calculated yet." fig = self.retrieval_results.visualize( query_ids=query_ids, n_galleries_to_show=n_instances, verbose=verbose, dataset=self.dataset ) fig.tight_layout() return fig
__all__ = ["EmbeddingMetrics", "calc_retrieval_metrics_rr", "calc_fnmr_at_fmr_rr"]