def test_cpu_fallback(self):
    """conv2d under FakeTensorMode with fallback kernels off and on."""
    # With fallback kernels disabled, meta propagation alone must get the
    # output device and shape right.
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=False)):
        weight = torch.randn(8, 4, 3, 3).cuda()
        inp = torch.randn(1, 4, 5, 5).cuda()
        result = torch.nn.functional.conv2d(inp, weight, padding=1)
        self.assertEqual(result.device.type, "cuda")
        self.assertEqual(list(result.size()), [1, 8, 5, 5])

    # With fallback enabled, channel-mismatched inputs should surface the
    # underlying kernel error.
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=True)):
        # intentionally bad inputs
        weight = torch.randn(8, 20, 3, 3).cuda()
        inp = torch.randn(1, 7, 10, 5).cuda()
        with self.assertRaises(RuntimeError):
            torch.nn.functional.conv2d(inp, weight, padding=1)

    # With fallback enabled and valid inputs, device and shape still come
    # out as expected.
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=True)):
        weight = torch.randn(8, 4, 3, 3).cuda()
        inp = torch.randn(1, 4, 5, 5).cuda()
        result = torch.nn.functional.conv2d(inp, weight, padding=1)
        self.assertEqual(result.device.type, "cuda")
        self.assertEqual(list(result.size()), [1, 8, 5, 5])
def wrapped(*args): phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] fx_tracer = PythonKeyTracer() fake_tensor_mode: Any = nullcontext() if tracing_mode == "real": fake_tensor_mode = nullcontext() elif tracing_mode == "fake": fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) elif tracing_mode == "symbolic": fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) else: raise AssertionError(f"Unexpected tracing type: {tracing_mode}") proxy_mode = ProxyTorchDispatchMode( fx_tracer, trace_factory_functions=trace_factory_functions) def wrap_fake_concrete(x): if isinstance(x, torch.Tensor): return fake_tensor_mode.from_tensor( x) # type: ignore[attr-defined] return x shape_env = ShapeEnv() # todo: Figure out a more informative name for symints def wrap_fake_symbolic(x, sym_shape): if isinstance(x, torch.Tensor): val = FakeTensor(fake_tensor_mode, torch.empty(sym_shape, device="meta"), x.device) return val return x wrap_fn_map = { "real": lambda x: x, "fake": wrap_fake_concrete, } if tracing_mode == "symbolic": flat_shapes = shape_env.create_shapes_for_args(args) flat_args, spec = pytree.tree_flatten(args) args = pytree.tree_unflatten( list( map(lambda a: wrap_fake_symbolic(a[0], a[1]), zip(flat_args, flat_shapes))), spec) else: args = pytree.tree_map(wrap_fn_map[tracing_mode], args) with decompose( decomposition_table ), fake_tensor_mode, proxy_mode: # type: ignore[attr-defined] t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs)) # TODO: kind of a bad way to do it, should maybe figure out a better way t.shape_env = shape_env # type: ignore[assignment] return t
def test_throw(self):
    """lerp with mixed fake/cuda/cpu operands should raise."""
    mode = FakeTensorMode(inner=None)
    x = torch.tensor(0.)  # TODO: tensor() errors
    with enable_torch_dispatch_mode(mode):
        x_conv = mode.from_tensor(x)
        cuda_t = torch.rand([4, 4], device="cuda")
        cpu_t = torch.rand([4, 4], device="cpu")
        self.assertRaises(Exception, lambda: torch.lerp(x_conv, cuda_t, cpu_t))
def test_fake_grad_copy(self):
    """from_tensor also fakes .grad and preserves its metadata."""
    leaf = torch.rand([4, 4], requires_grad=True)
    leaf.grad = torch.rand([4, 4])
    mode = FakeTensorMode()
    fake_leaf = mode.from_tensor(leaf)
    prims.utils.compare_tensor_meta(fake_leaf, leaf)
    prims.utils.compare_tensor_meta(fake_leaf.grad, leaf.grad)
    self.assertTrue(isinstance(fake_leaf.grad, FakeTensor))
def test_basic(self):
    """Broadcast add of two converted fakes yields a cpu FakeTensor."""
    mode = FakeTensorMode(inner=None)
    small = torch.empty(2, 2, device="cpu")
    big = torch.empty(4, 2, 2, device="cpu")
    with enable_torch_dispatch_mode(mode):
        small = mode.from_tensor(small)
        big = mode.from_tensor(big)
        total = small + big
        self.assertEqual(total.shape, (4, 2, 2))
        self.assertEqual(total.device, torch.device("cpu"))
        self.assertTrue(isinstance(total, FakeTensor))
def test_non_kwarg_device(self):
    """.to() a same-device target is the identity; cross-device moves."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        src = torch.rand([16, 1], device="cpu")
        same = src.to(torch.device("cpu"))
        self.assertIs(src, same)
        moved = src.to(torch.device("cuda"))
        self.assertEqual(moved.device.type, "cuda")
def propagate(self, *args):
    """Run the interpreter over fake copies of *args* to propagate metadata."""
    self.multi_output_view_nodes = {}
    self.node_counter = -1
    with FakeTensorMode(allow_meta=True) as mode:
        converted = [mode.from_tensor(arg) for arg in args]
        return super().run(*converted)
def test_deepcopy(self):
    """deepcopy under FakeCopyMode fakes params/buffers and keeps aliasing."""
    mode = FakeTensorMode(inner=None)
    mod = torch.nn.BatchNorm2d(10)
    with torch._subclasses.fake_tensor.FakeCopyMode(mode):
        mod_copied = copy.deepcopy(mod)

    def check_copy(original, copied):
        # Each parameter/buffer of the copy must be a FakeTensor with
        # matching metadata, Parameter-ness, and requires_grad.
        for name, param in itertools.chain(original.named_parameters(),
                                           original.named_buffers()):
            param_copied = getattr(copied, name)
            self.checkMetaProps(param, param_copied)
            self.assertTrue(isinstance(param_copied, FakeTensor))
            self.assertEqual(isinstance(param, torch.nn.Parameter),
                             isinstance(param_copied, torch.nn.Parameter))
            self.assertEqual(param.requires_grad, param_copied.requires_grad)

    check_copy(mod, mod_copied)

    class ModuleNew(torch.nn.Module):
        def __init__(self):
            super(ModuleNew, self).__init__()
            self.a = torch.rand([10, 2])
            self.b = self.a
            self.c = self.a[0]

    mod = ModuleNew()
    with torch._subclasses.fake_tensor.FakeCopyMode(mode):
        mod_copied = copy.deepcopy(mod)

    # Aliasing survives the copy: b is the very same object as a, and they
    # share one storage.
    self.assertIs(mod_copied.a, mod_copied.b)
    self.assertEqual(mod_copied.b.storage()._cdata,
                     mod_copied.a.storage()._cdata)
def test_type_as(self):
    """type_as adopts the target tensor's device under fake mode."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        src = torch.rand([16, 1], device="cpu")
        target = torch.rand([4, 4], device="cuda")
        converted = src.type_as(target)
        self.assertEqual(converted.device.type, "cuda")
        self.assertTrue(isinstance(converted, FakeTensor))
def test_new(self):
    """Tensor.new follows the source tensor's device under fake mode."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        cpu_src = torch.rand([16, 1])
        self.checkType(cpu_src.new(10, 10), "cpu", [10, 10])
        self.checkType(cpu_src.new([1, 2, 3, 4]), "cpu", [4])
        cuda_src = torch.rand([4, 4], device='cuda')
        self.checkType(cuda_src.new(device='cuda'), "cuda", [0])
def test_data_dependent_operator(self):
    """nonzero has a data-dependent shape; with fallback off it must raise."""
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=False)):
        inp = torch.rand([10, 10])
        self.assertRaises(DynamicOutputShapeException,
                          lambda: torch.nonzero(inp))
def test_binary_op_type_promotion(self):
    """float / int64 promotes to float and stays on cpu."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        lhs = torch.empty([2, 2], dtype=torch.float)
        rhs = torch.empty([2, 2], dtype=torch.int64)
        quotient = lhs / rhs
        self.assertEqual(quotient.dtype, torch.float)
        self.assertEqual(quotient.device.type, "cpu")
def test_mode(self):
    """Tensors created while a fake mode is active become FakeTensors."""
    x = FakeTensor.from_tensor(torch.rand([1]))
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        y = torch.rand([4], device="cpu")
        # Mixing the standalone fake with the mode-created tensor exercises
        # dispatch; the mode-created tensor must itself be fake.
        result = x + y
        self.assertTrue(isinstance(y, FakeTensor))
def test_shape_take_not_device(self):
    """resize_as_ takes the target's shape but keeps the source's device."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        cpu_small = torch.empty(1, device="cpu")
        cuda_big = torch.empty(8, 8, device="cuda")
        resized = cpu_small.resize_as_(cuda_big)
        self.assertEqual(resized.shape, (8, 8))
        self.assertEqual(resized.device.type, "cpu")
        self.assertTrue(isinstance(resized, FakeTensor))
def get_prim_fake_mode():
    """Return the shared FakeTensorMode for prims, creating it lazily.

    The mode is held only via a module-level weakref, so a fresh one is
    built whenever the previous instance has been garbage collected.
    """
    global prim_fake_mode_ref
    existing = None if prim_fake_mode_ref is None else prim_fake_mode_ref()
    if existing is not None:
        return existing
    mode = FakeTensorMode()
    prim_fake_mode_ref = weakref.ref(mode)
    return mode
def test_randperm(self):
    """randperm under fake mode matches real tensor metadata."""
    real_default = torch.randperm(10)
    real_cpu = torch.randperm(5, device="cpu")
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        fake_default = torch.randperm(10)
        prims.utils.compare_tensor_meta(real_default, fake_default)
        fake_cpu = torch.randperm(5, device="cpu")
        prims.utils.compare_tensor_meta(real_cpu, fake_cpu)
def test_separate_tensor_storages_view(self):
    """A view and its base convert to fakes sharing one storage id."""
    base = torch.rand(2, 2, 2)
    view = base[0]
    mode = FakeTensorMode(inner=None)
    converter = mode.fake_tensor_converter
    base_fake = converter(mode, base)
    view_fake = converter(mode, view)
    self.assertEqual(torch._C._storage_id(base_fake),
                     torch._C._storage_id(view_fake))
def test_like_constructor(self):
    """ones_like keeps the source device unless one is given explicitly."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        src = torch.rand([4, 4])
        same_device = torch.ones_like(src)
        self.assertTrue(isinstance(same_device, FakeTensor))
        self.assertEqual(same_device.device.type, "cpu")
        on_cuda = torch.ones_like(src, device="cuda")
        self.assertTrue(isinstance(on_cuda, FakeTensor))
        self.assertEqual(on_cuda.device.type, "cuda")
def test_nan_to_num(self):
    """nan_to_num preserves the input dtype for half and float inputs."""
    mode = FakeTensorMode(inner=None)
    with enable_torch_dispatch_mode(mode):
        for dtype in (torch.float16, torch.float32):
            src = torch.rand([4], dtype=dtype)
            defaulted = torch.nan_to_num(src, nan=None)
            zeroed = torch.nan_to_num(src, 0.0)
            self.assertEqual(dtype, defaulted.dtype)
            self.assertEqual(dtype, zeroed.dtype)
def test_zero_dim(self):
    """A zero-dim cpu tensor broadcasts against cuda; the result is cuda."""
    mode = FakeTensorMode(inner=None)
    with enable_torch_dispatch_mode(mode):
        scalar = torch.tensor(0.)
        cuda_t = torch.rand([4, 4], device="cuda")
        result = scalar + cuda_t
        self.assertEqual(result.shape, (4, 4))
        self.assertEqual(result.device, cuda_t.device)
        self.assertTrue(isinstance(result, FakeTensor))
def test_no_active_mode(self):
    """Results of fake-tensor ops record the mode that created them."""
    mode = FakeTensorMode(inner=None)
    with enable_torch_dispatch_mode(mode):
        lhs = torch.empty(2, 2, device="cpu")
        rhs = torch.empty(2, 2, device="cpu")
        total = lhs + rhs
        self.assertEqual(mode, total.fake_mode)
        self.assertTrue(isinstance(total, FakeTensor))
        self.assertEqual(total.device.type, "cpu")
def test_dead_key(self):
    """Converter memo entries expire when the source tensor dies."""
    real = torch.rand(2, 2, 2)
    mode = FakeTensorMode(inner=None)
    converter = FakeTensorConverter()
    fake = converter(mode, real)
    self.assertEqual(len(converter.tensor_memo), 1)
    self.assertEqual(len(converter.meta_converter.tensor_memo), 1)
    # Dropping the real tensor should clear both memo tables, even though
    # the fake itself is still alive.
    del real
    self.assertEqual(len(converter.tensor_memo), 0)
    self.assertEqual(len(converter.meta_converter.tensor_memo), 0)
def test_fake_dispatch_keys(self):
    """Fake tensors carry the expected dispatch keys, incl. inference mode."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        t = torch.rand([4])
        checker = FileCheck().check("CPU").check("ADInplaceOrView").check(
            "AutogradCPU").check("AutocastCPU")
        checker.run(torch._C._dispatch_key_set(t))

        with torch.inference_mode():
            inf_t = torch.rand([4])
            doubled = inf_t + inf_t
            # Inference-mode tensors must drop the autograd-related keys.
            FileCheck().check("CPU").check("AutocastCPU").run(
                torch._C._dispatch_key_set(doubled))
            FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
                torch._C._dispatch_key_set(doubled))
def test_dead_weak_ref(self):
    """After a fake dies, reconverting a view of its base reuses the storage."""
    base = torch.rand(2, 2, 2)
    view = base[0]
    mode = FakeTensorMode(inner=None)
    converter = FakeTensorConverter()
    base_fake = converter(mode, base)
    base_storage = torch._C._storage_id(base_fake)
    del base_fake
    # The tensor memo entry for the base is gone once its fake is dead...
    self.assertFalse(base in converter.tensor_memo)
    # ...but the storage memo keeps the aliasing relationship alive.
    view_fake = converter(mode, view)
    self.assertEqual(base_storage, torch._C._storage_id(view_fake))
def test_fallback_memory_prop(self):
    """Fallback kernels propagate channels_last memory format through a module."""
    conv = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)
    conv = conv.to(memory_format=torch.channels_last)
    mode = FakeTensorMode(inner=None)
    # TODO: module.to() doesn't work because it assigns .data, which is ignored
    with torch._subclasses.fake_tensor.FakeCopyMode(mode):
        mod_copied = copy.deepcopy(conv)

    with enable_torch_dispatch_mode(mode):
        inp = torch.rand(20, 16, 50, 100, dtype=torch.half,
                         device="cuda").to(memory_format=torch.channels_last)
        out = mod_copied(inp)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.checkType(out, "cuda", [20, 33, 24, 49])
def test_separate_tensor_storages_non_view(self):
    """Tensors aliased via set_() share fake storage; memos expire in order."""
    first = torch.rand(2, 2, 2)
    second = torch.rand(4, 2)
    second.set_(first.storage())
    mode = FakeTensorMode(inner=None)
    converter = mode.fake_tensor_converter
    first_fake = converter(mode, first)
    second_fake = converter(mode, second)
    shared_id = torch._C._storage_id(first_fake)
    self.assertEqual(shared_id, torch._C._storage_id(second_fake))
    # Dropping one alias shrinks the tensor memo but the shared storage
    # memo entry survives until every alias is gone.
    del first
    self.assertEqual(len(converter.tensor_memo), 1)
    converter.meta_converter.check_for_expired_weak_storages()
    self.assertEqual(len(converter.meta_converter.storage_memo), 1)
    del second
    self.assertEqual(len(converter.tensor_memo), 0)
    converter.meta_converter.check_for_expired_weak_storages()
    self.assertEqual(len(converter.meta_converter.storage_memo), 0)
def wrapped(*args):
    """Trace `f`, optionally under fake-tensor and proxy dispatch modes."""
    # One placeholder per positional arg so dispatch_trace sees symbolic inputs.
    phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
    fx_tracer = PythonKeyTracer()
    if use_fake:
        fake_tensor_mode = FakeTensorMode()
    else:
        fake_tensor_mode = nullcontext()
    if trace_factory_functions:
        proxy_mode = ProxyTorchDispatchMode(fx_tracer)
    else:
        proxy_mode = nullcontext()

    def to_fake(x):
        # Only tensors are converted; everything else passes through.
        if isinstance(x, torch.Tensor):
            return fake_tensor_mode.from_tensor(x)  # type: ignore[attr-defined]
        return x

    if use_fake:  # type: ignore[attr-defined]
        args = pytree.tree_map(to_fake, args)

    with decompose(decomposition_table), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
        traced = dispatch_trace(wrap_key(f, args),
                                tracer=fx_tracer,
                                concrete_args=tuple(phs))
    return traced
def test_separate_mode_error(self):
    """Combining tensors from two distinct fake modes must raise."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.empty(2, 2, device="cpu")
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        y = torch.empty(2, 2, device="cpu")
    # Fix: the original `self.assertRaises(Exception, lambda: x, y)` passed
    # `y` as a positional argument to a zero-arg lambda, so assertRaises
    # "passed" via a spurious TypeError without ever combining the tensors.
    # Actually add them so the cross-mode error path is exercised.
    self.assertRaises(Exception, lambda: x + y)
def test_index_cuda_with_cpu(self):
    """Indexing a cuda tensor with a cpu index tensor yields a cuda result."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        data = torch.rand([2048], device='cuda')
        selected = data[torch.zeros([36], dtype=torch.int64)]
        self.checkType(selected, "cuda", [36])
def test_memoized_conversion_from_meta(self):
    """Converting the same meta tensor to cpu twice returns the memoized fake."""
    meta_t = torch.rand(2, 2).to(device="meta")
    mode = FakeTensorMode(inner=None)
    converter = mode.fake_tensor_converter
    first = converter(mode, meta_t, "cpu")
    second = converter(mode, meta_t, "cpu")
    self.assertTrue(first is second)