def test_evaluate(self):
    """Evaluating a three-factor Kronecker product must equal the dense kron.

    The check runs twice: first on the module-level ``a``, ``b``, ``c``
    tensors as-is, then on copies produced by ``repeat(3, 1, 1)``. The second
    pass is presumably a batch of 3 along a new leading dimension, assuming
    ``a``/``b``/``c`` are 2-D (TODO confirm). Each pass wraps every factor in
    ``Variable`` and ``NonLazyVariable``. It then compares
    ``KroneckerProductLazyVariable(...).evaluate()`` against
    ``kron(kron(a, b), c)``.
    """
    for batched in (False, True):
        # Build the factors for this pass; the batched pass repeats each one.
        tensors = [t.repeat(3, 1, 1) if batched else t for t in (a, b, c)]
        first, second, third = (Variable(t) for t in tensors)

        lazy = KroneckerProductLazyVariable(
            NonLazyVariable(first),
            NonLazyVariable(second),
            NonLazyVariable(third),
        )
        result = lazy.evaluate()
        expected = kron(kron(first, second), third)
        self.assertTrue(approx_equal(result.data, expected.data))
def test_get_item_square_on_variable():
    """Square slicing ``[2:4, 2:4]`` must agree with slicing the dense matrix.

    The lazy variable is built from a 2x4 tensor of columns plus an
    ``added_diag`` of sixteen 3s. The diagonal length presumably equals the
    product's dimension (4 * 4), but that is not checked here (TODO confirm).
    """
    columns = Variable(torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]))
    diagonal = Variable(torch.ones(16) * 3)
    lazy = KroneckerProductLazyVariable(columns, added_diag=diagonal)

    dense = lazy.evaluate().data
    window = slice(2, 4)
    assert utils.approx_equal(lazy[window, window].evaluate().data, dense[window, window])
def test_get_item_on_interpolated_variable_no_diagonal():
    """Row and row/column slices of a diagonal-free copy match dense slices.

    A fresh ``KroneckerProductLazyVariable`` is rebuilt from the module-level
    ``lazy_kronecker_product_var``. It uses that object's ``columns``,
    ``J_lefts``, ``C_lefts``, ``J_rights`` and ``C_rights``, and no
    ``added_diag`` is passed. The J_*/C_* pairs look like interpolation
    indices and values (TODO confirm). Slices ``[4:6]`` and ``[4:6, 2:6]`` of
    the lazy variable are then compared with the same slices of its dense
    evaluation.
    """
    src = lazy_kronecker_product_var
    no_diag = KroneckerProductLazyVariable(
        src.columns,
        src.J_lefts,
        src.C_lefts,
        src.J_rights,
        src.C_rights,
    )
    dense = no_diag.evaluate().data

    rows = slice(4, 6)
    assert utils.approx_equal(no_diag[rows].evaluate().data, dense[rows])

    cols = slice(2, 6)
    assert utils.approx_equal(no_diag[rows, cols].evaluate().data, dense[rows, cols])