def test_evaluate(self):
        """Compare ``KroneckerProductLazyVariable.evaluate`` with an explicit ``kron``.

        The same check runs twice:

        * first on the tensors ``a``, ``b`` and ``c`` used as-is;
        * then on copies built with ``repeat(3, 1, 1)`` (presumably a batch
          of 3 when ``a``/``b``/``c`` are 2-D -- confirm against their
          definitions elsewhere in the file).

        Each factor is wrapped in ``Variable`` and then ``NonLazyVariable``.
        The lazy product's dense evaluation must approximately equal
        ``kron(kron(left, middle), right)`` computed on the same Variables.
        """
        # Each builder turns a raw factor tensor into the tensor that gets
        # wrapped; the identity builder reproduces the un-batched case.
        factor_builders = (
            lambda tensor: tensor,
            lambda tensor: tensor.repeat(3, 1, 1),
        )
        for build in factor_builders:
            left, middle, right = [Variable(build(tensor)) for tensor in (a, b, c)]
            lazy_product = KroneckerProductLazyVariable(
                NonLazyVariable(left),
                NonLazyVariable(middle),
                NonLazyVariable(right),
            )
            evaluated = lazy_product.evaluate()
            expected = kron(kron(left, middle), right)
            self.assertTrue(approx_equal(evaluated.data, expected.data))
def test_get_item_square_on_variable():
    """Slicing a square window must agree with slicing the dense evaluation.

    A KroneckerProductLazyVariable is built from the 2x4 tensor
    ``[[1, 2, 3, 4], [5, 6, 7, 8]]`` with ``added_diag`` set to sixteen 3s.
    (Presumably 16 is the side length of the evaluated matrix -- confirm.)
    Indexing rows/columns 2:4 on the lazy variable and then evaluating must
    approximately match the same window of the fully evaluated matrix.
    """
    window = slice(2, 4)
    lazy_var = KroneckerProductLazyVariable(
        Variable(torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])),
        added_diag=Variable(torch.ones(16) * 3),
    )
    dense = lazy_var.evaluate().data

    sliced = lazy_var[window, window].evaluate().data
    assert utils.approx_equal(sliced, dense[window, window])
def test_get_item_on_interpolated_variable_no_diagonal():
    """Row and row/column slicing of a KroneckerProductLazyVariable with no ``added_diag``.

    Rebuilds a lazy variable from the ``columns``, ``J_lefts``, ``C_lefts``,
    ``J_rights`` and ``C_rights`` attributes of the module-level
    ``lazy_kronecker_product_var``, passing no ``added_diag``. It then checks
    two things against the same slices of the dense evaluation:

    * slicing rows 4:6 on the lazy variable;
    * slicing rows 4:6 with columns 2:6 on the lazy variable.
    """
    template = lazy_kronecker_product_var
    no_diag_var = KroneckerProductLazyVariable(
        template.columns,
        template.J_lefts,
        template.C_lefts,
        template.J_rights,
        template.C_rights,
    )
    dense = no_diag_var.evaluate().data

    rows = slice(4, 6)
    cols = slice(2, 6)
    assert utils.approx_equal(no_diag_var[rows].evaluate().data, dense[rows])
    assert utils.approx_equal(
        no_diag_var[rows, cols].evaluate().data, dense[rows, cols]
    )