예제 #1
0
def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
) -> torch.Tensor:
    """Calculates the MRF loss. See :class:`pystiche.ops.MRFOperator` for details.

    Every entry along the first dimension of ``input`` is paired with the entry
    of ``target`` that has the highest cosine similarity to it. The mean squared
    error between ``input`` and these selected entries is returned.

    Args:
        input: Input tensor.
        target: Target tensor.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.

    Examples:

        >>> input = torch.rand(256, 64, 3, 3)
        >>> target = torch.rand(256, 64, 3, 3)
        >>> score = F.mrf_loss(input, target)

    """
    # The matching step only selects entries of ``target``; it is run without
    # gradient tracking.
    with torch.no_grad():
        best_match = pystiche.cosine_similarity(input, target, eps=eps).argmax(dim=1)
        matched_target = target.index_select(0, best_match)
    return mse_loss(input, matched_target, reduction=reduction)
예제 #2
0
def test_cosine_similarity_batched_input():
    """Batched ``pystiche.cosine_similarity`` agrees with the torch reference."""
    torch.manual_seed(0)
    eps = 1e-6
    # Tuple elements are evaluated left to right, so the random draws happen in
    # the same order as two separate assignments.
    x1, x2 = torch.rand(2, 1, 256), torch.rand(2, 1, 256)

    reference = F.cosine_similarity(x1, x2, dim=2, eps=eps).unsqueeze(2)
    result = pystiche.cosine_similarity(x1, x2, eps=eps, batched_input=True)

    ptu.assert_allclose(result, reference, rtol=1e-6)
예제 #3
0
def test_cosine_similarity():
    """``pystiche.cosine_similarity`` agrees with the torch reference."""
    torch.manual_seed(0)
    eps = 1e-6
    x1 = torch.rand(1, 256)
    x2 = torch.rand(1, 256)

    reference = F.cosine_similarity(x1, x2, dim=1, eps=eps).unsqueeze(1)
    result = pystiche.cosine_similarity(x1, x2, eps=eps)

    ptu.assert_allclose(result, reference, rtol=1e-6)
예제 #4
0
    def test_shape(self):
        """Batched similarity of two ``2 x 3 x 4 x 5`` tensors has size ``2 x 3 x 3``."""
        torch.manual_seed(0)
        shape = (2, 3, 4, 5)
        x1 = torch.rand(*shape)
        x2 = torch.rand(*shape)

        similarity = pystiche.cosine_similarity(x1, x2, batched_input=True)

        assert similarity.size() == (2, 3, 3)
예제 #5
0
def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
    batched_input: Optional[bool] = None,
) -> torch.Tensor:
    r"""Calculates the MRF loss. See :class:`pystiche.loss.MRFLoss` for details.

    Each input sample is paired with the target sample that has the highest
    cosine similarity to it. The mean squared error between the input and these
    selected target samples is returned.

    Args:
        input: Input of shape :math:`B \times S_1 \times N_1 \times \dots \times N_D`.
        target: Target of shape :math:`B \times S_2 \times N_1 \times \dots \times N_D`.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.
        batched_input: If ``False``, treat the first dimension of the inputs as sample
            dimension, i.e. :math:`S \times N_1 \times \dots \times N_D`. If omitted
            (``None``), a :class:`UserWarning` is emitted and ``True`` is used. See
            :func:`pystiche.cosine_similarity` for details.

    Returns:
        The reduced mean squared error between ``input`` and the matched target
        samples.

    Examples:

        >>> import pystiche.loss.functional as F
        >>> input = torch.rand(1, 256, 64, 3, 3)
        >>> target = torch.rand(1, 128, 64, 3, 3)
        >>> score = F.mrf_loss(input, target, batched_input=True)

    """
    if batched_input is None:
        msg = (
            "The default value of batched_input has changed "
            "from False to True in version 1.0.0. "
            "To suppress this warning, pass the wanted behavior explicitly.")
        # stacklevel=2 attributes the warning to the caller's line instead of
        # this module, so users can locate the call they need to update.
        warnings.warn(msg, UserWarning, stacklevel=2)
        batched_input = True

    # The matching step only selects samples of ``target``; it is run without
    # gradient tracking.
    with torch.no_grad():
        similarity = pystiche.cosine_similarity(input,
                                                target,
                                                eps=eps,
                                                batched_input=batched_input)

        # Position of the most similar target sample for every input sample.
        index = torch.argmax(similarity, dim=-1)

        # ``torch.gather`` needs an index with as many dimensions as ``target``:
        # append singleton dimensions, then broadcast them to the trailing
        # dimensions of ``target``.
        num_index_dims = index.ndim
        trailing_shape = target.shape[num_index_dims:]
        index = index.view(*index.shape, *([1] * len(trailing_shape)))
        index = index.expand(*([-1] * num_index_dims), *trailing_shape)

        sample_dim = 1 if batched_input else 0
        target = torch.gather(target, sample_dim, index)
    return mse_loss(input, target, reduction=reduction)
예제 #6
0
def mrf_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-8,
    reduction: str = "mean",
    batched_input: Optional[bool] = None,
) -> torch.Tensor:
    r"""Calculates the MRF loss. See :class:`pystiche.ops.MRFOperator` for details.

    Each input sample is paired with the target sample that has the highest
    cosine similarity to it. The mean squared error between the input and these
    selected target samples is returned.

    Args:
        input: Input of shape :math:`S_1 \times N_1 \times \dots \times N_D`.
        target: Target of shape :math:`S_2 \times N_1 \times \dots \times N_D`.
        eps: Small value to avoid zero division. Defaults to ``1e-8``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.
        batched_input: If ``True``, treat the first dimension of the inputs as batch
            dimension, i.e. :math:`B \times S \times N_1 \times \dots \times N_D`.
            Defaults to ``False``. See :func:`pystiche.cosine_similarity` for details.

    Returns:
        The reduced mean squared error between ``input`` and the matched target
        samples.

    Note:
        The default value of ``batched_input`` will change from ``False`` to ``True``
        in the future.

    Examples:

        >>> input = torch.rand(256, 64, 3, 3)
        >>> target = torch.rand(256, 64, 3, 3)
        >>> score = F.mrf_loss(input, target)

    """
    # The matching step only selects samples of ``target``; it is run without
    # gradient tracking.
    with torch.no_grad():
        # ``batched_input`` is forwarded unchanged (including ``None``) so that
        # ``pystiche.cosine_similarity`` keeps handling the unspecified case.
        similarity = pystiche.cosine_similarity(input,
                                                target,
                                                eps=eps,
                                                batched_input=batched_input)
        if batched_input:
            # Batched layout per the docstring: the samples live in dimension 1
            # and the similarity carries one extra leading batch dimension, so
            # the best match has to be picked per batch entry. ``index_select``
            # only accepts a 1-D index, hence ``gather`` with a broadcast index.
            idcs = torch.argmax(similarity, dim=2)
            trailing_shape = target.shape[2:]
            idcs = idcs.view(*idcs.shape, *([1] * len(trailing_shape)))
            idcs = idcs.expand(-1, -1, *trailing_shape)
            target = torch.gather(target, 1, idcs)
        else:
            # Unbatched layout (also used when ``batched_input`` is ``None``):
            # the samples live in dimension 0.
            idcs = torch.argmax(similarity, dim=1)
            target = torch.index_select(target, dim=0, index=idcs)
    return mse_loss(input, target, reduction=reduction)
예제 #7
0
def test_cosine_similarity_shape():
    """Similarity of two ``2 x 3 x 4 x 5`` tensors has size ``2 x 2``."""
    torch.manual_seed(0)
    shape = (2, 3, 4, 5)
    x1 = torch.rand(*shape)
    x2 = torch.rand(*shape)

    similarity = pystiche.cosine_similarity(x1, x2)

    assert similarity.size() == (2, 2)
예제 #8
0
def test_cosine_similarity_future_warning():
    """Calling ``pystiche.cosine_similarity`` without ``batched_input`` warns."""
    # Only the emitted warning is checked; the tensor values are not inspected.
    shape = (1, 2)
    first, second = torch.empty(*shape), torch.empty(*shape)

    with pytest.warns(FutureWarning):
        pystiche.cosine_similarity(first, second)
예제 #9
0
    def test_batched_input_not_specified(self):
        """Calling ``pystiche.cosine_similarity`` without ``batched_input`` warns."""
        shape = (2, 1, 256)
        x1 = torch.rand(*shape)
        x2 = torch.rand(*shape)

        with pytest.warns(UserWarning):
            pystiche.cosine_similarity(x1, x2)