PyTorch Lightning
MetricValCallback
- class oml.lightning.callbacks.metric.MetricValCallback(metric: IBasicMetric, log_images: bool = False, loader_idx: int = 0, samples_in_getitem: int = 1)
Bases: Callback
This is a wrapper that allows using IBasicMetric with PyTorch Lightning.
- __init__(metric: IBasicMetric, log_images: bool = False, loader_idx: int = 0, samples_in_getitem: int = 1)
- Parameters
metric – Metric to calculate (an instance of IBasicMetric)
log_images – Set to True if you want visual logging
loader_idx – Index of the loader to calculate the metric for
samples_in_getitem – Some datasets return several samples per __getitem__ call, so this has to be taken into account for the metric to be calculated properly. In most cases this value equals 1, but for a dataset that explicitly returns triplets it must equal 3, and for a dataset of pairs it must equal 2.
ExtractorModule
- class oml.lightning.modules.extractor.ExtractorModule(extractor: IExtractor, criterion: Optional[Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, scheduler_interval: str = 'step', scheduler_frequency: int = 1, input_tensors_key: str = 'input_tensors', labels_key: str = 'labels', embeddings_key: str = 'embeddings', scheduler_monitor_metric: Optional[str] = None, freeze_n_epochs: int = 0)
Bases: LightningModule
This is a base module to train your model with Lightning.
- __init__(extractor: IExtractor, criterion: Optional[Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, scheduler_interval: str = 'step', scheduler_frequency: int = 1, input_tensors_key: str = 'input_tensors', labels_key: str = 'labels', embeddings_key: str = 'embeddings', scheduler_monitor_metric: Optional[str] = None, freeze_n_epochs: int = 0)
- Parameters
extractor – Extractor to train
criterion – Criterion to optimize
optimizer – Optimizer
scheduler – Learning rate scheduler
scheduler_interval – Interval at which the scheduler is called (must be step or epoch)
scheduler_frequency – Frequency of calling the scheduler
input_tensors_key – Key to get tensors from the batches
labels_key – Key to get labels from the batches
embeddings_key – Key to get embeddings from the batches
scheduler_monitor_metric – Metric to monitor for the schedulers that depend on the metric value
freeze_n_epochs – Number of epochs to keep the model frozen (for n > 0 the model has to implement the IFreezable interface). When current_epoch >= freeze_n_epochs, the model is unfrozen. Note that epochs start at 0.
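A minimal sketch of how the scheduler arguments and freeze_n_epochs can be interpreted. It assumes that ExtractorModule forwards scheduler_interval, scheduler_frequency and scheduler_monitor_metric into Lightning's standard lr_scheduler configuration dictionary (keys scheduler, interval, frequency, monitor). build_lr_scheduler_config and is_frozen are hypothetical helpers, not OML functions, and the monitored metric name below is made up.

```python
from typing import Any, Dict, Optional


def build_lr_scheduler_config(
    scheduler: Any,
    interval: str = "step",
    frequency: int = 1,
    monitor: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: packs the scheduler arguments into the dictionary
    format Lightning accepts under the "lr_scheduler" key of configure_optimizers."""
    if interval not in ("step", "epoch"):
        raise ValueError(f"interval must be 'step' or 'epoch', got {interval!r}")
    config: Dict[str, Any] = {"scheduler": scheduler, "interval": interval, "frequency": frequency}
    if monitor is not None:
        # Metric-dependent schedulers (e.g. ReduceLROnPlateau) need the name of a metric to watch.
        config["monitor"] = monitor
    return config


def is_frozen(current_epoch: int, freeze_n_epochs: int) -> bool:
    """Hypothetical helper: epochs are 0-based, so the model stays frozen for
    epochs 0 .. freeze_n_epochs - 1 and is unfrozen from epoch freeze_n_epochs on."""
    return current_epoch < freeze_n_epochs


if __name__ == "__main__":
    print(build_lr_scheduler_config("my_scheduler", interval="epoch", monitor="val_metric"))
    print([is_frozen(epoch, freeze_n_epochs=2) for epoch in range(4)])
```

With freeze_n_epochs=2 the model is frozen during epochs 0 and 1 and trained fully from epoch 2 on; freeze_n_epochs=0 never freezes it.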
extractor_training_pipeline
- oml.lightning.pipelines.train.extractor_training_pipeline(cfg: Union[Dict[str, Any], DictConfig]) -> None
This pipeline allows you to train and validate a feature extractor that represents images as feature vectors.
The config can be specified as a dictionary or with hydra: https://hydra.cc/. For more details, see pipelines/features_extraction/README.md.
extractor_validation_pipeline
- oml.lightning.pipelines.validate.extractor_validation_pipeline(cfg: Union[Dict[str, Any], DictConfig]) -> Tuple[Trainer, Dict[str, Any]]
This pipeline allows you to validate a feature extractor that represents images as feature vectors.
The config can be specified as a dictionary or with hydra: https://hydra.cc/. For more details, see pipelines/features_extraction/README.md.
extractor_prediction_pipeline
- oml.lightning.pipelines.predict.extractor_prediction_pipeline(cfg: Union[Dict[str, Any], DictConfig]) -> None
This pipeline allows you to save features extracted by a feature extractor.
The config can be specified as a dictionary or with hydra: https://hydra.cc/. For more details, see pipelines/features_extraction/README.md.
postprocessor_training_pipeline
- oml.lightning.pipelines.train_postprocessor.postprocessor_training_pipeline(cfg: DictConfig) -> None
This pipeline allows you to train and validate a pairwise postprocessor that fixes mistakes of a feature extractor in a retrieval setup.
The config can be specified as a dictionary or with hydra: https://hydra.cc/. For more details, see pipelines/postprocessing/pairwise_postprocessing/README.md.