# Source code for oml.models.meta.projection

from pathlib import Path
from typing import List, Optional, Union

import torch
from torchvision.ops import MLP

from oml.const import STORAGE_CKPTS
from oml.interfaces.models import IExtractor, IFreezable
from oml.models.utils import (
    remove_criterion_in_state_dict,
    remove_prefix_from_state_dict,
)
from oml.models.vit_dino.extractor import ViTExtractor
from oml.utils.io import download_checkpoint


class ExtractorWithMLP(IExtractor, IFreezable):
    """
    Class-wrapper for extractors with an additional MLP head on top of their features.

    The wrapped extractor's output (of size ``extractor.feat_dim``) is passed through a
    ``torchvision.ops.MLP`` whose layer sizes are given by ``mlp_features``; the size of the
    last layer is the size of the resulting embedding (``feat_dim``).
    """

    pretrained_models = {
        "vits16_224_mlp_384_inshop": {  # type: ignore
            "url": f"{STORAGE_CKPTS}/inshop/vits16_224_mlp_384_inshop.ckpt",
            "hash": "35244966",
            "fname": "vits16_224_mlp_384_inshop.ckpt",
            "init_args": {
                # A factory rather than an instance: the backbone is built only when
                # ``from_pretrained`` is called, not at import time of this module.
                "extractor_creator": lambda: ViTExtractor(None, "vits16", False, use_multi_scale=False),
                "mlp_features": [384],
                "train_backbone": True,
            },
        }
    }

    def __init__(
        self,
        extractor: IExtractor,
        mlp_features: List[int],
        weights: Optional[Union[str, Path]] = None,
        train_backbone: bool = False,
    ):
        """
        Args:
            extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``)
            mlp_features: Sizes of projection layers; the last one is the output embedding size
            weights: Path to weights file, the name of one of ``pretrained_models`` (the checkpoint
                is downloaded in that case), or ``None`` for random initialization
            train_backbone: set ``False`` if you want to train only MLP head

        """
        IExtractor.__init__(self)

        self.extractor = extractor
        # Keep a private copy so that mutating ``self.mlp_features`` can affect neither the
        # caller's list nor the lists stored in the shared ``pretrained_models`` registry.
        self.mlp_features = list(mlp_features)
        self.train_backbone = train_backbone

        self.projection = MLP(self.extractor.feat_dim, self.mlp_features)

        if weights:
            if weights in self.pretrained_models:
                pretrained = self.pretrained_models[weights]  # type: ignore
                weights = download_checkpoint(
                    url_or_fid=pretrained["url"],  # type: ignore
                    hash_md5=pretrained["hash"],  # type: ignore
                    fname=pretrained["fname"],  # type: ignore
                )

            # NOTE(security): ``weights_only=False`` unpickles arbitrary Python objects, which can
            # execute code. Only load checkpoints that come from a trusted source.
            loaded = torch.load(weights, map_location="cpu", weights_only=False)
            # Accept both bare state dicts and checkpoints that nest them under "state_dict".
            loaded = loaded.get("state_dict", loaded)
            loaded = remove_criterion_in_state_dict(loaded)
            # NOTE(review): presumably strips an outer key prefix so that keys line up with this
            # module's "extractor." / "projection." names — confirm against the helper.
            loaded = remove_prefix_from_state_dict(loaded, trial_key="extractor.")
            self.load_state_dict(loaded, strict=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract backbone features and project them with the MLP head.

        Gradients flow into the backbone only if ``train_backbone`` is set AND grad mode is
        already enabled by the caller; the MLP head always follows the caller's grad mode.
        """
        # Never *enable* grad here: under an outer ``torch.no_grad()`` (e.g. validation or
        # inference) a trainable backbone must not start recording an autograd graph.
        with torch.set_grad_enabled(self.train_backbone and torch.is_grad_enabled()):
            features = self.extractor(x)

        return self.projection(features)

    @property
    def feat_dim(self) -> int:
        """Size of the output embedding, i.e. the size of the last MLP layer."""
        return self.mlp_features[-1]

    def freeze(self) -> None:
        """Stop computing gradients for the backbone; the MLP head stays trainable."""
        self.train_backbone = False

    def unfreeze(self) -> None:
        """Compute gradients for the backbone again (when grad mode is enabled)."""
        self.train_backbone = True

    @classmethod
    def from_pretrained(cls, weights: str, **kwargs) -> "IExtractor":  # type: ignore
        """
        Build the model described by ``pretrained_models[weights]`` and load its checkpoint.

        The registry stores a factory (``extractor_creator``) instead of a ready extractor, so that
        importing this module does not instantiate any backbone. Here the factory is called, and the
        result is temporarily exposed as ``init_args["extractor"]`` for the base implementation.
        Afterwards the registry entry is restored exactly. This keeps repeated calls working and
        prevents the class attribute from holding on to the built backbone.

        Raises:
            KeyError: if ``weights`` is not a key of ``pretrained_models``.

        """
        init_args = cls.pretrained_models[weights]["init_args"]
        snapshot = dict(init_args)  # shallow copy; values (factory, lists, flags) are reused as-is
        try:
            init_args["extractor"] = init_args.pop("extractor_creator")()  # type: ignore
            return super().from_pretrained(weights, **kwargs)
        finally:
            # Undo the temporary mutation even if building the extractor or the model failed.
            init_args.clear()
            init_args.update(snapshot)
# Names exported by ``from oml.models.meta.projection import *``.
__all__ = [
    "ExtractorWithMLP",
]