
OML is a PyTorch-based framework for training and validating models that produce high-quality embeddings.
OML features
Losses | Miners
miner = AllTripletsMiner()
miner = NHardTripletsMiner()
miner = MinerWithBank()
...
criterion = TripletLossWithMiner(0.1, miner)
criterion = ArcFaceLoss()
criterion = SurrogatePrecision()
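
To make the pairing of a miner with a triplet loss more concrete, here is a minimal pure-Python sketch of the general technique: enumerate every valid (anchor, positive, negative) triple in a batch, then average the hinge on the distance gap. The helper names `all_triplets` and `triplet_margin_loss` are hypothetical and are not part of OML. The library's own miners and losses operate on tensors and may differ in details such as reduction or deduplication.

```python
import math
from itertools import permutations


def all_triplets(labels):
    """Hypothetical helper: every (anchor, positive, negative) index triple.

    Anchor and positive share a label, the negative has a different one.
    """
    triplets = []
    for a, p in permutations(range(len(labels)), 2):
        if labels[a] != labels[p]:
            continue
        for n, label_n in enumerate(labels):
            if label_n != labels[a]:
                triplets.append((a, p, n))
    return triplets


def triplet_margin_loss(embeddings, labels, margin=0.1):
    """Hypothetical helper: mean of max(0, d(a, p) - d(a, n) + margin)."""
    triplets = all_triplets(labels)
    if not triplets:
        return 0.0
    losses = []
    for a, p, n in triplets:
        d_ap = math.dist(embeddings[a], embeddings[p])
        d_an = math.dist(embeddings[a], embeddings[n])
        losses.append(max(0.0, d_ap - d_an + margin))
    return sum(losses) / len(losses)


# Toy batch: two well-separated classes, so no triplet violates the margin.
toy_embeddings = [[0.0, 0.0], [0.0, 1.0], [5.0, 0.0], [5.0, 1.0]]
toy_labels = [0, 0, 1, 1]
print(triplet_margin_loss(toy_embeddings, toy_labels, margin=0.1))
```

A hard-mining strategy would keep only a subset of these triplets (for example, the ones with the largest loss) instead of averaging over all of them.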

Samplers
labels = train.get_labels()
l2c = train.get_label2category()
sampler = BalanceSampler(labels)
sampler = CategoryBalanceSampler(labels, l2c)
sampler = DistinctCategoryBalanceSampler(labels, l2c)

Configs support
max_epochs: 10
sampler:
  name: balance
  args:
    n_labels: 2
    n_instances: 2
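
The `n_labels` and `n_instances` arguments suggest batches made of `n_labels` distinct labels with `n_instances` items each. The sketch below illustrates that idea in plain Python, assuming this interpretation. The function `balanced_batches` is a hypothetical name, and OML's `BalanceSampler` may handle small classes and epoch length differently.

```python
import random
from collections import defaultdict


def balanced_batches(labels, n_labels, n_instances, seed=0):
    """Hypothetical helper: batches of n_labels labels x n_instances items."""
    rng = random.Random(seed)

    # Group dataset indices by their label.
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Simplification: skip labels that cannot fill their share of a batch.
    eligible = [lb for lb, idxs in by_label.items() if len(idxs) >= n_instances]
    rng.shuffle(eligible)

    batches = []
    # Walk over the shuffled labels in chunks of n_labels; drop the remainder.
    for start in range(0, len(eligible) - n_labels + 1, n_labels):
        batch = []
        for label in eligible[start:start + n_labels]:
            batch.extend(rng.sample(by_label[label], n_instances))
        batches.append(batch)
    return batches


toy_labels = [0, 0, 0, 1, 1, 2, 2, 3, 3, 3]
for batch in balanced_batches(toy_labels, n_labels=2, n_instances=2):
    print(batch, [toy_labels[i] for i in batch])
```

With the values from the config (`n_labels: 2`, `n_instances: 2`), every batch holds four items, which guarantees that each anchor has at least one positive and two negatives available for triplet mining.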

Pre-trained models of different modalities
model_hf = AutoModel.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
extractor_txt = HFWrapper(model_hf)
extractor_img = ViTExtractor.from_pretrained("vits16_dino")
transforms, _ = get_transforms_for_pretrained("vits16_dino")
extractor_audio = ECAPATDNNExtractor.from_pretrained()

Post-processing
emb = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(emb, dataset)
postprocessor = AdaptiveThresholding()
rr_upd = postprocessor.process(rr, dataset)

Post-processing by NN | Paper
embeddings = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)
postprocessor = PairwiseReranker(ConcatSiamese(), top_n=3)
rr_upd = postprocessor.process(rr, dataset)
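
The `top_n=3` argument above hints at the usual re-ranking pattern: only the first few retrieved items are re-scored by a more expensive pairwise model, and the rest of the ranking is left untouched. Below is a minimal sketch of that pattern under this assumption. `rerank_top_n` and the toy `hamming` distance are hypothetical stand-ins, not OML code, and a real pairwise model would be a neural network rather than a string comparison.

```python
def rerank_top_n(ranked_ids, pairwise_distance, query, gallery, top_n):
    """Hypothetical helper: re-sort only the first top_n retrieved items.

    ranked_ids is the gallery ordering produced by the first-stage retrieval.
    pairwise_distance(query, item) stands in for a pairwise model's output.
    """
    head = ranked_ids[:top_n]
    tail = ranked_ids[top_n:]
    # sorted() is stable, so ties keep their first-stage order.
    head_sorted = sorted(head, key=lambda i: pairwise_distance(query, gallery[i]))
    return head_sorted + tail


def hamming(a, b):
    """Toy pairwise distance: number of mismatching characters."""
    return sum(ca != cb for ca, cb in zip(a, b))


gallery = ["aaa", "abc", "abd", "zzz"]
first_stage = [0, 1, 2, 3]
print(rerank_top_n(first_stage, hamming, "abd", gallery, top_n=3))
```

Keeping `top_n` small bounds the number of pairwise model calls per query, which is the main reason to re-rank a prefix instead of the whole gallery.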

Logging
logger = TensorBoardPipelineLogger()
logger = NeptunePipelineLogger()
logger = WandBPipelineLogger()
logger = MLFlowPipelineLogger()
logger = ClearMLPipelineLogger()

PML
from pytorch_metric_learning import losses
criterion = losses.TripletMarginLoss(0.2, "all")
pred = ViTExtractor()(data)
criterion(pred, gts)

Categories support
# train
loader = DataLoader(CategoryBalanceSampler())
# validation
rr = RetrievalResults.from_embeddings()
m.calc_retrieval_metrics_rr(rr, query_categories)

Misc metrics
embeddings = inference(model, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)
m.calc_retrieval_metrics_rr(rr, precision_top_k=(5,))
m.calc_fnmr_at_fmr_rr(rr, fmr_vals=(0.1,))
m.calc_topological_metrics(embeddings, pcf_variance=(0.5,))
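
For readers unfamiliar with retrieval metrics, the sketch below shows one common definition of precision at k: for each query, the fraction of its k nearest gallery items that share the query's label, averaged over queries. The function `precision_at_k` is a hypothetical illustration in plain Python. OML's `precision_top_k` may use a different normalization, for instance when a query has fewer than k relevant gallery items.

```python
import math


def precision_at_k(query_embs, query_labels, gallery_embs, gallery_labels, k):
    """Hypothetical helper: mean fraction of correct labels in the top k."""
    per_query = []
    for q, q_label in zip(query_embs, query_labels):
        # Rank gallery indices by Euclidean distance to the query.
        ranked = sorted(
            range(len(gallery_embs)),
            key=lambda i: math.dist(q, gallery_embs[i]),
        )
        hits = sum(gallery_labels[i] == q_label for i in ranked[:k])
        per_query.append(hits / k)
    return sum(per_query) / len(per_query)


queries = [[0.0, 0.0], [10.0, 0.0]]
query_labels = [0, 1]
gallery = [[0.0, 1.0], [0.0, 2.0], [10.0, 1.0], [9.0, 5.0]]
gallery_labels = [0, 1, 1, 0]

print(precision_at_k(queries, query_labels, gallery, gallery_labels, k=1))
print(precision_at_k(queries, query_labels, gallery, gallery_labels, k=2))
```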

Lightning
import pytorch_lightning as pl
model = ViTExtractor.from_pretrained("vits16_dino")
clb = MetricValCallback(EmbeddingMetrics(dataset))
module = ExtractorModule(model, criterion, optimizer)
trainer = pl.Trainer(max_epochs=3, callbacks=[clb])
trainer.fit(module, train_loader, val_loader)

Lightning DDP
clb = MetricValCallback(EmbeddingMetrics(val))
module = ExtractorModuleDDP(
    model, criterion, optimizer, train, val
)
ddp = {"devices": 2, "strategy": DDPStrategy()}
trainer = pl.Trainer(max_epochs=3, callbacks=[clb], **ddp)
trainer.fit(module)
Features extraction
Postprocessing
Contents
- Base Interfaces
  - IExtractor
  - IPairwiseModel
  - IFreezable
  - IBatchSampler
  - ITripletLossWithMiner
  - IIndexedDataset
  - IBaseDataset
  - ILabeledDataset
  - IQueryGalleryDataset
  - IQueryGalleryLabeledDataset
  - IPairDataset
  - IVisualizableDataset
  - IHTMLVisualizableDataset
  - IBasicMetric
  - ITripletsMiner
  - ITripletsMinerInBatch
  - IPipelineLogger
  - IRetrievalPostprocessor
- Datasets
  - ImageBaseDataset
  - ImageLabeledDataset
  - ImageQueryGalleryLabeledDataset
  - ImageQueryGalleryDataset
  - TextBaseDataset
  - TextLabeledDataset
  - TextQueryGalleryLabeledDataset
  - TextQueryGalleryDataset
  - AudioBaseDataset
  - AudioLabeledDataset
  - AudioQueryGalleryDataset
  - AudioQueryGalleryLabeledDataset
  - PairDataset
  - get_mock_images_dataset
  - get_mock_texts_dataset
  - get_mock_audios_dataset
- Samplers
- Miners
- Losses
- Models
- Metrics
- PyTorch Lightning
- Utils
- DDP
- Retrieval & Post-processing