Example #1
import torch

import pystiche
from pystiche.misc import reduce


def total_variation_loss(input: torch.Tensor,
                         exponent: float = 2.0,
                         reduction: str = "mean") -> torch.Tensor:
    r"""Calculates the total variation loss. See
    :class:`pystiche.ops.TotalVariationOperator` for details.

    Args:
        input: Input image
        exponent: Parameter :math:`\beta` . A higher value leads to more smoothed
            results. Defaults to ``2.0``.
        reduction: Reduction method of the output passed to
            :func:`pystiche.misc.reduce`. Defaults to ``"mean"``.

    Examples:

        >>> import torch
        >>> import pystiche.loss.functional as F
        >>> input = torch.rand(2, 3, 256, 256)
        >>> score = F.total_variation_loss(input)
    """
    # this ignores the last row and column of the image
    grad_vert = input[:, :, :-1, :-1] - input[:, :, 1:, :-1]
    grad_horz = input[:, :, :-1, :-1] - input[:, :, :-1, 1:]
    grad = pystiche.nonnegsqrt(grad_vert**2.0 + grad_horz**2.0)
    loss = grad**exponent
    return reduce(loss, reduction)
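For a quick sanity check of the function above: with a 2x2 input that has a single bright pixel, both forward differences at the top-left position equal 1, so the loss reduces to (1^2 + 1^2)^(β/2) = 2 for the default β = 2.0. The snippet below assumes pystiche is installed and uses the import path shown in the docstring.

import torch

import pystiche.loss.functional as F

x = torch.zeros(1, 1, 2, 2)
x[0, 0, 0, 0] = 1.0  # one bright pixel, all neighbors zero

# both differences at (0, 0) are 1.0, so the loss is (1.0 + 1.0) ** (2.0 / 2.0) = 2.0
loss = F.total_variation_loss(x)
assert torch.allclose(loss, torch.tensor(2.0))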
Example #2
from math import sqrt

import torch
import pystiche
# `ptu` is the test suite's tensor-aware approx helper (presumably pytorch_testing_utils)


def test_nonnegsqrt():
    vals = (-1.0, 0.0, 1.0, 2.0)
    desireds = (0.0, 0.0, 1.0, sqrt(2.0))

    for val, desired in zip(vals, desireds):
        x = torch.tensor(val, requires_grad=True)
        y = pystiche.nonnegsqrt(x)

        assert y == ptu.approx(desired)
Example #3
    def test_main(self):
        vals = (-1.0, 0.0, 1.0, 2.0)
        desireds = (0.0, 0.0, 1.0, sqrt(2.0))

        for val, desired in zip(vals, desireds):
            x = torch.tensor(val)
            y = pystiche.nonnegsqrt(x)

            assert y == ptu.approx(desired)
Example #4
def test_nonnegsqrt_grad():
    vals = (-1.0, 0.0, 1.0, 2.0)
    # the (sub)gradient is defined as 0 for non-positive inputs; for positive x
    # it is the usual 1 / (2 * sqrt(x))
    desireds = (0.0, 0.0, 1.0 / 2.0, 1.0 / (2.0 * sqrt(2.0)))

    for val, desired in zip(vals, desireds):
        x = torch.tensor(val, requires_grad=True)
        y = pystiche.nonnegsqrt(x)
        y.backward()

        assert x.grad == ptu.approx(desired)
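Together with the value test in Example #2, the gradient test above fully specifies the behavior: nonnegsqrt(x) is sqrt(x) for positive inputs and 0 for non-positive ones, with the (sub)gradient defined as 0 at and below zero. Below is a minimal re-implementation sketch that reproduces these values and gradients; it is not necessarily pystiche's actual code.

import torch


def nonnegsqrt_sketch(x: torch.Tensor) -> torch.Tensor:
    # sqrt(x) for x > 0, otherwise 0 in both value and gradient.
    # Masking the input before the sqrt keeps NaN / inf out of the backward pass.
    mask = x > 0.0
    safe = torch.where(mask, x, torch.ones_like(x))
    return torch.sqrt(safe) * mask.to(x.dtype)

The loops from Examples #2 and #4 pass unchanged when pystiche.nonnegsqrt is swapped for this sketch.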
Example #5
    def test_nonnegsqrt(self):
        vals = (-1.0, 0.0, 1.0, 2.0)
        desireds = (0.0, 0.0, 1.0, sqrt(2.0))

        for val, desired in zip(vals, desireds):
            x = torch.tensor(val, requires_grad=True)
            y = pystiche.nonnegsqrt(x)

            actual = y.item()
            self.assertAlmostEqual(actual, desired)