# Source code for oml.models.vit_dino.extractor
from pathlib import Path
from typing import Optional, Union
import numpy as np
import PIL
import torch
from PIL.Image import Image as TPILImage
from torch import nn
from oml.const import MEAN, STD, STORAGE_CKPTS, TNormParam
from oml.interfaces.models import IExtractor
from oml.models.utils import (
remove_criterion_in_state_dict,
remove_prefix_from_state_dict,
)
from oml.models.vit_dino.external.hubconf import ( # type: ignore
dino_vitb8,
dino_vitb16,
dino_vits8,
dino_vits16,
)
from oml.models.vit_dino.external_v2.hubconf import ( # type: ignore
dinov2_vitb14,
dinov2_vitb14_reg,
dinov2_vitl14,
dinov2_vitl14_reg,
dinov2_vits14,
dinov2_vits14_reg,
)
from oml.transforms.images.albumentations import get_normalisation_albu
from oml.utils.io import download_checkpoint_one_of
from oml.utils.misc_torch import normalise, temporary_setting_model_mode
# Host prefix shared by the "url" entries of the MetaAI DINO / DINOv2 checkpoints
# listed in ``ViTExtractor.pretrained_models``.
_FB_URL = "https://dl.fbaipublicfiles.com"
class ViTExtractor(IExtractor):
    """
    The base class for the extractors that follow VisualTransformer architecture.

    """

    # Architecture name -> factory function that builds the backbone.
    # ``__init__`` looks the requested ``arch`` up here and calls the factory with ``pretrained=False``.
    constructors = {
        "vits8": dino_vits8,
        "vits16": dino_vits16,
        "vitb8": dino_vitb8,
        "vitb16": dino_vitb16,
        "vits14": dinov2_vits14,
        "vitb14": dinov2_vitb14,
        "vitl14": dinov2_vitl14,
        "vits14_reg": dinov2_vits14_reg,
        "vitb14_reg": dinov2_vitb14_reg,
        "vitl14_reg": dinov2_vitl14_reg,
    }

    # Key -> description of a downloadable checkpoint. ``__init__`` reads "url", "hash" and "fname"
    # when ``weights`` equals one of these keys; "init_args" is not read anywhere in this class.
    pretrained_models = {
        # checkpoints pretrained in DINO framework on ImageNet by MetaAI
        "vits16_dino": {
            "url": f"{_FB_URL}/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            "hash": "cf0f22",
            "fname": "vits16_dino.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits8_dino": {
            "url": f"{_FB_URL}/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
            "hash": "230cd5",
            "fname": "vits8_dino.ckpt",
            "init_args": {"arch": "vits8", "normalise_features": False},
        },
        "vitb16_dino": {
            "url": f"{_FB_URL}/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
            "hash": "552daf",
            "fname": "vitb16_dino.ckpt",
            "init_args": {"arch": "vitb16", "normalise_features": False},
        },
        "vitb8_dino": {
            "url": f"{_FB_URL}/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
            "hash": "556550",
            "fname": "vitb8_dino.ckpt",
            "init_args": {"arch": "vitb8", "normalise_features": False},
        },
        "vits14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
            "hash": "2e405c",
            "fname": "dinov2_vits14.ckpt",
            "init_args": {"arch": "vits14", "normalise_features": False},
        },
        "vits14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
            "hash": "2a50c5",
            "fname": "dinov2_vits14_reg4.ckpt",
            "init_args": {"arch": "vits14_reg", "normalise_features": False},
        },
        "vitb14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
            "hash": "8635e7",
            "fname": "dinov2_vitb14.ckpt",
            "init_args": {"arch": "vitb14", "normalise_features": False},
        },
        "vitb14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
            "hash": "13d13c",
            "fname": "dinov2_vitb14_reg4.ckpt",
            "init_args": {"arch": "vitb14_reg", "normalise_features": False},
        },
        "vitl14_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
            "hash": "19a02c",
            "fname": "dinov2_vitl14.ckpt",
            "init_args": {"arch": "vitl14", "normalise_features": False},
        },
        "vitl14_reg_dinov2": {
            "url": f"{_FB_URL}/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
            "hash": "8b6364",
            "fname": "dinov2_vitl14_reg4.ckpt",
            "init_args": {"arch": "vitl14_reg", "normalise_features": False},
        },
        # our pretrained checkpoints
        "vits16_inshop": {
            "url": [
                f"{STORAGE_CKPTS}/inshop/vits16_inshop_a76b85.ckpt",
                "1niX-TC8cj6j369t7iU2baHQSVN3MVJbW",
            ],
            "hash": "a76b85",
            "fname": "vits16_inshop.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits16_sop": {
            "url": [
                f"{STORAGE_CKPTS}/sop/vits16_sop_21e743.ckpt",
                "1zuGRHvF2KHd59aw7i7367OH_tQNOGz7A",
            ],
            "hash": "21e743",
            "fname": "vits16_sop.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": True},
        },
        "vits16_cub": {
            "url": [
                f"{STORAGE_CKPTS}/cub/vits16_cub.ckpt",
                "1p2tUosFpGXh5sCCdzlXtjV87kCDfG34G",
            ],
            "hash": "e82633",
            "fname": "vits16_cub.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
        "vits16_cars": {
            "url": [
                f"{STORAGE_CKPTS}/cars/vits16_cars.ckpt",
                "1hcOxDRRXrKr6ZTCyBauaY8Ue-pok4Icg",
            ],
            "hash": "9f1e59",
            "fname": "vits16_cars.ckpt",
            "init_args": {"arch": "vits16", "normalise_features": False},
        },
    }

    def __init__(
        self,
        weights: Optional[Union[Path, str]],
        arch: str,
        normalise_features: bool,
        use_multi_scale: bool = False,
    ):
        """
        Args:
            weights: Path to weights or a special key to download pretrained checkpoint, use ``None`` to
             randomly initialize model's weights. You can check the available pretrained checkpoints
             in ``self.pretrained_models``.
            arch: One of the keys of ``self.constructors``: the DINO architectures ``vits8``, ``vits16``,
             ``vitb8``, ``vitb16`` or the DINOv2 ones ``vits14``, ``vitb14``, ``vitl14``, ``vits14_reg``,
             ``vitb14_reg``, ``vitl14_reg``.
            normalise_features: Set ``True`` to normalise output features
            use_multi_scale: Set ``True`` to use multiscale (the analogue of test time augmentations)

        """
        assert arch in self.constructors, (
            f"Unknown arch '{arch}'. Available options: {sorted(self.constructors)}."
        )
        super().__init__()

        self.normalise_features = normalise_features
        self.mscale = use_multi_scale
        self.arch = arch

        # The backbone is always built without the factory's own pretrained weights;
        # weights (if any) are loaded explicitly below.
        build_backbone = self.constructors[self.arch]
        self.model = build_backbone(pretrained=False)

        if weights is None:
            # Random initialisation was requested: nothing to load.
            return

        if weights in self.pretrained_models:
            # ``weights`` is a key of a known checkpoint: replace it with the value returned by the
            # download helper (used as the checkpoint location by ``torch.load`` below).
            pretrained = self.pretrained_models[weights]  # type: ignore
            weights = download_checkpoint_one_of(
                url_or_fid_list=pretrained["url"],  # type: ignore
                hash_md5=pretrained["hash"],  # type: ignore
                fname=pretrained["fname"],  # type: ignore
            )

        # SECURITY NOTE: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary Python objects.
        # Only pass checkpoints that come from a trusted source.
        ckpt = torch.load(weights, map_location="cpu", weights_only=False)

        # A checkpoint may either wrap the parameters under "state_dict" or be the state dict itself.
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        # NOTE(review): judging by the helper names, these calls drop criterion entries and strip a
        # wrapper prefix (located through the "norm.bias" key) so the keys match ``self.model``.
        state_dict = remove_criterion_in_state_dict(state_dict)
        state_dict = remove_prefix_from_state_dict(state_dict, trial_key="norm.bias")
        self.model.load_state_dict(state_dict, strict=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features for a batch, optionally averaging over several input scales
        and optionally normalising the result.

        Args:
            x: Batch of input tensors passed to the backbone.

        Returns:
            Features produced by the backbone (or their multi-scale average), passed through
            ``normalise`` if ``self.normalise_features`` is set.

        """
        x = self._multi_scale(x) if self.mscale else self.model(x)

        if self.normalise_features:
            x = normalise(x)

        return x

    @property
    def feat_dim(self) -> int:
        """Size of the output features, taken as the length of ``self.model.norm.bias``."""
        return len(self.model.norm.bias)

    def _multi_scale(self, samples: torch.Tensor) -> torch.Tensor:
        """
        Average the backbone features over three rescaled copies of the input.

        Args:
            samples: Batch of input tensors.

        Returns:
            Tensor of shape ``(len(samples), self.feat_dim)`` with the mean of the features
            computed at scales ``1``, ``1 / sqrt(2)`` and ``1 / 2``.

        """
        # code from the original DINO
        # TODO: check grads later
        scales = (1.0, 1 / 2 ** (1 / 2), 1 / 2)  # we use 3 different scales

        v = torch.zeros((len(samples), self.feat_dim), device=samples.device)

        for s in scales:
            if s == 1:
                inp = samples.clone()
            else:
                inp = nn.functional.interpolate(samples, scale_factor=s, mode="bilinear", align_corners=False)

            # Call the module itself (not ``.forward``) so that registered hooks run, exactly as in
            # ``forward``. The features are only read by the in-place add, so no extra copy is needed.
            v += self.model(inp)

        v /= len(scales)
        # The averaged features are deliberately not re-normalised here:
        # we don't want to shift the norms values.
        return v

    def draw_attention(self, image: Union[TPILImage, np.ndarray]) -> Union[TPILImage, np.ndarray]:
        """
        Visualization of the multi-head attention on a particular image.

        Args:
            image: An image with pixel values in the range of ``[0..255]``.

        Returns:
            An image with drawn attention maps. The result is produced by ``vis_vit``, which returns
            a PIL image when a PIL image was passed and a numpy array otherwise.

        """
        return vis_vit(vit=self, image=image)
def vis_vit(
    vit: ViTExtractor,
    image: Union[TPILImage, np.ndarray],
    mean: TNormParam = MEAN,
    std: TNormParam = STD,
) -> Union[TPILImage, np.ndarray]:
    """
    Overlay the head-averaged attention of the last self-attention layer on an image.

    Args:
        vit: Extractor whose ``model`` exposes ``patch_embed.proj.kernel_size`` and
            ``get_last_selfattention``.
        image: An image with pixel values in the range of ``[0..255]``, either a PIL image
            or a numpy array.
        mean: Normalisation mean passed to ``get_normalisation_albu``.
        std: Normalisation std passed to ``get_normalisation_albu``.

    Returns:
        The image with the attention heat map drawn on top of it: a PIL image if a PIL image
        was passed, a numpy array otherwise. The picture is cropped so that both of its sides are
        multiples of the patch size, hence the output may be slightly smaller than the input.

    """
    # this is the optional dependency
    from pytorch_grad_cam.utils.image import show_cam_on_image

    need_to_convert = not isinstance(image, np.ndarray)
    if need_to_convert:
        image = np.asarray(image)

    patch_size = vit.model.patch_embed.proj.kernel_size[0]

    img_tensor = get_normalisation_albu(mean=mean, std=std)(image=image)["image"]

    # Crop both spatial axes down to a multiple of the patch size.
    size_1 = img_tensor.shape[1] - img_tensor.shape[1] % patch_size
    size_2 = img_tensor.shape[2] - img_tensor.shape[2] % patch_size
    img_tensor = img_tensor[:, :size_1, :size_2].unsqueeze(0)

    # The model may live on an accelerator while the freshly built tensor does not:
    # move the input to the model's device to avoid a device mismatch error.
    first_param = next(vit.model.parameters(), None)
    if first_param is not None:
        img_tensor = img_tensor.to(first_param.device)

    feat_map_1 = img_tensor.shape[-2] // patch_size
    feat_map_2 = img_tensor.shape[-1] // patch_size

    with temporary_setting_model_mode(vit, set_train=False), torch.no_grad():
        attentions = vit.model.get_last_selfattention(img_tensor)

    nh = attentions.shape[1]

    # Keep, for every head, the attention of token 0 to all the following tokens.
    # NOTE(review): this assumes token 0 is the class token and every remaining token is a patch
    # token; architectures with additional (e.g. register) tokens would break the reshape - confirm.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, feat_map_1, feat_map_2)

    # Upsample the per-patch values back to the pixel resolution of the cropped input.
    attentions = (
        nn.functional.interpolate(
            attentions.unsqueeze(0),
            scale_factor=patch_size,
            mode="nearest",
        )[0]
        .cpu()
        .numpy()
    )

    # Average over the heads.
    arr = attentions.mean(axis=0)

    # ``show_cam_on_image`` blends the picture and the mask element-wise, so the picture has to be
    # cropped to the mask's size (they differ when a side is not a multiple of the patch size).
    visible = image[: arr.shape[0], : arr.shape[1]]
    peak = visible.max()
    # Guard against a division by zero for an entirely black picture.
    base = visible / peak if peak > 0 else visible.astype(np.float64)

    arr = show_cam_on_image(base, 0.6 * arr / arr.max())  # type: ignore

    if need_to_convert:
        arr = PIL.Image.fromarray(arr)

    return arr
# Names exported by ``from oml.models.vit_dino.extractor import *``.
__all__ = ["ViTExtractor", "vis_vit"]