def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  At most two observer hops are followed, which covers the
    observer -> fake_quant chain shown below.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    Only ``call_module`` nodes are inspected for being observers.  Any other
    parent (placeholder, call_function, call_method, get_attr) is returned
    as-is instead of tripping an assertion.

    Args:
        node: the node to start from
        gm: the GraphModule owning ``node``, used to resolve module targets

    Returns:
        ``node`` itself, or its first non-observer ancestor within two hops.
    """
    # Bounded walk: at most two observer links (e.g. obs0 -> fq0), matching
    # the patterns documented above.
    for _ in range(2):
        if node.op != 'call_module':
            # only module calls can resolve to an activation post process;
            # other ops (e.g. call_function with a callable target) are
            # returned directly
            break
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if not is_activation_post_process(node_obj):
            break
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node
def remove_qconfig_observer_fx(model):
    """Return a copy of an FX model's graph with every activation post
    process (observer / fake-quant) node spliced out.

    Each observer node is bypassed: its consumers are rewired to the
    observer's single input.  All other nodes are copied unchanged.  The
    qconfig attributes are removed from ``model`` in place via
    ``_remove_qconfig`` before the new GraphModule is built from it.

    Args:
        model: an FX GraphModule, typically produced by ``prepare_fx``

    Returns:
        A new GraphModule over ``model`` with the observer-free graph.
    """
    stripped_graph = Graph()
    # maps original node names -> nodes in ``stripped_graph``
    name_to_new_node = {}  # type: Dict[str, Any]
    submodules = dict(model.named_modules())

    def lookup(arg):
        # swap every Node inside ``arg`` for its counterpart in the new graph
        return map_arg(arg, lambda n: name_to_new_node[n.name])

    for old_node in model.graph.nodes:
        if old_node.op == "output":
            stripped_graph.output(map_arg(old_node.args[0], lookup))
            continue

        is_observer = (
            old_node.op == "call_module"
            and is_activation_post_process(submodules[old_node.target])
        )
        if is_observer:
            # bypass the observer: reuse whatever its input maps to
            name_to_new_node[old_node.name] = name_to_new_node[old_node.args[0].name]
        else:
            name_to_new_node[old_node.name] = stripped_graph.node_copy(old_node, lookup)

    _remove_qconfig(model)
    return GraphModule(model, stripped_graph)
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Build a new graph from ``gm`` in which observer nodes are removed and a
    logger is attached to the output of every node that is a key of
    ``node_to_instrument_to_ref_node_name``.

    Args:
        gm: source GraphModule
        node_to_instrument_to_ref_node_name: node -> reference node name
            forwarded to the logger
        logger_cls: logger module class to instantiate
        model_name: model name forwarded to the logger

    Returns:
        A GraphModule wrapping the rewritten graph.
    """
    out_graph = Graph()
    # original node name -> corresponding node in ``out_graph``
    remap: Dict[str, Any] = {}
    named_mods = dict(gm.named_modules())

    def _load(arg):
        return map_arg(arg, lambda n: remap[n.name])

    for cur in gm.graph.nodes:
        if cur.op == 'output':
            out_graph.output(map_arg(cur.args[0], _load))
            continue

        if cur.op == 'call_module' and is_activation_post_process(named_mods[cur.target]):
            # observer: drop it and let consumers read the observer's input
            remap[cur.name] = remap[cur.args[0].name]
            continue

        copied = out_graph.node_copy(cur, _load)
        if cur in node_to_instrument_to_ref_node_name:
            ref_name = node_to_instrument_to_ref_node_name[cur]
            # downstream consumers read through the logger
            copied = _insert_logger_after_node(
                copied, gm, logger_cls, '_ns_logger_', model_name, ref_name)
        remap[cur.name] = copied

    return GraphModule(gm, out_graph)
def test_remove_qconfig_observer_fx(self):
    r"""Remove activation_post_process nodes from an fx prepared model and
    check that none remain while the model's real modules are kept."""
    float_model = SingleLayerLinearModel()
    float_model.eval()

    qengine = torch.backends.quantized.engine
    qconfig = get_default_qconfig(qengine)

    qconfig_dict = {"": qconfig}

    prepared_model = prepare_fx(float_model, qconfig_dict)

    prepared_float_model = copy.deepcopy(prepared_model)
    prepared_float_model.eval()

    model = remove_qconfig_observer_fx(prepared_float_model)

    modules = dict(model.named_modules())
    call_module_nodes = [n for n in model.graph.nodes if n.op == "call_module"]
    # guard against a vacuous pass: the linear layer must survive removal
    self.assertTrue(len(call_module_nodes) > 0)
    for node in call_module_nodes:
        self.assertFalse(is_activation_post_process(modules[node.target]))
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, adds loggers to the inputs
    of each node in node_to_instrument_inputs_to_ref_node_name and to the
    outputs of each node in node_to_instrument_outputs_to_ref_node_name.
    Returns a GraphModule with the new graph.

    Only the first argument of a node is input-logged; it must be a Node or
    an FX immutable_list of Nodes, otherwise AssertionError is raised.
    """

    new_graph = Graph()
    # keyed by node names of gm.graph (the ORIGINAL graph); values are nodes
    # of new_graph.  load_arg resolves arguments through this mapping.
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif ((node in node_to_instrument_inputs_to_ref_node_name) or
              (node in node_to_instrument_outputs_to_ref_node_name)):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                first_arg = node.args[0]
                if isinstance(first_arg, Node):
                    # create a single input logger
                    prev_node = env[first_arg.name]
                    env[first_arg.name] = _insert_logger_after_node(
                        prev_node, gm, logger_cls, '_ns_logger_', node.name,
                        model_name, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif isinstance(first_arg, torch.fx.immutable_collections.immutable_list):
                    # create N input loggers, one for each node
                    for arg_idx, arg in enumerate(first_arg):
                        prev_node = env[arg.name]
                        # key by the ORIGINAL arg name (like the branch
                        # above) so the node copy below consumes the logger
                        # even when arg was a removed observer or is already
                        # mapped to an output logger
                        env[arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    raise AssertionError(
                        f"type {type(node.args[0])} is not handled yet")

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8),
                         'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    * if should_log_inputs, also adds loggers to the inputs of both copies

    Matched pairs whose input/output dtypes cannot be determined are copied
    without shadows (a message is printed).
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    # keyed by node names of graph_b; values are nodes of graph_c
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_observer = \
            node_b.op == 'call_module' and is_activation_post_process(modules[node_b.target])
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if node_b_is_observer:
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(subgraph_a.start_node, gm_a, logger_cls)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(node_b, gm_b, logger_cls)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        # key by the graph_b arg name so that the node_copy
                        # below makes node_c consume this logger; the dtype
                        # cast step relies on node_c.args[0] being it
                        env_c[node_b.args[0].name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [
                            env_c[arg.name] for arg in node_b.args[0]
                        ]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            # key by the graph_b arg name, see above
                            env_c[arg.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx)
                    else:
                        # only a single Node or a list of Nodes is supported
                        raise AssertionError(
                            f"type {type(node_b.args[0])} is not handled yet")

                    # subgraph so far:
                    #
                    # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                    node_b.name + '_dtype_cast_', logger_cls)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # The shadow copy of subgraph_a does not exist yet, so the
                    # name of the node these loggers feed is unknown here.
                    # Use an empty placeholder; it is patched below once the
                    # copy has been created.
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                            ref_node_name, name_a, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node

                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second param as well. Note: dtype casting
                # for the second param is not implemented yet, it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(
                    subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg,
                    subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                    node_a_shadows_c.name, name_a, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                    node_b.name, name_b, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes inbetween.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, adds loggers to the inputs
    of each node in node_to_instrument_inputs_to_ref_node_name and to the
    outputs of each node in node_to_instrument_outputs_to_ref_node_name.
    Returns a GraphModule with the new graph.

    Which positional arguments are input-logged is decided by
    get_arg_indices_of_inputs_to_log.  Arguments that are neither a Node nor
    an FX immutable_list are not logged.
    """

    new_graph = Graph()
    # keyed by node names of gm.graph (the ORIGINAL graph); values are nodes
    # of new_graph.  load_arg resolves arguments through this mapping.
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = node.args[node_arg_idx]
                    if isinstance(node_arg, Node):
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx)
                    elif isinstance(node_arg, torch.fx.immutable_collections.immutable_list):
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):
                            prev_node = env[arg.name]
                            # key by the ORIGINAL arg name (like the branch
                            # above) so the node copy below consumes the
                            # logger even when arg was a removed observer
                            # or is already mapped to an output logger
                            env[arg.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx)
                    # any other argument type (e.g. a Python scalar) is
                    # intentionally left unlogged

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8),
                         'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Raises AssertionError if a matched subgraph of B spans more than one node.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a = {}
    for match in matched_subgraph_pairs.values():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name], node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8),
                         'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    * if should_log_inputs, also adds loggers to the inputs of both copies

    Raises AssertionError if a matched subgraph of B spans more than one node,
    or if an input to log is neither a Node nor a list of Nodes.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    # keyed by node names of graph_b; values are nodes of graph_c
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a_and_name[node_end_b] = \
            ((node_start_a, node_end_a), match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a_and_name:
            (node_start_a, node_end_a), ref_name = \
                node_b_to_matched_subgraph_a_and_name[node_b]

            # if necessary, log the input of node_c
            if should_log_inputs:
                if isinstance(node_b.args[0], Node):
                    prev_node_c = env_c[node_b.args[0].name]
                    # key by the graph_b arg name so that the node_copy below
                    # makes node_c consume this logger
                    env_c[node_b.args[0].name] = _insert_logger_after_node(
                        prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                        node_b.name, name_b, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif isinstance(node_b.args[0], list):
                    # first, save the prev_node instances, because they
                    # will be overwritten in the env after the first logger
                    # is added
                    prev_node_c_list = [
                        env_c[arg.name] for arg in node_b.args[0]
                    ]

                    for arg_idx, arg in enumerate(node_b.args[0]):
                        prev_node_c = prev_node_c_list[arg_idx]
                        # key by the graph_b arg name, see above
                        env_c[arg.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    # only a single Node or a list of Nodes is supported
                    raise AssertionError(
                        f"type {type(node_b.args[0])} is not handled yet")

                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            # note: not inserting to env_c because all nodes which use the dtype
            #   casts are copied from graph_a
            #
            # subgraph so far:
            #
            #           (dtype_cast_node)+
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # if input logging is enabled, log the input to the subgraph
            if should_log_inputs:
                # The shadow copy of the A subgraph does not exist yet, so the
                # name of the node these loggers feed is unknown here. Use an
                # empty placeholder; it is patched below once the copy exists.
                ref_node_name = ''
                if isinstance(dtype_cast_node, Node):
                    dtype_cast_node = _insert_logger_after_node(
                        dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                        ref_node_name, name_a, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                    input_logger: Union[Node, List[Node]] = dtype_cast_node
                else:
                    assert isinstance(dtype_cast_node, list)
                    new_loggers = []
                    for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                        dtype_cast_logger = _insert_logger_after_node(
                            dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                            ref_node_name, name_a, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=dtype_cast_idx)
                        new_loggers.append(dtype_cast_logger)
                    dtype_cast_node = new_loggers
                    input_logger = dtype_cast_node

                # subgraph so far:
                #
                #       (dtype_cast_node)+ -> (logger_a_input)?
                #                  /
                # prev_node_c -> (logger_c_input)? -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                dtype_cast_node, node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            if should_log_inputs:
                # When we created the input logger, we left the ref_node_name
                # as an empty string, because the subgraph copy did not exist
                # yet. Now that the subgraph copy exists, we modify this name
                # to its true value.
                # Note: the alternative to this is to create the input logger
                # after creating the subgraph, which is slightly more
                # complicated. This is the lesser of two evils.
                # input_logger = env_c[dtype_cast_node.name]
                # Find the first node in the subgraph
                cur_node = node_a_shadows_c
                while cur_node.args[0] != input_logger:
                    cur_node = cur_node.args[0]  # type: ignore
                if isinstance(input_logger, Node):
                    input_logger_mod = getattr(gm_b, input_logger.name)
                    input_logger_mod.ref_node_name = cur_node.name
                else:
                    assert isinstance(input_logger, list)
                    for input_logger_inner in input_logger:
                        input_logger_mod = getattr(gm_b, input_logger_inner.name)
                        input_logger_mod.ref_node_name = cur_node.name

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', node_b.name, name_b,
                ref_name, NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy
            #                  /
            # (prev_node_c+) -> (logger_c_input)? -> node_c -> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                node_a_shadows_c.name, name_a, ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c -> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c