from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from torch import Tensor, nn
from torch.nn.modules.activation import Sigmoid
from torchvision.ops import MLP
from oml.interfaces.models import IExtractor, IFreezable, IPairwiseModel
from oml.models.utils import remove_prefix_from_state_dict
from oml.utils.io import download_checkpoint
from oml.utils.misc_torch import elementwise_dist
class LinearTrivialDistanceSiamese(IPairwiseModel):
    """
    A minimal pairwise model, mostly useful for development purposes.

    Both inputs are passed through one shared bias-free linear projection and
    the result of ``elementwise_dist`` (``p=2``) between the two projections,
    shifted by ``output_bias``, is returned.

    """

    def __init__(self, feat_dim: int, identity_init: bool, output_bias: float = 0):
        """
        Args:
            feat_dim: Expected size of each input.
            identity_init: If ``True``, the projection weights are set to the identity matrix,
                so the model initially estimates the L2 distance between the original embeddings.
            output_bias: Value to add to the output.

        """
        super().__init__()

        self.feat_dim = feat_dim
        self.output_bias = output_bias

        # Square projection without bias; optionally overwritten by the identity matrix.
        projection = nn.Linear(feat_dim, feat_dim, bias=False)
        if identity_init:
            projection.load_state_dict({"weight": torch.eye(feat_dim)})
        self.proj = projection

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """
        Args:
            x1: Embedding with the shape of ``[batch_size, feat_dim]``
            x2: Embedding with the shape of ``[batch_size, feat_dim]``

        Returns:
            Distance between transformed inputs.

        """
        distances = elementwise_dist(self.proj(x1), self.proj(x2), p=2)
        return distances + self.output_bias

    def predict(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return exactly what ``forward`` returns for the same pair of inputs."""
        return self.forward(x1, x2)
class ConcatSiamese(IPairwiseModel, IFreezable):
    """
    This model concatenates two inputs, passes the result through
    a given backbone and applies an MLP head after that.

    """

    # Maps a weights key to a ``(url_or_fid, hash_md5, fname)`` triple consumed by ``download_checkpoint``.
    pretrained_models: Dict[str, Any] = {}

    def __init__(
        self,
        extractor: IExtractor,
        mlp_hidden_dims: List[int],
        use_tta: bool = False,
        weights: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Args:
            extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``)
            mlp_hidden_dims: Hidden dimensions of the head
            use_tta: Set ``True`` if you want to average the results obtained by two different orders of concatenating
                input images. Affects only ``self.predict()`` method.
            weights: Path to weights file, a key of ``self.pretrained_models``,
                or ``None`` to keep the freshly initialised head and the extractor as passed.

        """
        super(ConcatSiamese, self).__init__()

        self.extractor = extractor
        self.use_tta = use_tta

        # The head ends with a single output unit: one score per concatenated pair.
        self.head = MLP(
            in_channels=self.extractor.feat_dim,
            hidden_channels=[*mlp_hidden_dims, 1],
            activation_layer=Sigmoid,
            dropout=0.5,
            inplace=None,
        )

        # turn off the last bias
        self.head[-2] = nn.Linear(self.head[-2].in_features, self.head[-2].out_features, bias=False)

        # turn off the last dropout
        self.head[-1] = nn.Identity()

        # Controls whether gradients are tracked through the extractor in ``forward``.
        self.train_backbone = True

        if weights:
            if weights in self.pretrained_models:
                url_or_fid, hash_md5, fname = self.pretrained_models[weights]  # type: ignore
                weights = download_checkpoint(url_or_fid=url_or_fid, hash_md5=hash_md5, fname=fname)

            # NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects,
            # so only checkpoints from trusted sources must be passed here.
            loaded = torch.load(weights, map_location="cpu", weights_only=False)
            # The checkpoint may be either a bare state dict or a dict holding it under "state_dict".
            loaded = loaded.get("state_dict", loaded)
            loaded = remove_prefix_from_state_dict(loaded, trial_key="extractor.")
            self.load_state_dict(loaded, strict=True)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """
        Args:
            x1: The first input.
            x2: The second input; it is concatenated with ``x1`` along ``dim=2``
                (presumably the height axis of ``[B, C, H, W]`` image batches -- TODO confirm).

        Returns:
            Raw (pre-sigmoid) output of the head, reshaped to a 1D tensor with one value per pair.

        """
        x = torch.concat([x1, x2], dim=2)

        # ``set_grad_enabled(True)`` overrides an enclosing ``torch.no_grad()``, so autograd is
        # only kept on when the backbone is trainable AND the caller has not disabled it.
        track_grad = self.train_backbone and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            x = self.extractor(x)

        x = self.head(x)
        x = x.view(len(x))

        return x

    def predict(self, x1: Tensor, x2: Tensor) -> Tensor:
        """
        Args:
            x1: The first input.
            x2: The second input.

        Returns:
            Sigmoid of the ``forward`` output. If ``use_tta`` is set, the result is averaged
            with the one obtained for the swapped order of inputs.

        """
        x = self.forward(x1=x1, x2=x2)
        x = torch.sigmoid(x)

        if self.use_tta:
            y = self.forward(x1=x2, x2=x1)
            y = torch.sigmoid(y)
            return (x + y) / 2
        else:
            return x

    def freeze(self) -> None:
        """Stop tracking gradients through the extractor in ``forward``."""
        self.train_backbone = False

    def unfreeze(self) -> None:
        """Track gradients through the extractor in ``forward`` again."""
        self.train_backbone = True
class TrivialDistanceSiamese(IPairwiseModel):
    """
    A simple pairwise model, mostly useful for development purposes.

    Each input is passed through the same extractor, and the result of
    ``elementwise_dist`` (``p=2``) between the two outputs, shifted by
    ``output_bias``, is returned.

    """

    pretrained_models: Dict[str, Any] = {}

    def __init__(self, extractor: IExtractor, output_bias: float = 0) -> None:
        """
        Args:
            extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``)
            output_bias: Value to add to the outputs.

        """
        super().__init__()

        self.extractor = extractor
        self.output_bias = output_bias

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """
        Args:
            x1: The first input.
            x2: The second input.

        Returns:
            Distance between inputs.

        """
        # The same extractor is applied to both inputs, first to ``x1`` and then to ``x2``.
        features_1 = self.extractor(x1)
        features_2 = self.extractor(x2)

        distances = elementwise_dist(features_1, features_2, p=2)
        return distances + self.output_bias

    def predict(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return exactly what ``forward`` returns for the same pair of inputs."""
        return self.forward(x1, x2)
# Names exported by ``from <this module> import *``.
__all__ = [
    "LinearTrivialDistanceSiamese",
    "ConcatSiamese",
    "TrivialDistanceSiamese",
]