Code example #1
    def test_normal_trace_log_det_quad_form_forward(self):
        """Check gpytorch.trace_logdet_quad_form against a dense reference.

        Reference value: mu' K^{-1} mu + log|K| + tr(K^{-1} C'C), where
        K = covar and C = chol_covar.
        """
        covar = torch.Tensor([
            [3, -1, 0],
            [-1, 3, 0],
            [0, 0, 3],
        ])
        mu_diffs = torch.Tensor([0, -1, 1])
        chol_covar = torch.Tensor([
            [1, -2, 0],
            [0, 1, -2],
            [0, 0, 1],
        ])

        # Quadratic form: mu' K^{-1} mu
        actual = mu_diffs.dot(covar.inverse().matmul(mu_diffs))
        # log|K| via slogdet: numerically safer than log(det(K)), which can
        # overflow/underflow for larger or badly-scaled matrices.
        actual += np.linalg.slogdet(covar.numpy())[1]
        # Trace term: tr(K^{-1} C'C)
        actual += (covar.inverse().matmul(
            chol_covar.t().matmul(chol_covar))).trace()

        covarvar = Variable(covar)
        chol_covarvar = Variable(chol_covar)
        mu_diffsvar = Variable(mu_diffs)

        res = gpytorch.trace_logdet_quad_form(mu_diffsvar, chol_covarvar,
                                              covarvar)
        # Allow 10% relative error -- presumably the trace term is estimated
        # stochastically (cf. num_trace_samples in the backward tests).
        self.assertTrue(all(torch.abs(actual - res.data).div(res.data) < 0.1))
Code example #2
    def mvn_kl_divergence(self):
        """Compute the KL divergence between the variational distribution and
        the inducing-point output distribution (both multivariate normals).

        Uses the closed form
            KL = 0.5 * (quad_form + log|K| + trace_term - log|S| - n)
        where the first three terms are evaluated jointly by
        gpytorch.trace_logdet_quad_form and S = C'C with C the variational
        Cholesky factor.

        Raises:
            RuntimeError: if the Cholesky factor is not 2- or 3-dimensional.
        """
        mean_diffs = self.inducing_output.mean() - self.variational_mean
        chol_variational_covar = self.chol_variational_covar

        if chol_variational_covar.ndimension() == 2:
            matrix_diag = chol_variational_covar.diag()
        elif chol_variational_covar.ndimension() == 3:
            # Batch mode: gather each matrix's diagonal via advanced indexing,
            # since .diag() only handles the 2D case here.
            batch_size, diag_size, _ = chol_variational_covar.size()
            batch_index = chol_variational_covar.data.new(batch_size).long()
            torch.arange(0, batch_size, out=batch_index)
            # [0,...,0, 1,...,1, ...]: each batch index repeated diag_size times.
            batch_index = batch_index.unsqueeze(1).repeat(1,
                                                          diag_size).view(-1)
            diag_index = chol_variational_covar.data.new(diag_size).long()
            torch.arange(0, diag_size, out=diag_index)
            # [0,1,...,d-1, 0,1,...,d-1, ...]: tiled batch_size times.
            diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)
            matrix_diag = chol_variational_covar[batch_index, diag_index,
                                                 diag_index].view(
                                                     batch_size, diag_size)
        else:
            raise RuntimeError(
                'Invalid number of variational covar dimensions')

        # log|S| = 2 * sum(log(diag(C))) for a Cholesky factor C.
        logdet_variational_covar = matrix_diag.log().sum() * 2
        trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(
            mean_diffs, self.chol_variational_covar,
            gpytorch.add_jitter(self.inducing_output.covar()))

        # Compute the KL Divergence.
        res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar -
                     len(mean_diffs))
        return res
Code example #3
    def kl_divergence(self):
        """Compute KL(variational distribution || prior) for multivariate
        normals.

        Uses the closed form
            KL = 0.5 * (mu_d' K^{-1} mu_d + log|K| + tr(K^{-1} S)
                        - log|S| - n)
        where K is the prior covariance, S = C'C the variational covariance
        (C its Cholesky factor), mu_d the mean difference and n the dimension.
        The first three terms are evaluated jointly by
        gpytorch.trace_logdet_quad_form.

        Raises:
            RuntimeError: if the variational covar is not a CholLazyVariable,
                or if the Cholesky factor is not 2- or 3-dimensional.
        """
        prior_mean = self.prior_dist.mean()
        prior_covar = self.prior_dist.covar()
        variational_mean = self.variational_dist.mean()
        variational_covar = self.variational_dist.covar()
        if not isinstance(variational_covar, CholLazyVariable):
            raise RuntimeError('The variational covar for an MVN distribution should be a CholLazyVariable')
        chol_variational_covar = variational_covar.lhs

        mean_diffs = prior_mean - variational_mean

        if chol_variational_covar.ndimension() == 2:
            matrix_diag = chol_variational_covar.diag()
        elif chol_variational_covar.ndimension() == 3:
            # Batch mode: gather each matrix's diagonal via advanced indexing,
            # since .diag() only handles the 2D case here.
            batch_size, diag_size, _ = chol_variational_covar.size()
            batch_index = chol_variational_covar.data.new(batch_size).long()
            torch.arange(0, batch_size, out=batch_index)
            # [0,...,0, 1,...,1, ...]: each batch index repeated diag_size times.
            batch_index = batch_index.unsqueeze(1).repeat(1, diag_size).view(-1)
            diag_index = chol_variational_covar.data.new(diag_size).long()
            torch.arange(0, diag_size, out=diag_index)
            # [0,1,...,d-1, 0,1,...,d-1, ...]: tiled batch_size times.
            diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)
            matrix_diag = chol_variational_covar[batch_index, diag_index, diag_index].view(batch_size, diag_size)
        else:
            raise RuntimeError('Invalid number of variational covar dimensions')

        # log|S| = 2 * sum(log(diag(C))) for a Cholesky factor C.
        logdet_variational_covar = matrix_diag.log().sum() * 2
        trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(mean_diffs, chol_variational_covar,
                                                                 gpytorch.add_jitter(prior_covar))

        # Compute the KL Divergence.
        res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar - len(mean_diffs))
        return res
Code example #4
    def test_normal_trace_log_det_quad_form_backward(self):
        """Gradients of the quadratic-form + trace reference (plus the analytic
        logdet gradient K^{-1}, added by hand) must match the autograd result
        of gpytorch.trace_logdet_quad_form."""

        def make_inputs():
            # Fresh leaf Variables for each pass, so gradients do not
            # accumulate between the reference and the gpytorch computation.
            covar = Variable(torch.Tensor([
                [3, -1, 0],
                [-1, 3, 0],
                [0, 0, 3],
            ]), requires_grad=True)
            mu_diffs = Variable(torch.Tensor([0, -1, 1]), requires_grad=True)
            chol_covar = Variable(torch.Tensor([
                [1, -2, 0],
                [0, 1, -2],
                [0, 0, 1],
            ]), requires_grad=True)
            return covar, mu_diffs, chol_covar

        covar, mu_diffs, chol_covar = make_inputs()
        reference = mu_diffs.dot(covar.inverse().matmul(mu_diffs))
        reference = reference + covar.inverse().matmul(
            chol_covar.t().matmul(chol_covar)).trace()
        reference.backward()

        # The reference above omits the logdet term; its covar gradient is
        # covar^{-1}, so add that analytically.
        expected_grads = (
            covar.grad.data.clone() + covar.data.inverse(),
            mu_diffs.grad.data.clone(),
            chol_covar.grad.data.clone(),
        )

        covar, mu_diffs, chol_covar = make_inputs()
        with gpytorch.settings.num_trace_samples(1000):
            res = gpytorch.trace_logdet_quad_form(mu_diffs, chol_covar, covar)
            res.backward()

        actual_grads = (
            covar.grad.data,
            mu_diffs.grad.data,
            chol_covar.grad.data,
        )
        for expected_grad, actual_grad in zip(expected_grads, actual_grads):
            self.assertLess(torch.norm(expected_grad - actual_grad), 1e-1)
Code example #5
    def test_batch_trace_log_det_quad_form_backward(self):
        """Batch-mode gradient check: analytic gradients of the per-batch
        quadratic-form + trace reference (with the logdet gradient added by
        hand) must match gpytorch.trace_logdet_quad_form."""
        covar = Variable(torch.Tensor([[
            [3, -1, 0],
            [-1, 3, 0],
            [0, 0, 3],
        ], [
            [10, -2, 1],
            [-2, 10, 0],
            [1, 0, 10],
        ]]), requires_grad=True)
        mu_diffs = Variable(torch.Tensor([[0, -1, 1], [1, 2, 3]]),
                            requires_grad=True)
        chol_covar = Variable(torch.Tensor([[
            [1, -2, 0],
            [0, 1, -2],
            [0, 0, 1],
        ], [
            [2, -4, 0],
            [0, 2, -4],
            [0, 0, 2],
        ]]), requires_grad=True)

        # Reference: sum over the batch of quadratic form + trace term
        # (the logdet term is handled separately below).
        reference = 0
        for i in range(2):
            reference = reference + mu_diffs[i].dot(
                covar[i].inverse().matmul(mu_diffs[i]))
            reference = reference + covar[i].inverse().matmul(
                chol_covar[i].t().matmul(chol_covar[i])).trace()
        reference.backward()

        # d(sum_i log|K_i|)/dK is the batch of inverses; add it analytically.
        expected_covar_grad = covar.grad.data.clone() + torch.cat([
            covar[0].data.inverse().unsqueeze(0),
            covar[1].data.inverse().unsqueeze(0),
        ])
        expected_mu_diffs_grad = mu_diffs.grad.data.clone()
        expected_chol_covar_grad = chol_covar.grad.data.clone()

        # Same leaves are reused, so clear the accumulated gradients first.
        for leaf in (covar, mu_diffs, chol_covar):
            leaf.grad.data.fill_(0)
        with gpytorch.settings.num_trace_samples(1000):
            res = gpytorch.trace_logdet_quad_form(mu_diffs, chol_covar, covar)
            res.backward()

        self.assertLess(
            torch.norm(expected_covar_grad - covar.grad.data), 1e-1)
        self.assertLess(
            torch.norm(expected_mu_diffs_grad - mu_diffs.grad.data), 1e-1)
        self.assertLess(
            torch.norm(expected_chol_covar_grad - chol_covar.grad.data), 1e-1)
Code example #6
def test_batch_trace_log_det_quad_form_backward():
    """Batch-mode gradient check: analytic gradients of the per-batch
    quadratic-form + trace reference (with the logdet gradient added by hand)
    must match gpytorch.trace_logdet_quad_form."""
    covar = Variable(torch.Tensor([[
        [5, -3, 0],
        [-3, 5, 0],
        [0, 0, 2],
    ], [
        [10, -2, 1],
        [-2, 10, 0],
        [1, 0, 10],
    ]]), requires_grad=True)
    mu_diffs = Variable(torch.Tensor([[0, -1, 1], [1, 2, 3]]),
                        requires_grad=True)
    chol_covar = Variable(torch.Tensor([[
        [1, -2, 0],
        [0, 1, -2],
        [0, 0, 1],
    ], [
        [2, -4, 0],
        [0, 2, -4],
        [0, 0, 2],
    ]]), requires_grad=True)

    # Reference: sum over the batch of quadratic form + trace term
    # (the logdet term is handled separately below).
    reference = 0
    for i in range(2):
        reference = reference + mu_diffs[i].dot(
            covar[i].inverse().matmul(mu_diffs[i]))
        reference = reference + covar[i].inverse().matmul(
            chol_covar[i].t().matmul(chol_covar[i])).trace()
    reference.backward()

    # d(sum_i log|K_i|)/dK is the batch of inverses; add it analytically.
    expected_covar_grad = covar.grad.data.clone() + torch.cat([
        covar[0].data.inverse().unsqueeze(0),
        covar[1].data.inverse().unsqueeze(0),
    ])
    expected_mu_diffs_grad = mu_diffs.grad.data.clone()
    expected_chol_covar_grad = chol_covar.grad.data.clone()

    # Same leaves are reused, so clear the accumulated gradients first.
    for leaf in (covar, mu_diffs, chol_covar):
        leaf.grad.data.fill_(0)

    res = gpytorch.trace_logdet_quad_form(mu_diffs, chol_covar, covar)
    res.backward()

    assert approx_equal(expected_covar_grad, covar.grad.data)
    assert approx_equal(expected_mu_diffs_grad, mu_diffs.grad.data)
    assert approx_equal(expected_chol_covar_grad, chol_covar.grad.data)
Code example #7
def test_normal_trace_log_det_quad_form_backward():
    """Gradients of the quadratic-form + trace reference (plus the analytic
    logdet gradient K^{-1}, added by hand) must match the autograd result of
    gpytorch.trace_logdet_quad_form."""

    def fresh_inputs():
        # New leaf Variables for each pass, so gradients do not accumulate
        # between the reference and the gpytorch computation.
        covar = Variable(torch.Tensor([
            [5, -3, 0],
            [-3, 5, 0],
            [0, 0, 2],
        ]), requires_grad=True)
        mu_diffs = Variable(torch.Tensor([0, -1, 1]), requires_grad=True)
        chol_covar = Variable(torch.Tensor([
            [1, -2, 0],
            [0, 1, -2],
            [0, 0, 1],
        ]), requires_grad=True)
        return covar, mu_diffs, chol_covar

    covar, mu_diffs, chol_covar = fresh_inputs()
    reference = mu_diffs.dot(covar.inverse().matmul(mu_diffs))
    reference = reference + covar.inverse().matmul(
        chol_covar.t().matmul(chol_covar)).trace()
    reference.backward()

    # The reference above omits the logdet term; its covar gradient is
    # covar^{-1}, so add that analytically.
    expected_covar_grad = covar.grad.data.clone() + covar.data.inverse()
    expected_mu_diffs_grad = mu_diffs.grad.data.clone()
    expected_chol_covar_grad = chol_covar.grad.data.clone()

    covar, mu_diffs, chol_covar = fresh_inputs()
    res = gpytorch.trace_logdet_quad_form(mu_diffs, chol_covar, covar)
    res.backward()

    assert approx_equal(expected_covar_grad, covar.grad.data)
    assert approx_equal(expected_mu_diffs_grad, mu_diffs.grad.data)
    assert approx_equal(expected_chol_covar_grad, chol_covar.grad.data)
Code example #8
    def test_batch_trace_log_det_quad_form_forward(self):
        """Batch-mode forward check of gpytorch.trace_logdet_quad_form against
        a per-batch dense reference:
            sum_i mu_i' K_i^{-1} mu_i + log|K_i| + tr(K_i^{-1} C_i'C_i).
        """
        covar = torch.Tensor([
            [
                [3, -1, 0],
                [-1, 3, 0],
                [0, 0, 3],
            ], [
                [10, -2, 1],
                [-2, 10, 0],
                [1, 0, 10],
            ]
        ])
        mu_diffs = torch.Tensor([
            [0, -1, 1],
            [1, 2, 3]
        ])
        chol_covar = torch.Tensor([
            [
                [1, -2, 0],
                [0, 1, -2],
                [0, 0, 1],
            ], [
                [2, -4, 0],
                [0, 2, -4],
                [0, 0, 2],
            ]
        ])

        # Accumulate the reference value over the batch.
        actual = 0
        for i in range(2):
            # Quadratic form: mu_i' K_i^{-1} mu_i
            actual += mu_diffs[i].dot(covar[i].inverse().matmul(mu_diffs[i]))
            # log|K_i| via slogdet: numerically safer than log(det(K_i)),
            # which can overflow/underflow for larger matrices.
            actual += np.linalg.slogdet(covar[i].numpy())[1]
            # Trace term: tr(K_i^{-1} C_i'C_i)
            actual += (
                covar[i].inverse().matmul(
                    chol_covar[i].t().matmul(chol_covar[i]))
            ).trace()

        covarvar = Variable(covar)
        chol_covarvar = Variable(chol_covar)
        mu_diffsvar = Variable(mu_diffs)

        res = gpytorch.trace_logdet_quad_form(mu_diffsvar, chol_covarvar, covarvar)
        # Allow 10% relative error -- presumably the trace term is estimated
        # stochastically (cf. num_trace_samples in the backward tests).
        self.assertTrue((torch.abs(actual - res.data).div(res.data) < 0.1).all())