def test_batch_diag(self):
    """Batched InterpolatedLazyTensor.diag() must match the dense diagonals.

    Builds a batch of interpolated lazy tensors over a batch of 6x6
    matrices of the form A^T A, evaluates them densely, and compares the
    per-batch diagonals against the lazy ``diag()`` implementation.
    """
    # Single source of truth for the batch dimension; previously "5" was
    # repeated in five places and the expected stack was written out by hand.
    batch_size = 5

    left_interp_indices = torch.tensor(
        [[2, 3], [3, 4], [4, 5]], dtype=torch.long
    ).repeat(batch_size, 1, 1)
    left_interp_values = torch.tensor(
        [[1, 1], [1, 1], [1, 1]], dtype=torch.float
    ).repeat(batch_size, 1, 1)
    right_interp_indices = torch.tensor(
        [[0, 1], [1, 2], [2, 3]], dtype=torch.long
    ).repeat(batch_size, 1, 1)
    right_interp_values = torch.tensor(
        [[1, 1], [1, 1], [1, 1]], dtype=torch.float
    ).repeat(batch_size, 1, 1)

    base_lazy_tensor_mat = torch.randn(batch_size, 6, 6)
    # A^T A per batch element (transpose of the last two dims).
    base_lazy_tensor_mat = base_lazy_tensor_mat.transpose(1, 2).matmul(base_lazy_tensor_mat)
    base_lazy_tensor_mat.requires_grad = True
    base_lazy_tensor = NonLazyTensor(base_lazy_tensor_mat)

    interp_lazy_tensor = InterpolatedLazyTensor(
        base_lazy_tensor,
        left_interp_indices,
        left_interp_values,
        right_interp_indices,
        right_interp_values,
    )

    actual = interp_lazy_tensor.evaluate()
    # Iterating over dim 0 yields one matrix per batch element, so this
    # covers every batch entry regardless of batch_size.
    actual_diag = torch.stack([mat.diag() for mat in actual])
    self.assertTrue(approx_equal(actual_diag, interp_lazy_tensor.diag()))
def test_diag(self):
    """Non-batched InterpolatedLazyTensor.diag() must match the dense diagonal.

    Interpolates over a 6x6 matrix of the form A^T A and checks the lazy
    ``diag()`` against ``diag()`` of the densely evaluated result.
    """
    left_idx = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.long)
    left_vals = torch.ones(3, 2, dtype=torch.float)
    right_idx = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long)
    right_vals = torch.ones(3, 2, dtype=torch.float)

    raw = torch.randn(6, 6)
    gram = (raw.t() @ raw).requires_grad_()
    base = NonLazyTensor(gram)

    lazy = InterpolatedLazyTensor(base, left_idx, left_vals, right_idx, right_vals)

    dense = lazy.evaluate()
    self.assertTrue(approx_equal(dense.diag(), lazy.diag()))