Example #1
def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_submodule_or_attr(mod: torch.nn.Module,
                                  target: str) -> Optional[torch.nn.Module]:
        try:
            mod_match = mod.get_submodule(target)
            return mod_match
        except AttributeError:
            pass

        # Fall back to plain attribute lookup (covers get_attr targets
        # that are not submodules)
        try:
            attr = getattr(mod, target)
            return attr
        except AttributeError:
            return None

    for node in gm.graph.nodes:
        if node.op == "call_module" or node.op == "get_attr":

            gm_submod = try_get_submodule_or_attr(gm, node.target)

            replacement_submod = try_get_submodule_or_attr(
                replacement, node.target)

            # CASE 1: This target already exists as a submodule in our
            # result GraphModule. Whether or not it exists in
            # `replacement`, the existing submodule takes precedence.
            if gm_submod is not None:
                continue

            # CASE 2: The target exists as a submodule in `replacement`
            # only, so we need to copy it over.
            elif replacement_submod is not None:
                new_submod = copy.deepcopy(getattr(replacement, node.target))
                if isinstance(new_submod, torch.nn.Module):
                    gm.add_submodule(node.target, new_submod)
                else:
                    setattr(gm, node.target, new_submod)

            # CASE 3: The target doesn't exist as a submodule in `gm`
            # or `replacement`
            else:
                raise RuntimeError(
                    f"Attempted to create a \"{node.op}\" node during "
                    f"subgraph rewriting with target {node.target}, but "
                    "the referenced submodule does not exist in either "
                    "the original GraphModule `gm` or the replacement "
                    "GraphModule `replacement`")

    gm.graph.lint()
Example #2
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
    """Given a node with node.dtype and node.shape, return its total size and its output size.
    total_size = weights + bias + output_size
    """
    # Total num of elements
    total_num_of_elems = 0
    # For a module, consider all parameters
    if node.op == "call_module":
        submodule_dict = dict(fx_module.named_modules())
        submodule = submodule_dict[node.target]
        parameters = submodule.named_parameters()
        # named_parameters() yields (name, parameter) tuples
        for name, p in parameters:
            total_num_of_elems += p.numel()
    # Don't forget the output size
    # tensor_meta.shape is the shape of this node's output
    tensor_meta = get_tensor_meta(node)
    output_elem = tensor_meta.shape.numel()
    total_num_of_elems += output_elem
    # Assume for now if it's quantized then it's qint8 or quint8
    if tensor_meta.is_quantized:
        size_per_elem_bytes = torch._empty_affine_quantized(
            [], dtype=tensor_meta.dtype).element_size()
    else:
        size_per_elem_bytes = torch.tensor(
            [], dtype=tensor_meta.dtype).element_size()
    total_size = size_per_elem_bytes * total_num_of_elems
    output_size = size_per_elem_bytes * output_elem
    return size_bytes(output_size, total_size)
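A minimal usage sketch (hedged): get_tensor_meta reads the metadata that shape propagation attaches to each node, so a pass such as torch.fx.passes.shape_prop.ShapeProp must run first. TinyNet is a hypothetical module for illustration; size_bytes and get_tensor_meta are assumed to be in scope from the same helper module as get_size_of_node.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

gm = symbolic_trace(TinyNet())
# Populates node.meta["tensor_meta"] with shape/dtype for every node.
ShapeProp(gm).propagate(torch.randn(3, 4))

for node in gm.graph.nodes:
    if node.op == "call_module":
        # size_bytes(output_size, total_size), both in bytes
        print(node.name, get_size_of_node(gm, node))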
Example #3
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
    """Given a node with node.dtype and node.shape, return its total size and its output size.
    total_size = weights + bias + output_size
    """
    # Total num of elements
    total_num_of_elems = 0
    # For a module, consider all parameters
    if node.op == "call_module":
        submodule_dict = dict(fx_module.named_modules())
        submodule = submodule_dict[node.target]
        parameters = submodule.named_parameters()
        # named_parameters() yields (name, parameter) tuples
        for name, p in parameters:
            total_num_of_elems += p.numel()
    # Don't forget the output size
    # node.shape is the shape of this node's output
    shape = getattr(node, "shape", None)
    if shape is not None:
        output_elem = shape.numel()
    else:
        raise RuntimeError("Node has no shape attr")
    total_num_of_elems += output_elem
    dtype = getattr(node, "dtype", None)
    if dtype is not None:
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    else:
        raise RuntimeError("Node has no dtype attr")
    total_size = size_per_elem_bytes * total_num_of_elems
    output_size = size_per_elem_bytes * output_elem
    return size_bytes(output_size, total_size)
Example #4
def legalize_graph(gm: GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the
    topological order of its input GraphModule; this pass restores that order
    before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {
        node: node.all_input_nodes.copy()
        for node in gm.graph.nodes
    }

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = Graph()
    value_remap: Dict[Node, Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node,
                                                    lambda n: value_remap[n])

    # Remove the copied over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(
                    node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
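A short, self-contained sketch of the pass in action. The function f and the perturbation below are illustrative only; they simulate a transformation that has broken topological order.

import torch
from torch.fx import symbolic_trace

def f(x):
    return x.relu() + 1

gm = symbolic_trace(f)

# Illustrative perturbation: move the last node right after the first,
# breaking topological order the way a careless transform might.
first, *_, last = gm.graph.nodes
first.append(last)

legalize_graph(gm)   # rebuild the graph in topologically sorted order
gm.graph.lint()      # passes again: no node is used before it is defined
gm.recompile()       # regenerate forward() from the sorted graph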
Example #5
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
    """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
    """
    submodules = dict(fx_module.named_modules())

    for node in fx_module.graph.nodes:
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                lift_lowering_attrs_to_nodes(submodules[node.target])
            else:
                node.attrs_for_lowering = extract_attrs_for_lowering(
                    submodules[node.target])
Example #6
def insert_subgm(gm: GraphModule, sub_gm: GraphModule,
                 orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node,
                                                                    ...]):
    # add sub_gm into gm
    submodule_name = sub_gm.__class__.__name__
    gm.add_submodule(submodule_name, sub_gm)

    # Create a call_module node in main graph.
    module_node = gm.graph.call_module(submodule_name,
                                       args=orig_inputs,
                                       kwargs=None)

    if len(orig_outputs) == 1:
        # main_remapping[comp.orig_outputs[0]] = module_node
        orig_outputs[0].replace_all_uses_with(module_node)
    else:
        for i, orig_output in enumerate(orig_outputs):
            # Use Proxy to record getitem access.
            proxy_out = torch.fx.Proxy(
                module_node)[i].node  # type: ignore[index]
            orig_output.replace_all_uses_with(proxy_out)
    return gm
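The multi-output branch above relies on an FX Proxy trick: indexing a Proxy that wraps an existing Node records an operator.getitem call_function node in that Node's graph. A self-contained sketch, assuming a recent torch.fx where a Proxy built without an explicit tracer appends to the node's own graph:

import operator
from torch.fx import Graph, Proxy

g = Graph()
x = g.placeholder("x")
# Indexing the Proxy appends call_function(operator.getitem, (x, 0)) to g.
item0 = Proxy(x)[0].node
assert item0.op == "call_function" and item0.target == operator.getitem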
Example #7
def lift_subgraph_as_module(gm: GraphModule,
                            subgraph: Graph,
                            class_name: str = 'GraphModule') -> GraphModule:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        class_name (str): name for the submodule

    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g., if the original module has a get_attr node that fetches "conv.weight",
    # we create a HolderModule as root, add a HolderModule named "conv", and
    # make "weight" an attribute of the "conv" HolderModule that points to
    # conv.weight in the original module.
    submodule = HolderModule({})
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        target_name_parts = target.split(".")
        curr = submodule
        orig_gm = gm

        for name in target_name_parts[:-1]:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))

            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node_name = target_name_parts[-1]
        leaf_node = getattr(orig_gm, leaf_node_name)

        # Relies on custom __setattr__ magic.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name)
Example #8
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
    and updates them to match the new op code and target"""
    new_graph = Graph()
    val_map : Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(new_op, new_target, args, kwargs, node.name)
        else:
            val_map[node] = new_graph.node_copy(node, lambda n : val_map[n])
    fx_module.graph = new_graph
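A hypothetical usage example: rewrite every call_function node targeting torch.add into an equivalent node targeting torch.mul, then recompile so the generated forward() reflects the new graph.

import torch
from torch.fx import symbolic_trace

def f(x, y):
    return torch.add(x, y)

gm = symbolic_trace(f)
replace_target_nodes_with(gm, "call_function", torch.add,
                          "call_function", torch.mul)
gm.recompile()  # the swapped graph now computes x * y
assert gm(torch.tensor(2.0), torch.tensor(3.0)) == torch.tensor(6.0)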
Example #9
def symbolic_trace_with_rewrite(
        root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )
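RewritingTracer (torch.fx.experimental.rewriter) rewrites Python assert statements into torch._assert calls at the AST level, so asserting code stays traceable. A brief sketch; checked_relu is an illustrative function, not from the source:

import torch

def checked_relu(x):
    assert x.dim() == 2, "expected a 2-D input"
    return x.relu()

gm = symbolic_trace_with_rewrite(checked_relu)
print(gm.graph)  # the assert appears as a traceable torch._assert call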
Example #10
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    parameters = fx_module.named_parameters()
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in parameters:
        if isinstance(p, torch.Tensor):
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p
    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module.
        if node.op != "call_module" or not isinstance(submodules[node.target],
                                                      GraphModule):
            shape, dtype = get_shape_and_dtype(node)
            node_rep["shape"] = serialize_shape(shape)
            node_rep["dtype"] = str(dtype)

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                node.name = node.name[len("parent."):]
                node_rep["target"] = str(node.target[len("parent."):])
                weight = serialize_weight(
                    weights[node.target[len("parent."):]])
                serialized_dict["weights"][
                    node.target[len("parent."):]] = weight
            else:
                # Iterate through the module hierarchy to find the attr.
                target = fx_module
                split = node.target.split(".")
                assert len(split)
                while len(split):
                    target = getattr(target, split.pop(0))

                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target,
                              torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][prefix + node.target] = weight
                    weights[prefix + node.target] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":

            def get_output_info(arg: Node) -> Argument:
                shape, dtype = get_shape_and_dtype(arg)
                return {
                    "is_node": True,
                    "name": str(arg),
                    "shape": serialize_shape(shape),
                    "dtype": str(dtype),
                }

            node_rep["args"] = map_arg(
                node.args,
                get_output_info,
            )

            # If there are multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_arg(
                node.args, lambda arg: {
                    "is_node": True,
                    "name": str(arg)
                })

        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
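An end-to-end sketch under stated assumptions: the get_shape_and_dtype and serialize_weight helpers from the same module are in scope, and shape propagation has annotated the graph before serialization.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

gm = symbolic_trace(torch.nn.Linear(4, 2))
ShapeProp(gm).propagate(torch.randn(1, 4))  # annotate shapes/dtypes

weights: dict = {}
serialized = serialize_module(gm, weights)
print(list(serialized["weights"].keys()))  # e.g. ['weight', 'bias']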
Example #11
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = node.meta.get('tensor_meta')
        if not tensor_meta:
            raise RuntimeError(f'Node {node} has no tensor metadata! Ensure shape '
                               f'propagation has been run!')
        shape, dtype = get_shape_and_dtype(node)
        node_rep = {
            "shape": serialize_shape(shape),
            "dtype": str(dtype),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; not needed for a call_module node whose
        # target is a GraphModule, nor for the output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name])
                serialized_dict["weights"][stripped_name] = weight
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][qualname] = weight
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, torch.dtype):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there are multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(
                node.args, get_arg_info
            )

        node_rep["kwargs"] = map_aggregate(
            node.kwargs, get_arg_info
        )
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Example #12
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    parameters = fx_module.named_parameters()
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in parameters:
        if isinstance(p, torch.Tensor):
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module.
        if node.op != "call_module" or not isinstance(submodules[node.target],
                                                      GraphModule):
            shape = getattr(node, "shape", None)
            if shape is not None:
                node_rep["shape"] = serialize_shape(shape)
            else:
                raise RuntimeError(
                    "Node has no shape attr, this is likely because shape propagation has not been run on this Graph."
                )
            dtype = getattr(node, "dtype", None)
            if dtype is not None:
                node_rep["dtype"] = str(dtype)
            else:
                raise RuntimeError(
                    "Node has no dtype attr, this is likely because shape propagation has not been run on this Graph."
                )

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    submodules[node.target],
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            target = getattr(fx_module, node.target)
            qualname = prefix + node.target
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target)
                serialized_dict["weights"][prefix + node.target] = weight
                weights[prefix + node.target] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name
        node_rep["args"] = map_arg(
            node.args, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Example #13
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; not needed for a call_module node whose
        # target is a GraphModule, nor for the output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(
                    weights[stripped_name], weights, node.target[len("parent.") :]
                )
                # For quantized embedding tables we need to update the shape/type,
                # so we check whether a user of this get_attr is a quantized
                # EmbeddingBag and this get_attr is the weight for that EmbeddingBag.
                user_targets = {
                    _get_qualified_name(
                        n.target
                    ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                }
                if (
                    "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_byte_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint8fused"
                # Same as above, but for the 4 bit version.
                if (
                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint4fused"

                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there are multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
Example #14
    def _match_nodes(self, pn: Node, gn: Node, original_module: GraphModule,
                     pattern_module: GraphModule) -> bool:

        if isinstance(pn, (tuple, list)):
            if not isinstance(gn, type(pn)):
                return False
            return all(
                self._match_nodes(a1, a2, original_module, pattern_module)
                for a1, a2 in zip(pn, gn))  # type: ignore[call-overload]
        elif isinstance(pn, dict):
            if not isinstance(gn, dict):
                return False
            return pn.keys() == gn.keys() and \
                all(self._match_nodes(v1, v2, original_module, pattern_module)
                    for v1, v2 in zip(pn.values(), gn.values()))

        # Check if we've already matched these nodes in the current
        # traversal
        if pn in self.nodes_map:
            return self.nodes_map[pn] == gn

        PRIM_TYPES = (int, float, torch.dtype)
        # if both pattern and graph are not Node, we check for equality of these values
        if not isinstance(pn, Node) and not isinstance(gn, Node):
            return pn == gn

        # trying to match the input in pattern graph with a primitive type values
        if isinstance(gn, PRIM_TYPES):
            if isinstance(pn, Node) and pn.op == "placeholder":
                self.nodes_map[pn] = gn
                return True
            else:
                return False

        original_modules = dict(original_module.named_modules())
        pattern_modules = dict(pattern_module.named_modules())

        def attributes_are_equal(pn: Node, gn: Node) -> bool:
            # Use placeholder and output nodes as wildcards. The
            # only exception is that an output node can't match
            # a placeholder
            if (pn.op == "placeholder"
                    or (pn.op == "output" and gn.op != "placeholder")):
                return True
            elif pn.op == "get_attr" and gn.op == "get_attr":
                # assuming get_attr nodes are the same
                return True
            elif pn.op == "call_module" and gn.op == "call_module":
                original_m = original_modules[gn.target]
                pattern_m = pattern_modules[pn.target]
                return type(original_m) == type(pattern_m)
            return pn.op == gn.op and pn.target == gn.target

        # Terminate early if the node attributes are not equal
        if not attributes_are_equal(pn, gn):
            return False

        # Optimistically mark `pn` as a match for `gn`
        self.nodes_map[pn] = gn

        # Traverse the use-def relationships to ensure that `pn` is a true
        # match for `gn`
        if pn.op == "placeholder":
            return True
        if (pn.op != "output" and len(pn.args) != len(gn.args)):
            return False
        if pn.op == "output":
            match_found = any(
                self._match_nodes(pn.all_input_nodes[0], gn_, original_module,
                                  pattern_module)
                for gn_ in gn.all_input_nodes)
        else:
            # using args here to make sure we can match Node and non-Node
            # arguments
            # also allows us to match a Node with a primitive type value
            match_found = (len(pn.args) == len(gn.args)
                           and all(self._match_nodes(pn_, gn_, original_module, pattern_module)  # type: ignore[arg-type]
                                   for pn_, gn_ \
                                   in zip(pn.args, gn.args)))
        if not match_found:
            self.nodes_map.pop(pn)
            return False

        return True
Example #15
def replace_pattern(
        gm: GraphModule,
        pattern: Callable,
        replacement: Callable,
        is_match_filters: Optional[List[Callable]] = None) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with
        ``is_match_filters``: Optional list of filter callables; each is called
            as ``filter(match, pattern_graph, replacement_graph)`` and a match
            is kept only if every filter returns True

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the module and graph for `gm`, `pattern`, `replacement`
    original_module = gm
    original_graph = original_module.graph
    pattern_module = symbolic_trace(pattern)
    pattern_graph = pattern_module.graph
    replacement_module = symbolic_trace(replacement)
    replacement_graph = replacement_module.graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = _SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:

        if matcher.matches_subgraph_from_anchor(anchor, original_module,
                                                pattern_module):

            def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
                # `lookup` represents all the nodes in `original_graph`
                # that are part of `pattern`
                lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
                for n in lookup.keys():

                    # Nodes that can "leak"...
                    if not isinstance(lookup[n], Node):
                        continue
                    # Placeholders (by definition)
                    if lookup[n].op == "placeholder":
                        continue
                    # Pattern output (acts as a container)
                    if lookup[n].op == "output":
                        continue
                    # Result contained by pattern output (what we'll
                    # hook in to the new Graph, thus what we'll
                    # potentially use in other areas of the Graph as
                    # an input Node)
                    if (len(lookup[n].users) == 1 and list(
                            lookup[n].users.keys())[0].op == "output"):
                        continue

                    if not isinstance(n, Node):
                        continue

                    for user in n.users:
                        # If this node has users that were not in
                        # `lookup`, then it must leak out of the
                        # pattern subgraph
                        if user not in lookup:
                            return False
                return True

            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # Shallow copy nodes_map
                matches.append(
                    Match(anchor=anchor,
                          nodes_map=copy.copy({
                              key: value
                              for key, value in matcher.nodes_map.items()
                          })))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()
    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    match_changed_node: Dict[Node, Node] = dict()

    # Return True if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for pn, gn in match.nodes_map.items():
            if not isinstance(pn, Node):
                continue
            if pn.op in ["placeholder", "output"]:
                continue
            if not isinstance(gn, Node):
                continue
            if gn in replaced_nodes and gn.op != "placeholder":
                return True
        return False

    if is_match_filters is None:
        is_match_filters = []

    def is_match(match: Match):
        # for mypy
        assert is_match_filters is not None
        for filter in is_match_filters:
            if not filter(match, pattern_graph, replacement_graph):
                return False
        return True

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        if not is_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) > 0
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        placeholder_map = {
            r: p
            for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        for input_node in subgraph_output.all_input_nodes:
            mark_node_as_replaced(input_node)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match_changed_node.get(
                match.nodes_map[pattern_node], match.nodes_map[pattern_node])

            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location

        # CASE 1: We need to hook the replacement subgraph in somewhere
        # in the middle of the graph. We replace the Node in the
        # original graph that corresponds to the end of the pattern
        # subgraph
        if subgraph_output.op != "output":
            pattern_outputs = [
                n for n in pattern_graph.nodes if n.op == "output"
            ]
            assert len(pattern_outputs) > 0
            replacement_outputs = [
                n for n in replacement_graph.nodes if n.op == "output"
            ]
            assert len(replacement_outputs) == len(pattern_outputs)
            outputs_map = {
                p: r
                for r, p in zip(replacement_outputs, pattern_outputs)
            }

            for pn, gn in match.nodes_map.items():
                if not isinstance(gn, Node):
                    continue
                if gn.op == "placeholder":
                    continue

                # Search for the node corresponding to the output of the pattern
                if pn.op != "output":
                    continue
                assert subgraph_output == gn

                # Update all anchor inputs to the new nodes
                rn = outputs_map[pn]
                for pn_input, rn_input in zip(pn.args, rn.args):
                    gn_input = match.nodes_map[pn_input]  # type: ignore[index]
                    rn_input_in_original_graph = val_map[rn_input]
                    gn_input.replace_all_uses_with(rn_input_in_original_graph)
                    # We store the updated node point in case other nodes want to use it
                    match_changed_node[gn_input] = rn_input_in_original_graph

            assert subgraph_output.op != "output"
        # CASE 2: The pattern subgraph match extends to the end of the
        # original graph, so we need to change the current graph's
        # output Node to reflect the insertion of the replacement graph.
        # We'll keep the current output Node, but update its args and
        # `_input_nodes` as necessary
        else:
            subgraph_output.args = (copied_output,)
            if isinstance(copied_output, Node):
                subgraph_output._input_nodes = {copied_output: None}

        assert isinstance(copied_output, Node)
        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if (len(node.users) == 0 and node.op != "output"
                    and node.op != "placeholder"):
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches
Example #16
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"])

            if tensor_meta.qparams["qscheme"] in {
                    torch.per_tensor_affine,
                    torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.qparams["scale"]
                node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"]

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; not needed for a call_module node whose
        # target is a GraphModule, nor for the output node.
        if (not (node.op == "call_module"
                 and isinstance(submodules[node.target], GraphModule))
                and node.op != "output"):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                qualname = node.target[len("parent."):]
                node.name = qualname
                node_rep["target"] = qualname
            else:
                qualname = prefix + node.target
            # Find the actual target parameter/buffer from the fx_module.
            submod_path, _, target_name = node.target.rpartition(".")
            submod: Optional[torch.nn.Module] = (
                fx_module.get_submodule(submod_path)
                if submod_path else fx_module)
            assert submod is not None, f"submod {submod_path} not found"
            target = getattr(submod, target_name, None)
            assert target is not None, f"{target_name} not an attr of {submod_path}"
            # Check that the target is a tensor, and that we haven't added it already from a leaf module.
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target, weights, qualname)
                _update_weight_fused_dtypes(weight, qualname, node)
                serialized_dict["weights"].update(weight)
                weights[qualname] = target
        elif node.op == "placeholder":
            ph_type = node.meta.get("ph_type", "")
            assert ph_type in (
                "", "input_ph", "output_ph"
            ), "When present, placeholder type must be 'input_ph' or 'output_ph'"
            if ph_type == "input_ph":
                node_rep["ph_type"] = "input_ph"
            elif ph_type == "output_ph":
                node_rep["ph_type"] = "output_ph"

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg,
                            (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there are multiple outputs then node_rep["args"][0] will be a
            # tuple or list. In this case we want to unpack the tuple or list.
            if isinstance(node_rep["args"][0], (tuple, list)):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()),
                                          get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict