def test_batch_diag(self):
    """Batch mode: ``diag()`` must agree with the diagonals of ``evaluate()``.

    Builds a batch of five 6x6 base matrices of the form ``A^T A``, wraps them
    in an ``InterpolatedLazyVariable`` with all-ones interpolation values, and
    compares the per-batch diagonals of the evaluated result against the
    output of ``diag()``.
    """
    batch_size = 5

    # Random base matrices, one per batch entry, each formed as A^T A.
    root = torch.randn(batch_size, 6, 6)
    gram = root.transpose(1, 2).matmul(root)
    base_lazy_variable = NonLazyVariable(Variable(gram, requires_grad=True))

    # The same 3x2 interpolation pattern is tiled across every batch entry.
    left_interp_indices = Variable(
        torch.LongTensor([[2, 3], [3, 4], [4, 5]]).repeat(batch_size, 1, 1)
    )
    left_interp_values = Variable(
        torch.Tensor([[1, 1], [1, 1], [1, 1]]).repeat(batch_size, 1, 1)
    )
    right_interp_indices = Variable(
        torch.LongTensor([[0, 1], [1, 2], [2, 3]]).repeat(batch_size, 1, 1)
    )
    right_interp_values = Variable(
        torch.Tensor([[1, 1], [1, 1], [1, 1]]).repeat(batch_size, 1, 1)
    )

    interp_lazy_var = InterpolatedLazyVariable(
        base_lazy_variable,
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )

    # Reference: take the diagonal of each evaluated batch entry in turn.
    evaluated = interp_lazy_var.evaluate()
    expected_diag = torch.stack(
        [evaluated[index].diag() for index in range(batch_size)]
    )

    lazy_diag = interp_lazy_var.diag()
    self.assertTrue(approx_equal(expected_diag.data, lazy_diag.data))
def test_diag(self):
    """Non-batch mode: ``diag()`` must agree with the diagonal of ``evaluate()``.

    Builds a single 6x6 base matrix of the form ``A^T A``, wraps it in an
    ``InterpolatedLazyVariable`` with all-ones interpolation values, and
    compares the diagonal of the evaluated result against ``diag()``.
    """
    # Random base matrix formed as A^T A.
    root = torch.randn(6, 6)
    gram = root.t().matmul(root)
    base_lazy_variable = NonLazyVariable(Variable(gram, requires_grad=True))

    # Interpolation indices/values for the left and right sides (3 rows, 2 entries each).
    left_interp_indices = Variable(torch.LongTensor([[2, 3], [3, 4], [4, 5]]))
    left_interp_values = Variable(torch.Tensor([[1, 1], [1, 1], [1, 1]]))
    right_interp_indices = Variable(torch.LongTensor([[0, 1], [1, 2], [2, 3]]))
    right_interp_values = Variable(torch.Tensor([[1, 1], [1, 1], [1, 1]]))

    interp_lazy_var = InterpolatedLazyVariable(
        base_lazy_variable,
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )

    # Reference diagonal comes from the evaluated result; compare with diag().
    evaluated = interp_lazy_var.evaluate()
    expected_diag = evaluated.diag()
    lazy_diag = interp_lazy_var.diag()
    self.assertTrue(approx_equal(expected_diag.data, lazy_diag.data))