Examples
Here is an example of how to train, validate, and post-process a model on a tiny dataset of images, texts, or audio. See more details on the dataset format.
Extra illustrations, explanations, and tips for these examples can be found in the documentation.
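For reference, the mock datasets used throughout these examples follow OML's dataframe format. Below is a minimal illustrative sketch of such a dataframe (paths and labels are made up; see the dataset format documentation for the authoritative column list):
import pandas as pd

# Illustrative only: ``split`` marks train/validation rows;
# ``is_query`` / ``is_gallery`` are only used for the validation split.
df = pd.DataFrame({
    "label": [0, 0, 1, 1],
    "path": ["img0.jpg", "img1.jpg", "img2.jpg", "img3.jpg"],
    "split": ["train", "train", "validation", "validation"],
    "is_query": [None, None, True, True],
    "is_gallery": [None, None, True, True],
})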
Retrieval by trained model
Here is an inference-time example (in other words, retrieval on a test set). The code below works for both texts and images.
See example
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.utils import get_mock_images_dataset
from oml.retrieval import RetrievalResults, AdaptiveThresholding
_, df_test = get_mock_images_dataset(global_paths=True)
del df_test["label"] # we don't need gt labels for doing predictions
extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr = AdaptiveThresholding(n_std=3.5).process(rr)
rr.visualize(query_ids=[0, 1], dataset=dataset, show=True)
# you get the ids of retrieved items and the corresponding distances
print(rr)
Retrieval by trained model: streaming & txt2im
Here is an example where queries and galleries are processed separately.
First, it may be useful for streaming retrieval, when the gallery (index) set is huge and fixed, but queries arrive in batches.
Second, queries and galleries may have different modalities: for example, queries are texts, while galleries are images.
See example
import pandas as pd
from oml.datasets import ImageBaseDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, ConstantThresholding
from oml.utils import get_mock_images_dataset
extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")
paths = pd.concat(get_mock_images_dataset(global_paths=True))["path"]
galleries, queries1, queries2 = paths[:20], paths[20:22], paths[22:24]
# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset(galleries, transform=transform)
embeddings_gallery = inference(extractor, dataset_gallery, batch_size=4, num_workers=0)
# queries come "online" in stream
for queries in [queries1, queries2]:
    dataset_query = ImageBaseDataset(queries, transform=transform)
    embeddings_query = inference(extractor, dataset_query, batch_size=4, num_workers=0)
    # for the operation below we are going to provide integrations with vector search DBs like Qdrant or Faiss
    rr = RetrievalResults.from_embeddings_qg(
        embeddings_query=embeddings_query, embeddings_gallery=embeddings_gallery,
        dataset_query=dataset_query, dataset_gallery=dataset_gallery
    )
    rr = ConstantThresholding(th=80).process(rr)
    rr.visualize_qg([0, 1], dataset_query=dataset_query, dataset_gallery=dataset_gallery, show=True)
    print(rr)
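The comment in the code above mentions upcoming integrations with vector search engines. In the meantime, here is a minimal sketch of how you could plug the gallery embeddings into Faiss yourself (Faiss is not part of OML; ``pip install faiss-cpu`` and CPU float32 tensors from ``inference`` are assumed):
import faiss  # assumed installed separately via `pip install faiss-cpu`

# Build a flat L2 index over the fixed gallery once...
index = faiss.IndexFlatL2(embeddings_gallery.shape[1])
index.add(embeddings_gallery.numpy())

# ...then search it for every incoming batch of query embeddings.
distances, retrieved_ids = index.search(embeddings_query.numpy(), k=5)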
Usage with PyTorch Lightning
PyTorch Lightning
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.optim import Adam
from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
from oml.lightning import ExtractorModule
from oml.lightning import MetricValCallback
from oml.losses import ArcFaceLoss
from oml.metrics import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset
from oml.lightning import logging
from oml.retrieval import ConstantThresholding
df_train, df_val = get_mock_images_dataset(global_paths=True, df_name="df_with_category.csv")
# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=True)
# train
optimizer = Adam(extractor.parameters(), lr=1e-6)
train_dataset = ImageLabeledDataset(df_train)
criterion = ArcFaceLoss(in_features=extractor.feat_dim, num_classes=df_train["label"].nunique())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
# val
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
val_loader = DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset, postprocessor=ConstantThresholding(0.8)),
    log_images=True
)
# 1) Logging with Tensorboard
logger = logging.TensorBoardPipelineLogger(".")
# 2) Logging with Neptune
# logger = logging.NeptunePipelineLogger(api_key="", project="", log_model_checkpoints=False)
# 3) Logging with Weights and Biases
# import os
# os.environ["WANDB_API_KEY"] = ""
# logger = logging.WandBPipelineLogger(project="test_project", log_model=False)
# 4) Logging with MLFlow locally
# logger = logging.MLFlowPipelineLogger(experiment_name="exp", tracking_uri="file:./ml-runs")
# 5) Logging with ClearML
# logger = logging.ClearMLPipelineLogger(project_name="exp", task_name="test")
# run
pl_model = ExtractorModule(extractor, criterion, optimizer)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback], num_sanity_val_steps=0, logger=logger)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
PyTorch Lightning: DDP
pip install open-metric-learning[nlp]
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.optim import Adam
from oml import datasets as d
from oml.lightning import ExtractorModuleDDP
from oml.lightning import MetricValCallback
from oml.losses import TripletLossWithMiner
from oml.metrics import EmbeddingMetrics
from oml.miners import AllTripletsMiner
from oml.models import HFWrapper
from oml.samplers import BalanceSampler
from oml.utils import get_mock_texts_dataset
from oml.retrieval import AdaptiveThresholding
from pytorch_lightning.strategies import DDPStrategy
from transformers import AutoModel, AutoTokenizer
df_train, df_val = get_mock_texts_dataset()
# model
extractor = HFWrapper(AutoModel.from_pretrained("bert-base-uncased"), 768)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# train
optimizer = Adam(extractor.parameters(), lr=1e-6)
train_dataset = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
# val
val_dataset = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)
val_loader = DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset, postprocessor=AdaptiveThresholding(n_std=3)))
# run
pl_model = ExtractorModuleDDP(extractor=extractor, criterion=criterion, optimizer=optimizer,
                              loaders_train=train_loader, loaders_val=val_loader  # DDP specific
                              )
ddp_args = {"accelerator": "cpu", "devices": 2, "strategy": DDPStrategy(), "use_distributed_sampler": False} # DDP specific
trainer = pl.Trainer(max_epochs=1, callbacks=[metric_callback], num_sanity_val_steps=0, **ddp_args)
trainer.fit(pl_model) # we don't pass loaders to .fit() in DDP
PyTorch Lightning: Deal with 2 validation loaders
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from oml.datasets import ImageQueryGalleryLabeledDataset
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule
from oml.metrics import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils import get_mock_images_dataset
_, df_val = get_mock_images_dataset(global_paths=True)
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
# 1st validation dataset (big images)
val_dataset_1 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=224))
val_loader_1 = DataLoader(val_dataset_1, batch_size=4)
metric_callback_1 = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset_1), log_images=True, loader_idx=0
)
# 2nd validation dataset (small images)
val_dataset_2 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=48))
val_loader_2 = DataLoader(val_dataset_2, batch_size=4)
metric_callback_2 = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset_2), log_images=True, loader_idx=1
)
# run validation
pl_model = ExtractorModule(extractor, None, None)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback_1, metric_callback_2], num_sanity_val_steps=0)
trainer.validate(pl_model, dataloaders=(val_loader_1, val_loader_2))
print(metric_callback_1.metric.retrieval_results)
print(metric_callback_2.metric.retrieval_results)
Usage with PyTorch Metric Learning
You can easily access a lot of content from PyTorch Metric Learning. The examples below differ from the basic ones in only a few lines of code:
Losses from PyTorch Metric Learning
pip install pytorch-metric-learning
from torch.optim import Adam
from torch.utils.data import DataLoader
from oml.datasets import ImageLabeledDataset
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset
from pytorch_metric_learning import losses
df_train, _ = get_mock_images_dataset(global_paths=True)
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = Adam(extractor.parameters(), lr=1e-4)
train_dataset = ImageLabeledDataset(df_train)
# PML specific
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")
criterion = losses.ArcFaceLoss(num_classes=df_train["label"].nunique(), embedding_size=extractor.feat_dim) # for classification-like losses
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = DataLoader(train_dataset, batch_sampler=sampler)
for batch in train_loader:
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Losses from PyTorch Metric Learning: advanced
pip install pytorch-metric-learning
from torch.utils.data import DataLoader
from torch.optim import Adam
from oml.datasets import ImageLabeledDataset
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset
from oml.transforms.images.torchvision import get_augs_torch
from pytorch_metric_learning import losses, distances, reducers, miners
df_train, _ = get_mock_images_dataset(global_paths=True)
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = Adam(extractor.parameters(), lr=1e-4)
train_dataset = ImageLabeledDataset(df_train, transform=get_augs_torch(im_size=224))
# PML specific
distance = distances.LpDistance(p=2)
reducer = reducers.ThresholdReducer(low=0)
criterion = losses.TripletMarginLoss(distance=distance, reducer=reducer)  # use the distance and reducer defined above
miner = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="all")
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = DataLoader(train_dataset, batch_sampler=sampler)
for batch in train_loader:
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"], miner(embeddings, batch["labels"]))  # PML specific
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
To use content from PyTorch Metric Learning (PML) with our Pipelines, just follow the standard tutorial on adding a custom loss.
Note! During validation, OpenMetricLearning computes L2 distances. Thus, when choosing a distance from PML, we recommend picking ``distances.LpDistance(p=2)``.
Handling categories
Category is something that hierarchically unites a group of labels. For example, we have 3 different catalog items of tables with labels like ``table1``, ``table2``, ``table3``, and their category is ``tables``.
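In code, this hierarchy is just a mapping from labels to categories. A purely illustrative sketch (the example further below obtains the real mapping from the dataset via ``get_label2category()``):
# Illustrative only: three table items and two sofa items grouped into two categories
label2category = {
    "table1": "tables", "table2": "tables", "table3": "tables",
    "sofa1": "sofas", "sofa2": "sofas",
}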
Categories in training:
Category-balanced sampling may help to deal with category imbalance.
For contrastive losses, limiting the number of categories in a batch may help to mine harder negative samples (another table is a harder negative example than a sofa). Without such sampling there is no guarantee that we get enough tables in the batch.
Categories in validation:
Having categories allows obtaining fine-grained metrics and recognizing over- and under-performing subsets of the dataset.
See example
pip install open-metric-learning[nlp]
from pprint import pprint
import numpy as np
from torch.optim import Adam
from torch.utils.data import DataLoader
from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import AllTripletsMiner
from oml.models import ViTExtractor
from oml.retrieval import RetrievalResults
from oml.samplers import DistinctCategoryBalanceSampler, CategoryBalanceSampler
from oml.utils import get_mock_images_dataset
from oml.registry import get_transforms_for_pretrained
model = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transforms, _ = get_transforms_for_pretrained("vits16_dino")
df_train, df_val = get_mock_images_dataset(df_name="df_with_category.csv", global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transforms)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transforms)
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
# >>>>> You can use one of category aware samplers
args = {"n_categories": 2, "n_labels": 2, "n_instances": 2, "label2category": train.get_label2category(), "labels": train.get_labels()}
sampler = DistinctCategoryBalanceSampler(epoch_size=5, **args)
# sampler = CategoryBalanceSampler(resample_labels=False, weight_categories=True, **args) # a bit different sampling
def training():
    for batch in DataLoader(train, batch_sampler=sampler):
        embeddings = model(batch["input_tensors"])
        loss = criterion(embeddings, batch["labels"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pprint(criterion.last_logs)


def validation():
    embeddings = inference(model, val, batch_size=4, num_workers=0)
    rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
    rr.visualize(query_ids=[2, 1], dataset=val, show=True)
    # >>>> When query categories are known we may get fine-grained metrics
    query_categories = np.array(val.extra_data["category"])[val.get_query_ids()]
    pprint(calc_retrieval_metrics_rr(rr, query_categories=query_categories, map_top_k=(3,), cmc_top_k=(1,)))
training()
validation()
Handling sequences of photos
The following is mostly relevant to animal or person re-identification tasks, where observations often come as sequences of frames. A problem appears when calculating retrieval metrics: the closest retrieved images will most likely be the neighboring frames from the same sequence as the query. Thus, we get good metric values but don't really understand what is going on. So, it's better to ignore photos taken from the same sequence as a given query.
If you take a look at standard re-id benchmarks such as the MARS dataset, you may see that ignoring frames from the same camera is part of the actual evaluation protocol.
Following the same logic, we introduced the ``sequence`` field in our dataset format. If sequence ids are provided, retrieved items having the same sequence id as a given query will be ignored. Below is an example of how to label consecutive shots of a tiger with the same ``sequence``:

[Figure: consecutive frames of the same tiger labeled with one sequence id]
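For concreteness, here is a minimal sketch of what such a dataframe might look like (paths and ids are made up; the two consecutive shots of the first tiger share one sequence id):
import pandas as pd

# Illustrative only: rows sharing a ``sequence`` value will be ignored
# when they show up in each other's retrieval results.
df = pd.DataFrame({
    "label": [0, 0, 1],
    "path": ["tiger_0_shot_0.jpg", "tiger_0_shot_1.jpg", "tiger_1_shot_0.jpg"],
    "split": ["validation"] * 3,
    "is_query": [True, True, True],
    "is_gallery": [True, True, True],
    "sequence": [0, 0, 1],
})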
In the figure below we show how the provided sequence labels affect metrics calculation:

[Figure: the same retrieval results evaluated without (top) and with (bottom) sequence ids]
| metric      | consider sequence?  | value |
|-------------|---------------------|-------|
| CMC@1       | no (top figure)     | 1.0   |
| CMC@1       | yes (bottom figure) | 0.0   |
| Precision@2 | no (top figure)     | 0.5   |
| Precision@2 | yes (bottom figure) | 0.5   |
To use this functionality, you only need to provide a ``sequence`` column in your dataframe (containing strings or integers) and pass ``sequence_key`` to ``EmbeddingMetrics()``:
See example
from oml.inference import inference
from oml.datasets import ImageQueryGalleryLabeledDataset
from oml.models import ViTExtractor
from oml.retrieval import RetrievalResults
from oml.utils import get_mock_images_dataset
from oml.metrics import calc_retrieval_metrics_rr
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).to("cpu")
_, df_val = get_mock_images_dataset(global_paths=True, df_name="df_with_sequence.csv") # <- sequence info is in the file
dataset = ImageQueryGalleryLabeledDataset(df_val)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr.visualize(query_ids=[2, 1], dataset=dataset, show=True)
metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
print(rr, "\n", metrics)