# Source code for oml.interfaces.models

from abc import ABC
from typing import Any, Dict

from torch import Tensor, nn


class IExtractor(nn.Module, ABC):
    """
    Base interface for feature extractors.

    Models have to inherit this interface to be compatible with the rest of
    the library.
    """

    # Registry of available pretrained checkpoints, keyed by weights identifier.
    # ``from_pretrained`` reads ``pretrained_models[key]["init_args"]`` and passes
    # it to the constructor; any other content of an entry is defined by the child class.
    pretrained_models: Dict[str, Any] = {}

    def extract(self, x: Tensor) -> Tensor:
        """
        Extract features from the input.

        Args:
            x: Input tensor.

        Returns:
            The result of ``self.forward(x)``. Note that ``forward`` is called
            directly, not via ``self(x)``.

        """
        return self.forward(x)

    @property
    def feat_dim(self) -> int:
        """
        The size of the features produced by the extractor.

        This is the only member that a child class is obliged to implement.

        Raises:
            NotImplementedError: Always, unless overridden by a child class.

        """
        raise NotImplementedError()

    @classmethod
    def from_pretrained(cls, weights: str, **kwargs) -> "IExtractor":  # type: ignore
        """
        Create an extractor from one of the registered pretrained checkpoints.

        The class field ``pretrained_models`` is the dictionary which keeps
        records of all the available checkpoints; the exact format of a record
        depends on the implementation of a particular child of ``IExtractor``.
        Here, only the ``"init_args"`` entry of the record is used: it is
        unpacked into the constructor together with ``weights`` and ``kwargs``.
        Obtaining the checkpoint itself is left to the child's constructor.

        As a user, you don't need to worry about implementing this method.

        Args:
            weights: A unique identifier (key) of a pretrained model information
                stored in the class field ``pretrained_models``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            An instance of ``IExtractor``

        Raises:
            KeyError: If ``weights`` is not a key of ``pretrained_models``.

        """
        if weights not in cls.pretrained_models:
            raise KeyError(
                f"There is no pretrained model {weights}. The existing ones are {list(cls.pretrained_models.keys())}."
            )

        extractor = cls(weights=weights, **cls.pretrained_models[weights]["init_args"], **kwargs)  # type: ignore
        return extractor
class IFreezable(ABC):
    """
    Interface for models which can freeze and unfreeze their parts.
    """

    def freeze(self) -> None:
        """
        Function for freezing. You can use it to partially freeze a model.

        Raises:
            NotImplementedError: Always, unless overridden by a child class.

        """
        raise NotImplementedError()

    def unfreeze(self) -> None:
        """
        Function for unfreezing. You can use it to unfreeze a model.

        Raises:
            NotImplementedError: Always, unless overridden by a child class.

        """
        raise NotImplementedError()
class IPairwiseModel(nn.Module):
    """
    A model of this type takes two inputs, for example, two embeddings or two images.
    """

    def forward(self, x1: Any, x2: Any) -> Tensor:
        """
        Process a pair of inputs.

        Args:
            x1: The first input.
            x2: The second input.

        Raises:
            NotImplementedError: Always, unless overridden by a child class.

        """
        raise NotImplementedError()

    def predict(self, x1: Any, x2: Any) -> Tensor:
        """
        While ``self.forward()`` is called during training, this method is called during
        inference or validation time. For example, it allows application of some activation,
        which was a part of a loss function during the training.

        Args:
            x1: The first input.
            x2: The second input.

        Raises:
            NotImplementedError: Always, unless overridden by a child class.

        """
        raise NotImplementedError()
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "IExtractor",
    "IFreezable",
    "IPairwiseModel",
]