def test_base_sample_shape(self):
        """Check which ``base_samples`` shapes ``MultivariateNormal.rsample`` accepts.

        For a covariance wrapped in ``RootLazyTensor`` with a 5x10 root, this
        test expects:
        - base samples whose trailing dimension is the root's column count
          (10) to be accepted;
        - event-sized (5) base samples to raise ``RuntimeError``.

        After the same covariance is densified and re-lazified (no longer a
        root lazy tensor), the test expects event-sized base samples to be
        accepted. In every accepted case the draws have shape (16, 5).
        """
        sample_shape = torch.Size((16,))

        root = torch.randn(5, 10)
        root_covar = RootLazyTensor(lazify(root))
        dist = MultivariateNormal(torch.zeros(5), root_covar)

        # Root-sized base samples (trailing dim 10) are accepted and yield
        # draws over the 5-dimensional event.
        draws = dist.rsample(sample_shape, base_samples=torch.randn(16, 10))
        self.assertEqual(draws.shape, torch.Size((16, 5)))

        # Event-sized base samples (trailing dim 5) are expected to be
        # rejected for the root-decomposed covariance. The tensor is built
        # before entering the context so only ``rsample`` is under test.
        event_sized_base = torch.randn(16, 5)
        with self.assertRaises(RuntimeError):
            dist.rsample(sample_shape, base_samples=event_sized_base)

        # Densify and re-wrap the covariance so it is no longer a root lazy
        # tensor; event-sized base samples should now be accepted.
        dense_covar = lazify(root_covar.evaluate())
        dist = MultivariateNormal(torch.zeros(5), dense_covar)

        draws = dist.rsample(sample_shape, base_samples=torch.randn(16, 5))
        self.assertEqual(draws.shape, torch.Size((16, 5)))
 def test_evaluate(self):
     """``RootLazyTensor(R).evaluate()`` should match the dense product R @ R^T."""
     factor = torch.randn(5, 3)
     # Dense reference: a 5x3 factor times its transpose gives a 5x5 matrix.
     expected = factor @ factor.transpose(-2, -1)
     lazy_product = RootLazyTensor(factor)
     self.assertTrue(approx_equal(expected, lazy_product.evaluate()))