Code Example #1
File: partitioner.py  Project: timgates42/pytorch
    def propose_partitions(self) -> List[Partition]:
        candidates: NodeList = self.__get_supported_nodes()

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: Dict[Node, int] = {}  # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        def assign(node: Node, id: Optional[int] = None):
            # If id is None, remove the node from its original assignment.

            # If the node has been assigned before, clean up and re-assign.
            if node in assignment:
                original_id = assignment[node]
                del assignment[node]
                partitions_by_id[original_id].remove_node(node)
                if partitions_by_id[original_id].size() == 0:
                    del partitions_by_id[original_id]

            if id is not None:
                assignment[node] = id
                if id not in partitions_by_id:
                    partitions_by_id[id] = Partition(id=id, nodes=[node])
                else:
                    partitions_by_id[id].add_node(node)

        logging.debug("Proposing partitions...")

        # visit candidates in reversed topological order
        for node in reversed(candidates):
            # use a Dict as an ordered set to ensure a deterministic partitioning result; the value is unused
            user_partitions: Dict[Partition, None] = {}
            for user_node in node.users:
                if user_node in assignment:
                    id = assignment[user_node]
                    user_partitions[partitions_by_id[id]] = None
                else:
                    user_partitions[Partition(nodes=[user_node])] = None

            # Filter out all the partitions that depend on another user's partition
            # TODO: find a better way to do this, rather than pair-wise comparison
            user_partitions_list = list(user_partitions.keys())
            for i in range(len(user_partitions_list)):
                for j in range(i + 1, len(user_partitions_list)):
                    pi = user_partitions_list[i]
                    pj = user_partitions_list[j]
                    dependency = self.__partition_depends_on(pi, pj)
                    if dependency == 1 and pj in user_partitions:
                        del user_partitions[pj]
                    elif dependency == -1 and pi in user_partitions:
                        del user_partitions[pi]

            # We use the following rules for partition assignment:
            # 1. If none of the candidates has been assigned to a partition, create a new partition
            # 2. If there is exactly one candidate partition, assign the node to that partition
            # 3. If there is more than one candidate partition, assign the current node to the first
            #    partition and merge the other partitions into it, since the partitions in
            #    user_partitions have no dependencies on each other.

            assigned_candidate_partition_ids = [
                partition.id for partition in user_partitions
                if partition.id is not None
            ]

            if len(assigned_candidate_partition_ids) == 0:
                # create a new partition
                assign(node, next(new_partition_id))
            elif len(assigned_candidate_partition_ids) == 1:
                id = assigned_candidate_partition_ids[0]
                assign(node, id)
            else:
                # users are assigned to more than one partition; since those partitions have no
                # dependencies on each other, they can be fused into a single partition
                id = assigned_candidate_partition_ids[0]
                assign(node, id)

                reassignment: Dict[Node, int] = {}
                for other_id in assigned_candidate_partition_ids[1:]:
                    for other_node in partitions_by_id[other_id].nodes:
                        reassignment[other_node] = id
                for other_node in reassignment:
                    assign(other_node, id)

        # Post-processing: re-assign "getitem" nodes into the upstream partition
        logger.debug(
            "Reassigning getitem nodes to their producer node's partition...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                   _get_qualified_name(user.target) != "_operator.getitem":     # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs; re-assign all downstream getitem nodes
            # into this node's partition
            if is_tuple_output:
                id = assignment.get(node, None)  # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user,
                                      None) != id:  # type: ignore[arg-type]
                        nodes_reassignment[user] = id
        for node, id in nodes_reassignment.items():
            assign(node, id)

        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function" and \
                       _get_qualified_name(node.target) not in non_compute_ops:  # type: ignore[arg-type]
                        compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logging.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logging.debug(f"partition #{id}",
                          [node.name for node in partition.nodes])

        return list(partitions_by_id.values())
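For orientation, here is a minimal usage sketch of propose_partitions. It assumes this method belongs to torch.fx's CapabilityBasedPartitioner (torch.fx.passes.infra.partitioner) and that an OperatorSupport policy decides which nodes are candidates; the toy module and support policy below are illustrative, not part of the original snippet.

import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class AddReluSupport(OperatorSupport):
    # Illustrative policy: only add and relu call_function nodes are supported.
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target in (operator.add,
                                                              torch.relu)


class M(torch.nn.Module):
    def forward(self, x, y):
        # mul is unsupported, so it separates the supported region.
        return torch.relu(x + y) * y


gm = symbolic_trace(M())
partitioner = CapabilityBasedPartitioner(gm, AddReluSupport())
for partition in partitioner.propose_partitions():
    print(partition.id, [node.name for node in partition.nodes])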
Code Example #2
def normalize(mod: torch.fx.GraphModule,
              expect_nodes_have_shapes: bool = False):
    assert len(_normalization_dict) > 0
    graph = mod.graph

    # For "call_module" node we return _base_class_origin if it's a
    # RewrittenModule, otherwise, return its type. For other nodes,
    # we return node.target.
    def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
        if node.op != "call_module":
            return node.target

        # Find the module that node.target points to
        m = dict(mod.named_modules())[node.target]
        return getattr(m, "_base_class_origin", type(m))

    def normalize_to_acc_op(
        node: torch.fx.Node,
        normalization_info: NormalizationInfo,
        normalized_args: Tuple[Any, ...],
        normalized_kwargs: Dict[str, Any],
    ):
        # If there's a custom mapping function then use it.
        if normalization_info.custom_mapping_fn is not None:
            # For custom mapping, the normalized_kwargs are used for the original op,
            # i.e. *before* custom acc_ops normalization. Do that now.
            node.args = normalized_args
            node.kwargs = normalized_kwargs
            new_node = normalization_info.custom_mapping_fn(node, mod)
            # If a new node is returned then use it to replace the old node. Otherwise
            # the custom mapping function did its own replacement, so return early.
            if new_node is None:
                return
        else:
            # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in
            # normalized_kwargs, and remove the kwarg from normalized_kwargs.
            move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)

            # All acc ops are functions. Create a call to the correct acc_ops target using
            # the normalized kwargs provided.
            with graph.inserting_before(node):
                new_node = graph.create_node(
                    "call_function",
                    normalization_info.new_fn_target,
                    args=normalized_args,
                    kwargs=normalized_kwargs,
                    name=node.name,
                )
                new_node.meta = node.meta.copy()

        # Finally replace the original node with the normalized node.
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)

    for node in graph.nodes:
        if node.op in {"placeholder", "get_attr", "output"}:
            continue

        normalization_info = _normalization_dict.get(
            (node.op, get_target(mod, node)))

        # Also check if the torch_packaged version of the op was specified to be normalized.
        if normalization_info is None and node.op == "call_function":
            # Normalize away the mangle index in the <torch_package_N> prefix before checking the map.
            target = re.sub(
                r"\A<torch_package_\d+>",
                "<torch_package_>",
                _get_qualified_name(node.target),
            )
            torch_package_op_and_target = (node.op, target)
            normalization_info = _normalization_dict.get(
                torch_package_op_and_target)

        if normalization_info is None:
            continue

        # Get the normalized kwargs to be used by normalize_to_acc_op below. If
        # normalization_info.arg_replacement_tuples is empty then assume the function
        # signature must be left as is.
        assert normalization_info.arg_replacement_tuples is not None
        if len(normalization_info.arg_replacement_tuples) == 0:
            normalized_args = node.args
            normalized_kwargs = node.kwargs
        else:
            normalized_args = ()
            try:
                normalized_kwargs = get_normalized_kwargs(
                    node, normalization_info.arg_replacement_tuples)
            except Exception:
                print(
                    f"Error during kwarg normalization for: {node.format_node()}; "
                    f"arg_replacement_tuples={normalization_info.arg_replacement_tuples}"
                )
                raise

        if (normalization_info.needs_shapes_for_normalization
                and not expect_nodes_have_shapes):
            # All nodes needing shapes for normalization should be custom mapped.
            assert normalization_info.custom_mapping_fn is not None
            # For custom mapping, the normalized_kwargs are used for the original op,
            # i.e. *before* custom acc_ops normalization. Do that now so that whoever
            # consumes the graph next (e.g. shape inference) can use kwargs safely.
            node.args = normalized_args
            node.kwargs = normalized_kwargs
            continue

        try:
            normalize_to_acc_op(node, normalization_info, normalized_args,
                                normalized_kwargs)
        except Exception:
            print(f"Error during normalization for node: {node.format_node()}")
            raise

    # If there are any dead nodes left after normalization, eliminate them now.
    mod.graph.eliminate_dead_code()
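A hedged usage sketch follows. It assumes this is the acc_ops normalizer from FX's experimental fx_acc tooling, where importing the acc_ops module fills _normalization_dict with registered mappings (the assert at the top of normalize guards exactly that); the toy module is illustrative.

import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)


gm = torch.fx.symbolic_trace(M())
# normalize() mutates gm.graph in place, rewriting each registered op
# (e.g. torch.add) into its acc_ops equivalent with normalized kwargs.
normalize(gm)
gm.recompile()
print(gm.graph)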
Code Example #3
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                    torch.per_tensor_affine,
                    torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; currently not needed for call_module nodes
        # whose target is a GraphModule, or for the output node.
        if (not (node.op == "call_module"
                 and isinstance(submodules[node.target], GraphModule))
                and node.op != "output"):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent."):]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name], weights,
                                          node.target[len("parent."):])
                # For quantized embedding tables we need to update the shape/type,
                # so we check whether a user of this get_attr is a quantized
                # EmbeddingBag and whether this get_attr supplies its weight.
                user_targets = {
                    _get_qualified_name(n.target).replace(
                        "torch.fx.experimental.fx_acc.",
                        "").replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                }
                if ("acc_ops.embedding_bag_byte_rowwise_offsets"
                        in user_targets
                        and str(user_targets[
                            "acc_ops.embedding_bag_byte_rowwise_offsets"].
                                kwargs["weight"]) == stripped_name):
                    weight[stripped_name]["dtype"] = "acc.uint8fused"
                # Same as above, but for the 4 bit version.
                if ("acc_ops.embedding_bag_4bit_rowwise_offsets"
                        in user_targets
                        and str(user_targets[
                            "acc_ops.embedding_bag_4bit_rowwise_offsets"].
                                kwargs["weight"]) == stripped_name):
                    weight[stripped_name]["dtype"] = "acc.uint4fused"

                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path)
                    if submod_path else fx_module)
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target,
                              torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg,
                            (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there are multiple outputs, node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()),
                                          get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
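This version reads tensor metadata through get_tensor_meta, so shape propagation must run before serializing. A minimal driving sketch, assuming get_tensor_meta reads the node.meta["tensor_meta"] entries that ShapeProp produces; the toy module and the JSON dump are illustrative.

import json

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())
# ShapeProp executes the module once and records node.meta["tensor_meta"].
ShapeProp(gm).propagate(torch.randn(4))
weights: dict = {}
serialized = serialize_module(gm, weights)
print(json.dumps(serialized, indent=2, default=str))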
Code Example #4
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; currently not needed for call_module nodes
        # whose target is a GraphModule, or for the output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            shape, dtype = get_shape_and_dtype(node)
            node_rep["shape"] = serialize_shape(shape)
            node_rep["dtype"] = str(dtype)

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent.") :]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name])
                serialized_dict["weights"][stripped_name] = weight
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][qualname] = weight
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":

            def get_output_info(arg: Node) -> Argument:
                shape, dtype = get_shape_and_dtype(arg)
                return {
                    "is_node": True,
                    "name": str(arg),
                    "shape": serialize_shape(shape),
                    "dtype": str(dtype),
                }

            node_rep["args"] = map_arg(
                node.args,
                get_output_info,
            )

            # If there are multiple outputs, node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_arg(
                node.args, lambda arg: {"is_node": True, "name": str(arg)}
            )

        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {"is_node": True, "name": str(arg)}
        )
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
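get_shape_and_dtype is referenced above but not shown in this snippet. A plausible stand-in, labeled hypothetical, pulls both fields from the tensor metadata recorded by shape propagation; the real helper in the original file may differ.

# Hypothetical stand-in for the helper used above; not the original code.
def get_shape_and_dtype(node):
    tensor_meta = node.meta.get("tensor_meta")
    if tensor_meta is None:
        raise RuntimeError(
            f"node {node} has no tensor_meta; run shape propagation first")
    return tensor_meta.shape, tensor_meta.dtype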
Code Example #5
def serialize_module(fx_module: GraphModule,
                     weights: Dict,
                     name_prefix="") -> Dict:
    """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
    It also adds all weights the provided weights dictionary by qualified_name.
    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE],
        nodes: [NODE],
        weights {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales, [],
        q_per_channel_zero_points: [],
        q_per_channel_axis, int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    parameters = fx_module.named_parameters()
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in parameters:
        if isinstance(p, torch.Tensor):
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p
    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info; not needed for call_module nodes whose target
        # is a GraphModule.
        if node.op != "call_module" or not isinstance(submodules[node.target],
                                                      GraphModule):
            shape = getattr(node, "shape", None)
            if shape:
                node_rep["shape"] = serialize_shape(shape)
            else:
                raise RuntimeError(
                    "Node has no shape attr, this is likely because shape propagation has not been run on this Graph."
                )
            dtype = getattr(node, "dtype", None)
            if dtype:
                node_rep["dtype"] = str(dtype)
            else:
                raise RuntimeError(
                    "Node has no dtype attr, this is likely because shape propagation has not been run on this Graph."
                )

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                node.name = node.name[len("parent."):]
                node_rep["target"] = str(node.target[len("parent."):])
                weight = serialize_weight(
                    weights[node.target[len("parent."):]])
                serialized_dict["weights"][
                    node.target[len("parent."):]] = weight
            else:
                # Iterate through the module hierarchy to find the attr.
                target = fx_module
                split = node.target.split(".")
                assert len(split)
                while len(split):
                    target = getattr(target, split.pop(0))

                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it already from a leaf module.
                if isinstance(target,
                              torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][prefix + node.target] = weight
                    weights[prefix + node.target] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name
        node_rep["args"] = map_arg(
            node.args, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {
                "is_node": True,
                "name": str(arg)
            })
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
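This oldest variant expects shape and dtype attributes directly on each Node, which contemporaneous shape propagation attached. On a newer PyTorch you can stamp them from node.meta["tensor_meta"] before calling it; a hedged sketch with an illustrative toy module:

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))
for node in gm.graph.nodes:
    tensor_meta = node.meta.get("tensor_meta")
    if tensor_meta is not None:
        # Recreate the attributes this legacy serializer reads via getattr().
        node.shape, node.dtype = tensor_meta.shape, tensor_meta.dtype

weights: dict = {}
serialized = serialize_module(gm, weights)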