def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight", "bias", "kernel_size", "stride", "padding",
                "dilation", "groups", "padding_mode"
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d":
            ["weight", "bias", "running_mean", "running_var", "eps"],
        }

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an addition field of class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering
Пример #2
0
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    parameters = fx_module.named_parameters()
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in parameters:
        if isinstance(p, torch.Tensor):
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p
    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module.
        if node.op != "call_module" or not isinstance(submodules[node.target],
                                                      GraphModule):
            shape, dtype = get_shape_and_dtype(node)
            node_rep["shape"] = serialize_shape(shape)
            node_rep["dtype"] = str(dtype)

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                node.name = node.name[len("parent."):]
                node_rep["target"] = str(node.target[len("parent."):])
                weight = serialize_weight(
                    weights[node.target[len("parent."):]])
                serialized_dict["weights"][
                    node.target[len("parent."):]] = weight
            else:
                # Iterate through the module hierarchy to find the attr.
                target = fx_module
                split = node.target.split(".")
                assert len(split)
                while len(split):
                    target = getattr(target, split.pop(0))

                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target,
                              torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][prefix + node.target] = weight
                    weights[prefix + node.target] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":

            def get_output_info(arg: Node) -> Argument:
                shape, dtype = get_shape_and_dtype(arg)
                return {
                    "is_node": True,
                    "name": str(arg),
                    "shape": serialize_shape(shape),
                    "dtype": str(dtype),
                }

            node_rep["args"] = map_arg(
                node.args,
                get_output_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_arg(
                node.args, lambda arg: {
                    "is_node": True,
                    "name": str(arg)
                })

        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Пример #3
0
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(
                    weights[stripped_name], weights, node.target[len("parent.") :]
                )
                # For quantized embedding tables we need to update the shape/type,
                # so we check if the users of this get_attr is a quantized EB and this is the weight for the EB.
                user_targets = {
                    _get_qualified_name(
                        n.target
                    ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                }
                if (
                    "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_byte_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint8fused"
                # Same as above, but for the 4 bit version.
                if (
                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint4fused"

                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Пример #4
0
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        shape, dtype = get_shape_and_dtype(node)
        tensor_meta = node.meta.get('tensor_meta')
        if not tensor_meta:
            raise RuntimeError(f'Node {node} has no tensor metadata! Ensure shape '
                               f'propagation has been run!')
        node_rep = {
            "shape": serialize_shape(shape),
            "dtype": str(dtype),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name])
                serialized_dict["weights"][stripped_name] = weight
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][qualname] = weight
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, torch.dtype):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(
                node.args, get_arg_info
            )

        node_rep["kwargs"] = map_aggregate(
            node.kwargs, get_arg_info
        )
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Пример #5
0
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    parameters = fx_module.named_parameters()
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in parameters:
        if isinstance(p, torch.Tensor):
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module.
        if node.op != "call_module" or not isinstance(submodules[node.target],
                                                      GraphModule):
            shape = getattr(node, "shape", None)
            if shape:
                node_rep["shape"] = serialize_shape(shape)
            else:
                raise RuntimeError(
                    "Node has no shape attr, this is likely because shape propagation has not been run on this Graph."
                )
            dtype = getattr(node, "dtype", None)
            if dtype:
                node_rep["dtype"] = str(dtype)
            else:
                raise RuntimeError(
                    "Node has no dtype attr, this is likely because shape propagation has not been run on this Graph."
                )

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            target = getattr(fx_module, node.target)
            qualname = prefix + node.target
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target)
                serialized_dict["weights"][prefix + node.target] = weight
                weights[prefix + node.target] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name
        node_rep["args"] = map_arg(
            node.args, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict