Example #1
    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        if symbolic_shapes.is_symbolic_op(func_overload):
            return symbolic_shapes.handle_symbolic_op(func_overload, args, kwargs)

        func = func_overload.overloadpacket
        # We don't want to convert torch.tensor constants into tracing objects.
        if func_overload == aten.lift.default:
            return args[0]
        if any(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0]):
            return proxy_call(func_overload, args, kwargs)
        # When we trace through a torch.tensor invocation, you never actually
        # see a torch.ops.aten.tensor call. Instead, the way this function is
        # implemented internally is that we allocate a plain tensor (this is
        # *guaranteed* to be a plain tensor, we disable all modes when doing
        # so), and then call at::lift_fresh on it (to give modes a chance to do
        # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
        # to be freshly allocated, so we want lift_fresh to be a no-op (directly
        # returning the input argument).
        #
        # Here is the basic problem: when we trace this sequence of executions
        # into an FX graph, what happens to this call sequence?  Traditionally,
        # tensor constants get interned as buffers on the FX GraphModule.  But
        # this is dangerous.  Consider:
        #
        #       x = torch.tensor(1)
        #       x.add_(2)
        #
        # Naively, this traces into:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh(t)
        #       x.add_(2)
        #
        # If lift_fresh returns t directly, the subsequent add_ call will
        # modify the tensor constant. Really, the problem is we've violated
        # the invariant that the argument to lift is fresh.  So we should
        # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh_copy(t)
        #       x.add_(2)
        #
        # This is what the overload modification does.
        else:
            if func_overload is torch.ops.aten.lift_fresh.default:
                func_overload = torch.ops.aten.lift_fresh_copy.default

            proxy_res = self.tracer.create_proxy('call_function', func_overload, args, kwargs,
                                                 name=self.tracer.graph._target_to_str(func.__name__))

            inner_res = func_overload(*args, **kwargs)

            # If this is a lift, the input tensor is guaranteed to be a
            # constant, so we keep a copy of the original argument around so
            # we can query it if we're asked to call item() on it at some later point
            is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
            if is_lift:
                with maybe_disable_fake_tensor_mode():
                    constant = args[0].clone()
            else:
                constant = None
            return wrap_output(inner_res, proxy_res, constant=constant)
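A minimal sketch of how the rewrite above surfaces when tracing (assuming make_fx from torch.fx.experimental.proxy_tensor, which installs this dispatch mode): mutating a torch.tensor constant should yield an aten.lift_fresh_copy node in the graph, so the interned constant itself is never modified.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(y):
        x = torch.tensor(1)   # interned as self._tensor_constant0 on the GraphModule
        x.add_(2)             # must mutate a fresh copy, not the interned constant
        return x + y

    gm = make_fx(f)(torch.zeros(3))
    print(gm.graph)           # expect a torch.ops.aten.lift_fresh_copy.default call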
Example #2
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device
        flat_arg_tensors = [
            i for i in tree_flatten((args, kwargs))[0]
            if isinstance(i, FakeTensor)
        ]
        has_symbolic_sizes = any(i.has_sym_ints for i in flat_arg_tensors)
        if has_symbolic_sizes:
            # TODO: Find better approach for this
            # Avoid circular import
            from torch._decomp import decomposition_table
            from torch._meta_registrations import meta_table

            # TODO: hack, doesn't actually work.
            # see https://github.com/pytorch/pytorch/pull/81598#issuecomment-1192030435
            with enable_torch_dispatch_mode(
                    self), torch.overrides.enable_reentrant_dispatch():
                if func in meta_table:
                    r = meta_table[func](*args, **kwargs)
                    return r
                if func in decomposition_table:
                    return decomposition_table[func](*args, **kwargs)

            with no_dispatch():
                if symbolic_shapes.is_symbolic_op(func):
                    return symbolic_shapes.handle_symbolic_op(
                        func, args, kwargs)

        # Prims already wrap FakeTensor inputs into FakeTensor outputs
        # and do device logic; we don't need to do anything but run them

        if "prims::" in func._schema.name:
            with no_dispatch():
                return func(*args, **kwargs)

        if has_symbolic_sizes:
            constructors = [torch.ops.aten.empty.SymInt]
            if func not in constructors:
                raise RuntimeError(
                    f"{func} - couldn't find symbolic meta function/decomposition"
                )

        with no_dispatch():
            # TODO: apply as no_dispatch decorator
            converter = self.fake_tensor_converter

            def wrap(e, device=None):
                if isinstance(e,
                              torch.Tensor) and not isinstance(e, FakeTensor):
                    return converter(self, e, device)
                else:
                    return e

            # If we are in the dispatch mode, we will enter this function even if the inputs
            # are not FakeTensors. For now, throw if there are any non-FakeTensor inputs
            # and only support constructors. TODO: extend more broadly
            conversion_made = False
            subclass_seen = False

            def check_non_fake_tensor(x):
                nonlocal conversion_made, subclass_seen
                conversion_made = conversion_made or (isinstance(
                    x, torch.Tensor) and not isinstance(x, FakeTensor))
                subclass_seen = subclass_seen or (isinstance(
                    x, torch.Tensor) and not isinstance(x, FakeTensor) and
                                                  type(x) is not torch.Tensor)

            tree_map(check_non_fake_tensor, args)
            tree_map(check_non_fake_tensor, kwargs)

            # Suppose we enable fake tensor mode.  This means that fake tensor
            # mode will run first.  But what if we do an operation that
            # involves a tensor subclass that will desugar into normal tensor
            # operations?  Without this line, fake tensor mode will run first,
            # decide that a conversion was made (since there was a non-fake
            # tensor argument), and report an error that converting non-fake
            # tensors is not supported.  What we actually wanted to happen
            # was to give the subclass a chance to figure out what it wants to
            # do before erroring out.  Returning NotImplemented here allows this.
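            #
            # For example (hypothetical wrapper subclass, purely for illustration):
            #
            #     class LoggingTensor(torch.Tensor):
            #         @classmethod
            #         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            #             ...  # unwrap to plain tensors, log, redispatch
            #
            # If an op mixes a FakeTensor with a LoggingTensor while this mode is
            # enabled, the mode runs first; deferring with NotImplemented lets
            # LoggingTensor handle the call (e.g. by unwrapping and redispatching)
            # instead of the mode erroring out on the non-fake argument.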
            #
            # NB: If you're seeing a mysterious infinite loop involving fake
            # tensor, it might be related to this line.  Though I'm not sure
            # how you'll know to read this comment, as this line won't show up
            # in the stack trace.
            if subclass_seen:
                return NotImplemented

            # This is generated from torch.tensor(), which does not use the
            # dispatcher, to allow wrapper subclasses to wrap the new tensor.
            # We need to handle it before the error checking below.
            if func in [
                    torch.ops.aten.lift_fresh.default,
                    torch.ops.aten.lift_fresh_copy.default,
            ]:
                assert (len(kwargs) == 0 and len(args) == 1
                        and type(args[0]) is torch.Tensor), f"{args} {kwargs}"
                with no_dispatch():
                    return converter(self, args[0])

            if conversion_made:
                raise Exception(
                    "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                    f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})"
                )

            for run_impl_check, op_impl in op_implementations:
                if run_impl_check(func):
                    return op_impl(self, func, *args, **kwargs)

            with in_kernel_invocation_manager(self):
                try:
                    r = func(*args, **kwargs)
                except NotImplementedError as not_implemented_error:
                    if not self.allow_fallback_kernels:
                        raise not_implemented_error
                    r = run_fallback_kernel(func, args, kwargs,
                                            not_implemented_error)

            # TODO: handle non-kwarg devices
            assert func not in _device_not_kwarg_ops, f"NYI: {func}"

            # if device is specified, use that
            if kwargs.get("device", None):
                return tree_map(partial(wrap, device=kwargs["device"]), r)

            common_device = FakeTensor._find_common_device(func, args, kwargs)

            return tree_map(partial(wrap, device=common_device), r)
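A minimal sketch of exercising the fake tensor dispatch above (assuming the FakeTensor and FakeTensorMode exports in torch._subclasses from the same vintage as this code): real tensors are first converted with the mode's fake_tensor_converter, as in wrap() above, and the result comes back as a FakeTensor on the common device found by FakeTensor._find_common_device.

    import torch
    from torch._subclasses import FakeTensor, FakeTensorMode

    mode = FakeTensorMode()
    converter = mode.fake_tensor_converter            # same converter used by wrap() above
    a = converter(mode, torch.empty(4, device="cpu"))
    b = converter(mode, torch.empty(4, device="cpu"))
    out = a + b                                       # routed through __torch_dispatch__
    print(type(out) is FakeTensor, out.device)        # True cpu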
Example #3
    def inner_torch_dispatch(self, func_overload, types, args=(), kwargs=None):
        if not self.enable_tracing:
            return func_overload(*args, **kwargs)

        if symbolic_shapes.is_symbolic_op(func_overload):
            with self.restore():
                return symbolic_shapes.handle_symbolic_op(
                    func_overload, args, kwargs)

        func = func_overload.overloadpacket
        # We don't want to convert torch.tensor constants into tracing objects.
        if func_overload == aten.lift.default:
            return args[0]

        if func in [prim.device]:
            return func_overload(*args, **kwargs)

        if pytree.tree_any_only(torch.Tensor,
                                lambda t: has_proxy_slot(t, self.tracer),
                                (args, kwargs)):
            out = proxy_call(self, func_overload, args, kwargs)
        # When we trace through a torch.tensor invocation, you never actually
        # see a torch.ops.aten.tensor call. Instead, the way this function is
        # implemented internally is that we allocate a plain tensor (this is
        # *guaranteed* to be a plain tensor, we disable all modes when doing
        # so), and then call at::lift_fresh on it (to give modes a chance to do
        # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
        # to be freshly allocated, so we want lift_fresh to be a no-op (directly
        # returning the input argument).
        #
        # Here is the basic problem: when we trace this sequence of executions
        # into an FX graph, what happens to this call sequence?  Traditionally,
        # tensor constants get interned as buffers on the FX GraphModule.  But
        # this is dangerous.  Consider:
        #
        #       x = torch.tensor(1)
        #       x.add_(2)
        #
        # Naively, this traces into:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh(t)
        #       x.add_(2)
        #
        # If lift_fresh returns t directly, the subsequent add_ call will
        # modify the tensor constant. Really, the problem is we've violated
        # the invariant that the argument to lift is fresh.  So we should
        # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh_copy(t)
        #       x.add_(2)
        #
        # This is what the overload modification does.
        else:
            flat_args = pytree.tree_flatten((args, kwargs))[0]
            handled_types = [torch.Tensor, _ProxyTensor, torch.nn.Parameter]

            # If there are any tensor subclasses, we need to handle those tensor subclasses first
            # TODO: we could use types to test this
            if any(
                    isinstance(arg, torch.Tensor)
                    and type(arg) not in handled_types for arg in flat_args):
                return NotImplemented

            if func_overload is torch.ops.aten.lift_fresh.default:
                func_overload = torch.ops.aten.lift_fresh_copy.default

            n_args, n_kwargs = pytree.tree_map_only(
                SymInt, fetch_symint_proxy(self.tracer), (args, kwargs))

            proxy_out = self.tracer.create_proxy(
                'call_function',
                func_overload,
                n_args,
                n_kwargs,
                name=self.tracer.graph._target_to_str(func.__name__))

            out = func_overload(*args, **kwargs)

            # If this is a lift, the input tensor is guaranteed to be a
            # constant, so we keep a copy of the original argument around so
            # we can query it if we're asked to call item() on it at some later point
            is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
            if is_lift and out.numel() <= CONSTANT_NUMEL_LIMIT:
                with maybe_disable_fake_tensor_mode():
                    constant = args[0].clone()
            else:
                constant = None
            track_tensor_tree(out,
                              proxy_out,
                              constant=constant,
                              tracer=self.tracer)

        def assert_proxy_tensor(e):
            assert has_proxy_slot(e, self.tracer), \
                f"Internal Error: make_fx is incorrectly baking a tensor constant into the graph: {str(e)}"

        # When we trace factory functions, we expect that tensor outputs are *always* tracked.
        # (Except for torch.tensor() constants handled through lift(), which are handled
        # specially further up.)
        pytree.tree_map_only(torch.Tensor, assert_proxy_tensor, out)
        return out
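A minimal sketch of the tracking guarantee asserted at the end (assuming make_fx from torch.fx.experimental.proxy_tensor): factory functions traced under this mode become call_function nodes whose outputs carry proxy slots, while torch.tensor() literals go through the lift_fresh_copy path and, when small enough (numel() <= CONSTANT_NUMEL_LIMIT), keep a copy of their constant for later queries.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        w = torch.ones(3)       # factory call: traced as a graph node, output is tracked
        c = torch.tensor(1.0)   # literal constant: lifted via lift_fresh_copy
        return x * w + c

    gm = make_fx(f)(torch.zeros(3))
    print(gm.graph)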