def test_batch_get_indices(self):
    """Check ``_batch_get_indices`` against direct indexing of the dense batch product."""
    lhs = torch.randn(2, 5, 1)
    rhs = torch.randn(2, 1, 5)
    dense = lhs.matmul(rhs)
    lazy = MatmulLazyVariable(lhs, rhs)

    # Each case is (batch, row, column) index lists: a 4-element lookup
    # followed by an 18-element lookup.
    index_cases = (
        (
            [0, 1, 0, 1],
            [1, 2, 4, 0],
            [0, 1, 3, 2],
        ),
        (
            [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [1, 2, 4, 0, 1, 2, 3, 1, 2, 2, 1, 1, 0, 0, 4, 4, 4, 4],
            [0, 1, 3, 2, 3, 4, 2, 2, 1, 1, 2, 1, 2, 4, 4, 3, 3, 0],
        ),
    )

    for batch, left, right in index_cases:
        batch_indices = torch.LongTensor(batch)
        left_indices = torch.LongTensor(left)
        right_indices = torch.LongTensor(right)

        expected = dense[batch_indices, left_indices, right_indices]
        observed = lazy._batch_get_indices(batch_indices, left_indices, right_indices)
        self.assertTrue(approx_equal(expected, observed))
def test_batch_diag(self):
    """Check that ``diag()`` on a batch lazy product matches the per-batch dense diagonals."""
    lhs = Variable(torch.randn(4, 5, 3))
    rhs = Variable(torch.randn(4, 3, 5))
    dense = lhs.matmul(rhs)

    # Build the expected (4, 5) result one batch entry at a time.
    per_batch_diags = [dense[i].diag().unsqueeze(0) for i in range(4)]
    expected = torch.cat(per_batch_diags)

    lazy = MatmulLazyVariable(lhs, rhs)
    self.assertTrue(approx_equal(expected.data, lazy.diag().data))
def forward(self, x1, x2):
    """Return the lazy product of the offset-shifted inputs plus ``self.variance``.

    Args:
        x1: First input; ``self.offset`` is subtracted from it.
        x2: Second input; ``self.offset`` is subtracted from it and its
            last two dimensions are swapped before the product is formed.

    Returns:
        ``prod + self.variance.expand(prod.size())``, where ``prod`` is a
        ``RootLazyVariable`` when ``x1`` and ``x2`` have the same size and
        elements, and a ``MatmulLazyVariable`` otherwise.
    """
    if x1.size() == x2.size() and torch.equal(x1, x2):
        # Use RootLazyVariable when x1 == x2 for efficiency when composing
        # with other kernels
        prod = RootLazyVariable(x1 - self.offset)
    else:
        # Swap the trailing two dimensions with negative indices instead of
        # the hard-coded (2, 1). The result is identical for 3-D input, and
        # the transpose no longer depends on the input having exactly 3 dims.
        # NOTE(review): whether the lazy variable classes accept non-3-D
        # input is not visible here -- confirm before relying on it.
        prod = MatmulLazyVariable(x1 - self.offset, (x2 - self.offset).transpose(-1, -2))
    return prod + self.variance.expand(prod.size())
def test_matmul(self):
    """Check the lazy product's ``matmul`` result and gradients against dense matmul."""
    lhs = Variable(torch.randn(5, 3), requires_grad=True)
    rhs = Variable(torch.randn(3, 4), requires_grad=True)
    lazy_product = MatmulLazyVariable(lhs, rhs)
    mat = Variable(torch.randn(4, 10))
    lazy_result = lazy_product.matmul(mat)

    # Independent copies of the leaves so the dense path accumulates its
    # own gradients.
    lhs_ref = Variable(lhs.data.clone(), requires_grad=True)
    rhs_ref = Variable(rhs.data.clone(), requires_grad=True)
    mat_ref = Variable(mat.data.clone())
    dense_result = lhs_ref.matmul(rhs_ref).matmul(mat_ref)

    self.assertTrue(approx_equal(lazy_result.data, dense_result.data))

    dense_result.sum().backward()
    lazy_result.sum().backward()

    for lazy_leaf, dense_leaf in ((lhs, lhs_ref), (rhs, rhs_ref)):
        self.assertTrue(approx_equal(lazy_leaf.grad.data, dense_leaf.grad.data))
def _get_covariance(self, x1, x2):
    """Build the lazy covariance between ``x1`` and ``x2``.

    Each input is passed through ``self.base_kernel_module`` together with
    ``self.inducing_points``, evaluated, and multiplied by
    ``self._inducing_inv_root``.

    Returns:
        A ``RootLazyVariable`` of the ``x1`` factor when ``x1`` and ``x2``
        are equal, otherwise a ``MatmulLazyVariable`` of the ``x1`` factor
        and the transposed ``x2`` factor.
    """
    x1_inducing = self.base_kernel_module(x1, self.inducing_points).evaluate()

    if torch.equal(x1, x2):
        # Identical inputs: only the x1 factor is needed.
        return RootLazyVariable(x1_inducing.matmul(self._inducing_inv_root))

    x2_inducing = self.base_kernel_module(x2, self.inducing_points).evaluate()
    left_factor = x1_inducing.matmul(self._inducing_inv_root)
    right_factor = x2_inducing.matmul(self._inducing_inv_root).transpose(-1, -2)
    return MatmulLazyVariable(left_factor, right_factor)
def test_matmul(self):
    """Check the lazy product's ``matmul`` result and gradients against dense matmul."""
    lhs = torch.randn(5, 3, requires_grad=True)
    rhs = torch.randn(3, 4, requires_grad=True)
    lazy_product = MatmulLazyVariable(lhs, rhs)
    mat = torch.randn(4, 10)
    lazy_result = lazy_product.matmul(mat)

    # Detached copies that track their own gradients for the dense path.
    lhs_ref, rhs_ref, mat_ref = (t.clone().detach() for t in (lhs, rhs, mat))
    for ref in (lhs_ref, rhs_ref, mat_ref):
        ref.requires_grad = True

    dense_result = lhs_ref.matmul(rhs_ref).matmul(mat_ref)
    self.assertTrue(approx_equal(lazy_result, dense_result))

    dense_result.sum().backward()
    lazy_result.sum().backward()

    for lazy_leaf, dense_leaf in ((lhs, lhs_ref), (rhs, rhs_ref)):
        self.assertTrue(approx_equal(lazy_leaf.grad, dense_leaf.grad))
def test_transpose(self):
    """Check that transposing the lazy product matches transposing the dense product."""
    lhs = Variable(torch.randn(5, 3))
    rhs = Variable(torch.randn(3, 5))
    dense = lhs.matmul(rhs)
    lazy = MatmulLazyVariable(lhs, rhs)

    expected = dense.t().data
    observed = lazy.t().evaluate().data
    self.assertTrue(approx_equal(expected, observed))
def test_diag(self):
    """Check that the lazy product's ``diag()`` matches the dense product's diagonal."""
    lhs = Variable(torch.randn(5, 3))
    rhs = Variable(torch.randn(3, 5))
    dense = lhs.matmul(rhs)
    lazy = MatmulLazyVariable(lhs, rhs)

    expected = dense.diag().data
    observed = lazy.diag().data
    self.assertTrue(approx_equal(expected, observed))
def test_evaluate(self):
    """Check that ``evaluate()`` on the lazy product matches the dense matmul."""
    lhs = torch.randn(5, 3)
    rhs = torch.randn(3, 5)
    dense = lhs.matmul(rhs)
    lazy = MatmulLazyVariable(lhs, rhs)

    evaluated = lazy.evaluate()
    self.assertTrue(approx_equal(dense, evaluated))
def test_evaluate():
    """Check that ``evaluate()`` on the lazy product matches the dense matmul."""
    lhs = Variable(torch.randn(5, 3))
    rhs = Variable(torch.randn(3, 5))
    dense = lhs.matmul(rhs)
    lazy = MatmulLazyVariable(lhs, rhs)

    expected = dense.data
    observed = lazy.evaluate().data
    assert approx_equal(expected, observed)