# Source code for oml.interfaces.criterions
from typing import List, Union
from torch import Tensor
from torch.nn import Module
class ITripletLossWithMiner(Module):
    """
    Base class for TripletLoss combined with Miner.

    This class only fixes the ``forward`` signature; the base implementation
    raises ``NotImplementedError``, so concrete losses have to override it.
    """

    def forward(self, features: Tensor, labels: Union[Tensor, List[int]]) -> Tensor:
        """
        Compute the loss for a batch of features and their labels.

        Args:
            features: Features with the shape ``[batch_size, features_dim]``
            labels: Labels with the size of ``batch_size``

        Returns:
            Loss value

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError()
# Public names exported on a wildcard import of this module.
__all__ = ["ITripletLossWithMiner"]