def test_view_dtype(self): x = torch.randn(4, dtype=torch.float32) y = x.view(dtype=torch.int32) m = MetaConverter()(y) self.assertEqual(m.shape, y.shape) self.assertEqual(m.stride(), y.stride()) self.assertEqual(m.dtype, y.dtype)
def test_view_as_complex(self): x = torch.randn((4, 2), dtype=torch.float32) y = torch.view_as_complex(x) m = MetaConverter()(y) self.assertEqual(m.shape, y.shape) self.assertEqual(m.stride(), y.stride()) self.assertEqual(m.dtype, y.dtype)
def test_view_as_real(self): x = torch.randn(4, dtype=torch.complex64) y = torch.view_as_real(x) m = MetaConverter()(y) self.assertEqual(m.shape, y.shape) self.assertEqual(m.stride(), y.stride()) self.assertEqual(m.dtype, y.dtype)
def test_imag(self): x = torch.randn(4, dtype=torch.complex64) y = x.imag m = MetaConverter()(y) self.assertEqual(m.shape, y.shape) self.assertEqual(m.dtype, y.dtype) self.assertEqual(m.stride(), y.stride()) self.assertEqual(m.storage_offset(), y.storage_offset())
def test_leaf(self): x = torch.randn(4, requires_grad=True) to_meta = MetaConverter() m = to_meta(x) self.assertEqual(m.shape, x.shape) self.assertTrue(m.is_leaf) self.assertTrue(m.requires_grad)
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    """Run ``func`` for real on CPU as a fallback for unsupported fake ops.

    Fake tensor inputs are replaced with zero-filled CPU tensors, ``func``
    is executed with dispatch disabled, and the outputs are converted to
    meta tensors.  If the CPU run fails, or any output aliases an input
    (aliasing relationships cannot be reproduced after the CPU round
    trip), the caller's original not-implemented exception is re-raised.
    """
    with no_dispatch():
        def to_cpu(e):
            # Materialize fake tensors as zero CPU tensors; only metadata
            # propagation matters here, not values.
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            # Surface the original "not implemented" error, chaining the
            # CPU failure for debuggability.
            raise orig_not_implemented_exception from new_exception

        # Record the input tensors and their storages so aliased outputs
        # can be detected below.
        tensor_impls = set()
        storages = set()
        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to cpu, error on reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (isinstance(e, torch.Tensor) and e.storage()._cdata in storages):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
def test_tensor_outlives_converter(self): m = MetaConverter() ref = weakref.ref(m) x = torch.randn([4, 4]) y = m(x) del m self.assertIs(ref(), None)
def test_non_leaf(self): x = torch.randn(4, requires_grad=True) y = x.neg() to_meta = MetaConverter() m = to_meta(y) self.assertEqual(m.shape, y.shape) self.assertFalse(m.is_leaf) self.assertTrue(m.requires_grad)
def test_view_of_leaf(self):
    # Two views of one leaf must convert to meta views that share a single
    # meta base (and therefore a single version counter).
    leaf = torch.randn(4, requires_grad=True)
    view_a = leaf[:]
    view_b = leaf[:]
    converter = MetaConverter()
    meta_a = converter(view_a)
    meta_b = converter(view_b)
    self.assertEqual(meta_a.shape, view_a.shape)
    self.assertTrue(meta_a._is_view())
    self.assertTrue(meta_a._base.is_leaf)
    self.assertSameVersionCounter(meta_a, meta_b)
def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
    """Run ``func`` for real on its fake device as a fallback.

    Fake tensor inputs are replaced with zero-filled real tensors, ``func``
    is executed with dispatch disabled, and outputs that reuse an input
    impl are mapped back to the corresponding original fake tensor while
    genuinely new outputs are converted to meta (to be fakeified later).
    Raises ``orig_not_implemented_exception`` when the real run fails or
    when an output aliases an input storage without reusing its impl.
    """
    # these should all be supported, just to be safe
    # avoid fallback for operators which inplace modify metadata
    # because the input fake tensors would be umodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        # Maps id(real tensor) -> the FakeTensor it was materialized from,
        # so reused impls can be translated back after the call.
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                inp_impls[id(out)] = e
                return out
            return e

        try:
            args = tree_map(to_real_tensor, args)
            kwargs = tree_map(to_real_tensor, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            # Surface the original "not implemented" error, chaining the
            # real-run failure for debuggability.
            raise orig_not_implemented_exception from new_exception

        # Record input storages so aliasing outputs can be detected below.
        # (A set of input tensors is not needed: reuse is checked via
        # inp_impls by identity.)
        storages = set()
        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to device, unless we can reuse an
        # input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor) and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    # the outputs which are not reused from impls will be converted
    # to fake tensors later
    meta_converter = MetaConverter()

    def map_out(e):
        # Look up by identity first: dict.get with an eager default would
        # needlessly meta-convert outputs that reuse an input impl.
        if id(e) in inp_impls:
            return inp_impls[id(e)]
        return meta_converter(e)

    return tree_map(map_out, r)
def new(self, *args, **kwargs):
    """Intercept ``Tensor.new``, which bypasses the normal dispatcher.

    Because ``new`` does not go through dispatch, computing the result
    directly in meta errors out; instead we materialize the real result
    with dispatch disabled, convert it to a meta tensor, and wrap that as
    a FakeTensor on this tensor's fake device.
    """
    target_device = self.fake_device
    with no_dispatch():
        concrete = super().new(*args, **kwargs)
    # Sanity checks: the undispatched call must have produced a genuine
    # real-device tensor, not a fake or meta one.
    assert not isinstance(concrete, FakeTensor), concrete
    assert concrete.device.type != "meta", concrete.device
    with no_dispatch():
        as_meta = MetaConverter()(concrete)
    return FakeTensor(self.fake_mode, as_meta, target_device)
def test_weakref(self): x = torch.randn(4, 4, 4) m = MetaConverter() y = m(x) z = m(x) self.assertIs(y, z) self.assertEqual(len(m.tensor_memo), 1) self.assertEqual(len(m.storage_memo), 1) del x self.assertEqual(len(m.tensor_memo), 0) m.check_for_expired_weak_storages() self.assertEqual(len(m.storage_memo), 0) li = [] for i in range(4): li.append(torch.rand([i])) m(li[-1]) self.assertEqual(len(m.tensor_memo), 4) del li self.assertEqual(len(m.tensor_memo), 0) m.check_for_expired_weak_storages() self.assertEqual(len(m.storage_memo), 0)
def __init__(self):
    """Initialize an empty tensor memo table and a fresh MetaConverter."""
    self.meta_converter = MetaConverter()
    self.tensor_memo = {}
def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
):
    """Run ``func`` on real inputs, re-run it on meta copies, and compare.

    Always returns the real result ``rs``.  Depending on ``test_expect``
    the meta run is skipped (SKIP), required to fail (XFAILURE), or
    required to succeed and agree with the real result.  Successes and
    failures are recorded in the module-level ``seen_succeeded`` /
    ``seen_failed`` tables, which are consulted when ``COLLECT_EXPECT``
    is set instead of raising.

    NOTE(review): ``device_type`` is accepted but unused in this body —
    presumably kept for signature parity with callers; confirm.
    """
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP
    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e

    rs = func(*args, **kwargs)

    # TODO: also handle cases where func raise an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.ops.aten.repeat_interleave.Tensor:
            if kwargs.get("output_size", None) is None:
                meta_args = args
        elif func is torch.ops.aten.index.Tensor:
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [
                    torch.int8, torch.bool
                ]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)

        if kwargs.get("device", None) is not None:
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(
                        f"unexpected success {resolve_name(func)}")

    return rs
def __init__(self):
    """Initialize weak tensor memoization and a fresh MetaConverter."""
    # A FakeTensor holds its FakeTensorMode, which in turn holds this
    # converter; memoizing FakeTensors with strong references would
    # therefore create a circular reference, so hold the values weakly.
    self.tensor_memo = weakref.WeakValueDictionary()
    self.meta_converter = MetaConverter()
def test_requires_grad_false(self): x = torch.randn(4, requires_grad=False) to_meta = MetaConverter() m = to_meta(x) self.assertEqual(m.shape, x.shape) self.assertFalse(m.requires_grad)
def test_complex_noncontiguous_bug(self): x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :] m = MetaConverter()(x) self.assertEqual(m.shape, x.shape) self.assertEqual(m.stride(), x.stride()) self.assertEqual(m.dtype, x.dtype)