Beispiel #1
0
def gs_div(
    input: torch.Tensor,
    target: torch.Tensor,
    alpha: float = -1,
    lmd: float = 0.5,
    reduction: Optional[str] = 'sum',
) -> torch.Tensor:
    r"""The $\alpha$-geodesical skew divergence.

    Computes ``input * log(input / m)`` element-wise (then reduced according to
    ``reduction``), where ``m = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)``.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input
        alpha: Specifies the coordinate system that equips the geodesical skew divergence
        lmd: Specifies the position on the geodesic; must lie in ``[0, 1]``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'`` (or ``None``): no reduction will be applied
            ``'batchmean'``: the sum of the output will be divided by the batch size
                (the size of dimension 0 of ``input``)
            ``'sum'``: the output will be summed
            ``'mean'``: the output will be divided by the number of elements in the output
            Default: ``'sum'``

    Returns:
        The element-wise divergence when no reduction is applied, otherwise a
        0-dim tensor.

    Raises:
        AssertionError: if ``lmd`` is outside ``[0, 1]`` (note: ``assert`` is
            skipped when Python runs with ``-O``).
        ValueError: if ``reduction`` is not one of the values listed above.
    """
    assert 0 <= lmd <= 1, f"lmd must be in [0, 1], got {lmd}"
    if reduction not in (None, 'none', 'batchmean', 'sum', 'mean'):
        # Previously an unknown value silently fell through and returned the
        # unreduced tensor, hiding typos such as 'summ'.
        raise ValueError(f"Invalid reduction: {reduction!r}")

    skew_target = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)

    # Where input and skew_target are both zero, input / skew_target is 0/0 = NaN,
    # which would turn the whole reduced value (and its gradient) into NaN.
    # Substituting a denominator of 1 at exactly those entries makes them
    # 0 * log(0 + 1e-12) = 0, matching how entries with input == 0 and
    # skew_target > 0 already evaluate. All other entries are computed as before.
    both_zero = (input == 0) & (skew_target == 0)
    safe_skew = torch.where(both_zero, torch.ones_like(skew_target), skew_target)
    # The 1e-12 keeps the log finite when input == 0 but skew_target > 0.
    div = input * torch.log(input / safe_skew + 1e-12)

    if reduction == 'batchmean':
        div = div.sum() / input.size(0)
    elif reduction == 'sum':
        div = div.sum()
    elif reduction == 'mean':
        div = div.mean()

    return div
    def test_value_0_2d(self):
        """A batch containing a zero entry must not yield infinite values (alpha=1)."""
        p = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
        q = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

        midpoint = alpha_geodesic(p, q, alpha=1, lmd=0.5)

        inf_count = torch.isinf(midpoint).sum()
        self.assertTrue(inf_count == 0)
    def test_grad(self):
        """The output must stay attached to autograd when an input requires grad."""
        source = torch.tensor(
            [[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]],
            requires_grad=True,
        )
        dest = torch.tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

        out = alpha_geodesic(source, dest, alpha=1, lmd=0.5)

        self.assertIsNotNone(out.grad_fn)
    def test_alpha_minus_inf(self):
        """alpha = -inf should reduce to the element-wise maximum."""
        x, y = torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])
        got = alpha_geodesic(x, y, alpha=float('-inf'), lmd=0.5)
        expected = torch.max(x, y)

        # Max is a pure selection, so exact equality is the right check here.
        self.assertTrue(torch.equal(got, expected))
    def test_alpha_3(self):
        """alpha=3 with lmd=0.5 should give the harmonic mean 1 / (0.5/a + 0.5/b)."""
        a = torch.Tensor([1, 2, 3])
        b = torch.Tensor([4, 5, 6])
        g = alpha_geodesic(a, b, alpha=3, lmd=0.5)
        res = 1 / (0.5 * 1/a + 0.5 * 1/b)

        # Compare with a tolerance: the reference formula and alpha_geodesic may
        # reach the same value through different floating-point operations
        # (e.g. reciprocal vs. pow(x, -1)), so bit-exact equality is brittle.
        self.assertTrue(torch.allclose(g, res))
    def test_alpha_0(self):
        """alpha=0 with lmd=0.5 should give (0.5*sqrt(a) + 0.5*sqrt(b)) ** 2."""
        a = torch.Tensor([1, 2, 3])
        b = torch.Tensor([4, 5, 6])
        g = alpha_geodesic(a, b, alpha=0, lmd=0.5)
        res = (0.5 * torch.sqrt(a) + 0.5 * torch.sqrt(b))**2

        # Compare with a tolerance: the reference uses sqrt while the
        # implementation may use a general power, and the two can differ in the
        # last bit, so bit-exact equality is brittle.
        self.assertTrue(torch.allclose(g, res))
    def test_alpha_1(self):
        """alpha=1 with lmd=0.5 should give the geometric mean exp(0.5*log a + 0.5*log b)."""
        a = torch.Tensor([1, 2, 3])
        b = torch.Tensor([4, 5, 6])
        g = alpha_geodesic(a, b, alpha=1, lmd=0.5)
        res = torch.exp(0.5 * torch.log(a) + 0.5 * torch.log(b))

        # Compare with a tolerance: exp/log round-trips are sensitive to the
        # exact operation order, so bit-exact equality is brittle.
        self.assertTrue(torch.allclose(g, res))
    def test_value_inf(self):
        """A very large alpha (100) should approximate the element-wise minimum."""
        p = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
        q = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

        approx = alpha_geodesic(p, q, alpha=100, lmd=0.5)
        expected = torch.min(p, q)

        close = torch.isclose(approx, expected)
        self.assertTrue(close.all())
    def test_alpha_minus_1(self):
        """alpha=-1 with lmd=0.5 should give the arithmetic mean (a + b) / 2."""
        x = torch.Tensor([1, 2, 3])
        y = torch.Tensor([4, 5, 6])

        mixture = alpha_geodesic(x, y, alpha=-1, lmd=0.5)
        arithmetic_mean = (x + y) / 2

        self.assertTrue(torch.equal(mixture, arithmetic_mean))
    def test_value_0(self):
        """A zero entry in the first argument must not yield infinite values (alpha=-1)."""
        with_zero = torch.Tensor([0, 1, 2])
        other = torch.Tensor([1, 2, 3])

        mixture = alpha_geodesic(with_zero, other, alpha=-1, lmd=0.5)

        inf_count = torch.isinf(mixture).sum()
        self.assertTrue(inf_count == 0)