Example #1
0
 def test_view_dtype(self):
     """A dtype-reinterpreting view should survive meta conversion intact."""
     base = torch.randn(4, dtype=torch.float32)
     viewed = base.view(dtype=torch.int32)
     meta = MetaConverter()(viewed)
     self.assertEqual(meta.dtype, viewed.dtype)
     self.assertEqual(meta.shape, viewed.shape)
     self.assertEqual(meta.stride(), viewed.stride())
Example #2
0
 def test_view_as_complex(self):
     """view_as_complex output keeps shape/stride/dtype after meta conversion."""
     real_pairs = torch.randn((4, 2), dtype=torch.float32)
     cplx = torch.view_as_complex(real_pairs)
     meta = MetaConverter()(cplx)
     self.assertEqual(meta.shape, cplx.shape)
     self.assertEqual(meta.stride(), cplx.stride())
     self.assertEqual(meta.dtype, cplx.dtype)
Example #3
0
 def test_view_as_real(self):
     """view_as_real of a complex tensor converts to meta faithfully."""
     cplx = torch.randn(4, dtype=torch.complex64)
     as_real = torch.view_as_real(cplx)
     meta = MetaConverter()(as_real)
     self.assertEqual(meta.shape, as_real.shape)
     self.assertEqual(meta.dtype, as_real.dtype)
     self.assertEqual(meta.stride(), as_real.stride())
Example #4
0
 def test_imag(self):
     """The .imag view of a complex tensor round-trips through MetaConverter."""
     src = torch.randn(4, dtype=torch.complex64)
     imag_view = src.imag
     meta = MetaConverter()(imag_view)
     self.assertEqual(meta.shape, imag_view.shape)
     self.assertEqual(meta.dtype, imag_view.dtype)
     self.assertEqual(meta.stride(), imag_view.stride())
     self.assertEqual(meta.storage_offset(), imag_view.storage_offset())
Example #5
0
 def test_leaf(self):
     """A grad-requiring leaf maps to a meta leaf with grad still enabled."""
     leaf = torch.randn(4, requires_grad=True)
     meta = MetaConverter()(leaf)
     self.assertEqual(meta.shape, leaf.shape)
     self.assertTrue(meta.requires_grad)
     self.assertTrue(meta.is_leaf)
Example #6
0
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    """Run ``func`` for real on CPU stand-ins for any FakeTensor arguments.

    Each FakeTensor argument is replaced by a zero-filled CPU tensor with
    matching metadata, ``func`` executes outside the dispatcher, and every
    output is converted to a meta tensor.  The original exception is
    re-raised when the CPU run fails, or when an output is (or shares
    storage with) one of the stand-in inputs, since that aliasing
    relationship cannot be reproduced after the CPU round-trip.
    """
    with no_dispatch():

        def to_cpu(e):
            # Swap fake tensors for zero CPU tensors; pass everything else
            # through untouched.
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            # Surface the original "not implemented" error, keeping the CPU
            # failure chained as its cause for debugging.
            raise orig_not_implemented_exception from new_exception

        # Record every input tensor and its storage so output aliasing can
        # be detected below.
        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to cpu, error on reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (isinstance(e, torch.Tensor)
                                     and e.storage()._cdata in storages):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
Example #7
0
 def test_tensor_outlives_converter(self):
     """Deleting the converter frees it even while its output is alive."""
     converter = MetaConverter()
     converter_ref = weakref.ref(converter)
     real = torch.randn([4, 4])
     meta = converter(real)  # keep the converted tensor alive past `del`
     del converter
     self.assertIs(converter_ref(), None)
Example #8
0
 def test_non_leaf(self):
     """A grad-requiring non-leaf converts to a non-leaf meta tensor."""
     leaf = torch.randn(4, requires_grad=True)
     derived = leaf.neg()
     converter = MetaConverter()
     meta = converter(derived)
     self.assertEqual(meta.shape, derived.shape)
     self.assertTrue(meta.requires_grad)
     self.assertFalse(meta.is_leaf)
Example #9
0
 def test_view_of_leaf(self):
     """Two views of one leaf share a base and a version counter on meta."""
     leaf = torch.randn(4, requires_grad=True)
     view_a = leaf[:]
     view_b = leaf[:]
     convert = MetaConverter()
     meta_a = convert(view_a)
     meta_b = convert(view_b)
     self.assertEqual(meta_a.shape, view_a.shape)
     self.assertTrue(meta_a._is_view())
     self.assertTrue(meta_a._base.is_leaf)
     self.assertSameVersionCounter(meta_a, meta_b)
Example #10
0
def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
    """Run ``func`` for real on zero-filled stand-ins for FakeTensor args.

    Each FakeTensor argument is replaced by a zero tensor on its fake
    device; ``func`` then executes outside the dispatcher.  Outputs that
    are one of the stand-in inputs are mapped back to the corresponding
    original FakeTensor; all other outputs are converted to meta tensors
    (the caller converts those to fake tensors later).  The original
    exception is re-raised when the op is tagged inplace_view, when the
    real run fails, or when an output merely aliases an input's storage,
    since that aliasing relationship cannot be reconstructed.
    """
    # these should all be supported, just to be safe
    # avoid fallback for operators which inplace modify metadata
    # because the input fake tensors would be unmodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        # Maps id(stand-in tensor) -> original FakeTensor so outputs that
        # are reused input impls can be mapped straight back.
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                inp_impls[id(out)] = e
                return out
            return e

        try:
            args = tree_map(to_real_tensor, args)
            kwargs = tree_map(to_real_tensor, kwargs)

            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        # Collect the storages of all input tensors to detect aliasing of
        # outputs that are not themselves reused input impls.
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to device, unless we can reuse an
        # input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor) and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    # the outputs which are not reused from impls will be converted
    # to fake tensors later
    meta_converter = MetaConverter()

    def map_out(e):
        # Avoid dict.get(key, default): its default is evaluated eagerly,
        # which would needlessly meta-convert outputs that are reused
        # input impls.
        if id(e) in inp_impls:
            return inp_impls[id(e)]
        return meta_converter(e)

    return tree_map(map_out, r)
Example #11
0
    def new(self, *args, **kwargs):
        """Intercept ``Tensor.new``: compute a real result, return it as fake.

        ``torch.Tensor.new`` does not go through the normal dispatcher
        pattern, so in order to use the same pattern as normal invocation
        of returning meta device within the kernel we need to intercept
        the call here.  Because it doesn't go through the dispatcher, we
        run into errors when attempting to compute an output in meta, so
        we compute the real tensor then convert to meta.
        """
        out_device = self.fake_device
        with no_dispatch():
            real_out = super().new(*args, **kwargs)

        # Sanity checks: the undisturbed call must have produced a genuine
        # real tensor, neither fake nor already on the meta device.
        assert not isinstance(real_out, FakeTensor), real_out
        assert real_out.device.type != "meta", real_out.device

        with no_dispatch():
            # Convert the concrete result to a meta tensor, then rewrap it
            # as a FakeTensor on the device the fake computation targeted.
            meta_out = MetaConverter()(real_out)
            return FakeTensor(self.fake_mode, meta_out, out_device)
Example #12
0
 def test_weakref(self):
     """Memo entries must disappear once their source tensors are collected."""
     source = torch.randn(4, 4, 4)
     convert = MetaConverter()
     first = convert(source)
     second = convert(source)
     # Converting the same tensor twice must hit the memo, not re-convert.
     self.assertIs(first, second)
     self.assertEqual(len(convert.tensor_memo), 1)
     self.assertEqual(len(convert.storage_memo), 1)
     del source
     self.assertEqual(len(convert.tensor_memo), 0)
     convert.check_for_expired_weak_storages()
     self.assertEqual(len(convert.storage_memo), 0)
     # Same exercise with a batch of tensors released all at once.
     held = []
     for size in range(4):
         held.append(torch.rand([size]))
         convert(held[-1])
     self.assertEqual(len(convert.tensor_memo), 4)
     del held
     self.assertEqual(len(convert.tensor_memo), 0)
     convert.check_for_expired_weak_storages()
     self.assertEqual(len(convert.storage_memo), 0)
Example #13
0
 def __init__(self):
     """Set up an empty tensor cache and a fresh meta converter."""
     self.meta_converter = MetaConverter()
     self.tensor_memo = {}
Example #14
0
def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
):
    """Run ``func`` for real, then cross-check it against the meta backend.

    The real call ``func(*args, **kwargs)`` always runs and its result is
    returned.  Unless ``test_expect`` is SKIP, and provided every tensor
    argument converted to meta successfully, the same call is repeated on
    meta inputs and the meta result is compared with the real one via
    ``assert_ref_meta_equal``.  Success/failure bookkeeping
    (``seen_succeeded``, ``seen_failed``, ``failed_reasons``) is updated so
    expected-failure lists can be regenerated when ``COLLECT_EXPECT`` is on,
    and XFAILURE expectations are enforced.
    """
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP

    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(f"failed to convert args to meta; "
                               f"originally (*{args}, **{kwargs})") from e

    # The reference result; computed unconditionally and always returned.
    rs = func(*args, **kwargs)

    # TODO: also handle cases where func raise an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.ops.aten.repeat_interleave.Tensor:
            if kwargs.get("output_size", None) is None:
                # NOTE(review): presumably the output size is data dependent
                # without an explicit output_size, so real args are used —
                # confirm against the op's meta support.
                meta_args = args
        elif func is torch.ops.aten.index.Tensor:
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [
                        torch.int8, torch.bool
                ]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)

        if kwargs.get("device", None) is not None:
            # Redirect any explicit device argument so outputs land on meta.
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                # Record why the meta kernel is missing, for reporting.
                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(
                    test_case, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(
                        f"unexpected success {resolve_name(func)}")

    return rs
Example #15
0
 def __init__(self):
     """Initialize the per-mode caches.

     FakeTensors store the FakeTensorMode, which in turn stores a
     FakeTensor, so the memo must hold only weak references to its values
     to avoid inducing a circular reference.
     """
     self.meta_converter = MetaConverter()
     self.tensor_memo = weakref.WeakValueDictionary()
Example #16
0
 def test_requires_grad_false(self):
     """A tensor without grad tracking converts to a grad-free meta tensor."""
     plain = torch.randn(4, requires_grad=False)
     converter = MetaConverter()
     meta = converter(plain)
     self.assertEqual(meta.shape, plain.shape)
     self.assertFalse(meta.requires_grad)
Example #17
0
 def test_complex_noncontiguous_bug(self):
     """Regression test: a noncontiguous complex32 slice keeps its layout."""
     sliced = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
     meta = MetaConverter()(sliced)
     self.assertEqual(meta.dtype, sliced.dtype)
     self.assertEqual(meta.shape, sliced.shape)
     self.assertEqual(meta.stride(), sliced.stride())