Example #1
    def run_node(self, n: Node) -> Any:
        try:
            result = super().run_node(n)
        except Exception:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            )

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result
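
Example #1 is the core of torch.fx.passes.shape_prop.ShapeProp: run_node executes each node, then walks the (possibly nested) result with map_aggregate and records a TensorMetadata entry in node.meta['tensor_meta']. A minimal usage sketch; the toy module M and the input shape are made up for illustration only:

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(M())
# One interpreted pass with a sample input fills in node.meta['tensor_meta'].
ShapeProp(gm).propagate(torch.randn(2, 3))

for node in gm.graph.nodes:
    tm = node.meta.get('tensor_meta')
    if tm is not None:
        print(node.name, tm.shape, tm.dtype)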
Example #2
    def __torch_function__(self, func, types, args=(), kwargs={}):
        namespace, func_name = func.split("::")
        func = getattr(getattr(torch.ops, namespace), func_name)
        outs = kwargs['val']
        rets = []
        proxy_args = map_aggregate(
            args, lambda i: i.proxy if isinstance(i, PythonTensor) else i)
        out_proxy = func(*proxy_args)
        if len(outs) == 1 and isinstance(outs[0], torch.Tensor):
            return [PythonTensor(outs[0], out_proxy)]
        for idx, out in enumerate(outs):
            if isinstance(out, torch.Tensor):
                rets.append(PythonTensor(out, out_proxy[idx]))
            else:
                rets.append(out)
        return rets
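
All of the examples on this page lean on torch.fx.node.map_aggregate, which applies a function to every leaf of a nested structure of tuples, lists, dicts and slices while preserving that structure. A small self-contained sketch of that behavior:

import torch
from torch.fx.node import map_aggregate

args = (torch.randn(2), [torch.randn(3), 1], {"w": torch.randn(4)})

# Replace each tensor leaf with its shape; non-tensor leaves pass through.
shapes = map_aggregate(
    args, lambda a: tuple(a.shape) if isinstance(a, torch.Tensor) else a)
print(shapes)  # roughly: ((2,), [(3,), 1], {'w': (4,)})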
Example #3
    def run_node(self, n):
        result = super().run_node(n)

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return obj
            else:
                return obj

        from torch.fx.node import map_aggregate
        concrete_value = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['concrete_value'] = concrete_value
        return result
Example #4
    def run_node(self, n: Node) -> Any:
        result = super().run_node(n)

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result
Example #5
File: normalize.py  Project: zqzhai/pytorch
    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            if isinstance(arg, fx.Node):
                return n.meta['type'] if 'type' in n.meta else None
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
        if n.op == 'call_function':
            out = self.call_function(n.target, args, kwargs, arg_types,
                                     kwarg_types)
        else:
            out = super().run_node(n)
        self.node_map[out] = n
        return out
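
Example #5 looks like the run_node of the NormalizeArgs transformer pass (torch/fx/experimental/normalize.py): for each call_function node it collects argument types via map_aggregate and passes them to call_function so the call can be rewritten with normalized arguments. A hedged sketch of how such a Transformer pass is typically applied; model here is only a placeholder name:

import torch.fx
from torch.fx.experimental.normalize import NormalizeArgs

traced = torch.fx.symbolic_trace(model)
# transform() re-runs the graph through run_node above and returns a
# new GraphModule with normalized call arguments.
normalized = NormalizeArgs(traced).transform()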
Example #6
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        shape, dtype = get_shape_and_dtype(node)
        tensor_meta = node.meta.get('tensor_meta')
        if not tensor_meta:
            raise RuntimeError(f'Node {node} has no tensor metadata! Ensure shape '
                               f'propagation has been run!')
        node_rep = {
            "shape": serialize_shape(shape),
            "dtype": str(dtype),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name])
                serialized_dict["weights"][stripped_name] = weight
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][qualname] = weight
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, torch.dtype):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(
                node.args, get_arg_info
            )

        node_rep["kwargs"] = map_aggregate(
            node.kwargs, get_arg_info
        )
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
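
Example #6 (and its variants in Examples #7 and #13) assumes shape propagation has already filled in node.meta['tensor_meta']; get_node_info raises otherwise. A minimal sketch of how the serializer might be driven, assuming serialize_module and its helpers (serialize_weight, lift_lowering_attrs_to_nodes, and so on) are imported from the file that defines them; the toy module and input shape are chosen only for illustration:

import json
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(1, 4))   # populate tensor_meta first

weights = {}
serialized = serialize_module(gm, weights)   # schema documented above
print(json.dumps(serialized, indent=2, default=str))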
Example #7
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(
                    weights[stripped_name], weights, node.target[len("parent.") :]
                )
                # For quantized embedding tables we need to update the shape/type,
                # so we check if the users of this get_attr is a quantized EB and this is the weight for the EB.
                user_targets = {
                    _get_qualified_name(
                        n.target
                    ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                }
                if (
                    "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_byte_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint8fused"
                # Same as above, but for the 4 bit version.
                if (
                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint4fused"

                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Example #8
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f"\nstarting fqn {fqn}")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    hook_type = get_module_hook_type(parent_module, cur_module)
                    if enable_logging:
                        logger.debug(
                            f"_patched_module_call {type(self)} " +
                            # f"arg_types {[type(arg) for arg in args]} " +
                            f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} "
                            + f"hook_type {hook_type}")

                    if hook_type is HookType.OP_HOOKS:
                        # before hooks
                        assert parent_module is not None
                        assert isinstance(parent_module._auto_quant_state,
                                          AutoQuantizationState)
                        qstate = parent_module._auto_quant_state
                        if enable_logging:
                            logger.debug(qstate)
                        qstate.validate_cur_op(cur_module)
                        _, args, kwargs = qstate.op_convert_before_hook(
                            cur_module, args, kwargs, cur_module)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks
                        output = qstate.op_convert_after_hook(
                            cur_module, output)
                        qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        cur_qstate = cur_module._auto_quant_state
                        if enable_logging:
                            logger.debug(cur_qstate)

                        cur_qstate.validate_is_at_first_idx()

                        # before hooks (TODO)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks
                        assert isinstance(cur_qstate, AutoQuantizationState)
                        output = cur_qstate.outputs_convert_hook(output)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        # disabling torch function to prevent infinite recursion on
                        # getset
                        # TODO(future PR): handle more dtypes
                        with torch._C.DisableTorchFunction():
                            new_args = []
                            for arg in args:
                                if isinstance(
                                        arg,
                                        torch.Tensor) and arg.is_quantized:
                                    dequant = arg.dequantize().as_subclass(
                                        QuantizationConvertTensorProxy
                                    )  # type: ignore[arg-type]
                                    new_args.append(dequant)
                                else:
                                    new_args.append(arg)
                            args = tuple(new_args)
                        output = orig_module_call(self, *args, **kwargs)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        logger.debug(
                            f"_patched_module_call {type(self)} " +
                            # f"out {type(output)} " +
                            f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} "
                            + "end")
                        logger.debug(f"ending fqn {fqn}\n")
                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

            try:
                for k, v in self.named_modules():
                    module_id_to_fqn[id(v)] = k
                    if hasattr(v, '_auto_quant_state'):
                        v._auto_quant_state.reset_to_new_call()

                needs_io_hooks = hasattr(self, '_auto_quant_state')

                # handle module input dtype conversions
                # TODO(implement)

                output = super().__call__(*new_args, **new_kwargs)

                # handle module output dtype conversions
                if needs_io_hooks:
                    qstate = self._auto_quant_state
                    assert isinstance(qstate, AutoQuantizationState)
                    output = qstate.outputs_convert_hook(output)

                def unwrap_proxy(a):
                    if isinstance(a, QuantizationConvertTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = map_aggregate(output, unwrap_proxy)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
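
Examples #8 through #12 all follow the same patch-and-restore pattern: map_aggregate wraps the inputs in proxy tensors, torch.nn.Module.__call__ is temporarily replaced so every submodule call can be intercepted, and the original is restored in a finally block so an exception cannot leave the global patch in place. A stripped-down sketch of that pattern (the print is only illustrative):

import torch

def instrumented_forward(model: torch.nn.Module, *args, **kwargs):
    orig_module_call = torch.nn.Module.__call__

    def _patched_module_call(self, *a, **kw):
        # Every nn.Module call in the model passes through here
        # while the patch is active.
        print(f"calling {type(self).__name__}")
        return orig_module_call(self, *a, **kw)

    torch.nn.Module.__call__ = _patched_module_call
    try:
        return model(*args, **kwargs)
    finally:
        # Always restore the original, even if the forward pass raised.
        torch.nn.Module.__call__ = orig_module_call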
Example #9
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):

                if enable_logging:
                    logger.debug(f"_patched_module_call: {type(self)}")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f"\nstarting fqn {fqn}")

                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if hook_type is HookType.OP_HOOKS:
                        assert parent_module is not None
                        parent_qstate = parent_module._auto_quant_state
                        assert isinstance(parent_qstate, AutoQuantizationState)
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)
                        args, kwargs = parent_qstate.op_prepare_before_hook(
                            cur_module, args, kwargs, first_call, qtensor_id,
                            fqn, cur_module)

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        # TODO is it correct to call_cur_module twice here?
                        output = parent_qstate.op_prepare_after_hook(
                            cur_module, output, args, first_call, qtensor_id,
                            cur_module)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook

                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.validate_is_at_first_idx()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        assert isinstance(cur_qstate, AutoQuantizationState)
                        output = cur_qstate.outputs_prepare_hook(
                            output, first_call, qtensor_id)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        output = orig_module_call(self, *args, **kwargs)
                        # if this fp32 was inplace, make sure to set the output dtype
                        # back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f"\nending fqn {fqn}")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            nonlocal first_call
            try:
                # Create a list before iterating because we are adding new
                # named modules inside the loop.
                named_modules = list(self.named_modules())
                for k, v in named_modules:

                    # k is the global FQN, i.e. 'foo.bar.baz'
                    # v is the module instance
                    #
                    # we need to associate the global FQN with SeenOp
                    # for modules, this is the module FQN
                    # for functions, this is the parent module FQN
                    module_id_to_fqn[id(v)] = k

                    has_qconfig = hasattr(v,
                                          'qconfig') and v.qconfig is not None
                    if has_qconfig and not is_leaf(v):
                        if first_call:
                            if v is self:
                                # for the top level module only, specify input
                                # and output dtypes
                                v._auto_quant_state = AutoQuantizationState(
                                    v.qconfig, input_dtypes, output_dtypes)
                            else:
                                v._auto_quant_state = AutoQuantizationState(
                                    v.qconfig)
                        else:
                            if not isinstance(v, AutoQuantizationState):
                                assert hasattr(v, '_auto_quant_state')
                                v._auto_quant_state.reset_to_new_call()

                output = super().__call__(*new_args, **new_kwargs)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False
Example #10
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                nonlocal global_disable_torch_function_override
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    hook_type = get_module_hook_type(parent_module, cur_module)
                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} "
                            + f"hook_type {hook_type}")

                    if hook_type is HookType.OP_HOOKS:
                        # before hooks
                        qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf module.
                        # Therefore, we do not need to override any of its
                        # children. Disabling the overrides for performance.
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        _, args, kwargs = qstate.op_convert_before_hook(
                            cur_module, args, kwargs, cur_module)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks
                        output = qstate.op_convert_after_hook(
                            cur_module, output, global_op_idx)

                        # Re-enable the override.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        cur_qstate: AutoQuantizationState = cur_module._auto_quant_state

                        cur_qstate.reset_to_new_call()

                        # before hooks (TODO)
                        # forward
                        output = orig_module_call(self, *args, **kwargs)
                        # after hooks

                        # For the sake of performance, we assume no overrides
                        # are needed for quantizing/dequantizing things
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        output = cur_qstate.outputs_convert_hook(output)

                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        # TODO(future PR): handle more dtypes
                        new_args = []
                        for arg in args:
                            if isinstance(arg,
                                          torch.Tensor) and arg.is_quantized:
                                dequant = arg.dequantize().as_subclass(
                                    QuantizationConvertTensorProxy
                                )  # type: ignore[arg-type]
                                new_args.append(dequant)
                            else:
                                new_args.append(arg)
                        args = tuple(new_args)
                        output = orig_module_call(self, *args, **kwargs)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} "
                            + "end")
                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

            try:
                global_op_idx[0] = 0
                output = super().__call__(*new_args, **new_kwargs)

                def unwrap_proxy(a):
                    if isinstance(a, QuantizationConvertTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = map_aggregate(output, unwrap_proxy)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
Example #11
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):

                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)

                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if hook_type is HookType.OP_HOOKS:
                        parent_qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf module.
                        # Therefore, we do not need to override any of its
                        # children. Disabling the overrides for performance.
                        nonlocal global_disable_torch_function_override
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        if first_call:
                            # mypy ignore is used instead of assert because this
                            # runs on every forward and assert has a performance cost
                            args, kwargs = parent_qstate.first_call_op_prepare_before_hook(
                                cur_module,
                                args,
                                kwargs,
                                qtensor_id,
                                fqn,
                                cur_module,  # type: ignore[arg-type]
                                OpQuantizeabilityType.QUANTIZEABLE)
                        else:
                            # mypy ignore is used instead of assert because this
                            # runs on every forward and assert has a performance cost
                            args, kwargs = parent_qstate.op_prepare_before_hook(
                                cur_module, args,
                                kwargs)  # type: ignore[arg-type]

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # Re-enable the overrides.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        # after hooks
                        if first_call:
                            output = parent_qstate.first_call_op_prepare_after_hook(
                                cur_module, output, args, qtensor_id,
                                OpQuantizeabilityType.QUANTIZEABLE)
                        else:
                            output = parent_qstate.op_prepare_after_hook(
                                cur_module, output, args, global_op_idx)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook

                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.reset_to_new_call()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        if first_call:
                            output = cur_qstate.first_call_outputs_prepare_hook(
                                output, qtensor_id)
                        else:
                            output = cur_qstate.outputs_prepare_hook(output)

                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        if first_call and parent_module is not None:
                            parent_qstate_fc = getattr(parent_module,
                                                       '_auto_quant_state',
                                                       None)
                            if parent_qstate_fc:
                                args, kwargs = \
                                    parent_qstate_fc.first_call_op_prepare_before_hook(
                                        cur_module, args, kwargs, qtensor_id, fqn,
                                        cur_module,
                                        OpQuantizeabilityType.NOT_QUANTIZEABLE)

                        output = orig_module_call(self, *args, **kwargs)
                        # if this fp32 was inplace, make sure to set the output dtype
                        # back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                        if first_call and parent_module is not None:
                            parent_qstate_fc = getattr(parent_module,
                                                       '_auto_quant_state',
                                                       None)
                            if parent_qstate_fc:
                                output = \
                                    parent_qstate_fc.first_call_op_prepare_after_hook(
                                        cur_module, output, args, qtensor_id,
                                        OpQuantizeabilityType.NOT_QUANTIZEABLE)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            nonlocal first_call
            try:
                if first_call:
                    # Create a list before iterating because we are adding new
                    # named modules inside the loop.
                    named_modules = list(self.named_modules())

                    # Record module instances which are leaves or children of leaves
                    leaves = set()
                    for fqn, child in named_modules:
                        if is_leaf(child, prepare_custom_config_dict):
                            for _, child_child in child.named_modules():
                                leaves.add(child_child)

                    self._fqn_to_auto_quant_state_map = AutoQuantizationStateModuleDict(
                    )

                    for fqn, v in named_modules:

                        # fqn is the global FQN, i.e. 'foo.bar.baz'
                        # v is the module instance
                        #
                        # we need to associate the global FQN with SeenOp
                        # for modules, this is the module FQN
                        # for functions, this is the parent module FQN
                        module_id_to_fqn[id(v)] = fqn

                        if v in leaves:
                            continue

                        if v is self:
                            # for the top level module only, specify input
                            # and output dtypes
                            auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn, input_dtypes, output_dtypes)
                        else:
                            auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn)

                        # The code below registers the auto_quant_state object
                        # of the child in the module hierarchy of the parent,
                        # and adds the auto_quant_state object to the child
                        # with a raw __setattr__, without registering it in
                        # the module hierarchy of the child.
                        # This is solving the problem of both storing extra state
                        # (observers) as well as not modifying the meaning of user
                        # code in child modules which iterates over all module
                        # children.
                        #
                        # This narrows down the issue of dynamically adding
                        # children to only affect the top level module and not
                        # the children.

                        # On the parent, register this module in the FQN map
                        fqn_to_use_for_key = \
                            get_fqn_valid_for_module_dict_key(fqn)
                        self._fqn_to_auto_quant_state_map[fqn_to_use_for_key] = \
                            auto_quant_state
                        # On the child, manually set the attribute without
                        # going through the `torch.nn.Module.__setattr__`
                        # function, to prevent this object from appearing in
                        # the child's module hierarchy.
                        object.__setattr__(v, '_auto_quant_state',
                                           auto_quant_state)

                global_op_idx[0] = 0

                output = super().__call__(*new_args, **new_kwargs)

                if first_call:
                    for _, v in self.named_modules():
                        if hasattr(v, '_auto_quant_state'):
                            v._auto_quant_state.match_fusion_patterns()
                            v._auto_quant_state.insert_observers(v)

                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False
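
The long comment in Example #11 about registering _auto_quant_state hinges on the difference between torch.nn.Module.__setattr__, which registers nn.Module values as submodules (so they appear in children() and named_modules()), and a raw object.__setattr__, which attaches the attribute without registering it. A small sketch of that difference, with made-up attribute names:

import torch

child = torch.nn.Linear(2, 2)
extra_state = torch.nn.Module()   # stand-in for AutoQuantizationState

# Regular assignment goes through nn.Module.__setattr__ and registers
# extra_state as a submodule, changing what child.children() returns.
child.state_a = extra_state
print(len(list(child.children())))   # 1

# object.__setattr__ bypasses the registration, so the child's module
# hierarchy is unchanged but the attribute is still reachable.
object.__setattr__(child, 'state_b', extra_state)
print(len(list(child.children())))   # still 1
print(child.state_b is extra_state)  # True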
Example #12
        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):

                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)

                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if first_call and hook_type is not HookType.OP_HOOKS and \
                            parent_module is not None:
                        parent_qstate_fc = getattr(parent_module,
                                                   '_auto_quant_state', None)
                        if parent_qstate_fc:
                            parent_qstate_fc.add_seen_op_type_without_op_hooks(
                                type(cur_module))

                    if hook_type is HookType.OP_HOOKS:
                        parent_qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf module.
                        # Therefore, we do not need to override any of its
                        # children. Disabling the overrides for performance.
                        nonlocal global_disable_torch_function_override
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        # mypy ignore is used instead of assert because this
                        # runs on every forward and assert has a performance cost
                        args, kwargs = parent_qstate.op_prepare_before_hook(
                            cur_module, args, kwargs, first_call, qtensor_id,
                            fqn, cur_module)  # type: ignore[arg-type]

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # Re-enable the overrides.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        # after hooks
                        output = parent_qstate.op_prepare_after_hook(
                            cur_module, output, args, first_call, qtensor_id,
                            global_op_idx)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook

                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.reset_to_new_call()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        output = cur_qstate.outputs_prepare_hook(
                            output, first_call, qtensor_id)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        output = orig_module_call(self, *args, **kwargs)
                        # if this fp32 was inplace, make sure to set the output dtype
                        # back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            nonlocal first_call
            try:
                if first_call:
                    # Create a list before iterating because we are adding new
                    # named modules inside the loop.
                    named_modules = list(self.named_modules())

                    # Record module instances which are leaves or children of leaves
                    leaves = set()
                    for fqn, child in named_modules:
                        if is_leaf(child, prepare_custom_config_dict):
                            for _, child_child in child.named_modules():
                                leaves.add(child_child)

                    for fqn, v in named_modules:

                        # fqn is the global FQN, i.e. 'foo.bar.baz'
                        # v is the module instance
                        #
                        # we need to associate the global FQN with SeenOp
                        # for modules, this is the module FQN
                        # for functions, this is the parent module FQN
                        module_id_to_fqn[id(v)] = fqn

                        if v in leaves:
                            continue

                        if v is self:
                            # for the top level module only, specify input
                            # and output dtypes
                            v._auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn, input_dtypes, output_dtypes)
                        else:
                            v._auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn)

                global_op_idx[0] = 0

                output = super().__call__(*new_args, **new_kwargs)

                if first_call:
                    for _, v in self.named_modules():
                        if hasattr(v, '_auto_quant_state'):
                            v._auto_quant_state.insert_observers(v)

                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False
Example #13
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"])

            if tensor_meta.qparams["qscheme"] in {
                    torch.per_tensor_affine,
                    torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.qparams["scale"]
                node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"]

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (not (node.op == "call_module"
                 and isinstance(submodules[node.target], GraphModule))
                and node.op != "output"):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                qualname = node.target[len("parent."):]
                node.name = qualname
                node_rep["target"] = qualname
            else:
                qualname = prefix + node.target
            # Find the actual target parameter/buffer from the fx_module.
            submod_path, _, target_name = node.target.rpartition(".")
            submod: Optional[torch.nn.Module] = (
                fx_module.get_submodule(submod_path)
                if submod_path else fx_module)
            assert submod is not None, f"submod {submod_path} not found"
            target = getattr(submod, target_name, None)
            assert target is not None, f"{target_name} not an attr of {submod_path}"
            # Check that the target is a tensor, and that we haven't added it already from a leaf module.
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target, weights, qualname)
                _update_weight_fused_dtypes(weight, qualname, node)
                serialized_dict["weights"].update(weight)
                weights[qualname] = target
        elif node.op == "placeholder":
            ph_type = node.meta.get("ph_type", "")
            assert (
                ph_type == "" or ph_type == "input_ph"
                or ph_type == "output_ph"
            ), "When present, placeholder type must be 'input_ph' or 'ouput_ph'"
            if ph_type == "input_ph":
                node_rep["ph_type"] = "input_ph"
            elif ph_type == "output_ph":
                node_rep["ph_type"] = "output_ph"

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg,
                            (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple or
            # list. In this case we want to unpack the tuple or list.
            if isinstance(node_rep["args"][0], (tuple, list)):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()),
                                          get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict