Example #1
def packed_quantized_conv2d_mapper(node: torch.fx.Node,
                                   mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from a quantized Conv2d module to acc_op.quantized_conv2d. We unpack all
    the parameters in this mapper and pass them directly to the conv2d node.
    """
    assert isinstance(node.target, str)
    conv_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, conv_module.weight())
    if conv_module.bias() is not None:
        mod.register_buffer(bias_name, conv_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(
            conv_module.weight())

        get_bias = None
        if conv_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(
                conv_module.bias())

        # Create kwargs for acc_op.quantized_conv2d
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "stride": conv_module.stride,
            "padding": conv_module.padding,
            "dilation": conv_module.dilation,
            "groups": conv_module.groups,
            "padding_mode": conv_module.padding_mode,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=conv_module.scale,
                q_zero_point=conv_module.zero_point),
        }

        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
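
A mapper like this is normally driven by a normalization pass that swaps the original call_module node for the new call_function node. The sketch below is a hypothetical driver, not the actual acc_tracer dispatch; it assumes the mapper's input already arrives as a kwarg, as the example above expects.

import torch.fx

def apply_mapper(gm: torch.fx.GraphModule, wants, mapper) -> torch.fx.GraphModule:
    # Hypothetical driver: replace each matching call_module node with the
    # call_function node built by the mapper, then clean up the graph.
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and wants(modules[node.target]):
            new_node = mapper(node, gm)
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
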
Example #2
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
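
The trailing-underscore test above is the usual (admittedly fragile) heuristic for spotting in-place ATen ops by name. A quick illustration of what it accepts and rejects:

def looks_inplace(name: str) -> bool:
    # Mirrors the check above: "add_" is in-place, while dunder names
    # such as "__or__" start with an underscore and are excluded.
    return name.endswith("_") and not name.startswith("_")

assert looks_inplace("add_")        # aten.add_ mutates its first argument
assert not looks_inplace("add")     # out-of-place variant
assert not looks_inplace("__or__")  # dunder, not in-place
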
Example #3
    def _test_const_fold_tensor_meta(self, requires_grad):
        """
        Verify tensor_meta is handled correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        ShapeProp(gm).propagate(in_x, in_y)
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(mod_folded)

        mod_folded.run_folding()

        for n in mod_folded.graph.nodes:
            if n.op == "get_attr":
                attr = self._get_attr(n)
                self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

        # Now run both folded and non-folded to check results equal.
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
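
For reference, _extract_tensor_metadata is the helper from torch.fx.passes.shape_prop that summarizes a tensor into a TensorMetadata record, which is what the assertion above compares; the exact field set may vary by torch version:

import torch
from torch.fx.passes.shape_prop import _extract_tensor_metadata

meta = _extract_tensor_metadata(torch.zeros(2, 3))
# Fields include shape, dtype, requires_grad, stride, memory_format,
# is_quantized, and qparams (on recent torch versions).
print(meta.shape, meta.dtype, meta.stride)
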
Example #4
    def __init__(self, elem, proxy, *, requires_grad=None):
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
        self.elem = elem
        self.proxy = proxy
Example #5
def packed_quantized_linear_mapper(node: torch.fx.Node,
                                   mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from a quantized Linear module to acc_op.quantized_linear. We unpack
    weight and bias in this mapper and pass them directly to the linear node.
    """
    linear_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, linear_module.weight())
    if linear_module.bias() is not None:
        mod.register_buffer(bias_name, linear_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(
            linear_module.weight())

        get_bias = None
        if linear_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(
                linear_module.bias())

        # Create kwargs for acc_op.quantized_linear
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=linear_module.scale,
                q_zero_point=linear_module.zero_point),
        }

        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
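
Note that packed quantized modules expose weight() and bias() as methods that unpack the packed params, which is why the mappers above call them with parentheses. A quick illustration (on older torch releases the module path is torch.nn.quantized instead):

import torch
import torch.ao.nn.quantized as nnq

lin = nnq.Linear(4, 2)   # stores weight/bias in a packed-params object
w = lin.weight()         # method call: unpacks to a quantized tensor
b = lin.bias()           # may be None when the module has no bias
print(w.dtype, w.shape)  # torch.qint8 torch.Size([2, 4])
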
Example #6
    def __init__(self, elem, proxy, *, requires_grad=None):
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
        # This detects situations where you accidentally put a ProxyTensor
        # inside a ProxyTensor for the same trace; this is a layering violation
        assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer)
        self.elem = elem
        self.proxy = proxy
Example #7
    def __init__(self, elem, proxy, *, requires_grad=None, constant=None):
        # TODO: hack since _extract_tensor_metadata currently tries to access stride
        if elem.is_sparse or self.has_sym_ints:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
        # This detects situations where you accidentally put a ProxyTensor
        # inside a ProxyTensor for the same trace; this is a layering violation
        assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer)
        self.elem = elem
        self.proxy = proxy
        self.constant = constant
Example #8
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        # Commenting this out for now since it causes some spurious failures (such as error checking)
        # if func == aten._local_scalar_dense:
        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
        #                        "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, PythonTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(
                args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        def wrap_with_proxy(e, proxy):
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())
            if type(e) == torch.Tensor:
                return PythonTensor(e, proxy)
            else:
                return e

        if isinstance(real_out, tuple):
            return tuple(
                wrap_with_proxy(e, proxy_out[idx])
                for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return [
                wrap_with_proxy(e, proxy_out[idx])
                for idx, e in enumerate(real_out)
            ]
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out
Example #9
    def __new__(cls, elem, proxy, *, requires_grad=None):
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse or requires_grad is not None:
            r = torch.Tensor._make_subclass(cls, elem, requires_grad)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]

        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]

        return r
Example #10
    def __new__(cls, elem, proxy):
        # Wrapping something in PythonTensor implicitly detaches
        # gradients.  If something required grad, we will collect it as if it
        # were a leaf.  A consequence of detaching in this way is you
        # need to maintain a parameter cache when translating tensors
        # into PythonTensor, so you don't create multiple copies of
        # a gradient (they are aliased, but they would count as independent
        # leaves).  An alternate strategy would be to avoid implicitly
        # detaching and instead "catch" gradients as they exit the
        # PythonTensor boundary.
        # assert not elem.requires_grad or not torch.is_grad_enabled()

        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.proxy = proxy
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
        else:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        return r
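
The parameter cache mentioned in the comment could look roughly like the following sketch; the cache keying and the wrap_param helper are assumptions for illustration, not code from this tracer:

param_cache = {}  # hypothetical: id(param) -> PythonTensor

def wrap_param(p, make_proxy):
    # Reuse one PythonTensor per parameter so an aliased gradient is not
    # collected twice as if it came from two independent leaves.
    key = id(p)
    if key not in param_cache:
        param_cache[key] = PythonTensor(p, make_proxy(p))
    return param_cache[key]
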
Example #11
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
    # Needed to sync up metadata for in-place operators that modify metadata
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res)
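
torch.Tag.inplace_view marks ops such as transpose_ that mutate a tensor's metadata (sizes/strides) rather than its data, which is why the op is re-run on the outer wrappers to sync that metadata. A small illustration (tag membership as on recent torch builds):

import torch

t = torch.zeros(2, 3)
t.transpose_(0, 1)  # in-place *view* op: data untouched, metadata changed
print(t.shape)      # torch.Size([3, 2])
print(torch.Tag.inplace_view in torch.ops.aten.transpose_.default.tags)
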
Example #12
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        t, = args
        assert not kwargs
        if t.constant is not None:
            with maybe_disable_fake_tensor_mode():
                return t.constant.item()
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        if isinstance(e, torch._C.SymbolicIntNode):
            if isinstance(e.get_pyobj(), ProxySymInt):
                return e.get_pyobj().sym_int
            else:
                raise RuntimeError(f"Something has gone wrong, we are trying to put SymInt {e.get_pyobj()} into the graph,"
                                   f"even though it's not a ProxySymInt. This is a bug.")

        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))

    # Needed to sync up metadata for in-place operators that modify metadata
    # TODO: instead forward the metadata to the inner tensor so updating
    # is not necessary
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment).  When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error.)  The logic here tests if constant propagation is possible
    # (because all of the inputs are constant).  If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute.  The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!)  Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness.  Similarly, we don't require the constant to
    # live on CPU, but we could.
    all_constant = True
    any_constant = False

    def check_constant(e):
        nonlocal all_constant, any_constant
        if isinstance(e, ProxyTensor):
            if e.constant is None:
                all_constant = False
            else:
                any_constant = True

    pytree.tree_map(check_constant, args)
    pytree.tree_map(check_constant, kwargs)

    def unwrap_constant(e):
        if isinstance(e, ProxyTensor):
            return e.constant
        return e

    constant = None
    # NB: do NOT include factories as constants
    if all_constant and any_constant:
        with maybe_disable_fake_tensor_mode():
            constant = func_overload(
                *pytree.tree_map(unwrap_constant, args),
                **pytree.tree_map(unwrap_constant, kwargs)
            )

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res, constant=constant)
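
The payoff of this constant propagation is that item() on a statically known constant keeps working under tracing instead of raising the data-dependent-output error. A minimal illustration via make_fx (a sketch; exact behavior varies across versions):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # torch.tensor(...) is tracked as a constant, so item() is answered by
    # constant propagation rather than erroring out.
    return x + torch.tensor(2.0).item()

gm = make_fx(f)(torch.ones(3))
print(gm.code)
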
Example #13
def proxy_call(proxy_mode, func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        with proxy_mode.restore():
            r = CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
            if r is not NotImplemented:
                return r

    # Some of these are not "real" aten ops and will fail if we
    # call _dispatch_has_kernel_for_dispatch_key on them.
    # This list is probably incomplete
    if func_overload not in [torch.ops.aten.size.default]:
        with proxy_mode.restore():
            r = func_overload.decompose(*args, **kwargs)
            if r is not NotImplemented:
                return r

    tracer = proxy_mode.tracer

    f_args, f_kwargs = pytree.tree_map_only(torch.Tensor,
                                            fetch_tensor_proxy(tracer),
                                            (args, kwargs))

    # If there are SymInts, we also should not consider this constant.
    # However, fake tensor handling of SymInts is sufficiently broken that
    # I couldn't write a test for this case
    all_constant = (
        pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None,
                             (f_args, f_kwargs))
        # TODO: maybe constant SymInts should also be allowed?  Not sure if
        # this can happen
        and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs)))

    if torch.Tag.data_dependent_output in func_overload.tags:  # type: ignore[attr-defined]
        # Check if all of the Tensor inputs are constants
        if all_constant:
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs))
            with maybe_disable_fake_tensor_mode():
                return func_overload(*const_args, **const_kwargs)
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar."
        )

    proxy_args, proxy_kwargs = pytree.tree_map_only(
        SymInt, fetch_symint_proxy(proxy_mode.tracer),
        pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy,
                             (f_args, f_kwargs)))
    proxy_out = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        # This makes DCE marginally less likely to DCE inplace operations.
        # It is not strictly necessary
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    out = func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment).  When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error.)  The logic here tests if constant propagation is possible
    # (because all of the inputs are constant).  If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute.  The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!)  Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness.  Similarly, we don't require the constant to
    # live on CPU, but we could.
    any_constant = pytree.tree_any_only(_ProxyTensor,
                                        lambda t: t.constant is not None,
                                        (f_args, f_kwargs))

    constant = None
    # NB: do NOT include factories as constants
    if (torch.Tag.nondeterministic_seeded
            not in func_overload.tags  # type: ignore[attr-defined]
            and all_constant and any_constant and pytree.tree_all_only(
                torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT,
                out)):
        with maybe_disable_fake_tensor_mode():
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs))
            constant = func_overload(*const_args, **const_kwargs)

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    return out
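
The tree_*_only helpers used above come from torch.utils._pytree; they apply a function or predicate only to leaves of a given type. A small self-contained example of the pattern:

import torch.utils._pytree as pytree

args = ([1, 2.0], {"k": 3})
all_ints_positive = pytree.tree_all_only(int, lambda x: x > 0, args)
doubled_ints = pytree.tree_map_only(int, lambda x: x * 2, args)
print(all_ints_positive)  # True
print(doubled_ints)       # ([2, 2.0], {'k': 6})
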
Example #14
    def wrap_with_proxy(e, proxy, constant):
        if isinstance(e, torch.Tensor):
            track_tensor(e, proxy, tracer=tracer, constant=constant)
            if not e.is_sparse:
                proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(e)