def test_batch_diag(self):
        """Check ``MatmulLazyVariable.diag()`` on a batch of matrix products.

        ``lhs`` is (4, 5, 3) and ``rhs`` is (4, 3, 5), so ``lhs.matmul(rhs)`` is a
        batch of four 5x5 matrices. The reference value is each batch element's
        diagonal, concatenated along a new leading dimension into a (4, 5)
        tensor. The lazy variable's ``diag()`` must match it approximately.
        """
        lhs = Variable(torch.randn(4, 5, 3))
        rhs = Variable(torch.randn(4, 3, 5))
        actual = lhs.matmul(rhs)
        # Derive the number of batch elements from the product itself instead of
        # hard-coding one line per index, so the reference stays correct if the
        # batch size used above is changed.
        actual_diag = torch.cat([
            actual[i].diag().unsqueeze(0) for i in range(actual.size(0))
        ])

        res = MatmulLazyVariable(lhs, rhs)
        self.assertTrue(approx_equal(actual_diag.data, res.diag().data))
 def test_diag(self):
     """Compare ``MatmulLazyVariable.diag()`` with the dense product's diagonal.

     A (5, 3) by (3, 5) product yields a 5x5 matrix. The diagonal of that
     explicitly computed matrix is the reference for the lazy result.
     """
     left, right = Variable(torch.randn(5, 3)), Variable(torch.randn(3, 5))
     dense = left.matmul(right)
     lazy = MatmulLazyVariable(left, right)
     expected, got = dense.diag().data, lazy.diag().data
     self.assertTrue(approx_equal(expected, got))