# Source code for oml.models.resnet.extractor
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL.Image
import torch
from PIL.Image import Image as TPILImage
from torch import nn
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from oml.interfaces.models import IExtractor
from oml.models.resnet.pooling import GEM
from oml.models.utils import (
remove_criterion_in_state_dict,
remove_prefix_from_state_dict,
)
from oml.transforms.images.albumentations import get_normalisation_albu
from oml.utils.io import download_checkpoint
from oml.utils.misc_torch import get_device, normalise
def resnet50_projector() -> nn.Module:
    """
    Build a randomly initialised ResNet-50 whose head is a two-layer MLP.

    The network is created with a 128-way linear head; that head is then wrapped so that ``fc`` becomes
    ``Linear(d, d) -> ReLU -> original head``, where ``d`` is the input size of the original head.

    Returns:
        The ResNet-50 module with the replaced ``fc`` attribute.

    """
    backbone = resnet50(weights=None, num_classes=128)

    final_linear = backbone.fc
    hidden_dim = final_linear.in_features

    backbone.fc = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        final_linear,
    )
    return backbone
class ResnetExtractor(IExtractor):
    """
    The base class for the extractors that follow ResNet architecture.

    """

    # Maps the value of ``arch`` to the callable that builds the underlying network.
    constructors = {
        "resnet18": resnet18,
        "resnet34": resnet34,
        "resnet50": resnet50,
        "resnet50_projector": resnet50_projector,
        "resnet101": resnet101,
        "resnet152": resnet152,
    }

    # Known values of ``weights`` together with the arguments needed to build the matching extractor.
    # Only ``resnet50_moco_v2`` carries download information (``url`` / ``hash`` / ``fname``).
    pretrained_models = {
        "resnet50_moco_v2": {
            "url": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar",
            "hash": "a04e12f8",
            "fname": "moco_v2_800ep_pretrain.pth.tar",
            "init_args": {"arch": "resnet50_projector", "remove_fc": True, "normalise_features": False, "gem_p": 5.0},
        },
        "resnet18_imagenet1k_v1": {
            "init_args": {"arch": "resnet18", "remove_fc": True, "normalise_features": False, "gem_p": None}
        },
        "resnet34_imagenet1k_v1": {
            "init_args": {"arch": "resnet34", "remove_fc": True, "normalise_features": False, "gem_p": None}
        },
        "resnet50_imagenet1k_v1": {
            "init_args": {"arch": "resnet50", "remove_fc": True, "normalise_features": False, "gem_p": None},
        },
        "resnet101_imagenet1k_v1": {
            "init_args": {"arch": "resnet101", "remove_fc": True, "normalise_features": False, "gem_p": None},
        },
        "resnet152_imagenet1k_v1": {
            "init_args": {"arch": "resnet152", "remove_fc": True, "normalise_features": False, "gem_p": None},
        },
    }

    def __init__(
        self,
        weights: Optional[Union[Path, str]],
        arch: str,
        gem_p: Optional[float],
        remove_fc: bool,
        normalise_features: bool,
    ):
        """
        Args:
            weights: Path to weights or a special key to download pretrained checkpoint, use ``None`` to randomly
             initialize model's weights. You can check the available pretrained checkpoints
             in ``self.pretrained_models``.
            arch: Different types of ResNet, please, check ``self.constructors``
            gem_p: Value of power in `Generalized Mean Pooling` that we use as the replacement for the default one
             (if ``gem_p == 1`` or ``None`` it's just a normal average pooling and if ``gem_p -> inf`` it's max-pooling)
            remove_fc: Set ``True`` if you want to remove the last fully connected layer. Note, that having this layer
             is obligatory for calling ``draw_gradcam()`` method
            normalise_features: Set ``True`` to normalise output features

        """
        assert arch in self.constructors, f"Unknown arch '{arch}', expected one of {list(self.constructors)}."

        super().__init__()

        self.arch = arch
        self.gem_p = gem_p
        self.remove_fc = remove_fc
        self.normalise_features = normalise_features

        constructor = self.constructors[self.arch]
        self.model = constructor()

        if gem_p is not None:
            self.model.avgpool = GEM(p=gem_p)

        if weights is None:
            # Random initialisation: nothing to load, only drop the head if requested.
            if self.remove_fc:
                self.model.fc = nn.Identity()
            return

        if weights == "resnet50_moco_v2":
            pretrained = self.pretrained_models[weights]  # type: ignore
            moco_path = download_checkpoint(
                url_or_fid=pretrained["url"],  # type: ignore
                hash_md5=pretrained["hash"],  # type: ignore
                fname=pretrained["fname"],  # type: ignore
            )
            state_dict = load_moco_weights(moco_path)
        elif str(weights).endswith("_imagenet1k_v1"):
            state_dict = constructor(weights="IMAGENET1K_V1").state_dict()
        else:
            # NOTE(security): ``weights_only=False`` unpickles arbitrary Python objects,
            # so only checkpoints from a trusted source must be passed here.
            state_dict = torch.load(weights, map_location="cpu", weights_only=False)

        # Checkpoints may wrap the parameters into a ``state_dict`` field.
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        state_dict = remove_criterion_in_state_dict(state_dict)  # type: ignore
        state_dict = remove_prefix_from_state_dict(state_dict, "layer4.")  # type: ignore

        if self.remove_fc:
            # The head is thrown away anyway, so drop all of its parameters ("fc.weight", "fc.0.bias", ...)
            # and replace the head *before* the strict load. This way the load succeeds for every arch,
            # both for checkpoints that still contain a head and for those saved after it was removed.
            for head_key in [key for key in state_dict if key.startswith("fc.")]:
                del state_dict[head_key]
            self.model.fc = nn.Identity()

        self.model.load_state_dict(state_dict, strict=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input that is passed to the underlying ResNet as is.

        Returns:
            Output of the underlying model, passed through ``normalise()`` if ``normalise_features`` is set.

        """
        x = self.model(x)

        if self.normalise_features:
            x = normalise(x)

        return x

    def get_last_conv_channels(self) -> int:
        """
        Returns:
            The number of output channels of the last convolution in the final block of ``layer4``.

        """
        last_block = self.model.layer4[-1]

        if self.arch in ("resnet18", "resnet34"):
            # these architectures read the value from the block's ``conv2``
            n_out_channels = last_block.conv2.out_channels
        else:
            # resnet50, resnet101, resnet152
            n_out_channels = last_block.conv3.out_channels

        return n_out_channels

    @property
    def feat_dim(self) -> int:
        """
        Size of the produced features: taken from the last convolution if the head was replaced
        by ``Identity``, otherwise from the (last) linear layer of the head.

        """
        head = self.model.fc

        if isinstance(head, nn.Identity):
            return self.get_last_conv_channels()
        elif isinstance(head, nn.Linear):
            return head.out_features
        else:
            # 2-layer mlp case
            return head[-1].out_features

    def draw_gradcam(self, image: Union[np.ndarray, TPILImage]) -> Union[np.ndarray, TPILImage]:
        """
        Visualization of the gradients on a particular image using `GradCam`_.

        Args:
            image: An image with pixel values in the range of ``[0..255]``.

        Returns:
            An image with drawn gradients. It is a ``PIL`` image if the input was not a ``numpy`` array.

        Raises:
            ValueError: If the extractor was created with ``remove_fc=True``.

        .. _GradCam: https://arxiv.org/abs/1610.02391

        """
        # this is the optional dependency
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image

        if self.remove_fc:
            raise ValueError("This method does not work if there is no FC layer in the model.")

        # Remember the input type to return the result in the same form.
        need_to_convert = not isinstance(image, np.ndarray)
        if need_to_convert:
            image = np.asarray(image)

        device = get_device(self.model)
        image_tensor = get_normalisation_albu()(image=image)["image"].to(device)

        cam = GradCAM(model=self.model, target_layers=[self.model.layer4[-1]], use_cuda=device != "cpu")
        gray_image = cam(image_tensor.unsqueeze(0), None)[0]

        # the image is rescaled by 255 before being combined with the CAM mask
        img_with_grads = show_cam_on_image(image / 255, gray_image)

        if need_to_convert:
            img_with_grads = PIL.Image.fromarray(img_with_grads)

        return img_with_grads
def load_moco_weights(path_to_model: Union[str, Path]) -> Dict[str, Any]:
    """
    Args:
        path_to_model: Path to model trained using original
            code from MoCo repository:
            https://github.com/facebookresearch/moco

    Returns:
        The ``state_dict`` of the checkpoint reduced to the entries stored under
        the ``module.encoder_q.`` prefix, with that prefix stripped from the keys.
        All the other entries are dropped.

    """
    # The very same (dotted) prefix is used for both matching and stripping, so keys that merely
    # start with the same characters (e.g. ``module.encoder_qq.x``) are not mangled into bogus names.
    prefix = "module.encoder_q."

    # NOTE(security): ``weights_only=False`` unpickles arbitrary Python objects,
    # so only checkpoints from a trusted source must be passed here.
    checkpoint = torch.load(path_to_model, map_location="cpu", weights_only=False)
    state_dict = checkpoint["state_dict"]

    # Iterate over a snapshot of the keys because the dict is modified in place.
    for key in list(state_dict.keys()):
        # every original entry is removed; only the ones under the prefix are re-added (renamed)
        value = state_dict.pop(key)
        if key.startswith(prefix):
            state_dict[key[len(prefix) :]] = value

    return state_dict
# Names exported by ``from <this module> import *``: only the extractor class is listed.
__all__ = ["ResnetExtractor"]