def forward(self, input):
    r"""
    Add per-task noises and a shared noise term to the diagonal of the
    covariance matrix of the supplied
    :obj:`gpytorch.random_variables.GaussianRandomVariable` or
    :obj:`gpytorch.random_variables.MultitaskGaussianRandomVariable`.

    The task noises are stored in log space (``self.log_task_noises``) and are
    exponentiated here to form a diagonal matrix :math:`D_{t}`. That matrix is
    combined with :math:`I_{n}`, an identity matrix whose size equals the
    number of data points, in a
    :obj:`gpytorch.lazy.KroneckerProductLazyVariable`
    :math:`D_{t} \otimes I_{n}`.

    We also incorporate a shared ``log_noise`` parameter from the base
    :class:`gpytorch.likelihoods.GaussianLikelihood` that we extend; its
    exponential :math:`\sigma^{2}` is added to every diagonal entry.

    The final covariance matrix after this method is then
    :math:`K + D_{t} \otimes I_{n} + \sigma^{2}I_{nt}`.

    Args:
        input (:obj:`gpytorch.random_variables.MultitaskGaussianRandomVariable`): Random variable whose covariance
            matrix is a :obj:`gpytorch.lazy.LazyVariable` we intend to augment.
    Returns:
        :obj:`gpytorch.random_variables.MultitaskGaussianRandomVariable`: A new random variable (same class as
        ``input``) whose covariance matrix is a :obj:`gpytorch.lazy.LazyVariable` with
        :math:`D_{t} \otimes I_{n}` and :math:`\sigma^{2}I_{nt}` added.
    """
    mean, covar = input.representation()
    # NOTE(review): assumes covar is (n * n_tasks) x (n * n_tasks), so the
    # number of data points n is recovered by integer division.
    n_data = covar.size(-1) // self.n_tasks
    eye_lv = DiagLazyVariable(
        torch.ones(n_data, device=self.log_noise.device))
    task_var_lv = DiagLazyVariable(self.log_task_noises.exp())
    # D_t (x) I_n: each task's noise variance repeated across all n points.
    diag_kron_lv = KroneckerProductLazyVariable(task_var_lv, eye_lv)
    noise = covar + diag_kron_lv
    # Shared homoskedastic term sigma^2 = exp(log_noise) on every diagonal entry.
    noise = add_diag(noise, self.log_noise.exp())
    return input.__class__(mean, noise)
# Example #2
# 0
 def test_get_indices(self):
     """Check that ``_get_indices`` gathers the requested (row, col) entries."""
     lazy_diag = DiagLazyVariable(Variable(diag))
     rows = Variable(torch.LongTensor([1, 2, 0]))
     cols = Variable(torch.LongTensor([0, 2, 0]))
     picked = lazy_diag._get_indices(rows, cols)
     # (1, 0) is off-diagonal, hence 0; the other two entries come from the
     # module-level ``diag`` (presumably [1, 2, 3] — confirm at definition).
     expected = torch.Tensor([0, 3, 1])
     self.assertTrue(torch.equal(picked.data, expected))
 def forward(self, x1, x2):
     """
     Return the covariance given by ``self.variances`` between ``x1`` and ``x2``.

     When ``x1`` and ``x2`` are the same points the result is a
     ``DiagLazyVariable`` built from ``self.variances.unsqueeze(0)``. In
     training mode only the equality check is applied; otherwise the number of
     rows must also match ``self.variances.size(-1)``. Any other pair of
     inputs (e.g. a train/test cross-covariance) gets a ``ZeroLazyVariable``.
     """
     # The noise is only non-zero between identical points; without the
     # equality check the training branch would return the diagonal even for
     # a cross-covariance between different inputs.
     if self.training and torch.equal(x1, x2):
         return DiagLazyVariable(self.variances.unsqueeze(0))
     if (x1.size(-2) == x2.size(-2)
             and x1.size(-2) == self.variances.size(-1)
             and torch.equal(x1, x2)):
         return DiagLazyVariable(self.variances.unsqueeze(0))
     # NOTE(review): size(-3) is read, so inputs are assumed to be batched
     # (b x n x d) — confirm with callers.
     return ZeroLazyVariable(x1.size(-3), x1.size(-2), x2.size(-2))
# Example #4
# 0
 def forward(self, x1, x2):
     """
     Return a batch of diagonal noise matrices when ``x1`` and ``x2`` are the
     same points, and a ``ZeroLazyVariable`` otherwise.

     In training mode equality of the inputs is sufficient; outside training
     the row counts must also match ``self.variances.size(1)``.
     """
     inputs_match = self.training and torch.equal(x1, x2)
     if not inputs_match:
         n_rows = x1.size(-2)
         inputs_match = (
             n_rows == x2.size(-2)
             and n_rows == self.variances.size(1)
             and torch.equal(x1, x2)
         )
     if not inputs_match:
         return ZeroLazyVariable(x1.size(-3), x1.size(-2), x2.size(-2))
     # Reshape into a batch of batch_size diagonal matrices, each of which is
     # (data_size * task_size) x (data_size * task_size)
     batch_size = self.variances.size(0)
     return DiagLazyVariable(self.variances.view(batch_size, -1))
def test_batch_function_factory():
    """Batched ``DiagLazyVariable.matmul`` must match dense matmul in values and gradients."""
    # Batch of 5 diagonal matrices built from the module-level ``diag``.
    lazy_diag_var = Variable(diag.repeat(5, 1), requires_grad=True)
    dense_diag_var = Variable(diag.repeat(5, 1), requires_grad=True)
    rhs = Variable(torch.eye(3).repeat(5, 1, 1))

    # Forward: lazy product vs. product with the evaluated dense matrix.
    lazy_result = DiagLazyVariable(lazy_diag_var).matmul(rhs)
    dense_result = torch.matmul(DiagLazyVariable(dense_diag_var).evaluate(), rhs)
    assert torch.norm(lazy_result.data - dense_result.data) < 1e-4

    # Backward: both paths must produce (nearly) the same gradient.
    for result in (lazy_result, dense_result):
        result.sum().backward()
    grad_gap = torch.norm(lazy_diag_var.grad.data - dense_diag_var.grad.data)
    assert grad_gap < 1e-3
    def _covar_diag(self, inputs):
        """
        Evaluate ``self.base_kernel_module`` with every input point as its own
        one-row batch and wrap the resulting values in a ``DiagLazyVariable``
        whose diagonal has the leading (non-feature) shape of ``inputs``.
        """
        if inputs.ndimension() == 1:
            # Treat a 1-D input as n points with a single feature each.
            inputs = inputs.unsqueeze(1)
        *leading_dims, n_features = inputs.size()

        # One point per batch entry: shape (-1, 1, n_features).
        single_points = inputs.unsqueeze(-2).view(-1, 1, n_features)

        # Kernel output per point, intended as the diagonal of the covariance.
        values = self.base_kernel_module(single_points)
        if isinstance(values, LazyVariable):
            values = values.evaluate()
        return DiagLazyVariable(values.view(leading_dims))
# Example #7
# 0
    def test_function_factory(self):
        """``DiagLazyVariable.inv_matmul`` must match ``gpytorch.inv_matmul`` on the dense matrix."""
        # First a 1d right-hand side, then a 2d (identity) right-hand side.
        for rhs_tensor in (torch.Tensor([3, 4, 5]), torch.eye(3)):
            lazy_diag_var = Variable(diag, requires_grad=True)
            dense_diag_var = Variable(diag, requires_grad=True)
            rhs = Variable(rhs_tensor)

            # Forward
            lazy_result = DiagLazyVariable(lazy_diag_var).inv_matmul(rhs)
            dense_matrix = DiagLazyVariable(dense_diag_var).evaluate()
            dense_result = gpytorch.inv_matmul(dense_matrix, rhs)
            self.assertLess(
                torch.norm(lazy_result.data - dense_result.data), 1e-4)

            # Backward
            lazy_result.sum().backward()
            dense_result.sum().backward()
            self.assertLess(
                torch.norm(lazy_diag_var.grad.data - dense_diag_var.grad.data),
                1e-3,
            )
# Example #8
# 0
 def test_get_item(self):
     """Slicing a DiagLazyVariable with [0:2] matches slicing its dense form."""
     lazy = DiagLazyVariable(Variable(diag))
     dense_head = lazy.evaluate()[0:2].data
     lazy_head = lazy[0:2].evaluate().data
     self.assertTrue(torch.equal(lazy_head, dense_head))
# Example #9
# 0
 def test_evaluate(self):
     """Evaluating a DiagLazyVariable yields the dense matrix ``diag.diag()``."""
     dense = DiagLazyVariable(Variable(diag)).evaluate().data
     self.assertTrue(torch.equal(dense, diag.diag()))
def test_evaluate():
    """Evaluating a DiagLazyVariable yields the dense matrix ``diag.diag()``."""
    expected = diag.diag()
    actual = DiagLazyVariable(Variable(diag)).evaluate().data
    assert torch.equal(actual, expected)
def test_get_item():
    """Slicing a DiagLazyVariable with [0:2] matches slicing its dense form."""
    lazy = DiagLazyVariable(Variable(diag))
    dense_head = lazy.evaluate()[0:2].data
    lazy_head = lazy[0:2].evaluate().data
    assert torch.equal(lazy_head, dense_head)