def test_batch_separate(self):
        a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1)
        b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1)
        means = torch.tensor([[1, 2], [2, 3]], dtype=torch.float).view(2, 2, 1, 1)
        scales = torch.tensor([[0.5, 0.25], [0.25, 1]], dtype=torch.float).view(2, 2, 1, 1)
        weights = torch.tensor([[4, 2], [1, 2]], dtype=torch.float).view(2, 2)
        kernel = SpectralMixtureKernel(batch_shape=torch.Size([2]), num_mixtures=2)
        kernel.initialize(mixture_weights=weights, mixture_means=means, mixture_scales=scales)
        kernel.eval()

        actual = torch.zeros(2, 3, 3)
        for l in range(2):
            for k in range(2):
                for i in range(3):
                    for j in range(3):
                        new_term = torch.cos(2 * math.pi * (a[l, i] - b[l, j]) * means[l, k])
                        new_term *= torch.exp(-2 * (math.pi * (a[l, i] - b[l, j])) ** 2 * scales[l, k] ** 2)
                        new_term *= weights[l, k]
                        actual[l, i, j] += new_term.item()

        res = kernel(a, b).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # diag
        res = kernel(a, b).diag()
        actual = torch.cat([actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
        self.assertLess(torch.norm(res - actual), 1e-5)
    def test_standard(self):
        a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
        b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)
        means = [1, 2]
        scales = [0.5, 0.25]
        weights = [4, 2]
        kernel = SpectralMixtureKernel(num_mixtures=2)
        kernel.initialize(
            log_mixture_weights=torch.tensor([[4, 2]],
                                             dtype=torch.float).log(),
            log_mixture_means=torch.tensor([[[[1]], [[2]]]],
                                           dtype=torch.float).log(),
            log_mixture_scales=torch.tensor([[[[0.5]], [[0.25]]]],
                                            dtype=torch.float).log(),
        )
        kernel.eval()

        actual = torch.zeros(3, 2)
        for i in range(3):
            for j in range(2):
                for k in range(2):
                    new_term = torch.cos(2 * math.pi * (a[i] - b[j]) *
                                         means[k])
                    new_term *= torch.exp(-2 * (math.pi * (a[i] - b[j]))**2 *
                                          scales[k]**2)
                    new_term *= weights[k]
                    actual[i, j] += new_term.item()

        res = kernel(a, b).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)
Example #3
0
    def test_standard(self):
        a = torch.tensor([[0, 1], [2, 2], [2, 0]], dtype=torch.float)
        means = torch.tensor([[1, 2], [2, 1]], dtype=torch.float)
        scales = torch.tensor([[0.5, 0.25], [0.25, 0.5]], dtype=torch.float)
        scales = scales.unsqueeze(1)
        means = means.unsqueeze(1)
        weights = torch.tensor([4, 2], dtype=torch.float)
        kernel = SpectralMixtureKernel(num_mixtures=2, ard_num_dims=2)
        kernel.initialize(
            mixture_weights=weights,
            mixture_means=means,
            mixture_scales=scales,
        )
        kernel.eval()

        actual = torch.zeros(2, 3, 3, 2)
        for i in range(3):
            for j in range(3):
                for k in range(2):
                    new_term = torch.cos(2 * math.pi * (a[i] - a[j]) *
                                         means[k])
                    new_term *= torch.exp(-2 * (math.pi * (a[i] - a[j]))**2 *
                                          scales[k]**2)
                    actual[k, i, j] = new_term
        actual = actual.prod(-1)
        actual[0].mul_(weights[0])
        actual[1].mul_(weights[1])
        actual = actual.sum(0)

        res = kernel(a, a).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # diag
        res = kernel(a, a).diag()
        actual = actual.diag()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # batch_dims
        actual = torch.zeros(2, 3, 3, 2)
        for i in range(3):
            for j in range(3):
                for k in range(2):
                    new_term = torch.cos(2 * math.pi * (a[i] - a[j]) *
                                         means[k])
                    new_term *= torch.exp(-2 * (math.pi * (a[i] - a[j]))**2 *
                                          scales[k]**2)
                    actual[k, i, j] = new_term
        actual[0].mul_(weights[0])
        actual[1].mul_(weights[1])
        actual = actual.sum(0)
        actual = actual.permute(2, 0, 1)
        res = kernel(a, a, batch_dims=(0, 2)).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # batch_dims + diag
        res = kernel(a, a, batch_dims=(0, 2)).diag()
        actual = torch.cat(
            [actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
        self.assertLess(torch.norm(res - actual), 1e-5)