def test_throw(self):
    """torch.lerp on a fake CPU tensor mixed with a CUDA and a CPU tensor must raise.

    The three operands are placed on different devices (fake-of-CPU, CUDA,
    CPU); presumably the expected exception is a device-mismatch error
    raised through FakeTensorMode.
    """
    # The test allocates a real CUDA tensor; without a GPU that allocation
    # fails before the assertion is reached, so skip instead of erroring.
    if not torch.cuda.is_available():
        self.skipTest("requires cuda")
    mode = FakeTensorMode(inner=None)
    # TODO: tensor() errors -- so the scalar source tensor is built before
    # the dispatch mode is enabled (presumably torch.tensor() fails inside it).
    x = torch.tensor(0.)
    with enable_torch_dispatch_mode(mode):
        x_conv = mode.from_tensor(x)
        y = torch.rand([4, 4], device="cuda")
        z = torch.rand([4, 4], device="cpu")
        with self.assertRaises(Exception):
            torch.lerp(x_conv, y, z)
def test_fake_grad_copy(self):
    """Converting a tensor that has a populated .grad also converts the grad.

    Checks that the fake tensor's metadata matches the real tensor's, that the
    fake .grad's metadata matches the real .grad's (prims.utils.compare_tensor_meta
    is assumed to raise on mismatch -- TODO confirm), and that the fake .grad
    is itself a FakeTensor.
    """
    real = torch.rand([4, 4], requires_grad=True)
    real.grad = torch.rand([4, 4])
    mode = FakeTensorMode()
    # Note: conversion happens without enabling the dispatch mode here.
    fake = mode.from_tensor(real)
    prims.utils.compare_tensor_meta(fake, real)
    prims.utils.compare_tensor_meta(fake.grad, real.grad)
    self.assertIsInstance(fake.grad, FakeTensor)
def test_basic(self):
    """Adding two fake CPU tensors broadcasts and yields a fake CPU tensor.

    A (2, 2) tensor plus a (4, 2, 2) tensor must produce a FakeTensor of
    shape (4, 2, 2) whose reported device is CPU.
    """
    mode = FakeTensorMode(inner=None)
    real_x = torch.empty(2, 2, device="cpu")
    real_y = torch.empty(4, 2, 2, device="cpu")
    with enable_torch_dispatch_mode(mode):
        fake_x = mode.from_tensor(real_x)
        fake_y = mode.from_tensor(real_y)
        z = fake_x + fake_y
        self.assertEqual(z.shape, (4, 2, 2))
        self.assertEqual(z.device, torch.device("cpu"))
        self.assertIsInstance(z, FakeTensor)
def test_memoized_conversion_to_meta(self):
    """FakeTensorMode.from_tensor memoizes conversions per source tensor.

    Converting the same real tensor twice through one mode must return the
    identical fake tensor object, not merely an equal one.
    """
    x = torch.rand(2, 2, 2)
    mode = FakeTensorMode(inner=None)
    self.assertIs(mode.from_tensor(x), mode.from_tensor(x))