def propose_partitions(self) -> List[Partition]:
    """Group the supported nodes of ``self.graph_module`` into partitions.

    Candidate nodes are visited in reverse topological order. Each node joins
    the partition(s) of its users. User partitions that depend on another
    user's partition are discarded first. When several independent user
    partitions remain, they are merged into the first one.

    Post-processing:
      * ``operator.getitem`` users of a node whose users are all getitems are
        moved into that node's partition. If the node is unassigned, they are
        removed from their partitions.
      * Unless ``self.allows_single_node_partition`` is set, partitions with
        at most one compute node are dropped. ``torch.ops.aten.view`` and
        ``_operator.getitem`` calls do not count as compute nodes.

    Returns:
        The list of proposed partitions.
    """
    candidates: NodeList = self.__get_supported_nodes()

    # assumption: nodes in the candidate list are sorted in topological order
    assignment: Dict[Node, int] = {}  # mapping from node to partition_id
    partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
    new_partition_id = itertools.count()

    def assign(node: Node, partition_id: Optional[int] = None):
        """(Re)assign ``node`` to ``partition_id``.

        Any previous assignment is removed first, and the old partition is
        deleted if it becomes empty. If ``partition_id`` is None, the node is
        only removed from its original partition.
        """
        if node in assignment:
            original_id = assignment.pop(node)
            partitions_by_id[original_id].remove_node(node)
            if partitions_by_id[original_id].size() == 0:
                del partitions_by_id[original_id]

        if partition_id is not None:
            assignment[node] = partition_id
            if partition_id not in partitions_by_id:
                partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
            else:
                partitions_by_id[partition_id].add_node(node)

    logger.debug("Proposing partitions...")

    # visit candidates in reversed topological order
    for node in reversed(candidates):
        # Use a Dict as an ordered set so the partitioning result is
        # deterministic; the values are unused.
        user_partitions: Dict[Partition, None] = {}
        for user_node in node.users:
            if user_node in assignment:
                user_partitions[partitions_by_id[assignment[user_node]]] = None
            else:
                user_partitions[Partition(nodes=[user_node])] = None

        # Filter out all the partitions that have a dependency on other users.
        # TODO: find a better way to do this, rather than pair-wise comparison
        user_partitions_list = list(user_partitions.keys())
        for i, pi in enumerate(user_partitions_list):
            for pj in user_partitions_list[i + 1:]:
                dependency = self.__partition_depends_on(pi, pj)
                if dependency == 1 and pj in user_partitions:
                    del user_partitions[pj]
                elif dependency == -1 and pi in user_partitions:
                    del user_partitions[pi]

        # Partition assignment rules:
        # 1. If none of the user partitions has an id yet, create a new partition.
        # 2. Otherwise, join the first assigned user partition.
        # 3. Merge any further assigned user partitions into that first one.
        #    After the filter above, these partitions do not depend on each
        #    other, so they can be fused.
        assigned_candidate_partition_ids = [
            partition.id for partition in user_partitions if partition.id is not None
        ]

        if not assigned_candidate_partition_ids:
            # create a new partition
            assign(node, next(new_partition_id))
        else:
            target_id = assigned_candidate_partition_ids[0]
            assign(node, target_id)

            # Collect the nodes first, then reassign them: assign() mutates
            # partitions_by_id while we would otherwise be iterating over it.
            reassignment: Dict[Node, int] = {}
            for other_id in assigned_candidate_partition_ids[1:]:
                for other_node in partitions_by_id[other_id].nodes:
                    reassignment[other_node] = target_id
            for other_node in reassignment:
                assign(other_node, target_id)

    # post processing to re-assign "getitem" nodes into upstream partition
    logger.debug("Reassigning getitem nodes to its producer node's partition...")
    nodes_reassignment: Dict[Node, Optional[int]] = {}
    for node in self.graph_module.graph.nodes:
        # True when every user is an operator.getitem call. Like the plain
        # loop it replaces, this is also True for nodes with no users.
        is_tuple_output = all(
            user.op == "call_function"
            and _get_qualified_name(user.target) == "_operator.getitem"  # type: ignore[arg-type]
            for user in node.users
        )

        # node has tuple outputs: re-assign all following getitem nodes into node's partition
        if is_tuple_output:
            producer_id = assignment.get(node, None)  # type: ignore[arg-type]
            for user in node.users:
                if assignment.get(user, None) != producer_id:  # type: ignore[arg-type]
                    nodes_reassignment[user] = producer_id
    for node, partition_id in nodes_reassignment.items():
        assign(node, partition_id)

    # filter out single node partitions
    if not self.allows_single_node_partition:
        logger.debug("Filtering out single node partitions...")
        non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
        partitions_to_remove: List[int] = []
        for partition_id, partition in partitions_by_id.items():
            compute_node_count = sum(
                1
                for n in partition.nodes
                if n.op == "call_function"
                and _get_qualified_name(n.target) not in non_compute_ops  # type: ignore[arg-type]
            )
            if compute_node_count <= 1:
                partitions_to_remove.append(partition_id)
        for partition_id in partitions_to_remove:
            del partitions_by_id[partition_id]

    logger.debug("Partitions proposed:")
    for partition_id, partition in partitions_by_id.items():
        # Each extra positional argument to a logging call needs a matching
        # %-placeholder in the message, otherwise record formatting fails.
        logger.debug(
            "partition #%s: %s", partition_id, [n.name for n in partition.nodes]
        )

    return list(partitions_by_id.values())
def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False):
    """Rewrite the nodes of ``mod`` in place into their normalized acc_ops form.

    Nodes with op ``placeholder``, ``get_attr`` and ``output`` are skipped.
    Every other node is looked up in ``_normalization_dict`` by
    ``(op, target)``. For ``call_function`` nodes, the
    ``<torch_package_N>``-mangled qualified name is also tried with the mangle
    index stripped. For each matched node:

    1. Its args are normalized into kwargs using
       ``normalization_info.arg_replacement_tuples``. If that list is empty,
       the args are left as they are.
    2. The node is then either handed to ``normalization_info.custom_mapping_fn``
       or replaced with a ``call_function`` to
       ``normalization_info.new_fn_target``.

    Dead code is eliminated at the end.

    Args:
        mod: graph module to normalize; its graph is mutated in place.
        expect_nodes_have_shapes: when False, nodes whose normalization needs
            shape information only get their args/kwargs rewritten and are
            otherwise left unchanged. Such nodes must have a custom mapping
            function.

    Raises:
        AssertionError: if ``_normalization_dict`` is empty, or if a
            normalization entry is inconsistent.
        Exception: any error raised while normalizing a node is re-raised
            after the offending node is printed.
    """
    assert len(_normalization_dict) > 0
    graph = mod.graph

    # For "call_module" node we return _base_class_origin if it's a
    # RewrittenModule, otherwise, return its type. For other nodes, we
    # return node.target.
    def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
        if node.op != "call_module":
            return node.target
        # Resolve the submodule directly. Rebuilding dict(mod.named_modules())
        # costs O(#modules) per node. named_modules() also de-duplicates
        # modules registered under several names, so an aliased target could
        # be missing from that dict.
        m = mod.get_submodule(node.target)
        return getattr(m, "_base_class_origin", type(m))

    def normalize_to_acc_op(
        node: torch.fx.Node,
        normalization_info: NormalizationInfo,
        normalized_args: Tuple[Any, ...],
        normalized_kwargs: Dict[str, Any],
    ):
        # If there's a custom mapping function then use it.
        if normalization_info.custom_mapping_fn is not None:
            # For custom mapping, the normalized_kwargs are used for the original op,
            # i.e. *before* custom acc_ops normalization. Do that now.
            node.args = normalized_args
            node.kwargs = normalized_kwargs
            new_node = normalization_info.custom_mapping_fn(node, mod)
            # If a new node is returned then use it to replace the old node. Otherwise
            # the custom mapping function did its own replacement, so return early.
            if new_node is None:
                return
        else:
            # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in
            # normalized_kwargs, and remove the kwarg from normalized_kwargs.
            move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)

            # All acc ops are functions. Create a call to the correct acc_ops target using
            # the normalized kwargs provided.
            with graph.inserting_before(node):
                new_node = graph.create_node(
                    "call_function",
                    normalization_info.new_fn_target,
                    args=normalized_args,
                    kwargs=normalized_kwargs,
                    name=node.name,
                )
                new_node.meta = node.meta.copy()

        # Finally replace the original node with the normalized node.
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)

    for node in graph.nodes:
        if node.op in {"placeholder", "get_attr", "output"}:
            continue

        normalization_info = _normalization_dict.get((node.op, get_target(mod, node)))

        # Also check if the torch_packaged version of the op was specified to be normalized.
        if normalization_info is None and node.op == "call_function":
            # Strip off the mangle_index suffix here before checking the map.
            target = re.sub(
                r"\A<torch_package_\d+>",
                "<torch_package_>",
                _get_qualified_name(node.target),
            )
            torch_package_op_and_target = (node.op, target)
            normalization_info = _normalization_dict.get(torch_package_op_and_target)

        if normalization_info is None:
            continue

        # Get the normalized kwargs to be used by normalize_to_acc_op below. If
        # normalization_info.arg_replacement_tuples is empty then assume the function
        # signature must be left as is.
        assert normalization_info.arg_replacement_tuples is not None
        if len(normalization_info.arg_replacement_tuples) == 0:
            normalized_args = node.args
            normalized_kwargs = node.kwargs
        else:
            normalized_args = ()
            try:
                normalized_kwargs = get_normalized_kwargs(
                    node, normalization_info.arg_replacement_tuples
                )
            except Exception:
                print(
                    f"Error during kwarg normalization for: {node.format_node()}; "
                    f"arg_replacement_tuples={normalization_info.arg_replacement_tuples}"
                )
                raise

        if (
            normalization_info.needs_shapes_for_normalization
            and not expect_nodes_have_shapes
        ):
            # All nodes needing shapes for normalization should be custom mapped.
            assert normalization_info.custom_mapping_fn is not None
            # For custom mapping, the normalized_kwargs are used for the original op,
            # i.e. *before* custom acc_ops normalization. Do that now so that whoever
            # consumes the graph next (e.g. shape inference) can use kwargs safely.
            node.args = normalized_args
            node.kwargs = normalized_kwargs
            continue

        try:
            normalize_to_acc_op(
                node, normalization_info, normalized_args, normalized_kwargs
            )
        except Exception:
            print(f"Error during normalization for node: {node.format_node()}")
            raise

    # If there are any dead nodes left after normalization, eliminate them now.
    mod.graph.eliminate_dead_code()
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serialize a graph module (fx_module) to a dictionary which
    is later exported to JSON.

    Parameters, buffers and tensor ``get_attr`` targets of the module are also
    recorded in the provided ``weights`` dictionary, keyed by qualified name.

    Args:
        fx_module: graph module to serialize.
        weights: qualified name -> tensor map, updated in place.
        name_prefix: qualified-name prefix of ``fx_module`` ("" for the root).

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        requires_grad: str,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {},
        users: [{is_node: True, name: name}],
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # The argument/user formatters below do not depend on the loop variable,
    # so they are defined once rather than once per node.
    def get_user_info(user_node: Argument) -> Any:
        return {"is_node": True, "name": str(user_node)}

    def get_arg_info(arg: Argument) -> Any:
        if isinstance(arg, torch.fx.Node):
            return {"is_node": True, "name": str(arg)}
        elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
            return str(arg)
        else:
            return arg

    def get_output_arg_info(arg: Node) -> Dict[str, Any]:
        node_rep: Dict[str, Any] = get_arg_info(arg)
        node_rep.update(get_node_info(arg))
        return node_rep

    # (qualified embedding-bag op name, fused weight dtype to report). The 4-bit
    # entry comes second so it wins if both ops consume the same weight.
    fused_embedding_bag_dtypes = (
        ("acc_ops.embedding_bag_byte_rowwise_offsets", "acc.uint8fused"),
        ("acc_ops.embedding_bag_4bit_rowwise_offsets", "acc.uint4fused"),
    )

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent."):]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(
                    weights[stripped_name], weights, stripped_name
                )
                # For quantized embedding tables we need to update the
                # shape/type, so we check whether a user of this get_attr is a
                # quantized EB and this is the EB's weight. Only call_function
                # users have a callable target; _get_qualified_name cannot
                # handle the string targets of call_module/call_method/output
                # users, which could never match the acc_ops keys anyway.
                user_targets = {
                    _get_qualified_name(n.target)
                    .replace("torch.fx.experimental.fx_acc.", "")
                    .replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                    if n.op == "call_function"
                }
                for eb_target, fused_dtype in fused_embedding_bag_dtypes:
                    eb_node = user_targets.get(eb_target)
                    if (
                        eb_node is not None
                        and str(eb_node.kwargs["weight"]) == stripped_name
                    ):
                        weight[stripped_name]["dtype"] = fused_dtype
                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it
                # already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Serialize ``fx_module`` into a JSON-ready dictionary, recursing into
    every nested GraphModule it calls.

    Every parameter, buffer and tensor ``get_attr`` target that gets
    serialized is also recorded in ``weights`` under its qualified name.

    Args:
        fx_module: the graph module to serialize.
        weights: qualified name -> tensor map; updated in place.
        name_prefix: qualified-name prefix of ``fx_module`` ("" for the root).

    Returns:
        A MODULE dictionary with this layout:

        MODULE
        {
            modules: {module_name: MODULE},
            nodes: [NODE],
            weights: {qualified_name: WEIGHT},
        }
        NODE
        {
            shape: [],
            dtype: dtype,
            target: target,
            op_code: op_code,
            name: name,
            args: [],
            kwargs: {}
        }
        WEIGHT
        {
            dtype: dtype,
            is_quantized: bool,
            shape: [],
            quantization_info: QUANTIZATION
        }
        QUANTIZATION
        {
            qscheme: qscheme,
            q_scale: float,
            q_zero_point: float,
            q_per_channel_scales: [],
            q_per_channel_zero_points: [],
            q_per_channel_axis: int
        }
    """
    serialized_dict: Dict[str, Any] = {"modules": {}, "weights": {}, "nodes": []}
    weight_table = serialized_dict["weights"]
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def as_node_ref(arg: Node) -> Dict[str, Any]:
        # Plain reference to another node in the graph.
        return {"is_node": True, "name": str(arg)}

    def as_output_ref(arg: Node) -> Argument:
        # Reference to a graph output, annotated with its shape and dtype.
        shape, dtype = get_shape_and_dtype(arg)
        return {
            "is_node": True,
            "name": str(arg),
            "shape": serialize_shape(shape),
            "dtype": str(dtype),
        }

    # Register parameters first, then buffers; "parent." entries belong to the
    # enclosing module and non-tensor entries are ignored.
    for named_tensors_fn in (fx_module.named_parameters, fx_module.named_buffers):
        for tensor_name, tensor in named_tensors_fn():
            if tensor_name.startswith("parent.") or not isinstance(tensor, torch.Tensor):
                continue
            qualified = prefix + tensor_name
            weight_table[qualified] = serialize_weight(tensor)
            weights[qualified] = tensor

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)

    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        calls_graph_module = node.op == "call_module" and isinstance(
            submodules[node.target], GraphModule
        )

        # Shape/dtype are skipped for the output node and for calls into a
        # nested GraphModule.
        if node.op != "output" and not calls_graph_module:
            shape, dtype = get_shape_and_dtype(node)
            node_rep["shape"] = serialize_shape(shape)
            node_rep["dtype"] = str(dtype)

        if calls_graph_module:
            # Nested graph: serialize it as its own MODULE entry.
            serialized_dict["modules"][node.target] = serialize_module(
                getattr(fx_module, node.target), weights, node.target
            )
        elif node.op == "call_module":
            node_rep["parameters"] = serialize_leaf_module(
                node, weight_table, weights, prefix + node.target
            )

        node_rep["target"] = (
            _get_qualified_name(node.target)
            if node.op == "call_function"
            else str(node.target)
        )

        if node.op == "get_attr":
            if node.target.startswith("parent."):
                # The constant lives in the parent module: drop the prefix.
                local_name = node.target[len("parent."):]
                node.name = local_name
                node_rep["target"] = local_name
                weight_table[local_name] = serialize_weight(weights[local_name])
            else:
                # Resolve the owning submodule, then the attribute on it.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                attr = getattr(submod, target_name, None)
                assert attr is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Only tensors, and only if a leaf module did not register it already.
                if isinstance(attr, torch.Tensor) and qualname not in weights:
                    weight_table[qualname] = serialize_weight(attr)
                    weights[qualname] = attr

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":
            output_args = map_arg(node.args, as_output_ref)
            # Several outputs arrive packed as a tuple in the first arg: flatten it.
            node_rep["args"] = (
                output_args[0] if isinstance(output_args[0], tuple) else output_args
            )
        else:
            node_rep["args"] = map_arg(node.args, as_node_ref)
        node_rep["kwargs"] = map_arg(node.kwargs, as_node_ref)

        serialized_dict["nodes"].append(node_rep)

    return serialized_dict
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serialize a graph module (fx_module) to a dictionary which
    is later exported to JSON.

    The module's parameters and tensor ``get_attr`` targets are also added to
    the provided ``weights`` dictionary, keyed by qualified name.

    Args:
        fx_module: graph module to serialize. Shape propagation must have been
            run so that nodes carry ``shape``/``dtype`` attributes.
        weights: qualified name -> tensor map, updated in place.
        name_prefix: qualified-name prefix of ``fx_module`` ("" for the root).

    Raises:
        RuntimeError: if a node that needs shape/dtype info has no ``shape``
            or ``dtype`` attribute.

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {}
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        quantization_info: QUANTIZATION
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    prefix = f"{name_prefix}." if name_prefix else ""
    submodules = dict(fx_module.named_modules())
    for name, p in fx_module.named_parameters():
        if isinstance(p, torch.Tensor):
            serialized_dict["weights"][prefix + name] = serialize_weight(p)
            weights[prefix + name] = p

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module nodes that
        # call into a nested GraphModule.
        if node.op != "call_module" or not isinstance(submodules[node.target], GraphModule):
            # Compare against None explicitly. A 0-d (scalar) tensor has shape
            # torch.Size([]), which is falsy but is still a valid shape.
            shape = getattr(node, "shape", None)
            if shape is None:
                raise RuntimeError(
                    "Node has no shape attr, this is likely because shape propagation has not been run on this Graph."
                )
            node_rep["shape"] = serialize_shape(shape)
            dtype = getattr(node, "dtype", None)
            if dtype is None:
                raise RuntimeError(
                    "Node has no dtype attr, this is likely because shape propagation has not been run on this Graph."
                )
            node_rep["dtype"] = str(dtype)

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_target = node.target[len("parent."):]
                # NOTE(review): this strips len("parent.") characters from
                # node.name, presumably an equally long "parent_"-style prefix
                # on the node name. Confirm this matches how names are generated.
                node.name = node.name[len("parent."):]
                node_rep["target"] = str(stripped_target)
                weight = serialize_weight(weights[stripped_target])
                serialized_dict["weights"][stripped_target] = weight
            else:
                # Walk the module hierarchy to find the attr.
                target = fx_module
                for atom in node.target.split("."):
                    target = getattr(target, atom)
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it
                # already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    serialized_dict["weights"][qualname] = serialize_weight(target)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name
        node_rep["args"] = map_arg(
            node.args, lambda arg: {"is_node": True, "name": str(arg)}
        )
        node_rep["kwargs"] = map_arg(
            node.kwargs, lambda arg: {"is_node": True, "name": str(arg)}
        )
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict