def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
) -> torch.Tensor:
    """Calculates the MRF loss.

    See :class:`pystiche.ops.MRFOperator` for details.

    Each sample of ``input`` is paired with the sample of ``target`` it is most
    similar to (cosine similarity). The result is the :func:`mse_loss` between
    ``input`` and those paired target samples. The pairing is computed inside
    ``torch.no_grad()``.

    Args:
        input: Input tensor.
        target: Target tensor.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.

    Examples:
        >>> input = torch.rand(256, 64, 3, 3)
        >>> target = torch.rand(256, 64, 3, 3)
        >>> score = F.mrf_loss(input, target)
    """
    with torch.no_grad():
        # Row i holds the similarities of input sample i to every target sample.
        scores = pystiche.cosine_similarity(input, target, eps=eps)
        best_match = scores.argmax(dim=1)
        matched_target = target.index_select(0, best_match)
    return mse_loss(input, matched_target, reduction=reduction)
def test_cosine_similarity_batched_input():
    """Batched cosine similarity agrees with torch's reference implementation."""
    torch.manual_seed(0)
    eps = 1e-6
    x1, x2 = torch.rand(2, 1, 256), torch.rand(2, 1, 256)

    expected = F.cosine_similarity(x1, x2, dim=2, eps=eps).unsqueeze(2)
    actual = pystiche.cosine_similarity(x1, x2, eps=eps, batched_input=True)

    ptu.assert_allclose(actual, expected, rtol=1e-6)
def test_cosine_similarity():
    """Unbatched cosine similarity agrees with torch's reference implementation."""
    torch.manual_seed(0)
    x1 = torch.rand(1, 256)
    x2 = torch.rand(1, 256)
    eps = 1e-6

    # Pass batched_input explicitly so the test neither depends on nor warns
    # about the default value of the parameter.
    actual = pystiche.cosine_similarity(x1, x2, eps=eps, batched_input=False)
    expected = F.cosine_similarity(x1, x2, dim=1, eps=eps).unsqueeze(1)

    ptu.assert_allclose(actual, expected, rtol=1e-6)
def test_shape(self):
    """Batched cosine similarity yields one S x S matrix per batch element."""
    torch.manual_seed(0)
    x1 = torch.rand(2, 3, 4, 5)
    x2 = torch.rand(2, 3, 4, 5)

    similarity = pystiche.cosine_similarity(x1, x2, batched_input=True)

    assert similarity.size() == (2, 3, 3)
def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
    batched_input: Optional[bool] = None,
) -> torch.Tensor:
    r"""Calculates the MRF loss.

    See :class:`pystiche.loss.MRFLoss` for details.

    Every sample of ``input`` is matched to its most similar sample of ``target``
    (by cosine similarity). The result is the :func:`mse_loss` between ``input``
    and the matched target samples. The matching is computed inside
    ``torch.no_grad()``.

    Args:
        input: Input of shape :math:`B \times S_1 \times N_1 \times \dots \times N_D`.
        target: Target of shape :math:`B \times S_2 \times N_1 \times \dots \times N_D`.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.
        batched_input: If ``False``, treat the first dimension of the inputs as
            sample dimension, i.e. :math:`S \times N_1 \times \dots \times N_D`.
            If omitted (``None``), ``True`` is used and a :class:`UserWarning`
            about the changed default is emitted. See
            :func:`pystiche.cosine_similarity` for details.

    Examples:
        >>> import pystiche.loss.functional as F
        >>> input = torch.rand(1, 256, 64, 3, 3)
        >>> target = torch.rand(1, 128, 64, 3, 3)
        >>> score = F.mrf_loss(input, target, batched_input=True)
    """
    if batched_input is None:
        msg = (
            "The default value of batched_input has changed "
            "from False to True in version 1.0.0. "
            "To suppress this warning, pass the wanted behavior explicitly."
        )
        # stacklevel=2 attributes the warning to the caller's line, which is
        # where the argument has to be added.
        warnings.warn(msg, UserWarning, stacklevel=2)
        batched_input = True

    with torch.no_grad():
        similarity = pystiche.cosine_similarity(
            input, target, eps=eps, batched_input=batched_input
        )
        # For every input sample, the index of its most similar target sample.
        index = torch.argmax(similarity, dim=-1)
        # Broadcast the index over the trailing feature dimensions so that
        # torch.gather copies whole target samples.
        index = index.view(*index.shape, *[1] * (target.ndim - index.ndim)).expand(
            *[-1] * index.ndim, *target.shape[index.ndim:]
        )
        target = torch.gather(target, 1 if batched_input else 0, index)

    return mse_loss(input, target, reduction=reduction)
def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
    batched_input: Optional[bool] = None,
) -> torch.Tensor:
    r"""Calculates the MRF loss.

    See :class:`pystiche.ops.MRFOperator` for details.

    Every sample of ``input`` is matched to its most similar sample of ``target``
    (by cosine similarity). The result is the :func:`mse_loss` between ``input``
    and the matched target samples. The matching is computed inside
    ``torch.no_grad()``.

    Args:
        input: Input of shape :math:`S_1 \times N_1 \times \dots \times N_D`.
        target: Target of shape :math:`S_2 \times N_1 \times \dots \times N_D`.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.
        batched_input: If ``True``, treat the first dimension of the inputs as
            batch dimension, i.e.
            :math:`B \times S \times N_1 \times \dots \times N_D`. If omitted
            (``None``), the value is forwarded to
            :func:`pystiche.cosine_similarity`, which currently treats it as
            ``False``. See :func:`pystiche.cosine_similarity` for details.

    Note:
        The default value of ``batched_input`` will change from ``False`` to
        ``True`` in the future.

    Examples:
        >>> input = torch.rand(256, 64, 3, 3)
        >>> target = torch.rand(256, 64, 3, 3)
        >>> score = F.mrf_loss(input, target)
    """
    with torch.no_grad():
        similarity = pystiche.cosine_similarity(
            input, target, eps=eps, batched_input=batched_input
        )
        # The similarity is S_1 x S_2 (unbatched) or B x S_1 x S_2 (batched), so
        # the target samples are always indexed by its last dimension.
        index = torch.argmax(similarity, dim=-1)
        # The sample dimension of target is the last dimension of index: 0 when
        # unbatched, 1 when batched. Deriving it from the result instead of from
        # batched_input also covers None, whose meaning is decided by
        # pystiche.cosine_similarity.
        sample_dim = index.ndim - 1
        # Broadcast the index over the trailing feature dimensions so that
        # torch.gather copies whole target samples.
        index = index.view(*index.shape, *[1] * (target.ndim - index.ndim)).expand(
            *[-1] * index.ndim, *target.shape[index.ndim:]
        )
        target = torch.gather(target, sample_dim, index)
    return mse_loss(input, target, reduction=reduction)
def test_cosine_similarity_shape():
    """Unbatched cosine similarity of S samples yields an S x S matrix."""
    torch.manual_seed(0)
    x1 = torch.rand(2, 3, 4, 5)
    x2 = torch.rand(2, 3, 4, 5)

    # Pass batched_input explicitly so the expected shape does not depend on
    # (or warn about) the parameter's default value.
    similarity = pystiche.cosine_similarity(x1, x2, batched_input=False)

    assert similarity.size() == (2, 2)
def test_cosine_similarity_future_warning():
    """Omitting ``batched_input`` emits a FutureWarning."""
    # The tensor values are irrelevant here; only the warning is checked.
    x1, x2 = torch.empty(1, 2), torch.empty(1, 2)

    with pytest.warns(FutureWarning):
        pystiche.cosine_similarity(x1, x2)
def test_batched_input_not_specified(self):
    """Omitting ``batched_input`` emits a UserWarning."""
    x1, x2 = torch.rand(2, 1, 256), torch.rand(2, 1, 256)

    with pytest.warns(UserWarning):
        pystiche.cosine_similarity(x1, x2)