def transpose_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """Replace a transpose node with an equivalent acc_ops ``permute`` node.

    The rank is read from ``node.meta["tensor_meta"].shape``. The permutation
    is the identity over that rank with axes ``dim0`` and ``dim1`` swapped.

    The ``permute`` node is inserted right after ``node`` and receives a copy
    of its ``meta``. It then takes over all of ``node``'s uses, and ``node``
    is erased from the graph.

    Args:
        node: The transpose node. Its kwargs must contain ``"dim0"`` and
            ``"dim1"``; ``"input"`` (read with ``.get``) is forwarded as the
            permute input.
        _: Unused module argument.

    Returns:
        The newly created ``permute`` node.

    Raises:
        IndexError: If ``dim0`` or ``dim1`` is outside ``[-rank, rank)``.
    """
    rank = len(node.meta["tensor_meta"].shape)
    dim0 = cast(int, node.kwargs["dim0"])
    dim1 = cast(int, node.kwargs["dim1"])

    # Swap the *entries* of the identity permutation rather than writing the
    # raw dims into it:
    # - List indexing resolves negative dims and raises IndexError when a dim
    #   is out of range.
    # - The entries are themselves axis indices, so the result never contains
    #   negative values. E.g. rank 3 with dims (-1, 0) gives [2, 1, 0] instead
    #   of [-1, 1, 0].
    permutation = list(range(rank))
    permutation[dim0], permutation[dim1] = permutation[dim1], permutation[dim0]

    # Insert the permute right after the transpose, move every use over to it,
    # then drop the now-unused transpose node.
    graph = node.graph
    with graph.inserting_after(node):
        permute_node = graph.call_function(
            the_function=permute,
            kwargs={
                "input": node.kwargs.get("input"),
                "permutation": permutation,
            },
        )
    permute_node.meta = node.meta.copy()
    node.replace_all_uses_with(permute_node)
    graph.erase_node(node)
    return permute_node
def normalize_to_acc_op(
    node: torch.fx.Node,
    normalization_info: NormalizationInfo,
    normalized_args: Tuple[Any, ...],
    normalized_kwargs: Dict[str, Any],
):
    """Swap ``node`` for its acc_ops-normalized replacement in the graph.

    Custom-mapping path (``normalization_info.custom_mapping_fn`` is set):
    ``node`` first receives ``normalized_args`` / ``normalized_kwargs``. It is
    then passed, together with ``mod``, to that function. A ``None`` result
    means the mapping function already did the replacement, and this function
    returns immediately.

    Default path: ``move_kwargs_to_acc_out_ty`` is applied to
    ``normalized_kwargs`` (it may mutate it). A ``call_function`` node
    targeting ``normalization_info.new_fn_target`` is then created just
    before ``node``, requesting ``node.name`` as its name and copying
    ``node.meta``.

    Unless the early return was taken, all uses of ``node`` are redirected to
    the replacement, ``node`` is erased, and the replacement's
    ``"is_wrapped"`` meta entry is removed if present.

    NOTE(review): ``graph`` and ``mod`` are not parameters; they are
    presumably bound in an enclosing scope not visible here (``graph``
    presumably being the graph that owns ``node``) -- confirm.
    """
    mapping_fn = normalization_info.custom_mapping_fn
    if mapping_fn is None:
        # Relocate any kwargs destined for acc_out_ty before building the call.
        move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)
        # Every acc op is a plain function, so emit a call_function node
        # right before the node being replaced.
        with graph.inserting_before(node):
            replacement = graph.create_node(
                "call_function",
                normalization_info.new_fn_target,
                args=normalized_args,
                kwargs=normalized_kwargs,
                name=node.name,
            )
        replacement.meta = node.meta.copy()
    else:
        # The custom mapper works on the original op, so it has to see the
        # normalized (pre-acc_ops) args/kwargs on the node itself.
        node.args = normalized_args
        node.kwargs = normalized_kwargs
        replacement = mapping_fn(node, mod)
        if replacement is None:
            # The mapper rewired the graph itself; nothing left to do.
            return

    node.replace_all_uses_with(replacement)
    graph.erase_node(node)

    # The acc_op must not be wrapped merely because the original node was.
    replacement.meta.pop("is_wrapped", None)
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
    """Return the Graphviz record-style attributes used to draw ``node``.

    If ``node.op`` is a key of ``_COLOR_MAP``, that colour is used as the
    fill. Otherwise the fill is chosen from ``_HASH_COLOR_MAP`` using an MD5
    hash of the node's pretty-printed target, so a given target always maps
    to the same colour.
    """
    if node.op in _COLOR_MAP:
        fill = _COLOR_MAP[node.op]
    else:
        # Deterministic, not random: hashing the printable target name keeps
        # the colour stable between renders.
        pretty_target = node._pretty_print_target(node.target)
        digest_prefix = hashlib.md5(pretty_target.encode()).hexdigest()[:8]
        palette_index = int(digest_prefix, 16) % len(_HASH_COLOR_MAP)
        fill = _HASH_COLOR_MAP[palette_index]

    return {
        "shape": "record",
        "fillcolor": fill,
        "style": '"filled,rounded"',
        "fontcolor": "#000000",
    }
def inline_lowp_func(n: torch.fx.Node):
    """Inline a call to a function from this ("lowp") module into the graph.

    This only acts when ``n`` is a ``call_function`` node whose target's
    ``__module__`` equals this function's module. In that case:

    1. The target is called on ``torch.fx.Proxy`` wrappers of ``n``'s
       args/kwargs, with the graph's insertion point set just before ``n``.
       This emits the target's implementation as individual nodes.
    2. All uses of ``n`` are redirected to the implementation's output node.
    3. ``n`` is erased from its graph.

    Any other node is left untouched.

    Args:
        n: The node to inspect and possibly inline.
    """
    if n.op != "call_function":
        return
    # getattr: not every callable target is guaranteed to define __module__.
    if getattr(n.target, "__module__", None) != inline_lowp_func.__module__:
        return

    graph = n.graph
    # Emit the implementation before the call itself, so its output value can
    # then stand in for the call's output.
    with graph.inserting_before(n):
        # map_arg traverses aggregate types and applies the given function to
        # every Node inside; here that wraps each Node in an fx.Proxy.
        proxy_args = torch.fx.node.map_arg(n.args, torch.fx.Proxy)
        proxy_kwargs = torch.fx.node.map_arg(n.kwargs, torch.fx.Proxy)
        # Calling the function on proxies emits nodes for the operations in
        # its implementation.
        output_proxy = n.target(*proxy_args, **proxy_kwargs)

    # Redirect users of the call to the implementation's output, then remove
    # the call. (Previously this referenced an undefined name `node`.)
    n.replace_all_uses_with(output_proxy.node)
    graph.erase_node(n)