def test_batch_diag(self):
        """Check ``diag()`` of a batched InterpolatedLazyVariable.

        Builds ``batch_size`` copies of the same left/right interpolation
        indices and unit interpolation values over a batch of random 6x6
        base matrices of the form A^T A. It then asserts that the lazily
        computed ``diag()`` matches the diagonal of each batch entry of the
        dense ``evaluate()`` result.
        """
        # Single source of truth for the batch dimension; previously the
        # literal 5 was repeated in every tensor and in the diag list below.
        batch_size = 5

        # Three interpolation rows, each with two indices and unit weights,
        # repeated identically for every batch entry.
        left_interp_indices = Variable(
            torch.LongTensor([[2, 3], [3, 4], [4, 5]]).repeat(batch_size, 1, 1))
        left_interp_values = Variable(
            torch.Tensor([[1, 1], [1, 1], [1, 1]]).repeat(batch_size, 1, 1))
        right_interp_indices = Variable(
            torch.LongTensor([[0, 1], [1, 2], [2, 3]]).repeat(batch_size, 1, 1))
        right_interp_values = Variable(
            torch.Tensor([[1, 1], [1, 1], [1, 1]]).repeat(batch_size, 1, 1))

        # Per-batch A^T A, so each base matrix is symmetric positive
        # semi-definite.
        base_lazy_variable_mat = torch.randn(batch_size, 6, 6)
        base_lazy_variable_mat = base_lazy_variable_mat.transpose(
            1, 2).matmul(base_lazy_variable_mat)

        base_lazy_variable = NonLazyVariable(
            Variable(base_lazy_variable_mat, requires_grad=True))
        interp_lazy_var = InterpolatedLazyVariable(base_lazy_variable,
                                                   left_interp_indices,
                                                   left_interp_values,
                                                   right_interp_indices,
                                                   right_interp_values)

        actual = interp_lazy_var.evaluate()
        # Diagonal of every batch entry, stacked along a new leading batch
        # dimension so it lines up with the batched ``diag()`` output.
        actual_diag = torch.stack([
            actual[i].diag() for i in range(batch_size)
        ])

        self.assertTrue(
            approx_equal(actual_diag.data,
                         interp_lazy_var.diag().data))
Exemplo n.º 2
0
    def test_diag(self):
        """Check ``diag()`` of a non-batched InterpolatedLazyVariable.

        Interpolates a random 6x6 base matrix of the form A^T A through
        three left/right interpolation rows (two indices each, unit
        weights). It then asserts that the lazily computed diagonal agrees
        with the diagonal of the densely evaluated matrix.
        """
        left_indices = Variable(torch.LongTensor([[2, 3], [3, 4], [4, 5]]))
        left_values = Variable(torch.Tensor([[1, 1], [1, 1], [1, 1]]))
        right_indices = Variable(torch.LongTensor([[0, 1], [1, 2], [2, 3]]))
        right_values = Variable(torch.Tensor([[1, 1], [1, 1], [1, 1]]))

        raw = torch.randn(6, 6)
        # A^T A: symmetric positive semi-definite base matrix.
        gram = raw.t().matmul(raw)

        base = NonLazyVariable(Variable(gram, requires_grad=True))
        interp = InterpolatedLazyVariable(
            base, left_indices, left_values, right_indices, right_values)

        dense = interp.evaluate()
        expected = dense.diag().data
        computed = interp.diag().data
        self.assertTrue(approx_equal(expected, computed))