Source code for oml.models.vit_dino.extractor

from pathlib import Path
from typing import Optional, Union

import numpy as np
import PIL
import torch
from PIL.Image import Image as TPILImage
from torch import nn

from oml.const import MEAN, STD, STORAGE_CKPTS, TNormParam
from oml.interfaces.models import IExtractor
from oml.models.utils import (
    remove_criterion_in_state_dict,
    remove_prefix_from_state_dict,
)
from oml.models.vit_dino.external.hubconf import (  # type: ignore
    dino_vitb8,
    dino_vitb16,
    dino_vits8,
    dino_vits16,
)
from oml.models.vit_dino.external_v2.hubconf import (  # type: ignore
    dinov2_vitb14,
    dinov2_vitb14_reg,
    dinov2_vitl14,
    dinov2_vitl14_reg,
    dinov2_vits14,
    dinov2_vits14_reg,
)
from oml.transforms.images.albumentations import get_normalisation_albu
from oml.utils.io import download_checkpoint_one_of
from oml.utils.misc_torch import normalise, temporary_setting_model_mode

_FB_URL = "https://dl.fbaipublicfiles.com"


class ViTExtractor(IExtractor):
    """
    The base class for the extractors that follow VisualTransformer architecture.

    """

    # Maps an architecture name to the factory function which builds the bare backbone.
    constructors = {
        "vits8": dino_vits8,
        "vits16": dino_vits16,
        "vitb8": dino_vitb8,
        "vitb16": dino_vitb16,
        "vits14": dinov2_vits14,
        "vitb14": dinov2_vitb14,
        "vitl14": dinov2_vitl14,
        "vits14_reg": dinov2_vits14_reg,
        "vitb14_reg": dinov2_vitb14_reg,
        "vitl14_reg": dinov2_vitl14_reg,
    }

    # Maps a checkpoint key to its download location(s), hash, local file name
    # and the constructor arguments the checkpoint is meant to be used with.
    pretrained_models = {
        # checkpoints pretrained in DINO framework on ImageNet by MetaAI
        "vits16_dino": {
            "url": f"{_FB_URL}/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            "hash": "cf0f22",
            "fname": "vits16_dino.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits8_dino": {
            "url": f"{_FB_URL}/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
            "hash": "230cd5",
            "fname": "vits8_dino.ckpt",
            "init_args": {"arch": "vits8", "normalise_features": False},
        },
        "vitb16_dino": {
            "url": f"{_FB_URL}/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
            "hash": "552daf",
            "fname": "vitb16_dino.ckpt",
            "init_args": {"arch": "vitb16", "normalise_features": False},
        },
        "vitb8_dino": {
            "url": f"{_FB_URL}/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
            "hash": "556550",
            "fname": "vitb8_dino.ckpt",
            "init_args": {"arch": "vitb8", "normalise_features": False},
        },
        "vits14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
            "hash": "2e405c",
            "fname": "dinov2_vits14.ckpt",
            "init_args": {"arch": "vits14", "normalise_features": False},
        },
        "vits14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
            "hash": "2a50c5",
            "fname": "dinov2_vits14_reg4.ckpt",
            "init_args": {"arch": "vits14_reg", "normalise_features": False},
        },
        "vitb14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
            "hash": "8635e7",
            "fname": "dinov2_vitb14.ckpt",
            "init_args": {"arch": "vitb14", "normalise_features": False},
        },
        "vitb14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
            "hash": "13d13c",
            "fname": "dinov2_vitb14_reg4.ckpt",
            "init_args": {"arch": "vitb14_reg", "normalise_features": False},
        },
        "vitl14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
            "hash": "19a02c",
            "fname": "dinov2_vitl14.ckpt",
            "init_args": {"arch": "vitl14", "normalise_features": False},
        },
        "vitl14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
            "hash": "8b6364",
            "fname": "dinov2_vitl14_reg4.ckpt",
            "init_args": {"arch": "vitl14_reg", "normalise_features": False},
        },
        # our pretrained checkpoints
        "vits16_inshop": {
            "url": [
                f"{STORAGE_CKPTS}/inshop/vits16_inshop_a76b85.ckpt",
                "1niX-TC8cj6j369t7iU2baHQSVN3MVJbW",
            ],
            "hash": "a76b85",
            "fname": "vits16_inshop.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits16_sop": {
            "url": [
                f"{STORAGE_CKPTS}/sop/vits16_sop_21e743.ckpt",
                "1zuGRHvF2KHd59aw7i7367OH_tQNOGz7A",
            ],
            "hash": "21e743",
            "fname": "vits16_sop.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": True},
        },
        "vits16_cub": {
            "url": [
                f"{STORAGE_CKPTS}/cub/vits16_cub.ckpt",
                "1p2tUosFpGXh5sCCdzlXtjV87kCDfG34G",
            ],
            "hash": "e82633",
            "fname": "vits16_cub.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits16_cars": {
            "url": [
                f"{STORAGE_CKPTS}/cars/vits16_cars.ckpt",
                "1hcOxDRRXrKr6ZTCyBauaY8Ue-pok4Icg",
            ],
            "hash": "9f1e59",
            "fname": "vits16_cars.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
    }

    def __init__(
        self,
        weights: Optional[Union[Path, str]],
        arch: str,
        normalise_features: bool,
        use_multi_scale: bool = False,
    ):
        """
        Args:
            weights: Path to weights or a special key to download pretrained checkpoint, use ``None`` to randomly
                initialize model's weights. You can check the available pretrained checkpoints in
                ``self.pretrained_models``.
            arch: Might be one of ``vits8``, ``vits16``, ``vitb8``, ``vitb16``, ``vits14``, ``vitb14``, ``vitl14``,
                ``vits14_reg``, ``vitb14_reg``, ``vitl14_reg``.
                You can check all the available options in ``self.constructors``
            normalise_features: Set ``True`` to normalise output features
            use_multi_scale: Set ``True`` to use multiscale (the analogue of test time augmentations)

        """
        assert arch in self.constructors, (
            f"Unknown arch {arch!r}, expected one of {sorted(self.constructors)}."
        )
        super(ViTExtractor, self).__init__()

        self.normalise_features = normalise_features
        self.mscale = use_multi_scale
        self.arch = arch

        # The backbone is always built without weights: they are loaded below if requested.
        factory_fun = self.constructors[self.arch]
        self.model = factory_fun(pretrained=False)

        if weights is None:
            return

        # A known key is resolved to a local checkpoint file, anything else is treated as a path.
        if weights in self.pretrained_models:
            pretrained = self.pretrained_models[weights]  # type: ignore
            weights = download_checkpoint_one_of(
                url_or_fid_list=pretrained["url"],  # type: ignore
                hash_md5=pretrained["hash"],  # type: ignore
                fname=pretrained["fname"],  # type: ignore
            )

        # NOTE(security): ``weights_only=False`` unpickles arbitrary Python objects,
        # so only checkpoints from trusted sources should be passed here.
        ckpt = torch.load(weights, map_location="cpu", weights_only=False)

        # The weights may be stored directly or under the "state_dict" key.
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        state_dict = remove_criterion_in_state_dict(state_dict)
        ckpt = remove_prefix_from_state_dict(state_dict, trial_key="norm.bias")
        self.model.load_state_dict(ckpt, strict=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features, averaging over several scales if ``use_multi_scale`` was set
        and normalising them if ``normalise_features`` was set.

        """
        if self.mscale:
            x = self._multi_scale(x)
        else:
            x = self.model(x)

        if self.normalise_features:
            x = normalise(x)

        return x

    @property
    def feat_dim(self) -> int:
        """Size of the output features, taken as the length of the final norm layer's bias."""
        return len(self.model.norm.bias)

    def _multi_scale(self, samples: torch.Tensor) -> torch.Tensor:
        """Average the backbone's outputs over three input scales: ``1``, ``1/sqrt(2)`` and ``1/2``."""
        # code from the original DINO
        # TODO: check grads later
        v = torch.zeros((len(samples), self.feat_dim), device=samples.device)
        scales = [1.0, 1 / 2 ** (1 / 2), 1 / 2]  # we use 3 different scales

        for s in scales:
            if s == 1:
                inp = samples.clone()
            else:
                inp = nn.functional.interpolate(samples, scale_factor=s, mode="bilinear", align_corners=False)

            # The module is called instead of its ``.forward()`` so registered hooks are executed,
            # exactly as in the single-scale path of ``forward``.
            v += self.model(inp)

        v /= len(scales)
        # v /= v.norm(dim=1)  # we don't want to shift the norms values

        return v

    def draw_attention(self, image: Union[TPILImage, np.ndarray]) -> np.ndarray:
        """
        Args:
            image: An image with pixel values in the range of ``[0..255]``.

        Returns:
            An image with drawn attention maps. The result is produced by ``vis_vit``,
            which converts it back to a PIL image if the input was not a numpy array.

        Visualization of the multi-head attention on a particular image.

        """
        return vis_vit(vit=self, image=image)
def vis_vit(
    vit: ViTExtractor,
    image: Union[TPILImage, np.ndarray],
    mean: TNormParam = MEAN,
    std: TNormParam = STD,
) -> np.ndarray:
    """
    Draw the last layer's self-attention, averaged over the heads, on top of the image.

    Args:
        vit: The extractor whose backbone provides ``get_last_selfattention``.
        image: A numpy array or an image convertible to it with ``np.asarray``.
        mean: Mean passed to the normalisation transform.
        std: Std passed to the normalisation transform.

    Returns:
        The image blended with the attention map. Both sides are cropped down to a multiple
        of the patch size. If the input was not a numpy array, the result is converted
        to a PIL image.

    Raises:
        ValueError: If the image is smaller than a single patch.

    """
    # this is the optional dependency
    from pytorch_grad_cam.utils.image import show_cam_on_image

    need_to_convert = not isinstance(image, np.ndarray)
    if need_to_convert:
        image = np.asarray(image)

    patch_size = vit.model.patch_embed.proj.kernel_size[0]

    img_tensor = get_normalisation_albu(mean=mean, std=std)(image=image)["image"]

    # Only a whole number of patches can be processed, so the remainder is cut off.
    w = img_tensor.shape[1] - img_tensor.shape[1] % patch_size
    h = img_tensor.shape[2] - img_tensor.shape[2] % patch_size

    if w == 0 or h == 0:
        raise ValueError(
            f"The image is smaller than a single patch of size {patch_size}: "
            f"got spatial size {tuple(img_tensor.shape[1:])}."
        )

    img_tensor = img_tensor[:, :w, :h].unsqueeze(0)

    # The attention map below has the cropped size, so the image it is blended with
    # has to be cropped in the same way, otherwise their shapes do not match.
    image = image[:w, :h]

    # The input has to live on the same device as the model (the output is moved back to CPU below).
    img_tensor = img_tensor.to(next(vit.model.parameters()).device)

    w_feat_map = img_tensor.shape[-2] // patch_size
    h_feat_map = img_tensor.shape[-1] // patch_size

    with temporary_setting_model_mode(vit, set_train=False), torch.no_grad():
        attentions = vit.model.get_last_selfattention(img_tensor)

    nh = attentions.shape[1]

    # Take the first sample and the attention of the first token to all tokens after it,
    # then lay these values out as a grid of patches for every head.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_feat_map, h_feat_map)

    # Upscale the grid of patches back to the pixel resolution.
    attentions = (
        nn.functional.interpolate(
            attentions.unsqueeze(0),
            scale_factor=patch_size,
            mode="nearest",
        )[0]
        .cpu()
        .numpy()
    )

    # Average over the heads.
    arr = attentions.mean(axis=0)

    # Scale the image by its maximum; a completely black image stays black instead of turning into NaNs.
    peak = image.max()
    image_scaled = image / peak if peak > 0 else np.zeros(image.shape, dtype=np.float64)

    arr = show_cam_on_image(image_scaled, 0.6 * arr / arr.max())  # type: ignore

    if need_to_convert:
        arr = PIL.Image.fromarray(arr)

    return arr


__all__ = ["ViTExtractor", "vis_vit"]