
OML is a PyTorch-based framework for training and validating models that produce high-quality embeddings.

OML features

Losses | Miners
miner = AllTripletsMiner()
miner = NHardTripletsMiner()
miner = MinerWithBank()
...
criterion = TripletLossWithMiner(0.1, miner)
criterion = ArcFaceLoss()
criterion = SurrogatePrecision()
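
To illustrate what a miner and a triplet loss do together, here is a minimal pure-Python sketch, not OML's implementation. It enumerates every valid (anchor, positive, negative) triplet in a batch and averages the hinge `max(0, d(a, p) - d(a, n) + margin)` over them. The helper names `all_triplets` and `triplet_loss`, the Euclidean distance, and the mean reduction are assumptions made for this example.

```python
import math
from itertools import permutations


def all_triplets(labels):
    """Return every (anchor, positive, negative) index triplet in the batch."""
    triplets = []
    for a, p in permutations(range(len(labels)), 2):
        if labels[a] != labels[p]:
            continue  # anchor and positive must share a label
        for n in range(len(labels)):
            if labels[n] != labels[a]:
                triplets.append((a, p, n))
    return triplets


def triplet_loss(embeddings, labels, margin=0.1):
    """Mean hinge loss over all mined triplets (Euclidean distance assumed)."""
    triplets = all_triplets(labels)
    if not triplets:
        return 0.0
    total = 0.0
    for a, p, n in triplets:
        d_ap = math.dist(embeddings[a], embeddings[p])
        d_an = math.dist(embeddings[a], embeddings[n])
        total += max(0.0, d_ap - d_an + margin)
    return total / len(triplets)


# Two well-separated classes: every negative is much farther than any positive.
embeddings = [(0.0, 0.0), (0.0, 1.0), (3.0, 0.0), (3.0, 1.0)]
labels = [0, 0, 1, 1]
triplets = all_triplets(labels)
loss = triplet_loss(embeddings, labels, margin=0.1)
```

Harder miners keep only a subset of these triplets, for example the ones with the closest negatives, which matters once the number of possible triplets grows with batch size.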
Samplers
labels = train.get_labels()
l2c = train.get_label2category()


sampler = BalanceSampler(labels)
sampler = CategoryBalanceSampler(labels, l2c)
sampler = DistinctCategoryBalanceSampler(labels, l2c)
Configs support
max_epochs: 10
sampler:
  name: balance
  args:
    n_labels: 2
    n_instances: 2
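
With `n_labels: 2` and `n_instances: 2`, each batch holds 2 × 2 = 4 items: two labels, two samples per label. The sketch below shows that batching idea in pure Python. It is an illustration under stated assumptions, not the library's sampler: `balanced_batches` is a hypothetical helper, and it simply skips labels with too few samples, which the real sampler may handle differently.

```python
import random
from collections import defaultdict


def balanced_batches(labels, n_labels, n_instances, seed=0):
    """Group dataset indices into batches of n_labels x n_instances items."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Assumption: labels with fewer than n_instances samples are skipped.
    eligible = [lb for lb, idxs in by_label.items() if len(idxs) >= n_instances]
    rng.shuffle(eligible)

    batches = []
    for start in range(0, len(eligible) - n_labels + 1, n_labels):
        batch = []
        for label in eligible[start:start + n_labels]:
            batch.extend(rng.sample(by_label[label], n_instances))
        batches.append(batch)
    return batches


labels = [0, 0, 0, 1, 1, 2, 2, 3, 3, 3]
batches = balanced_batches(labels, n_labels=2, n_instances=2)
```

Balanced batches guarantee that every anchor has at least one positive and several negatives in the same batch, which triplet-style losses need.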
Pre-trained models of different modalities
model_hf = AutoModel.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
extractor_txt = HFWrapper(model_hf)

extractor_img = ViTExtractor.from_pretrained("vits16_dino")
transforms, _ = get_transforms_for_pretrained("vits16_dino")

extractor_audio = ECAPATDNNExtractor.from_pretrained()
Post-processing
emb = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(emb, dataset)

postprocessor = AdaptiveThresholding()
rr_upd = postprocessor.process(rr, dataset)
Post-processing by NN | Paper
embeddings = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)

postprocessor = PairwiseReranker(ConcatSiamese(), top_n=3)
rr_upd = postprocessor.process(rr, dataset)
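
The re-ranking idea with `top_n=3` can be sketched without any model: keep the initial ranking, re-score only the first `top_n` candidates with a more expensive pairwise function, and reorder just that head. The function `rerank_top_n` and the score table below are hypothetical stand-ins for illustration; in the snippet above the pairwise scores would come from the Siamese network.

```python
def rerank_top_n(ranked_ids, pair_distance, top_n):
    """Reorder the first top_n gallery ids by a pairwise distance; keep the tail."""
    head = sorted(ranked_ids[:top_n], key=pair_distance)
    return head + ranked_ids[top_n:]


# Initial ranking of gallery ids for one query (closest first).
initial = [4, 2, 7, 1, 9]

# Hypothetical pairwise distances for the top-3 candidates (lower is closer).
pairwise = {4: 0.9, 2: 0.1, 7: 0.5}

reranked = rerank_top_n(initial, lambda gid: pairwise[gid], top_n=3)
```

Only `top_n` pairs are scored per query, so the cost of the pairwise model stays bounded regardless of gallery size.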
Logging
logger = TensorBoardPipelineLogger()
logger = NeptunePipelineLogger()
logger = WandBPipelineLogger()
logger = MLFlowPipelineLogger()
logger = ClearMLPipelineLogger()
PyTorch Metric Learning (PML)
from pytorch_metric_learning import losses

criterion = losses.TripletMarginLoss(0.2, "all")
pred = ViTExtractor()(data)
criterion(pred, gts)
Categories support
# train
loader = DataLoader(CategoryBalanceSampler())

# validation
rr = RetrievalResults.from_embeddings()
m.calc_retrieval_metrics_rr(rr, query_categories)
Misc metrics
embeddings = inference(model, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)

m.calc_retrieval_metrics_rr(rr, precision_top_k=(5,))
m.calc_fnmr_at_fmr_rr(rr, fmr_vals=(0.1,))
m.calc_topological_metrics(embeddings, pcf_variance=(0.5,))
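
For intuition about `precision_top_k=(5,)`, the sketch below computes precision@k for a single query by hand: rank the gallery by distance and count how many of the first `k` items share the query's label. It uses the textbook definition (hits divided by `k`) with Euclidean distance; the library's metric may normalise differently, and `precision_at_k` is a name chosen for this example.

```python
import math


def precision_at_k(query_emb, query_label, gallery_embs, gallery_labels, k):
    """Fraction of the k nearest gallery items that share the query's label."""
    order = sorted(
        range(len(gallery_embs)),
        key=lambda i: math.dist(query_emb, gallery_embs[i]),
    )
    hits = sum(1 for i in order[:k] if gallery_labels[i] == query_label)
    return hits / k


gallery_embs = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0), (0.0, 2.0)]
gallery_labels = [0, 1, 1, 0]

p_at_1 = precision_at_k((0.0, 0.1), 0, gallery_embs, gallery_labels, k=1)
p_at_3 = precision_at_k((0.0, 0.1), 0, gallery_embs, gallery_labels, k=3)
```

Averaging this value over all queries gives the dataset-level number that a retrieval metric report would show.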
Lightning
import pytorch_lightning as pl

model = ViTExtractor.from_pretrained("vits16_dino")
clb = MetricValCallback(EmbeddingMetrics(dataset))
module = ExtractorModule(model, criterion, optimizer)

trainer = pl.Trainer(max_epochs=3, callbacks=[clb])
trainer.fit(module, train_loader, val_loader)
Lightning DDP
clb = MetricValCallback(EmbeddingMetrics(val))
module = ExtractorModuleDDP(
    model, criterion, optimizer, train, val
)

ddp = {"devices": 2, "strategy": DDPStrategy()}
trainer = pl.Trainer(max_epochs=3, callbacks=[clb], **ddp)
trainer.fit(module)

Contents