Example #1
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
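
The unwrap helpers above rely on pytree.tree_map. A minimal sketch of that helper (using only the public torch.utils._pytree module; nothing below is taken from the snippet itself): it applies a function to every leaf of a nested structure and rebuilds the same structure, which is why unwrap_proxy handles arbitrarily nested args and kwargs in a single call.

import torch.utils._pytree as pytree

# doubles every int leaf, leaves other leaves (e.g. strings) untouched
nested = (1, [2, {"x": 3}], "keep")
doubled = pytree.tree_map(lambda e: e * 2 if isinstance(e, int) else e, nested)
assert doubled == (2, [4, {"x": 6}], "keep")
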
Example #2
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    with no_dispatch():

        def to_cpu(e):
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to cpu, error on reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (isinstance(e, torch.Tensor)
                                     and e.storage()._cdata in storages):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
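
The loops above use pytree.tree_flatten to reach every tensor buried in (args, kwargs). A small sketch of its contract (again only the public torch.utils._pytree helpers, not code from this snippet): it returns the list of leaves plus a spec describing the nesting, and tree_unflatten inverts it.

import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten(((1, 2), {"a": 3}))
assert leaves == [1, 2, 3]
assert pytree.tree_unflatten(leaves, spec) == ((1, 2), {"a": 3})
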
Example #3
    def new(self, *args, **kwargs):
        # torch.Tensor.new does not go through the normal dispatcher pattern,
        # so in order to follow the same pattern as a normal invocation
        # (returning a meta-device tensor within the kernel) we need to
        # intercept the call here
        out_device = self.fake_device
        if "device" in kwargs:
            kwarg_device = kwargs.pop("device")
            out_device = kwarg_device if kwarg_device else out_device
            kwargs["device"] = "meta"

        with in_kernel_invocation_manager(self.fake_mode):
            with no_dispatch():
                meta_out = super().new(*args, **kwargs)

        with no_dispatch():
            return FakeTensor(self.fake_mode, meta_out, out_device)
Example #4
 def wrap_with_proxy(e, proxy, constant):
     if isinstance(e, torch.Tensor):
         with no_dispatch():
             return ProxyTensor(e,
                                proxy,
                                constant=constant,
                                proxy_mode=proxy_mode)
     else:
         return e
Example #5
    def new(self, *args, **kwargs):
        # torch.Tensor.new does not go through the normal dispatcher pattern,
        # so in order to follow the same pattern as a normal invocation
        # (returning a meta-device tensor within the kernel) we need to
        # intercept the call here.
        # Because it doesn't go through the dispatcher, we run into errors
        # when attempting to compute the output on the meta device, so we
        # compute the real tensor and then convert it to meta
        out_device = self.fake_device
        with no_dispatch():
            real_out = super().new(*args, **kwargs)

        assert not isinstance(real_out, FakeTensor), real_out
        assert real_out.device.type != "meta", real_out.device

        with no_dispatch():
            meta_out = MetaConverter()(real_out)
            return FakeTensor(self.fake_mode, meta_out, out_device)
Example #6
def to_copy(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    input_device = new_kwargs.pop("device", None)
    out_device = input_device if input_device else new_kwargs["input"].device
    with no_dispatch():
        input = new_kwargs.pop("input").to("meta")
        return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)
Example #7
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    # the inplace_view ops should all be supported anyway; just to be safe,
    # avoid the fallback for operators which in-place modify metadata,
    # because the input fake tensors would be left unmodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to device, unless we can reuse an
        # input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    def map_out(e):
        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            else:
                return fake_mode.fake_tensor_converter(fake_mode, e)
        else:
            return e

    return tree_map(map_out, r)
Example #8
    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if any(isinstance(arg, ProxyTensor) for arg in args):
            return proxy_call(func_overload, args, kwargs)
        else:
            proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs,
                                                 name=self.tracer.graph._target_to_str(func.__name__))

            with no_dispatch():
                real_out = func_overload(*args, **kwargs)

            return wrap_output(real_out, proxy_out)
Example #9
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                with no_dispatch():
                    flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)
Example #10
 def from_real_tensor(self, fake_mode, t):
     maybe_memo = self._get_memo(t)
     if maybe_memo is not None:
         return maybe_memo
     existing_device = t.device
     # not yet supported in metatensors
     if t.is_complex():
         raise UnsupportedFakeTensorException("complex nyi in meta tensors")
     if t.is_sparse:
         raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
     if t.is_quantized:
         raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
     with no_dispatch():
         out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
     if type(t) is torch.nn.Parameter:
         out = torch.nn.Parameter(out, requires_grad=out.requires_grad)  # type: ignore[assignment]
     self.set_tensor_memo(t, out)
     return out
Example #11
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
    # Needed to sync up metadata for in-place operators that modify metadata
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res)
Example #12
 def __repr__(self):
     with no_dispatch():
         return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]
Example #13
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            def unwrap(e):
                return e.elem if isinstance(e, CompositeCompliantTensor) else e

            def wrap(e):
                return CompositeCompliantTensor(e) if isinstance(
                    e, torch.Tensor) else e

            if func == torch.ops.aten._local_scalar_dense.default:
                raise RuntimeError(
                    ".item() is not allowed to be called inside of composite "
                    "functions in the PyTorch library because not all backends "
                    "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them."
                )

            if func.overloadpacket.__name__ in ('set_', 'resize_'):
                raise RuntimeError(
                    f"{func.__name__} is not allowed to be called inside of "
                    f"Composite operators.")

            if is_inplace(func):
                # NB: We are making an assumption that if the function is in-place,
                # then the first argument is being written to. Introspection please save us!
                mutated_argument = args[0]
                if not isinstance(mutated_argument, CompositeCompliantTensor) and \
                        any([isinstance(a, CompositeCompliantTensor) for a in args[1:]]):
                    raise RuntimeError(
                        'Not composite compliant: performing in-place operation '
                        f'{func.__name__} where the Tensor being written to is '
                        'regular Tensor but the other tensors are Tensor Subclasses. '
                        'Please try to avoid this in-place operation.')

            with enable_reentrant_dispatch():
                with contextlib.nullcontext(
                ) if enable_recursive_torch_dispatch else no_dispatch():
                    unwrapped_args = tree_map(unwrap, args)
                    unwrapped_kwargs = tree_map(unwrap, kwargs)
                    unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
                    rs = tree_map(wrap, unwrapped_rs)

            if is_view_fn(func) and autograd_view_consistency:
                # Note [Alias Result]
                # Autograd asserts that for B = A.view_fn(...), B and A's storages
                # are the same. Here we try to make B alias A to avoid those asserts.
                # See https://github.com/pytorch/pytorch/issues/65339 for more information
                # about the issue.
                with enable_reentrant_dispatch():
                    with no_dispatch():
                        # Idea: this is a weird way of getting a storage that aliases the input.
                        # This is a workaround for #65339.
                        # 1. under no_dispatch, all of the wrapper tensors look like regular
                        #    tensors with special storage (the storage is nullptr and
                        #    advertises CPU/CUDA device).
                        # 2. we run func, which ends up running the view operation
                        # 3. All view operations reuse the input's storage and return
                        #    result Tensor(s) with new sizes/strides/offset that alias
                        #    the input.
                        # 4. we set the storage (and sizes/strides/offset) of the wrapper
                        #    tensor results to be that of the tensors that alias the input
                        result = func(*args, **kwargs)
                        if isinstance(result, tuple) or isinstance(
                                result, list):
                            for a, b in zip(rs, result):
                                a.set_(b)
                        else:
                            rs.set_(result)

            # Some operations are allowed to in-place modify the metadata of the
            # inputs. The only ones are the "inplace view functions"; when we
            # run into these, we manually modify the metadata of the input.
            with no_dispatch():
                if is_inplace_view_fn(func):
                    func(*args, **kwargs)

            # For each CompositeCompliantTensor t, we check that t and t.elem
            # have consistent metadata. If they don't have consistent metadata,
            # that means the operator did something fishy.
            check = partial(check_metadata_consistency, CCT=cls)
            tree_map(check, args)
            tree_map(check, kwargs)
            tree_map(check, rs)
            return rs
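
The unwrap, run, rewrap structure at the top of this __torch_dispatch__ is the standard wrapper-subclass pattern. A minimal, self-contained sketch of just that skeleton, using a hypothetical LoggingTensor (not part of the code above):

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    # wrapper subclass: holds no data itself, only mirrors elem's metadata
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda e: e.elem if isinstance(e, LoggingTensor) else e
        wrap = lambda e: LoggingTensor(e) if isinstance(e, torch.Tensor) else e
        print(f"dispatching {func}")
        # unwrap inputs, run the real op, rewrap outputs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)

x = LoggingTensor(torch.randn(3))
y = x + 1   # prints the aten op(s) hit; y is again a LoggingTensor
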
Example #14
    def meta_tensor(self, t):
        if t not in self.tensor_memo:
            with torch.inference_mode(t.is_inference()):
                if t._is_view():
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create the view off that.  NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.
                    assert t._is_view()
                    base = self.meta_tensor(t._base)

                    def is_c_of_r(complex_dtype, real_dtype):
                        return is_complex_dtype(complex_dtype) and \
                            corresponding_real_dtype(complex_dtype) == real_dtype

                    if base.dtype == t.dtype:
                        pass
                    elif is_c_of_r(base.dtype, t.dtype):
                        base = torch.view_as_real(base)
                    elif is_c_of_r(t.dtype, base.dtype):
                        base = torch.view_as_complex(base)
                    else:
                        # This is not guaranteed to succeed.  If it fails, it
                        # means there is another dtype-converting view function
                        # that hasn't been handled here
                        base = base.view(t.dtype)

                    with torch.enable_grad():
                        r = base.as_strided(t.size(), t.stride(),
                                            t.storage_offset())
                else:
                    is_leaf = safe_is_leaf(t)
                    # Fake up some autograd history.
                    if t.requires_grad:
                        r = torch.empty((0, ),
                                        dtype=t.dtype,
                                        device='meta',
                                        requires_grad=True)
                        if not is_leaf:
                            with torch.enable_grad():
                                # The backward function here will be wrong, but
                                # that's OK; our goal is just to get the metadata
                                # looking as close as possible; we're not going to
                                # actually try to backward() on these produced
                                # metas.  TODO: would be safer to install some
                                # sort of unsupported grad_fn here
                                r = r.clone()
                    else:
                        r = torch.empty((0, ), dtype=t.dtype, device='meta')
                    # As long as meta storage is not supported, need to prevent
                    # redispatching on set_(Storage, ...) which will choke with
                    # meta storage
                    s = self.meta_storage(t.storage())
                    with no_dispatch():
                        with torch.no_grad():
                            r.set_(s, t.storage_offset(), t.size(), t.stride())

                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())
            self.tensor_memo[t] = r

        return self.tensor_memo[t]
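
The two-step view construction above (meta-fy the base, then as_strided off it) can be illustrated with public APIs alone. A small sketch, not taken from this file: rebuilding the view from its base yields a genuine view (shared base, and hence a shared version counter), which reconstructing directly from storage would not give you.

import torch

t = torch.randn(4, 4)
v = t[1:]                                    # a view of t
base_meta = torch.empty(4, 4, device="meta")
view_meta = base_meta.as_strided(v.size(), v.stride(), v.storage_offset())
assert view_meta._is_view() and view_meta._base is base_meta
assert view_meta.shape == v.shape and view_meta.stride() == v.stride()
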
Example #15
 def from_tensor(self, tensor):
     with no_dispatch():
         return self.fake_tensor_converter(self, tensor)
Example #16
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # prims already wrap FakeTensor inputs into FakeTensor outputs
        # and handle device logic; we don't need to do anything but run them
        if "prims::" in func._schema.name:
            with no_dispatch():
                return func(*args, **kwargs)

        with no_dispatch():
            # TODO: apply as no_dispatch decorator
            converter = self.fake_tensor_converter

            # aten.lift is generated from torch.tensor(), which does not use
            # the dispatcher, in order to allow wrapper subclasses to wrap the
            # new tensor; we need to handle it before the error checking below
            if func == torch.ops.aten.lift.default:
                assert (len(kwargs) == 0 and len(args) == 1
                        and type(args[0]) is torch.Tensor)
                with no_dispatch():
                    return converter(self, args[0])

            def wrap(e, device=None):
                if isinstance(e,
                              torch.Tensor) and not isinstance(e, FakeTensor):
                    return converter(self, e, device)
                else:
                    return e

            # if we are in the dispatch mode, we will enter this function even if the inputs
            # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
            # and just support constructors. TODO: extend more broadly
            conversion_made = False

            def check_non_fake_tensor(x):
                nonlocal conversion_made
                conversion_made = conversion_made or (isinstance(
                    x, torch.Tensor) and not isinstance(x, FakeTensor))

            tree_map(check_non_fake_tensor, args)
            tree_map(check_non_fake_tensor, kwargs)

            if conversion_made:
                raise Exception(
                    "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                    f"Please convert all Tensors to FakeTensors first. Found in {func}"
                )

            for run_impl_check, op_impl in op_implementations:
                if run_impl_check(func):
                    return op_impl(self, func, *args, **kwargs)

            self.in_kernel_invocation = True
            try:
                r = func(*args, **kwargs)
            except NotImplementedError as not_implemented_error:
                if not self.allow_cpu_fallback:
                    raise not_implemented_error
                r = run_cpu_fallback(func, args, kwargs, not_implemented_error)
            finally:
                self.in_kernel_invocation = False

            # TODO: handle non-kwarg devices
            assert func not in _device_not_kwarg_ops, f"NYI: {func}"

            # if device is specified, use that
            if kwargs.get("device", None):
                return tree_map(partial(wrap, device=kwargs["device"]), r)

            common_device = FakeTensor._find_common_device(func, args, kwargs)

            return tree_map(partial(wrap, device=common_device), r)
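
A small usage sketch of the mode this method implements. Assumptions: the torch._subclasses import path matches the release these snippets come from, and recent PyTorch lets FakeTensorMode be entered directly as a context manager (older versions went through enable_torch_dispatch_mode). No real data is allocated or computed under the mode; only metadata (shape, dtype, device) is propagated.

import torch
from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(8, 8)        # a FakeTensor; no real memory is allocated
    y = x @ x                    # shape/dtype/device propagation only
    print(type(y), y.shape, y.device)
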
Example #17
def torch_dispatch_impl(cls_or_mode_instance, func, types, args, kwargs,
                        run_function):
    kwargs = kwargs if kwargs else {}
    in_fake_mode = isinstance(cls_or_mode_instance, FakeTensorMode)
    converter = cls_or_mode_instance.fake_tensor_converter if in_fake_mode else FakeTensorConverter(
    )

    # This class virtualizes .device() calls; we need to short-circuit them
    # here instead of dispatching device again, or we would recurse forever
    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        return args[0].fake_device

    def wrap(e, device=None):
        if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
            return converter(e, device)
        else:
            return e

    # if we are in the dispatch mode, we will enter this function even if the inputs
    # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
    # and just support constructors. TODO: extend more broadly
    if isinstance(cls_or_mode_instance, FakeTensorMode):
        conversion_made = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made
            conversion_made = conversion_made or (isinstance(
                x, torch.Tensor) and not isinstance(x, FakeTensor))

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)

        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}"
            )

    # _to_copy fails when run with FakeTensors to cuda device
    # TODO: debug
    if func == torch.ops.aten._to_copy.default:
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        out_device = new_kwargs.pop("device", new_kwargs["input"].device)
        with no_dispatch():
            input = new_kwargs.pop("input").to("meta")
            return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs),
                              out_device)

    if _is_tensor_constructor(func):
        assert func not in _non_kwarg_device_constructors
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        # cpu is default device if none is specified
        out_device = new_kwargs.pop("device", torch.device("cpu"))
        new_kwargs["device"] = torch.device("meta")
        r = run_function(func, types, (), new_kwargs)
        return FakeTensor(r, out_device)

    r = run_function(func, types, args, kwargs)

    # TODO: handle non-kwarg devices
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"

    # if device is specified, use that
    if kwargs.get("device", None):
        return tree_map(partial(wrap, device=kwargs["device"]), r)

    # operators which copy size from another tensor do not
    # also take device from the size tensor
    # other size_as operators are not builtin operators
    if func == aten.resize_as_.default:
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        # device of the input is returned
        return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

    common_device = FakeTensor._find_common_device(func, args, kwargs)

    return tree_map(partial(wrap, device=common_device), r)
Example #18
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        t, = args
        assert not kwargs
        if t.constant is not None:
            with maybe_disable_fake_tensor_mode():
                return t.constant.item()
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        if isinstance(e, torch._C.SymbolicIntNode):
            if isinstance(e.get_pyobj(), ProxySymInt):
                return e.get_pyobj().sym_int
            else:
                raise RuntimeError(f"Something has gone wrong, we are trying to put SymInt {e.get_pyobj()} into the graph,"
                                   f"even though it's not a ProxySymInt. This is a bug.")

        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))

    # Needed to sync up metadata for in-place operators that modify metadata
    # TODO: instead forward the metadata to the inner tensor so updating
    # is not necessary
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment).  When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error).  The logic here tests if constant propagation is possible
    # (because all of the inputs are constant).  If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute.  The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!)  Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness.  Similarly, we don't require the constant to
    # live on CPU, but we could.
    all_constant = True
    any_constant = False

    def check_constant(e):
        nonlocal all_constant, any_constant
        if isinstance(e, ProxyTensor):
            if e.constant is None:
                all_constant = False
            else:
                any_constant = True

    pytree.tree_map(check_constant, args)
    pytree.tree_map(check_constant, kwargs)

    def unwrap_constant(e):
        if isinstance(e, ProxyTensor):
            return e.constant
        return e

    constant = None
    # NB: do NOT include factories as constants
    if all_constant and any_constant:
        with maybe_disable_fake_tensor_mode():
            constant = func_overload(
                *pytree.tree_map(unwrap_constant, args),
                **pytree.tree_map(unwrap_constant, kwargs)
            )

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res, constant=constant)
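
A hedged sketch of what the constant-propagation branch enables under tracing, assuming the make_fx entry point from torch.fx.experimental.proxy_tensor (the module the error message above references): .item() on a tensor that is statically a constant folds to a Python number instead of raising the data-dependent error.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    scale = torch.tensor(2.0)    # statically known constant
    return x * scale.item()      # folds to x * 2.0 during tracing

gm = make_fx(f)(torch.randn(3))
print(gm.code)
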
Example #19
 def __repr__(self):
     with no_dispatch():
         return f"ProxyTensor({self.elem}, proxy={self.proxy})"
Example #20
 def wrap_with_proxy(e, proxy):
     if isinstance(e, torch.Tensor):
         with no_dispatch():
             return ProxyTensor(e, proxy, **kwargs)
     else:
         return e
Example #21
 def wrap_with_proxy(e, proxy):
     if type(e) == torch.Tensor:
         with no_dispatch():
             return ProxyTensor(e, proxy)
     else:
         return e
Example #22
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device
        flat_arg_tensors = [
            i for i in tree_flatten((args, kwargs))[0]
            if isinstance(i, FakeTensor)
        ]
        has_symbolic_sizes = any([i.has_sym_ints for i in flat_arg_tensors])
        if has_symbolic_sizes:
            # TODO: Find better approach for this
            # Avoid circular import
            from torch._decomp import decomposition_table
            from torch._meta_registrations import meta_table

            # TODO: hack, doesn't actually work.
            # see https://github.com/pytorch/pytorch/pull/81598#issuecomment-1192030435
            with enable_torch_dispatch_mode(
                    self), torch.overrides.enable_reentrant_dispatch():
                if func in meta_table:
                    r = meta_table[func](*args, **kwargs)
                    return r
                if func in decomposition_table:
                    return decomposition_table[func](*args, **kwargs)

            with no_dispatch():
                if symbolic_shapes.is_symbolic_op(func):
                    return symbolic_shapes.handle_symbolic_op(
                        func, args, kwargs)

        # prims already wrap FakeTensor inputs into FakeTensor outputs
        # and handle device logic; we don't need to do anything but run them

        if "prims::" in func._schema.name:
            with no_dispatch():
                return func(*args, **kwargs)

        if has_symbolic_sizes:
            constructors = [torch.ops.aten.empty.SymInt]
            if func not in constructors:
                raise RuntimeError(
                    f"{func} - couldn't find symbolic meta function/decomposition"
                )

        with no_dispatch():
            # TODO: apply as no_dispatch decorator
            converter = self.fake_tensor_converter

            def wrap(e, device=None):
                if isinstance(e,
                              torch.Tensor) and not isinstance(e, FakeTensor):
                    return converter(self, e, device)
                else:
                    return e

            # if we are in the dispatch mode, we will enter this function even if the inputs
            # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
            # and just support constructors. TODO: extend more broadly
            conversion_made = False
            subclass_seen = False

            def check_non_fake_tensor(x):
                nonlocal conversion_made, subclass_seen
                conversion_made = conversion_made or (isinstance(
                    x, torch.Tensor) and not isinstance(x, FakeTensor))
                subclass_seen = subclass_seen or (isinstance(
                    x, torch.Tensor) and not isinstance(x, FakeTensor) and
                                                  type(x) is not torch.Tensor)

            tree_map(check_non_fake_tensor, args)
            tree_map(check_non_fake_tensor, kwargs)

            # Suppose we enable fake tensor mode.  This means that fake tensor
            # mode will run first.  But what if we do an operation that
            # involves a tensor subclass that will desugar into normal tensor
            # operations?  Without this line, fake tensor mode will run first,
            # decide that a conversion was made (since there was a non fake
            # tensor argument), and report an error that converting non
            # fake tensor is not supported.  What we actually wanted to happen
            # was to give the subclass a chance to figure out what it wants to
            # do before erroring out.  Returning NotImplemented here allows this.
            #
            # NB: If you're seeing a mysterious infinite loop involving fake
            # tensor, it might be related to this line.  Though I'm not sure
            # how you'll know to read this comment, as this line won't show up
            # in the stack trace.
            if subclass_seen:
                return NotImplemented

            # these lift ops are generated from torch.tensor(), which does not
            # use the dispatcher, in order to allow wrapper subclasses to wrap
            # the new tensor; we need to handle them before the error checking
            # below
            if func in [
                    torch.ops.aten.lift_fresh.default,
                    torch.ops.aten.lift_fresh_copy.default,
            ]:
                assert (len(kwargs) == 0 and len(args) == 1
                        and type(args[0]) is torch.Tensor), f"{args} {kwargs}"
                with no_dispatch():
                    return converter(self, args[0])

            if conversion_made:
                raise Exception(
                    "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                    f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})"
                )

            for run_impl_check, op_impl in op_implementations:
                if run_impl_check(func):
                    return op_impl(self, func, *args, **kwargs)

            with in_kernel_invocation_manager(self):
                try:
                    r = func(*args, **kwargs)
                except NotImplementedError as not_implemented_error:
                    if not self.allow_fallback_kernels:
                        raise not_implemented_error
                    r = run_fallback_kernel(func, args, kwargs,
                                            not_implemented_error)

            # TODO: handle non-kwarg devices
            assert func not in _device_not_kwarg_ops, f"NYI: {func}"

            # if device is specified, use that
            if kwargs.get("device", None):
                return tree_map(partial(wrap, device=kwargs["device"]), r)

            common_device = FakeTensor._find_common_device(func, args, kwargs)

            return tree_map(partial(wrap, device=common_device), r)
Example #23
 def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
     with no_dispatch():
         return cls._torch_dispatch(func, types, args, kwargs)
Example #24
def clone(fake_mode, func, input, memory_format=None):
    out_device = input.device
    with no_dispatch():
        out = torch.ops.aten._to_copy(input.to("meta"),
                                      memory_format=memory_format)
        return FakeTensor(fake_mode, out, out_device)