Example #1
from typing import cast

import torch.fx
from torch import nn

# Note: `permute` is the acc_ops permute function this mapper lowers to;
# import it from your acc_ops module.

def transpose_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    # Get the dim-permutation/shuffle
    shape_as_list = node.meta["tensor_meta"].shape
    ranks = len(shape_as_list)
    shuffle = list(range(ranks))
    dim0 = cast(int, node.kwargs["dim0"])
    dim1 = cast(int, node.kwargs["dim1"])
    shuffle[dim0] = dim1
    shuffle[dim1] = dim0

    # Create the new acc_ops.permute node. Update all uses of the transpose
    # node and then delete the transpose node.
    with node.graph.inserting_after(node):
        permute_node = node.graph.call_function(
            the_function=permute,
            kwargs={
                "input": node.kwargs.get("input"),
                "permutation": shuffle,
            },
        )
        permute_node.meta = node.meta.copy()
        node.replace_all_uses_with(permute_node)

    permute_node.graph.erase_node(node)
    return permute_node
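For context, here is a minimal sketch of how such a mapper might be exercised. The `permute` stand-in below is hypothetical (the real one comes from acc_ops), and the transpose args are manually moved into kwargs, since the mapper reads `dim0`/`dim1` from `node.kwargs`:

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

def permute(*, input, permutation):  # hypothetical stand-in for acc_ops.permute
    return input.permute(*permutation)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.transpose(x, 0, 1)

gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))  # fills node.meta["tensor_meta"]

for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target is torch.transpose:
        # The mapper reads kwargs, so normalize positional args into kwargs first.
        n.kwargs = {"input": n.args[0], "dim0": n.args[1], "dim1": n.args[2]}
        n.args = ()
        transpose_mapper(n, gm)
gm.recompile()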
Example #2
    # Nested inside the normalizer; `graph`, `mod`, and `move_kwargs_to_acc_out_ty`
    # come from the enclosing scope.
    def normalize_to_acc_op(
        node: torch.fx.Node,
        normalization_info: NormalizationInfo,
        normalized_args: Tuple[Any, ...],
        normalized_kwargs: Dict[str, Any],
    ):
        # If there's a custom mapping function then use it.
        if normalization_info.custom_mapping_fn is not None:
            # For custom mapping, the normalized_kwargs are used for the original op,
            # i.e. *before* custom acc_ops normalization. Do that now.
            node.args = normalized_args
            node.kwargs = normalized_kwargs
            new_node = normalization_info.custom_mapping_fn(node, mod)
            # If a new node is returned then use it to replace the old node. Otherwise
            # the custom mapping function did its own replacement, so return early.
            if new_node is None:
                return
        else:
            # If there's kwargs_to_move_to_acc_out_ty then use it to set up acc_out_ty in
            # normalized_kwargs, and remove the kwarg from normalized_kwargs.
            move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)

            # All acc ops are functions. Create a call to the correct acc_ops target using
            # the normalized kwargs provided.
            with graph.inserting_before(node):
                new_node = graph.create_node(
                    "call_function",
                    normalization_info.new_fn_target,
                    args=normalized_args,
                    kwargs=normalized_kwargs,
                    name=node.name,
                )
                new_node.meta = node.meta.copy()

        # Finally replace the original node with the normalized node.
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)

        # Don't wrap the acc_op node just because the original node was wrapped.
        if "is_wrapped" in new_node.meta:
            del new_node.meta["is_wrapped"]
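Both branches end in the same graph-surgery sequence: create the replacement at the right insert point, copy `node.meta`, rewire all users, then erase the original. Here is that pattern distilled into a runnable sketch (`replace_with_function` is a hypothetical name, not part of the acc normalizer):

import torch
import torch.fx

def replace_with_function(graph: torch.fx.Graph, node: torch.fx.Node, new_target):
    # Insert the replacement before the old node so topological order holds.
    with graph.inserting_before(node):
        new_node = graph.create_node(
            "call_function", new_target, args=node.args, kwargs=node.kwargs
        )
        new_node.meta = node.meta.copy()
    # Rewire every user, then erase the now-dead original node.
    node.replace_all_uses_with(new_node)
    graph.erase_node(node)
    return new_node

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = torch.fx.symbolic_trace(M())
for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target is torch.relu:
        replace_with_function(gm.graph, n, torch.nn.functional.relu)
gm.recompile()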
Example #3
    # From an FX graph drawer class; `hashlib` is imported at module level and
    # _COLOR_MAP / _HASH_COLOR_MAP are module-level color tables.
    def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            # Use a pseudo-random color for each node, derived from its name so
            # it's stable across runs.
            target_name = node._pretty_print_target(node.target)
            target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
            template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
        return template
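The hash trick above gives each distinct target a stable color across runs. A standalone illustration of the same idea, with a made-up palette (the `_HASH_COLOR_MAP` here is hypothetical, not the drawer's real table):

import hashlib

_HASH_COLOR_MAP = ["#ffadad", "#ffd6a5", "#fdffb6", "#caffbf"]  # hypothetical palette

def stable_color(target_name: str) -> str:
    # md5 is used only as a cheap, deterministic hash, not for security.
    target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
    return _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]

assert stable_color("torch.relu") == stable_color("torch.relu")  # same color every run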
Example #4
import torch.fx

def inline_lowp_func(n: torch.fx.Node):
    # If we find a call to a function in our "lowp" module, inline it
    if n.op == 'call_function' and n.target.__module__ == inline_lowp_func.__module__:
        # We want to insert the operations comprising the implementation of the
        # function before the function itself. Then, we can swap the output value
        # of the function call with the output value for its implementation nodes
        with n.graph.inserting_before(n):
            # We can inline code by using `fx.Proxy` instances.
            # map_arg traverses all aggregate types and applies the given function
            # to Node instances in the data structure. In this case, we are applying
            # the fx.Proxy constructor.
            proxy_args = torch.fx.node.map_arg(n.args, torch.fx.Proxy)
            proxy_kwargs = torch.fx.node.map_arg(n.kwargs, torch.fx.Proxy)
            # Call the function itself with proxy arguments. This will emit
            # nodes in the graph corresponding to the operations in the
            # implementation of the function
            output_proxy = n.target(*proxy_args, **proxy_kwargs)
            # Now replace the original node's uses with the output node of
            # the implementation.
            n.replace_all_uses_with(output_proxy.node)
            # Delete the old node
            n.graph.erase_node(n)
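A minimal way to see the pass in action: mark a function as a tracing leaf with torch.fx.wrap so it shows up as a call_function node, then inline it. This sketch assumes the wrapped function and inline_lowp_func are defined in the same module, which is what the __module__ check requires; `add_scaled` is a hypothetical example function:

import torch
import torch.fx

torch.fx.wrap("add_scaled")  # keep calls to add_scaled as leaf call_function nodes

def add_scaled(x, y):
    return x + 2.0 * y

class M(torch.nn.Module):
    def forward(self, x, y):
        return add_scaled(x, y)

gm = torch.fx.symbolic_trace(M())
for n in list(gm.graph.nodes):
    inline_lowp_func(n)
gm.recompile()
print(gm.code)  # the add and mul now appear inline; the add_scaled call is gone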