Examples

Here is an example of how to train and validate a model, and post-process its retrieval results, on a tiny dataset of images, texts, or audios. See more details on the dataset format.
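
For reference, below is a rough sketch of the dataframe format the mock datasets return. The column names (label, path, split, is_query, is_gallery) follow the documented dataset format, but the concrete values here are made up purely for illustration; for text datasets, a text column takes the place of path.

import pandas as pd

# an illustrative dataframe in the expected format (values are made up)
df = pd.DataFrame(
    {
        "label": [0, 0, 1, 1],
        "path": ["cat_0.jpg", "cat_1.jpg", "dog_0.jpg", "dog_1.jpg"],
        "split": ["train", "train", "validation", "validation"],
        "is_query": [None, None, True, True],
        "is_gallery": [None, None, True, True],
    }
)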

Images

from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import AllTripletsMiner
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset

model = ViTExtractor.from_pretrained("vits16_dino").to("cpu").train()
transform, _ = get_transforms_for_pretrained("vits16_dino")

df_train, df_val = get_mock_images_dataset(global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transform)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)


# training 1 epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# validation by retrieving relevant items
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))

Texts

from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import AllTripletsMiner
from oml.models import HFWrapper
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_texts_dataset

model = HFWrapper(AutoModel.from_pretrained("bert-base-uncased"), 768).to("cpu").train()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

df_train, df_val = get_mock_texts_dataset()
train = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
val = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)


# training 1 epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# validation by retrieving relevant items
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))

Audios

from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import AllTripletsMiner
from oml.models import ECAPATDNNExtractor
from oml.retrieval import AdaptiveThresholding, RetrievalResults
from oml.samplers import BalanceSampler
from oml.utils import get_mock_audios_dataset

model = ECAPATDNNExtractor.from_pretrained("ecapa_tdnn_taoruijie").to("cpu").train()

df_train, df_val = get_mock_audios_dataset(global_paths=True)
train = d.AudioLabeledDataset(df_train)
val = d.AudioQueryGalleryLabeledDataset(df_val)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)


# training 1 epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# validation by retrieving relevant items
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize_as_html(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))

Output of the images example:
{'active_tri': 0.125, 'pos_dist': 82.5, 'neg_dist': 100.5}  # batch 1
{'active_tri': 0.0, 'pos_dist': 36.3, 'neg_dist': 56.9}     # batch 2

{'cmc': {1: 0.75}, 'precision': {5: 0.75}, 'map': {3: 0.8}}

Output of the texts example:
{'active_tri': 0.0, 'pos_dist': 8.5, 'neg_dist': 11.0}  # batch 1
{'active_tri': 0.25, 'pos_dist': 8.9, 'neg_dist': 9.8}  # batch 2

{'cmc': {1: 0.8}, 'precision': {5: 0.7}, 'map': {3: 0.9}}

Output of the audios example:
{'active_tri': 0.25, 'pos_dist': 17.3, 'neg_dist': 18.4}  # batch 1
{'active_tri': 0.0, 'pos_dist': 17.1, 'neg_dist': 18.5}   # batch 2

{'cmc': {1: 1.0}, 'precision': {5: 1.0}, 'map': {3: 1.0}}

Extra illustrations, explanations and tips for the code above.

Retrieval by trained model

Here is an inference-time example (in other words, retrieval on a test set). The code below works for both texts and images.

See example

from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.utils import get_mock_images_dataset
from oml.retrieval import RetrievalResults, AdaptiveThresholding

_, df_test = get_mock_images_dataset(global_paths=True)
del df_test["label"]  # we don't need ground-truth labels for making predictions

extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")

dataset = ImageQueryGalleryDataset(df_test, transform=transform)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)

rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr = AdaptiveThresholding(n_std=3.5).process(rr)
rr.visualize(query_ids=[0, 1], dataset=dataset, show=True)

# you get the ids of retrieved items and the corresponding distances
print(rr)

Retrieval by trained model: streaming & txt2im

Here is an example where queries and galleries are processed separately.

  • First, it may be useful for streaming retrieval, when a gallery (index) set is huge and fixed, but queries come in batches.

  • Second, queries and galleries may have different natures: for example, queries are texts, while galleries are images.

See example

import pandas as pd

from oml.datasets import ImageBaseDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, ConstantThresholding
from oml.utils import get_mock_images_dataset

extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")

paths = pd.concat(get_mock_images_dataset(global_paths=True))["path"]
galleries, queries1, queries2 = paths[:20], paths[20:22], paths[22:24]

# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset(galleries, transform=transform)
embeddings_gallery = inference(extractor, dataset_gallery, batch_size=4, num_workers=0)

# queries come "online" in stream
for queries in [queries1, queries2]:
    dataset_query = ImageBaseDataset(queries, transform=transform)
    embeddings_query = inference(extractor, dataset_query, batch_size=4, num_workers=0)

    # for the operation below we are going to provide integrations with vector search DBs like Qdrant or Faiss
    rr = RetrievalResults.from_embeddings_qg(
        embeddings_query=embeddings_query, embeddings_gallery=embeddings_gallery,
        dataset_query=dataset_query, dataset_gallery=dataset_gallery
    )
    rr = ConstantThresholding(th=80).process(rr)
    rr.visualize_qg([0, 1], dataset_query=dataset_query, dataset_gallery=dataset_gallery, show=True)
    print(rr)

Usage with PyTorch Lightning

PyTorch Lightning

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.optim import Adam

from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
from oml.lightning import ExtractorModule
from oml.lightning import MetricValCallback
from oml.losses import ArcFaceLoss
from oml.metrics import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset
from oml.lightning import logging
from oml.retrieval import ConstantThresholding

df_train, df_val = get_mock_images_dataset(global_paths=True, df_name="df_with_category.csv")

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=True)

# train
optimizer = Adam(extractor.parameters(), lr=1e-6)
train_dataset = ImageLabeledDataset(df_train)
criterion = ArcFaceLoss(in_features=extractor.feat_dim, num_classes=df_train["label"].nunique())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
val_loader = DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset, postprocessor=ConstantThresholding(0.8)),
    log_images=True
)

# 1) Logging with Tensorboard
logger = logging.TensorBoardPipelineLogger(".")

# 2) Logging with Neptune
# logger = logging.NeptunePipelineLogger(api_key="", project="", log_model_checkpoints=False)

# 3) Logging with Weights and Biases
# import os
# os.environ["WANDB_API_KEY"] = ""
# logger = logging.WandBPipelineLogger(project="test_project", log_model=False)

# 4) Logging with MLFlow locally
# logger = logging.MLFlowPipelineLogger(experiment_name="exp", tracking_uri="file:./ml-runs")

# 5) Logging with ClearML
# logger = logging.ClearMLPipelineLogger(project_name="exp", task_name="test")

# run
pl_model = ExtractorModule(extractor, criterion, optimizer)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback], num_sanity_val_steps=0, logger=logger)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

PyTorch Lightning: DDP

pip install open-metric-learning[nlp]

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.optim import Adam

from oml import datasets as d
from oml.lightning import ExtractorModuleDDP
from oml.lightning import MetricValCallback
from oml.losses import TripletLossWithMiner
from oml.metrics import EmbeddingMetrics
from oml.miners import AllTripletsMiner
from oml.models import HFWrapper
from oml.samplers import BalanceSampler
from oml.utils import get_mock_texts_dataset
from oml.retrieval import AdaptiveThresholding
from pytorch_lightning.strategies import DDPStrategy

from transformers import AutoModel, AutoTokenizer

df_train, df_val = get_mock_texts_dataset()

# model
extractor = HFWrapper(AutoModel.from_pretrained("bert-base-uncased"), 768)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# train
optimizer = Adam(extractor.parameters(), lr=1e-6)
train_dataset = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)
val_loader = DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset, postprocessor=AdaptiveThresholding(n_std=3)))

# run
pl_model = ExtractorModuleDDP(extractor=extractor, criterion=criterion, optimizer=optimizer,
                              loaders_train=train_loader, loaders_val=val_loader  # DDP specific
                              )

ddp_args = {"accelerator": "cpu", "devices": 2, "strategy": DDPStrategy(), "use_distributed_sampler": False}  # DDP specific
trainer = pl.Trainer(max_epochs=1, callbacks=[metric_callback], num_sanity_val_steps=0, **ddp_args)
trainer.fit(pl_model)  # we don't pass loaders to .fit() in DDP

PyTorch Lightning: Deal with 2 validation loaders

import pytorch_lightning as pl

from torch.utils.data import DataLoader

from oml.datasets import ImageQueryGalleryLabeledDataset
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule
from oml.metrics import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils import get_mock_images_dataset

_, df_val = get_mock_images_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# 1st validation dataset (big images)
val_dataset_1 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=224))
val_loader_1 = DataLoader(val_dataset_1, batch_size=4)
metric_callback_1 = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset_1), log_images=True, loader_idx=0
)

# 2nd validation dataset (small images)
val_dataset_2 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=48))
val_loader_2 = DataLoader(val_dataset_2, batch_size=4)
metric_callback_2 = MetricValCallback(
    metric=EmbeddingMetrics(dataset=val_dataset_2), log_images=True, loader_idx=1
)

# run validation
pl_model = ExtractorModule(extractor, None, None)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback_1, metric_callback_2], num_sanity_val_steps=0)
trainer.validate(pl_model, dataloaders=(val_loader_1, val_loader_2))

print(metric_callback_1.metric.retrieval_results)
print(metric_callback_2.metric.retrieval_results)


Usage with PyTorch Metric Learning

You can easily access a lot of content from PyTorch Metric Learning. The examples below are different from the basic ones only in a few lines of code:

Losses from PyTorch Metric Learning

pip install pytorch-metric-learning

from torch.optim import Adam
from torch.utils.data import DataLoader

from oml.datasets import ImageLabeledDataset
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset

from pytorch_metric_learning import losses

df_train, _ = get_mock_images_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = Adam(extractor.parameters(), lr=1e-4)

train_dataset = ImageLabeledDataset(df_train)

# PML specific
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")
criterion = losses.ArcFaceLoss(num_classes=df_train["label"].nunique(), embedding_size=extractor.feat_dim)  # for classification-like losses

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = DataLoader(train_dataset, batch_sampler=sampler)

for batch in train_loader:
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Losses from PyTorch Metric Learning: advanced

pip install pytorch-metric-learning

from torch.utils.data import DataLoader
from torch.optim import Adam

from oml.datasets import ImageLabeledDataset
from oml.models import ViTExtractor
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset
from oml.transforms.images.torchvision import get_augs_torch

from pytorch_metric_learning import losses, distances, reducers, miners

df_train, _ = get_mock_images_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = Adam(extractor.parameters(), lr=1e-4)
train_dataset = ImageLabeledDataset(df_train, transform=get_augs_torch(im_size=224))

# PML specific
distance = distances.LpDistance(p=2)
reducer = reducers.ThresholdReducer(low=0)
criterion = losses.TripletMarginLoss()
miner = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="all")

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = DataLoader(train_dataset, batch_sampler=sampler)

for batch in train_loader:
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"], miner(embeddings, batch["labels"]))  # PML specific
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


To use content from PyTorch Metric Learning (PML) with our Pipelines, just follow the standard tutorial on adding a custom loss.

Note! During validation, OpenMetricLearning computes L2 distances. Thus, when choosing a distance from PML, we recommend picking distances.LpDistance(p=2).

Handling categories

A category is something that hierarchically unites a group of labels. For example, we may have 3 different catalog items of tables with labels like table1, table2, table3, and their category is tables.
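
As a tiny illustration (the names below are made up, not taken from the mock dataset), such a hierarchy is simply a label-to-category mapping:

# a hypothetical label -> category mapping
label2category = {
    "table1": "tables",
    "table2": "tables",
    "table3": "tables",
    "sofa1": "sofas",
    "sofa2": "sofas",
}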

Categories in training:

  • Category balanced sampling may help to deal with category imbalance.

  • For contrastive losses, limiting the number of categories in a batch may help to mine harder negative samples (another table is a harder negative example than another sofa). Without such samples, there is no guarantee that we get enough tables in the batch.

Categories in validation:

  • Having categories allows obtaining fine-grained metrics and recognizing over- and under-performing subsets of the dataset.

See example
pip install open-metric-learning[nlp]

from pprint import pprint

import numpy as np
from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import AllTripletsMiner
from oml.models import ViTExtractor
from oml.retrieval import RetrievalResults
from oml.samplers import DistinctCategoryBalanceSampler, CategoryBalanceSampler
from oml.utils import get_mock_images_dataset
from oml.registry import get_transforms_for_pretrained

model = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transforms, _ = get_transforms_for_pretrained("vits16_dino")

df_train, df_val = get_mock_images_dataset(df_name="df_with_category.csv", global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transforms)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transforms)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)

# >>>>> You can use one of category aware samplers
args = {"n_categories": 2, "n_labels": 2, "n_instances": 2, "label2category": train.get_label2category(), "labels": train.get_labels()}
sampler = DistinctCategoryBalanceSampler(epoch_size=5, **args)
# sampler = CategoryBalanceSampler(resample_labels=False, weight_categories=True, **args)  # a bit different sampling


def training():
    for batch in DataLoader(train, batch_sampler=sampler):
        embeddings = model(batch["input_tensors"])
        loss = criterion(embeddings, batch["labels"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pprint(criterion.last_logs)


def validation():
    embeddings = inference(model, val, batch_size=4, num_workers=0)
    rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
    rr.visualize(query_ids=[2, 1], dataset=val, show=True)

    # >>>> When query categories are known we may get fine-grained metrics
    query_categories = np.array(val.extra_data["category"])[val.get_query_ids()]
    pprint(calc_retrieval_metrics_rr(rr, query_categories=query_categories, map_top_k=(3,), cmc_top_k=(1,)))


training()
validation()

Handling sequences of photos

This section is mostly related to animal or person re-identification tasks, where observations often come in the form of sequences of frames. The problem appears when calculating retrieval metrics: the closest retrieved images will most likely be the neighboring frames from the same sequence as the query, so we get good metric values but don't really understand what is going on. Thus, it's better to ignore photos taken from the same sequence as a given query.

If you take a look at standard re-identification benchmarks such as the MARS dataset, you may see that ignoring frames from the same camera is part of the actual evaluation protocol. Following the same logic, we introduced the sequence field in our dataset format.

If sequence ids are provided, retrieved items having the same sequence id as a given query will be ignored.

Below is an example of how to label consecutive shots of a tiger with the same sequence:
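
The sketch below is illustrative: the labels, paths, and sequence ids are made up, and in a real dataframe the sequence column sits alongside the other columns of the dataset format.

import pandas as pd

# consecutive frames of the same tiger share one sequence id,
# so they will be ignored when that tiger is used as a query
df = pd.DataFrame(
    {
        "label": ["tiger_1", "tiger_1", "tiger_1", "tiger_2"],
        "path": ["t1_frame_0.jpg", "t1_frame_1.jpg", "t1_frame_2.jpg", "t2_frame_0.jpg"],
        "sequence": ["seq_0", "seq_0", "seq_0", "seq_1"],
    }
)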

Below we show how the provided sequence labels affect metric calculation (computed without and with taking the sequence into account):

metric         consider sequence?   value
CMC@1          no                   1.0
CMC@1          yes                  0.0
Precision@2    no                   0.5
Precision@2    yes                  0.5

To use this functionality, you only need to provide the sequence column in your dataframe (containing strings or integers) and pass sequence_key to EmbeddingMetrics():

See example
from oml.inference import inference
from oml.datasets import ImageQueryGalleryLabeledDataset
from oml.models import ViTExtractor
from oml.retrieval import RetrievalResults
from oml.utils import get_mock_images_dataset
from oml.metrics import calc_retrieval_metrics_rr

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).to("cpu")

_, df_val = get_mock_images_dataset(global_paths=True, df_name="df_with_sequence.csv")  # <- sequence info is in the file
dataset = ImageQueryGalleryLabeledDataset(df_val)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)

rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr.visualize(query_ids=[2, 1], dataset=dataset, show=True)

metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
print(rr, "\n", metrics)