Пример #1
0
    def test_batch_diag(self):
        """Check batched ``diag()`` against the dense evaluation.

        Builds a batch of identical interpolation patterns over a batch of
        random symmetric PSD base matrices. Asserts that
        ``InterpolatedLazyTensor.diag()`` equals the diagonal of every batch
        member of ``evaluate()``.
        """
        # Single source of truth for the batch dimension; the expected diagonal
        # below is derived from it rather than enumerated by hand.
        batch_size = 5

        left_interp_indices = torch.tensor(
            [[2, 3], [3, 4], [4, 5]], dtype=torch.long
        ).repeat(batch_size, 1, 1)
        left_interp_values = torch.tensor(
            [[1, 1], [1, 1], [1, 1]], dtype=torch.float
        ).repeat(batch_size, 1, 1)
        right_interp_indices = torch.tensor(
            [[0, 1], [1, 2], [2, 3]], dtype=torch.long
        ).repeat(batch_size, 1, 1)
        right_interp_values = torch.tensor(
            [[1, 1], [1, 1], [1, 1]], dtype=torch.float
        ).repeat(batch_size, 1, 1)

        # A^T A per batch member gives a symmetric positive semi-definite base.
        base_lazy_tensor_mat = torch.randn(batch_size, 6, 6)
        base_lazy_tensor_mat = base_lazy_tensor_mat.transpose(1, 2).matmul(
            base_lazy_tensor_mat
        )
        base_lazy_tensor_mat.requires_grad = True

        base_lazy_tensor = NonLazyTensor(base_lazy_tensor_mat)
        interp_lazy_tensor = InterpolatedLazyTensor(
            base_lazy_tensor,
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )

        actual = interp_lazy_tensor.evaluate()
        # Diagonal of every matrix in the batch at once (last two dims), so the
        # expectation stays correct if batch_size changes.
        actual_diag = actual.diagonal(dim1=-2, dim2=-1)

        self.assertTrue(approx_equal(actual_diag, interp_lazy_tensor.diag()))
Пример #2
0
    def test_batch_sample(self):
        """Check batched sampling against the dense evaluation.

        Draws zero-mean samples from a batched ``InterpolatedLazyTensor``. Asserts
        that their empirical second moment matches the dense ``evaluate()``
        result within a tolerance scaled to each entry's Monte-Carlo noise.
        """
        # Seed so the Monte-Carlo check is reproducible instead of flaky.
        torch.manual_seed(0)

        batch_size = 5
        num_samples = 10000

        interp_indices = torch.tensor(
            [[2, 3], [3, 4], [4, 5]], dtype=torch.long
        ).repeat(batch_size, 1, 1)
        interp_values = torch.tensor(
            [[1, 1], [1, 1], [1, 1]], dtype=torch.float
        ).repeat(batch_size, 1, 1)

        # A^T A per batch member gives a symmetric positive semi-definite base.
        base_lazy_tensor_mat = torch.randn(batch_size, 6, 6)
        base_lazy_tensor_mat = base_lazy_tensor_mat.transpose(1, 2).matmul(
            base_lazy_tensor_mat
        )
        base_lazy_tensor_mat.requires_grad = True

        base_lazy_tensor = NonLazyTensor(base_lazy_tensor_mat)
        # The same interpolation is used on both sides, presumably so that the
        # evaluated matrix is symmetric PSD and hence a valid covariance.
        interp_lazy_tensor = InterpolatedLazyTensor(
            base_lazy_tensor,
            interp_indices,
            interp_values,
            interp_indices,
            interp_values,
        )

        actual = interp_lazy_tensor.evaluate()

        samples = interp_lazy_tensor.zero_mean_mvn_samples(num_samples)
        # Empirical second moment: mean of per-sample outer products over dim 0.
        sample_covar = samples.unsqueeze(-1).matmul(samples.unsqueeze(-2)).mean(0)

        # Assuming Gaussian samples, the standard error of entry (i, j) is
        # sqrt((K_ii K_jj + K_ij^2) / N) <= sqrt(2 K_ii K_jj / N).
        # Normalise by sqrt(K_ii K_jj) instead of |K_ij|: the latter makes the
        # ratio explode whenever a random off-diagonal entry is close to zero.
        # Diagonal entries (i == j) are normalised exactly as before.
        variances = actual.diagonal(dim1=-2, dim2=-1)
        scale = (variances.unsqueeze(-1) * variances.unsqueeze(-2)).sqrt()
        scale = scale.clamp(1e-5, 1e5)
        max_rel_error = ((sample_covar - actual).abs() / scale).max().item()

        self.assertLess(max_rel_error, 2e-1)
Пример #3
0
    def test_diag(self):
        """Check non-batched ``diag()`` against the dense evaluation.

        Asserts that ``diag()`` of a non-batched ``InterpolatedLazyTensor`` agrees
        with the diagonal of the dense matrix returned by ``evaluate()``.
        """
        # Interpolation weights are all ones; left and right get their own
        # tensors, as in the original construction.
        left_idx = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.long)
        left_vals = torch.ones(3, 2, dtype=torch.float)
        right_idx = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long)
        right_vals = torch.ones(3, 2, dtype=torch.float)

        # raw^T raw is symmetric positive semi-definite.
        raw = torch.randn(6, 6)
        base_mat = raw.t() @ raw
        base_mat.requires_grad = True

        lazy = InterpolatedLazyTensor(
            NonLazyTensor(base_mat),
            left_idx,
            left_vals,
            right_idx,
            right_vals,
        )

        dense = lazy.evaluate()
        self.assertTrue(approx_equal(dense.diag(), lazy.diag()))