def _insert_quantize_per_tensor_node( prev_node_c: Node, node_a: Node, gm_b: GraphModule, graph_c: Graph, scale: Union[torch.Tensor, float], zero_point: Union[torch.Tensor, int], dtype_cast_name: str, ) -> Node: # copy scale scale_node_name = \ get_new_attr_name_with_prefix( node_a.name + '_input_scale_')(gm_b) setattr(gm_b, scale_node_name, scale) scale_node = graph_c.create_node('get_attr', scale_node_name, (), {}, scale_node_name) # copy zero_point zero_point_node_name = \ get_new_attr_name_with_prefix( node_a.name + '_input_zero_point_')(gm_b) setattr(gm_b, zero_point_node_name, zero_point) zero_point_node = graph_c.create_node('get_attr', zero_point_node_name, (), {}, zero_point_node_name) # create the quantize_per_tensor call return graph_c.create_node( 'call_function', torch.quantize_per_tensor, (prev_node_c, scale_node, zero_point_node, torch.quint8), {}, dtype_cast_name)
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    `prev_node_c` may be a single Node or a list of Nodes; a list yields one
    cast node per element, returned as a list in the same order.

    Raises:
        AssertionError: if the (node_c input type -> node_a input type) pair
            is anything other than INT8 -> FP32, or if `prev_node_c` is
            neither a Node nor a list.
    """
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    # Only int8 -> fp32 (dequantize) is supported by this version.
    if node_input_type_a == NodeInputType.FP32 and node_input_type_c == NodeInputType.INT8:
        dtype_cast_op = torch.dequantize
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented"
        )

    def _create_cast_node(input_node: Node) -> Node:
        # one cast node reading from `input_node`
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        return graph_c.create_node(
            'call_function', dtype_cast_op, (input_node, ), {},
            new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _create_cast_node(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_create_cast_node(inner) for inner in prev_node_c]
    else:
        # fix: the original message contained a stray literal "f" before the
        # interpolated type ("type f<class ...>")
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _insert_logger_after_node( node: Node, gm: GraphModule, logger_cls: Callable, logger_node_name_suffix: str, ref_node_name: str, model_name: str, ref_name: str, results_type: str, index_within_arg: int, index_of_arg: int, ) -> Node: """ Given a starting graph of prev_node -> node -> next_node This function creates a new logger_cls obj and adds it after node, resulting in prev_node -> node -> logger_obj -> next_node """ # create new name logger_node_name = \ get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm) target_type = get_target_type_str(node, gm) # create the logger object logger_obj = logger_cls(ref_node_name, node.name, model_name, ref_name, target_type, results_type, index_within_arg, index_of_arg) # attach the logger object to the parent module setattr(gm, logger_node_name, logger_obj) logger_node = node.graph.create_node('call_module', logger_node_name, (node, ), {}) return logger_node
def _copy_node_from_a_to_c( node_a: Node, gm_a: GraphModule, gm_b: GraphModule, graph_c: Graph, ) -> Node: """ Simple copy of node_a to graph_c. """ if node_a.op == 'get_attr': node_a_copy_name = \ get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b) node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type] if torch.is_tensor(node_a_obj): node_a_obj = node_a_obj.detach() setattr(gm_b, node_a_copy_name, node_a_obj) node_a_copy = graph_c.create_node(node_a.op, node_a_copy_name, (), {}, node_a_copy_name) return node_a_copy elif node_a.op == 'call_method': assert node_a.target in ('dequantize', 'to'), \ f"target {node_a.target} is not implemented" if node_a.target == 'dequantize': arg_copy = _copy_node_from_a_to_c( node_a.args[0], gm_a, gm_b, graph_c) # type: ignore[arg-type] node_a_copy_name = \ get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b) node_a_copy = graph_c.create_node(node_a.op, node_a.target, (arg_copy, ), {}, node_a_copy_name) return node_a_copy else: # to arg_copy = _copy_node_from_a_to_c( node_a.args[0], gm_a, gm_b, graph_c) # type: ignore[arg-type] node_a_copy_name = \ get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b) node_a_copy = graph_c.create_node(node_a.op, node_a.target, (arg_copy, node_a.args[1]), {}, node_a_copy_name) return node_a_copy else: raise AssertionError( f"handling of node with op {node_a.op} is not implemented")
def _insert_logger_after_node( node: Node, gm: GraphModule, logger_cls: Callable, logger_node_name_suffix: str, ref_node_name: str, model_name: str, ref_name: str, results_type: str, index_within_arg: int, ) -> Node: """ Given a starting graph of prev_node -> node -> next_node This function creates a new logger_cls obj and adds it after node, resulting in prev_node -> node -> logger_obj -> next_node """ # create new name logger_node_name = \ get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm) # print('node.name', node.name, 'suffix', logger_node_name_suffix, 'new name', logger_node_name) # create a string representation of the node's target type target_type = '' if node.op == 'call_function': target_type = str(node.target) elif node.op == 'call_module': assert isinstance(node.target, str) target_mod = getattr_from_fqn(gm, node.target) target_type = str(type(target_mod)) # create the logger object logger_obj = logger_cls( ref_node_name, node.name, model_name, ref_name, target_type, results_type, index_within_arg) # attach the logger object to the parent module setattr(gm, logger_node_name, logger_obj) logger_node = node.graph.create_node( 'call_module', logger_node_name, (node,), {}) return logger_node
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    The cast node is created in `prev_node_c.graph`. Only the
    INT8 (node_c) -> FP32 (node_a) combination is supported; every other
    combination raises AssertionError.
    """
    node_io_type_a = get_node_io_type(node_a, gm_a)
    node_io_type_c = get_node_io_type(node_c, gm_b)

    is_int8_to_fp32 = (
        node_io_type_a == NodeIOType.FP32 and
        node_io_type_c == NodeIOType.INT8
    )
    if not is_int8_to_fp32:
        raise AssertionError(
            f"dtype cast from {node_io_type_c} to {node_io_type_a} needs to be implemented"
        )

    # int8 -> fp32 is a dequantize on the output of prev_node_c
    cast_node_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
    owning_graph = prev_node_c.graph
    return owning_graph.create_node(
        'call_function', torch.dequantize, (prev_node_c, ), {},
        cast_node_name)
def _insert_copy_of_node_a_after_input_node_c( input_node_c: Union[Node, List[Node]], input_node_c_2: Optional[Union[Node, List[Node]]], node_a: Node, gm_a: GraphModule, gm_b: GraphModule, node_name_prefix: str, ) -> Node: """ Assume that node_a from graph_a has args (input, (input2)?, arg1, ...), and kwargs {kw0: kwarg0, ...} Note: input2 is optional. If it equals to None, we assume that the op has a single non-param input. If it is specified, we assume that the op has two non-param inputs. Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b, and creates the corresponding nodes in graph_c. Note: observers are ignored, so if an arg is an observer we navigate up until we find a non-observer parent. If node_a is a call_module, points the module pointed to by node_a to gm_b. Creates the copy of node_a in graph_c, with input as the first arg, and all other args and kwargs pointing to the copies of the objects in gm_b created above. An example in pictures: graph A: ======== input -------------> node_a / / / (input_2)?----------/ / / / / weight -> weight_obs / / bias ---------------- graph C (derived from B): ========================= input_node_c --> node_a_copy / / / (input_node_c_2)? / / / / weight_copy ----/ / / bias_copy ------/ """ if isinstance(input_node_c, Node): graph_c = input_node_c.graph else: assert isinstance(input_node_c, list) graph_c = input_node_c[0].graph # generically handle all args and kwargs except for the input # Note: this hasn't been tested with many ops, logic may change. 
new_args: List[Any] = [] # assumes that the first arg is the input num_non_param_args = 1 if input_node_c_2 is None else 2 for node_a_arg in node_a.args[num_non_param_args:]: if isinstance(node_a_arg, Node): arg_a = return_first_non_observer_node(node_a_arg, gm_a) node_a_arg_copy = _copy_node_from_a_to_c(arg_a, gm_a, gm_b, graph_c) new_args.append(node_a_arg_copy) elif isinstance(node_a_arg, (int, float)): new_args.append(node_a_arg) elif isinstance(node_a_arg, (list, tuple)): for el in node_a_arg: assert not isinstance(el, Node), \ "handling of Node inside list is not implemented" new_args.append(node_a_arg) else: raise AssertionError( f"handling for arg of type {type(node_a_arg)} is not implemented" ) new_kwargs: Dict[str, Any] = {} for node_a_k, node_a_kwarg in node_a.kwargs.items(): if isinstance(node_a_kwarg, Node): kwarg_a = return_first_non_observer_node(node_a_kwarg, gm_a) node_a_kwarg_copy = _copy_node_from_a_to_c(kwarg_a, gm_a, gm_b, graph_c) new_kwargs[node_a_k] = node_a_kwarg_copy else: new_kwargs[node_a_k] = node_a_kwarg node_a_shadows_c_name = \ get_new_attr_name_with_prefix(node_name_prefix)(gm_b) if input_node_c_2: input_node_c_args = [input_node_c, input_node_c_2] else: input_node_c_args = [input_node_c] if node_a.op == 'call_module': # if target is a module, we point to the module from gm_b new_mod_copy_name = \ get_new_attr_name_with_prefix(node_name_prefix)(gm_b) # fetch the corresponding module from gm_a assert isinstance(node_a.target, str) mod_a = getattr_from_fqn(gm_a, node_a.target) setattr(gm_b, new_mod_copy_name, mod_a) node_a_shadows_c = graph_c.create_node(node_a.op, new_mod_copy_name, (*input_node_c_args, *new_args), new_kwargs, node_a_shadows_c_name) return node_a_shadows_c else: assert node_a.op in ('call_function', 'call_method') node_a_shadows_c = graph_c.create_node(node_a.op, node_a.target, (*input_node_c_args, *new_args), new_kwargs, node_a_shadows_c_name) return node_a_shadows_c
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Cast selection, by first-input type of (node_c -> node_a):
      * INT8 / FP16 / FP32_OR_INT8 -> FP32: `torch.dequantize`
      * equal and not UNKNOWN: a `torch.nn.Identity` module attached to gm_b
      * FP32 -> INT8: `torch.quantize_per_tensor`, when
        `get_node_input_qparams` returns qparams for node_a; if it returns
        None, no cast op is chosen and the later assert on the Identity
        class fails

    `prev_node_c` may be a single Node or a list of Nodes; a list yields one
    cast node per element. Note that the list path does not yet pass qparams
    to `quantize_per_tensor` (see the TODO below).

    Raises:
        AssertionError: if the dtype combination is not handled, or if
            `prev_node_c` is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                # quantize_per_tensor needs its qparams wired in as well
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c, ), {},
                    new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c, ), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner, ), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner, ), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        # fix: the original message contained a stray literal "f" before the
        # interpolated type ("type f<class ...>")
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Cast selection, by first-input type of (node_c -> node_a):
      * INT8 / FP16 -> FP32: `torch.dequantize`
      * FP32 -> FP32, INT8 -> INT8, FP32_OR_INT8 -> FP32_OR_INT8:
        a `torch.nn.Identity` module attached to gm_b

    `prev_node_c` may be a single Node or a list of Nodes; a list yields one
    cast node per element, each reading from its own element.

    Raises:
        AssertionError: if the dtype combination is not handled, or if
            `prev_node_c` is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(node_a, gm_a, logger_cls)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(node_c, gm_b, logger_cls)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == NodeInputOrOutputType.FP32 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.INT8
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.FP32_OR_INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _create_cast_node(input_node: Node) -> Node:
        # Creates a single cast node which reads from `input_node`. Sharing
        # this between the Node and list paths guarantees that every cast is
        # wired to exactly one input node.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            return graph_c.create_node(
                'call_function', dtype_cast_op, (input_node, ), {},
                new_dtype_cast_name)
        assert dtype_cast_mod_cls
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node(
            'call_module', new_dtype_cast_name, (input_node, ), {},
            new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _create_cast_node(prev_node_c)
    elif isinstance(prev_node_c, list):
        # fix: the original Identity path passed the whole `prev_node_c` list
        # as the input of every cast node instead of the current element
        return [_create_cast_node(inner) for inner in prev_node_c]
    else:
        # fix: the original message contained a stray literal "f" before the
        # interpolated type ("type f<class ...>")
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _insert_copy_of_node_a_after_input_node_c( input_node_c: Node, node_a: Node, gm_a: GraphModule, gm_b: GraphModule, node_name_prefix: str, ) -> Node: """ Assume that node_a from graph_a has args (input, arg1, ...), and kwargs {kw0: kwarg0, ...} Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b, and creates the corresponding nodes in graph_c. Note: observers are ignored, so if an arg is an observer we navigate up until we find a non-observer parent. If node_a is a call_module, points the module pointed to by node_a to gm_b. Creates the copy of node_a in graph_c, with input as the first arg, and all other args and kwargs pointing to the copies of the objects in gm_b created above. An example in pictures: graph A: ======== input -------------> node_a / / weight -> weight_obs / / bias ---------------- graph C (derived from B): ========================= input_node_c --> node_a_copy / / weight_copy ----/ / / bias_copy ------/ """ graph_c = input_node_c.graph # generically handle all args and kwargs except for the input # Note: this hasn't been tested with many ops, logic may change. 
new_args = [] # assumes that the first arg is the input for node_a_arg in node_a.args[1:]: if isinstance(node_a_arg, Node): arg_a = return_first_non_observer_node(node_a_arg, gm_a) arg_a_copy_name = \ get_new_attr_name_with_prefix(arg_a.name + '_shadow_copy_')(gm_b) # type: ignore arg_a_obj = getattr_from_fqn(gm_a, arg_a.target) # type: ignore setattr(gm_b, arg_a_copy_name, arg_a_obj.detach()) node_a_arg_copy = graph_c.create_node('get_attr', arg_a_copy_name, (), {}, arg_a_copy_name) new_args.append(node_a_arg_copy) else: raise AssertionError( f"handling for arg of type {type(node_a_arg)} is not implemented" ) new_kwargs: Dict[str, Any] = {} for node_a_k, node_a_kwarg in node_a.kwargs.items(): if isinstance(node_a_kwarg, Node): kwarg_a_copy_name = \ get_new_attr_name_with_prefix(node_a_kwarg.name + '_shadow_copy_')(gm_b) # type: ignore kwarg_a_obj = getattr_from_fqn(gm_a, node_a_kwarg.target) # type: ignore setattr(gm_b, kwarg_a_copy_name, kwarg_a_obj.detach()) node_a_kwarg_copy = graph_c.create_node('get_attr', kwarg_a_copy_name, (), {}, kwarg_a_copy_name) new_kwargs[node_a_k] = node_a_kwarg_copy else: new_kwargs[node_a_k] = node_a_kwarg node_a_shadows_c_name = \ get_new_attr_name_with_prefix(node_name_prefix)(gm_b) if node_a.op == 'call_module': # if target is a module, we point to the module from gm_b new_mod_copy_name = \ get_new_attr_name_with_prefix(node_name_prefix)(gm_b) # fetch the corresponding module from gm_a assert isinstance(node_a.target, str) mod_a = getattr_from_fqn(gm_a, node_a.target) setattr(gm_b, new_mod_copy_name, mod_a) node_a_shadows_c = graph_c.create_node( node_a.op, new_mod_copy_name, (input_node_c, *new_args), new_kwargs, node_a_shadows_c_name) # type: ignore return node_a_shadows_c else: assert node_a.op == 'call_function' node_a_shadows_c = graph_c.create_node( node_a.op, node_a.target, (input_node_c, *new_args), new_kwargs, node_a_shadows_c_name) # type: ignore return node_a_shadows_c