def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node
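# A minimal usage sketch (not from the original module): the toy model and
# the _LeafTracer helper below are illustrative. A custom tracer keeps the
# observer as a single call_module node, since the default FX tracer may
# trace through modules defined outside torch.nn.
def _example_return_first_non_observer_node() -> None:
    import torch.fx
    import torch.nn as nn
    from torch.ao.quantization.observer import MinMaxObserver

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)
            self.obs = MinMaxObserver()

        def forward(self, x):
            return self.obs(self.fc(x))

    class _LeafTracer(torch.fx.Tracer):
        def is_leaf_module(self, m: nn.Module, qualname: str) -> bool:
            # treat every submodule, including the observer, as opaque
            return True

    model = _Toy()
    gm = torch.fx.GraphModule(model, _LeafTracer().trace(model))
    obs_node = next(n for n in gm.graph.nodes if n.target == "obs")
    # navigating up from the observer lands on the linear node
    assert return_first_non_observer_node(obs_node, gm).target == "fc"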
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
    if node.op in ('call_function', 'call_method'):
        return node.target
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        return type(mod)
    return None
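# A small sketch of the behavior above, using an illustrative toy model:
# call_function nodes return the function itself, call_module nodes return
# the type of the resolved module.
def _example_get_node_target_type() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return torch.relu(self.fc(x))

    gm = symbolic_trace(_Toy())
    targets = {
        _get_node_target_type(n, gm)
        for n in gm.graph.nodes
        if n.op in ("call_function", "call_module")
    }
    assert targets == {torch.relu, nn.Linear}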
def _is_matchable(self, node: Node) -> bool:
    if node.op == 'call_function':
        return node.target not in self.non_matchable_functions
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(self.gm, node.target)
        return not any(
            isinstance(target_mod, t)  # type: ignore[arg-type]
            for t in self.non_matchable_modules
        )
    elif node.op == 'call_method':
        return node.target not in self.non_matchable_methods
    else:
        return False
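# A sketch of the matchability check. Since `_is_matchable` is a method of
# the graph matcher class (not shown here), the stand-in object below
# carries only the attributes the method reads; it is illustrative, not the
# real matcher.
def _example_is_matchable() -> None:
    import types
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = symbolic_trace(_Toy())
    stub = types.SimpleNamespace(
        gm=gm,
        non_matchable_functions={torch.relu},
        non_matchable_modules=set(),
        non_matchable_methods=set(),
    )
    relu_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    # torch.relu was declared non-matchable above
    assert _is_matchable(stub, relu_node) is False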
def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type
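# A brief sketch with an illustrative toy model: the returned string is
# torch.typename of the resolved target, and '' for other node types.
def _example_get_target_type_str() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    gm = symbolic_trace(_Toy())
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # e.g. 'torch.nn.modules.linear.Linear'
            assert get_target_type_str(n, gm).endswith("Linear")
        elif n.op == "placeholder":
            assert get_target_type_str(n, gm) == ""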
def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node. For example, for

      F.linear(x, weight, bias)

    returns 1, because x is a non-param arg and weight and bias are params.
    For

      lstm_mod(x, hid)

    returns 2, because both x and hid are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if isinstance(node_obj, nn.LSTM):
            return 2
    # default is 1
    return 1
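# A sketch of the LSTM special case with an illustrative toy model. The LSTM
# call is recorded as a leaf call_module node, so its forward is not traced.
def _example_get_number_of_non_param_args() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(2, 2)
            self.fc = nn.Linear(2, 2)

        def forward(self, x, hid):
            out = self.lstm(x, hid)
            return self.fc(out[0])

    gm = symbolic_trace(_Toy())
    lstm_node = next(n for n in gm.graph.nodes if n.target == "lstm")
    fc_node = next(n for n in gm.graph.nodes if n.target == "fc")
    assert get_number_of_non_param_args(lstm_node, gm) == 2  # x and hid
    assert get_number_of_non_param_args(fc_node, gm) == 1    # just the input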
def __next__(self) -> NSSubgraph:
    """
    Returns the next matchable subgraph.
    """
    while len(self.stack) > 0:
        cur_end_node = self.stack.pop()
        if cur_end_node in self.seen_nodes:
            continue

        # for subgraphs which are single nodes, start_node == end_node
        # for subgraphs with more than one node, start_node != end_node
        cur_start_node = cur_end_node
        # Subgraphs like linear-relu have the base node as the start node.
        # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
        # base node as the second node.
        # The cur_base_op_node var will move to the actual node during
        # the fusion matching later in this code block.
        cur_base_op_node = cur_end_node

        # Check for potential fusions. For now, we are greedy
        # and always skip all non-base nodes of a fusion. For example,
        # if we match linear-relu backwards, we will always skip the
        # relu node and attempt to match the linear node. This can
        # be made configurable later if needed.
        for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
            is_match = end_node_matches_reversed_fusion(
                cur_end_node, _reverse_fusion_ops, self.gm)
            if is_match:
                # navigate to the base node
                for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
                    self.seen_nodes.add(cur_start_node)
                    # for now, assume that there are no other nodes
                    # which need to be added to the stack
                    cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
                    # if the base op index matches the current node, set it
                    rev_base_op_idx = \
                        len(_reverse_fusion_ops) - 2 - base_op_idx
                    if rev_fusion_idx == rev_base_op_idx:
                        cur_base_op_node = cur_start_node
                break

        self.seen_nodes.add(cur_start_node)
        # add args of previous nodes to stack
        for arg in cur_start_node.all_input_nodes:
            self._recursively_add_node_arg_to_stack(arg)

        # skip unmatchable nodes
        # note: this check is done on the base_op_node, i.e.
        # if we are matching linear-relu in reverse, this would do the
        # matchable check on the linear
        if not self._is_matchable(cur_base_op_node):
            continue

        # If an observer or a fake_quant was not matched as a part of
        # a pattern of multiple nodes, ignore it. One case where this is
        # relevant is an observer on a graph input, which was added because
        # it is necessary for the next node.
        if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
            maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
            if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
                continue

        return NSSubgraph(
            start_node=cur_start_node, end_node=cur_end_node,
            base_op_node=cur_base_op_node)
    raise StopIteration
def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node

    # TODO(next): make this code handle matching by what is before the base op
    if node_a.op != node_b.op:
        if not (node_a.op in ('call_function', 'call_method') and
                node_b.op in ('call_function', 'call_method')):
            return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in ('call_function', 'call_method'):
        key = (node_a.target, node_b.target)

        if key not in type_a_related_to_b:
            if node_a.target == node_b.target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        # after this point, we are dealing with known types

        if node_a.target == node_b.target:
            node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
            node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
            if node_a_has_prev and (not node_b_has_prev):
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and node_b_has_prev:
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and (not node_b_has_prev):
                return SubgraphTypeRelationship.EQUAL
            else:
                # TODO(future PR): check for matches start_op_node and base_op_node
                return SubgraphTypeRelationship.EQUAL

        if key in type_a_related_to_b:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
        else:
            return SubgraphTypeRelationship.NOT_RELATED

    elif node_a.op == 'call_module':
        assert (subgraph_a.base_op_node == subgraph_a.start_node and
                subgraph_b.base_op_node == subgraph_b.start_node), \
            "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        key = (type(mod_a), type(mod_b))

        if key not in type_a_related_to_b:
            if type(mod_a) == type(mod_b):
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        elif type(mod_a) == type(mod_b):
            return SubgraphTypeRelationship.EQUAL
        else:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    return SubgraphTypeRelationship.NOT_RELATED
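# A sketch of the call_function path: two single-node subgraphs with
# different targets whose pair is present in the relatedness set come back
# as RELATED_BUT_NOT_EQUAL. The toy models and the relatedness set below
# are illustrative.
def _example_get_subgraph_relationship_type() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _A(nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class _B(nn.Module):
        def forward(self, x):
            return torch.sigmoid(x)

    gm_a, gm_b = symbolic_trace(_A()), symbolic_trace(_B())
    node_a = next(n for n in gm_a.graph.nodes if n.op == "call_function")
    node_b = next(n for n in gm_b.graph.nodes if n.op == "call_function")
    sub_a = NSSubgraph(start_node=node_a, end_node=node_a, base_op_node=node_a)
    sub_b = NSSubgraph(start_node=node_b, end_node=node_b, base_op_node=node_b)
    rel = _get_subgraph_relationship_type(
        sub_a, sub_b, gm_a, gm_b, {(torch.relu, torch.sigmoid)})
    assert rel == SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL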
def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            return (
                NodeInputOrOutputType.FP32_OR_INT8,
                NodeInputOrOutputType.FP32_OR_INT8,
            )
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        if isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)):  # type: ignore[arg-type]
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = node.args[0]
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map)
            return (prev_node_output_type, prev_node_output_type)
        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif is_known_fp32_or_int8_input_module:
            return (
                NodeInputOrOutputType.FP32_OR_INT8,
                NodeInputOrOutputType.FP32_OR_INT8,
            )
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = node.args[0]
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map)
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # to is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of to and return the correct output type.
            prev_node = node.args[0]
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map)

            cur_node_dtype_target = node.args[1]
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            return (
                NodeInputOrOutputType.FP32_OR_INT8,
                NodeInputOrOutputType.FP32_OR_INT8,
            )

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
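# A sketch of the known-fp32-module path. The io-type map below is a
# hand-built stand-in containing only the entries this example needs; real
# callers pass the numeric suite's full mapping, and _DummyLogger stands in
# for the suite's logger class.
def _example_get_node_first_input_and_output_type() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _DummyLogger(nn.Module):
        pass

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    gm = symbolic_trace(_Toy())
    fc_node = next(n for n in gm.graph.nodes if n.target == "fc")
    io_map: Dict[str, Set[NSNodeTargetType]] = {
        "funs_io_type_fp32": set(), "funs_io_type_fp16": set(),
        "funs_io_type_int8": set(), "funs_io_type_fp32_or_int8": set(),
        "mods_io_type_fp32": {nn.Linear}, "mods_io_type_int8": set(),
        "mods_io_type_fp32_or_int8": set(), "meths_io_type_fp32_or_int8": set(),
    }
    in_type, out_type = get_node_first_input_and_output_type(
        fc_node, gm, _DummyLogger, io_map)
    # nn.Linear is registered above as a known fp32 module
    assert in_type == NodeInputOrOutputType.FP32
    assert out_type == NodeInputOrOutputType.FP32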
def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
    """
    prev_node = node.args[0]

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
        scale_node, zp_node = node.args[scale_arg_idx], node.args[zp_arg_idx]
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale_obj = getattr_from_fqn(gm, scale_node.target)
        zp_obj = getattr_from_fqn(gm, zp_node.target)
        return (scale_obj, zp_obj)

    if prev_node.op == "call_function":

        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

        return None
        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input

    elif prev_node.op == "call_module":

        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nniq.ConvReLU2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None
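# A sketch of reading qparams through a quantize_per_tensor producer. The
# toy model and buffer names are illustrative; the io-type map is a minimal
# stand-in, since only the key read unconditionally by the function is
# needed on this path.
def _example_get_node_input_qparams() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("s", torch.tensor(0.5))
            self.register_buffer("zp", torch.tensor(1))

        def forward(self, x):
            xq = torch.quantize_per_tensor(x, self.s, self.zp, torch.quint8)
            return xq.dequantize()

    gm = symbolic_trace(_Toy())
    deq_node = next(n for n in gm.graph.nodes if n.op == "call_method")
    qparams = get_node_input_qparams(
        deq_node, gm, {"mods_io_type_fp32_or_int8": set()})
    assert qparams is not None
    # the scale and zero_point buffers are read back via get_attr nodes
    scale, zp = qparams
    assert float(scale) == 0.5 and int(zp) == 1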
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.
    """
    cur_node = end_node
    for fusion_idx in range(len(reversed_fusion)):
        cur_fusion_el = reversed_fusion[fusion_idx]

        if cur_node.op == 'call_function':
            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
                (not isinstance(cur_fusion_el, type))
            if fusion_el_is_fun:
                if cur_node.target != cur_fusion_el:
                    return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_module':
            fusion_el_is_mod = isinstance(cur_fusion_el, type)
            if fusion_el_is_mod:
                assert isinstance(cur_node.target, str)
                target_mod = getattr_from_fqn(gm, cur_node.target)
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_method':
            fusion_el_is_meth_with_second_arg = \
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            return False
    return True
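# A sketch of matching a reversed linear-relu module fusion against an
# illustrative toy model: the end node (relu) is checked first, then its
# input (linear).
def _example_end_node_matches_reversed_fusion() -> None:
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc(x))

    gm = symbolic_trace(_Toy())
    relu_node = next(n for n in gm.graph.nodes if n.target == "relu")
    # reversed patterns list the end node's type first
    assert end_node_matches_reversed_fusion(relu_node, (nn.ReLU, nn.Linear), gm)
    assert not end_node_matches_reversed_fusion(relu_node, (nn.ReLU, nn.Conv2d), gm)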