Source code for oml.models.vit_clip.extractor

from pathlib import Path
from typing import Any, Dict, Iterable, Optional

import torch

from oml.interfaces.models import IExtractor
from oml.models.utils import (
    TStateDict,
    filter_state_dict,
    patch_device_and_float,
    remove_criterion_in_state_dict,
    remove_prefix_from_state_dict,
)
from oml.models.vit_clip.external.model import VisionTransformer
from oml.utils.io import download_checkpoint

_OPENAI_URL = "https://openaipublic.azureedge.net/clip/models"
_SBER_URL = "https://huggingface.co/sberbank-ai"


def vitb16_224() -> VisionTransformer:
    return VisionTransformer(
        output_dim=512,
        input_resolution=224,
        layers=12,
        width=768,
        patch_size=16,
        heads=12,
    )


def vitb32_224() -> VisionTransformer:
    return VisionTransformer(
        output_dim=512,
        input_resolution=224,
        layers=12,
        width=768,
        patch_size=32,
        heads=12,
    )


def vitl14_224() -> VisionTransformer:
    return VisionTransformer(
        output_dim=768,
        input_resolution=224,
        layers=24,
        width=1024,
        patch_size=14,
        heads=16,
    )


def vitl14_336() -> VisionTransformer:
    return VisionTransformer(
        output_dim=768,
        input_resolution=336,
        layers=24,
        width=1024,
        patch_size=14,
        heads=16,
    )


[docs]class ViTCLIPExtractor(IExtractor): constructors = { "vitb16_224": vitb16_224, "vitb32_224": vitb32_224, "vitl14_224": vitl14_224, "vitl14_336": vitl14_336, } pretrained_models: Dict[str, Any] = { # checkpoints pretrained by OpenAI "openai_vitb16_224": { "url": f"{_OPENAI_URL}/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", "hash": "44c3d804ecac03d9545ac1a3adbca3a6", "is_jitted": True, "fname": "openai_vitb16_224.ckpt", "init_args": {"arch": "vitb16_224", "normalise_features": False}, }, "openai_vitb32_224": { "url": f"{_OPENAI_URL}/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", "hash": "3ba34e387b24dfe590eeb1ae6a8a122b", "is_jitted": True, "fname": "openai_vitb32_224.ckpt", "init_args": {"arch": "vitb32_224", "normalise_features": False}, }, "openai_vitl14_224": { "url": f"{_OPENAI_URL}/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", "hash": "096db1af569b284eb76b3881534822d9", "is_jitted": True, "fname": "openai_vitl14_224.ckpt", "init_args": {"arch": "vitl14_224", "normalise_features": False}, }, # checkpoints pretrained by SberbankAI "sber_vitb16_224": { "url": f"{_SBER_URL}/ruclip-vit-base-patch16-224/resolve/main/pytorch_model.bin", "hash": "7882e07674d78c674e33cb892a68bbfc", "is_jitted": False, "fname": "sber_vitb16_224.ckpt", "init_args": {"arch": "vitb16_224", "normalise_features": False}, }, "sber_vitb32_224": { "url": f"{_SBER_URL}/ruclip-vit-base-patch32-224/resolve/main/pytorch_model.bin", "hash": "e2c4dab46a3cfa608bdd762973e90d32", "is_jitted": False, "fname": "sber_vitb32_224.ckpt", "init_args": {"arch": "vitb32_224", "normalise_features": False}, }, "sber_vitl14_224": { "url": f"{_SBER_URL}/ruclip-vit-large-patch14-224/resolve/main/pytorch_model.bin", "hash": "9b4a1cd25d15bad4ffd2ba6e34b8a67c", "is_jitted": False, "fname": "sber_vitl14_224.ckpt", "init_args": {"arch": "vitl14_224", "normalise_features": False}, }, }
[docs] def __init__( self, weights: Optional[str], arch: str, normalise_features: bool = True, ): """ Args: weights: Path to weights or special key for pretrained ones or ``None`` for random initialization. You can check available pretrained checkpoints in ``ViTCLIPExtractor.pretrained_models``. arch: Might be one of ``vitb16_224``, ``vitb32_224``, ``vitl14_224``, ``vitl14_336``. normalise_features: Set ``True`` to normalise output features """ super().__init__() self.normalize = normalise_features self.visual = self.constructors[arch]() self.input_size = int(arch.split("_")[-1]) if weights is None: return if weights in self.pretrained_models: pretrained = self.pretrained_models[weights] jitted_weights = pretrained["is_jitted"] weights = download_checkpoint(pretrained["url"], pretrained["hash"], fname=pretrained["fname"]) else: jitted_weights = False if jitted_weights: # check if weights are jitted visual = torch.jit.load(Path(weights), map_location="cpu").visual patch_device_and_float(visual, device="cpu") state_dict = visual.state_dict() else: state_dict = torch.load(Path(weights), map_location="cpu", weights_only=False) state_dict = state_dict.get("state_dict", state_dict) state_dict = remove_criterion_in_state_dict(state_dict) state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys()) state_dict = remove_prefix_from_state_dict(state_dict, trial_key="conv1.weight") self.visual.load_state_dict(state_dict=state_dict, strict=True)
def forward(self, x: torch.Tensor) -> torch.Tensor: assert (x.shape[-2] == self.input_size) and ( x.shape[-1] == self.input_size ), f"The model expects input images to be resized to {self.input_size}x{self.input_size}" res = self.visual.forward(x) if self.normalize: res = res / torch.linalg.norm(res, 2, dim=1, keepdim=True).detach() return res @property def feat_dim(self) -> int: return self.visual.state_dict()["proj"].shape[-1]
def take_visual_part_of_vit_clip(state_dict: TStateDict, needed_keys: Iterable[str]) -> TStateDict: for k in list(state_dict): if "visual" in k: new_key = k[k.find("visual") + len("visual") + 1 :] state_dict[new_key] = state_dict.pop(k) state_dict = filter_state_dict(state_dict, needed_keys=needed_keys) return state_dict __all__ = ["ViTCLIPExtractor"]