Examples
Using the Python API is the most flexible approach: you are not limited by our project & config structures and you can use only the needed part of OML's functionality. You will find code snippets below to train, validate and run inference with the model on a tiny dataset of figures. See the documentation for more details regarding the dataset format.
Schemas, explanations and tips illustrating the code below are available in the documentation.
Training
# Train a ViT feature extractor with a triplet loss and in-batch mining
# on the tiny mock dataset of figures.
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
# BalanceSampler guarantees each batch contains n_labels classes x n_instances samples,
# so the miner can always build valid triplets.
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # info for logging: positive/negative distances, number of active triplets
    print(criterion.last_logs)
Validation
# Validate a trained extractor: accumulate embeddings for the query/gallery
# split and compute retrieval metrics.
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)

# "paths" are collected alongside embeddings so that mistakes can be visualised later
calculator = EmbeddingMetrics(extra_keys=("paths",))
calculator.setup(num_samples=len(val_dataset))

with torch.no_grad():
    for batch in tqdm(val_loader):
        batch["embeddings"] = extractor(batch["input_tensors"])
        calculator.update_data(batch)

metrics = calculator.compute_metrics()

# Logging
print(calculator.metrics)  # metrics
print(calculator.metrics_unreduced)  # metrics without averaging over queries

# Visualisation
calculator.get_plot_for_queries(query_ids=[0, 2], n_instances=5)  # draw predictions on predefined queries
calculator.get_plot_for_worst_queries(metric_name="OVERALL/map/5", n_queries=2, n_instances=5)  # draw mistakes
calculator.visualize()  # draw mistakes for all the available metrics
Using a trained model for retrieval
# Use a pretrained extractor for retrieval: embed queries and galleries,
# then rank gallery items by distance (dense matrix or kNN).
import torch

from oml.const import MOCK_DATASET_PATH
from oml.inference.flat import inference_on_images
from oml.models import ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)
queries = df_val[df_val["is_query"]]["path"].tolist()
galleries = df_val[df_val["is_gallery"]]["path"].tolist()

extractor = ViTExtractor.from_pretrained("vits16_dino")
# use the exact transforms the checkpoint was trained with
transform, _ = get_transforms_for_pretrained("vits16_dino")

args = {"num_workers": 0, "batch_size": 8}
features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args)
features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args)

# Now we can explicitly build a pairwise matrix of distances or save your RAM by using kNN
use_knn = False
top_k = 3

if use_knn:
    from sklearn.neighbors import NearestNeighbors

    knn = NearestNeighbors(algorithm="auto", p=2)  # p=2 -> Euclidean, matching OML's validation
    knn.fit(features_galleries)
    dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
else:
    dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
    dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)

print(f"Top {top_k} items closest to queries are:\n {ii_closest}")
Training + Validation [Lightning and logging]
# Training + validation orchestrated by PyTorch Lightning, with a choice of loggers.
import pytorch_lightning as pl
import torch

from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule
from oml.lightning.pipelines.logging import (
    ClearMLPipelineLogger,
    MLFlowPipelineLogger,
    NeptunePipelineLogger,
    TensorBoardPipelineLogger,
    WandBPipelineLogger,
)
from oml.losses.triplet import TripletLossWithMiner
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# training pieces: optimizer, dataset, loss with miner, balanced batch sampler
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# validation pieces: query/gallery dataset + metric callback (with image logging)
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(
    metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key]),
    log_images=True,
)

# 1) Logging with Tensorboard
logger = TensorBoardPipelineLogger(".")

# 2) Logging with Neptune
# logger = NeptunePipelineLogger(api_key="", project="", log_model_checkpoints=False)

# 3) Logging with Weights and Biases
# import os
# os.environ["WANDB_API_KEY"] = ""
# logger = WandBPipelineLogger(project="test_project", log_model=False)

# 4) Logging with MLFlow locally
# logger = MLFlowPipelineLogger(experiment_name="exp", tracking_uri="file:./ml-runs")

# 5) Logging with ClearML
# logger = ClearMLPipelineLogger(project_name="exp", task_name="test")

# run
pl_model = ExtractorModule(extractor, criterion, optimizer)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback], num_sanity_val_steps=0, logger=logger)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
Training + Validation [Lightning Distributed]
# Distributed (DDP) training + validation with PyTorch Lightning.
import pytorch_lightning as pl
import torch
from pytorch_lightning.strategies import DDPStrategy

from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
from oml.lightning.callbacks.metric import MetricValCallbackDDP
from oml.lightning.modules.extractor import ExtractorModuleDDP
from oml.losses.triplet import TripletLossWithMiner
from oml.metrics.embeddings import EmbeddingMetricsDDP
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP())  # DDP specific

# run: in DDP the module owns its loaders, so they are NOT passed to .fit()
pl_model = ExtractorModuleDDP(
    extractor=extractor,
    criterion=criterion,
    optimizer=optimizer,
    loaders_train=train_loader,  # DDP specific
    loaders_val=val_loader,  # DDP specific
)
ddp_args = {"accelerator": "cpu", "devices": 2, "strategy": DDPStrategy(), "use_distributed_sampler": False}  # DDP specific
trainer = pl.Trainer(max_epochs=1, callbacks=[metric_callback], num_sanity_val_steps=0, **ddp_args)
trainer.fit(pl_model)  # we don't pass loaders to .fit() in DDP
Colab: there is no Colab link since Colab provides only single-GPU machines.
Validation with 2 loaders
# Validation only, with two dataloaders at different image resolutions;
# each loader gets its own metric callback, matched via loader_idx.
import pytorch_lightning as pl
import torch

from oml.datasets.base import DatasetQueryGallery
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# 1st validation dataset (big images)
val_dataset_1 = DatasetQueryGallery(
    df_val,
    dataset_root=dataset_root,
    transform=get_normalisation_resize_torch(im_size=224),
)
val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4)
metric_callback_1 = MetricValCallback(
    metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key]),
    log_images=True,
    loader_idx=0,
)

# 2nd validation dataset (small images)
val_dataset_2 = DatasetQueryGallery(
    df_val,
    dataset_root=dataset_root,
    transform=get_normalisation_resize_torch(im_size=48),
)
val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4)
metric_callback_2 = MetricValCallback(
    metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key]),
    log_images=True,
    loader_idx=1,
)

# run validation: no criterion or optimizer is needed
pl_model = ExtractorModule(extractor, None, None)
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback_1, metric_callback_2], num_sanity_val_steps=0)
trainer.validate(pl_model, dataloaders=(val_loader_1, val_loader_2))

print(metric_callback_1.metric.metrics)
print(metric_callback_2.metric.metrics)
Usage with PyTorch Metric Learning
You can easily access a lot of content from PyTorch Metric Learning. The examples below are different from the basic ones only in a few lines of code:
Training with loss from PML
# Training loop identical to the basic example, but the loss comes from
# PyTorch Metric Learning (PML) instead of OML.
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

from pytorch_metric_learning import losses, distances, reducers, miners

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)

# PML specific
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")
criterion = losses.ArcFaceLoss(num_classes=df_train["label"].nunique(), embedding_size=extractor.feat_dim)  # for classification-like losses

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Training with distance, reducer, miner and loss from PML
# Training loop combining a distance, reducer, miner and loss taken from
# PyTorch Metric Learning (PML).
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

from pytorch_metric_learning import losses, distances, reducers, miners

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)

# PML specific
distance = distances.LpDistance(p=2)
# NOTE(review): `reducer` is created but never passed to the loss/miner below — confirm intent
reducer = reducers.ThresholdReducer(low=0)
criterion = losses.TripletMarginLoss()
miner = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="all")

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"], miner(embeddings, batch["labels"]))  # PML specific
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
To use content from PyTorch Metric Learning with our Pipelines just follow the standard tutorial of adding custom loss.
Note: during the validation process OpenMetricLearning computes L2 distances. Thus, when choosing a distance from PML, we recommend picking distances.LpDistance(p=2).
Handling sequences of photos
The section below mostly relates to animal or person re-identification tasks, where observations are often made in the form of sequences of frames. The problem appears when calculating retrieval metrics, because the closest retrieved images will most likely be the neighboring frames from the same sequence as a query. Thus, we get good values of metrics, but don't really understand what is going on. So, it's better to ignore photos taken from the same sequence as a given query.
If you take a look at standard Re-id benchmarks such as the MARS dataset, you may see that ignoring frames from the same camera is a part of the actual protocol.
Following the same logic, we introduced the sequence field in our dataset format.
If sequence ids are provided, retrieved items having the same sequence id as a given query will be ignored.
Below is an example of how to label consecutive shots of the tiger with the same sequence:
On the figure below we show how provided sequence labels affect metrics calculation:
| metric      | consider sequence?  | value |
|-------------|---------------------|-------|
| CMC@1       | no (top figure)     | 1.0   |
| CMC@1       | yes (bottom figure) | 0.0   |
| Precision@2 | no (top figure)     | 0.5   |
| Precision@2 | yes (bottom figure) | 0.5   |
To use this functionality you only need to provide a sequence column in your dataframe (containing strings or integers) and pass sequence_key to EmbeddingMetrics():
Validation + handling sequences
# Validation that respects frame sequences: retrieved items sharing a query's
# sequence id are ignored when computing metrics.
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root, df_name="df_with_sequence.csv")  # <- sequence info is in the file

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)

# passing sequence_key enables the same-sequence filtering during metric computation
calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key)
calculator.setup(num_samples=len(val_dataset))

with torch.no_grad():
    for batch in tqdm(val_loader):
        batch["embeddings"] = extractor(batch["input_tensors"])
        calculator.update_data(batch)

metrics = calculator.compute_metrics()
ㅤ ㅤ