def gs_div(
    input: torch.Tensor,
    target: torch.Tensor,
    alpha: float = -1,
    lmd: float = 0.5,
    reduction: Optional[str] = 'sum',
) -> torch.Tensor:
    r"""The $\alpha$-geodesical skew divergence.

    Computes ``input * log(input / skew_target)`` element-wise, where
    ``skew_target = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)``,
    and then applies ``reduction``. Elements where ``input`` is exactly zero
    contribute zero (the ``0 * log 0 = 0`` convention), even when the
    matching ``skew_target`` element is also zero.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input
        alpha: Specifies the coordinate system with which the geodesical skew
            divergence is equipped
        lmd: Specifies the position on the geodesic; must lie in ``[0, 1]``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'`` (or ``None``): no reduction will be applied
            ``'batchmean'``: the sum of the output will be divided by the batchsize
            (the size of the first dimension of ``input``)
            ``'sum'``: the output will be summed
            ``'mean'``: the output will be divided by the number of elements in the output
            Default: ``'sum'``

    Returns:
        A 0-dim tensor for ``'sum'``, ``'mean'`` and ``'batchmean'``; the
        unreduced element-wise tensor for ``'none'`` / ``None``.

    Raises:
        ValueError: if ``lmd`` is outside ``[0, 1]`` or ``reduction`` is not
            one of the supported values.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= lmd <= 1:
        raise ValueError(f"lmd must be in [0, 1], got {lmd}")
    if reduction not in (None, 'none', 'batchmean', 'sum', 'mean'):
        raise ValueError(
            "reduction must be one of 'none', 'batchmean', 'sum', 'mean', "
            f"got {reduction!r}"
        )

    skew_target = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)

    # Where input == 0 the term is 0 * log(...) and should be 0, but if the
    # skewed target is also 0 there, input / skew_target is 0/0 = NaN and
    # NaN + eps stays NaN, poisoning the sum (and the backward pass).
    # Replace the denominator with 1 only at those positions; every element
    # with a non-zero input is computed exactly as before.
    safe_skew = torch.where(input != 0, skew_target, torch.ones_like(skew_target))
    div = input * torch.log(input / safe_skew + 1e-12)

    if reduction == 'batchmean':
        div = div.sum() / input.size()[0]
    elif reduction == 'sum':
        div = div.sum()
    elif reduction == 'mean':
        div = div.mean()
    return div
def test_value_0_2d(self):
    """A zero entry in a 2-D input must not give an infinite geodesic value at alpha=1."""
    p = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
    q = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])
    point = alpha_geodesic(p, q, alpha=1, lmd=0.5)
    self.assertFalse(torch.isinf(point).any())
def test_grad(self):
    """The geodesic output must be attached to the autograd graph of a grad-requiring input."""
    src = torch.tensor(
        [[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]],
        requires_grad=True,
    )
    dst = torch.tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])
    out = alpha_geodesic(src, dst, alpha=1, lmd=0.5)
    self.assertIsNotNone(out.grad_fn)
def test_alpha_minus_inf(self):
    """With alpha = -inf the geodesic point is expected to equal the element-wise max."""
    lhs, rhs = torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])
    expected = torch.max(lhs, rhs)
    actual = alpha_geodesic(lhs, rhs, alpha=-float('inf'), lmd=0.5)
    self.assertTrue(torch.equal(actual, expected))
def test_alpha_3(self):
    """With alpha = 3 and lmd = 0.5 the geodesic point is expected to be the harmonic mean."""
    a = torch.Tensor([1, 2, 3])
    b = torch.Tensor([4, 5, 6])
    g = alpha_geodesic(a, b, alpha=3, lmd=0.5)
    res = 1 / (0.5 * 1/a + 0.5 * 1/b)
    # alpha_geodesic may reach this value through a different sequence of
    # floating-point operations than the closed form above, so compare within
    # a tolerance instead of bit-for-bit. assert_close also checks shape and dtype.
    torch.testing.assert_close(g, res)
def test_alpha_0(self):
    """With alpha = 0 and lmd = 0.5 the geodesic point is expected to be the squared mean of square roots."""
    a = torch.Tensor([1, 2, 3])
    b = torch.Tensor([4, 5, 6])
    g = alpha_geodesic(a, b, alpha=0, lmd=0.5)
    res = (0.5 * torch.sqrt(a) + 0.5 * torch.sqrt(b))**2
    # sqrt/pow results can differ in the last ulp depending on how they are
    # computed, so compare within a tolerance rather than bit-for-bit.
    # assert_close also checks shape and dtype.
    torch.testing.assert_close(g, res)
def test_alpha_1(self):
    """With alpha = 1 and lmd = 0.5 the geodesic point is expected to be the geometric mean."""
    a = torch.Tensor([1, 2, 3])
    b = torch.Tensor([4, 5, 6])
    g = alpha_geodesic(a, b, alpha=1, lmd=0.5)
    res = torch.exp(0.5 * torch.log(a) + 0.5 * torch.log(b))
    # exp/log round-trips are not guaranteed to be bit-identical across
    # equivalent formulations, so compare within a tolerance.
    # assert_close also checks shape and dtype.
    torch.testing.assert_close(g, res)
def test_value_inf(self):
    """A large alpha (100) is expected to approximate the element-wise min."""
    x = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
    y = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])
    approx_min = alpha_geodesic(x, y, alpha=100, lmd=0.5)
    self.assertTrue(torch.allclose(approx_min, torch.min(x, y)))
def test_alpha_minus_1(self):
    """With alpha = -1 and lmd = 0.5 the geodesic point is expected to be the arithmetic mean."""
    a, b = torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])
    midpoint = alpha_geodesic(a, b, alpha=-1, lmd=0.5)
    expected = (a + b) / 2
    self.assertTrue(torch.equal(midpoint, expected))
def test_value_0(self):
    """A zero entry in the input must not give an infinite geodesic value at alpha=-1."""
    with_zero = torch.Tensor([0, 1, 2])
    other = torch.Tensor([1, 2, 3])
    point = alpha_geodesic(with_zero, other, alpha=-1, lmd=0.5)
    self.assertFalse(torch.isinf(point).any())