Esempio n. 1
0
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first ancestor which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    Only `call_module` nodes are resolved against `gm`, so the walk stops
    cleanly at any other kind of node (for example a `call_function` node,
    whose target is a callable rather than a fully qualified name). Observer
    chains of any length are followed.

    Args:
        node: the node to start from.
        gm: the GraphModule which owns `node`; used to resolve module targets.

    Returns:
        `node` itself, or its closest ancestor (following the single input of
        each observer) that is not an activation post process module.

    Raises:
        AssertionError: if an observer node does not have exactly one
            positional argument, or that argument is not a Node.
    """
    while node.op == 'call_module':
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if not is_activation_post_process(node_obj):
            break
        # observers are expected to wrap exactly one upstream node
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node
Esempio n. 2
0
def remove_qconfig_observer_fx(model):
    """
    Rebuilds the graph of `model` without activation post process nodes.

    Every node of `model.graph` is copied into a fresh graph, except
    `call_module` nodes whose target module is an activation post process;
    users of such a node are rewired to the node's first argument instead.
    `_remove_qconfig` is then applied to `model`, and a new GraphModule built
    from `model` and the new graph is returned.
    """
    named_modules = dict(model.named_modules())
    pruned_graph = Graph()
    # original node name -> corresponding node in pruned_graph
    remapped = {}  # type: Dict[str, Any]

    def load_arg(a):
        return map_arg(a, lambda n: remapped[n.name])

    for cur in model.graph.nodes:
        if cur.op == "output":
            pruned_graph.output(map_arg(cur.args[0], load_arg))
            continue

        is_observer = (
            cur.op == "call_module"
            and is_activation_post_process(named_modules[cur.target])
        )
        if is_observer:
            # drop the observer: alias it to whatever fed it
            remapped[cur.name] = remapped[cur.args[0].name]
        else:
            remapped[cur.name] = pruned_graph.node_copy(cur, load_arg)

    _remove_qconfig(model)
    return GraphModule(model, pruned_graph)
Esempio n. 3
0
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Copies the graph of gm into a new graph, dropping every activation post
    process node and inserting a logger after each node which is a key of
    node_to_instrument_to_ref_node_name. The value stored under that key is
    forwarded to the logger insertion helper. Returns a GraphModule built
    from gm and the new graph.
    """

    named_modules = dict(gm.named_modules())
    instrumented_graph = Graph()
    # original node name -> corresponding node in instrumented_graph
    remapped: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: remapped[n.name])

    for cur in gm.graph.nodes:
        if cur.op == 'output':
            instrumented_graph.output(map_arg(cur.args[0], load_arg))
            continue

        if cur.op == 'call_module' and is_activation_post_process(
                named_modules[cur.target]):
            # drop the observer: alias it to whatever fed it
            remapped[cur.name] = remapped[cur.args[0].name]
            continue

        needs_logger = cur in node_to_instrument_to_ref_node_name
        copied = instrumented_graph.node_copy(cur, load_arg)
        remapped[cur.name] = copied

        if needs_logger:
            ref_node_name = node_to_instrument_to_ref_node_name[cur]
            # downstream users must consume the logger, not the base node
            remapped[cur.name] = _insert_logger_after_node(
                copied, gm, logger_cls, '_ns_logger_', model_name,
                ref_node_name)

    return GraphModule(gm, instrumented_graph)
Esempio n. 4
0
    def test_remove_qconfig_observer_fx(self):
        r"""Check that no activation_post_process node is left in an fx prepared model"""
        float_model = SingleLayerLinearModel()
        float_model.eval()

        # prepare with the default qconfig of the active quantized engine
        default_qconfig = get_default_qconfig(torch.backends.quantized.engine)
        prepared_model = prepare_fx(float_model, {"": default_qconfig})

        # work on a copy so the prepared model itself is left untouched
        prepared_copy = copy.deepcopy(prepared_model)
        prepared_copy.eval()

        stripped = remove_qconfig_observer_fx(prepared_copy)

        named_modules = dict(stripped.named_modules())
        for n in stripped.graph.nodes:
            if n.op != "call_module":
                continue
            self.assertFalse(is_activation_post_process(named_modules[n.target]))
Esempio n. 5
0
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Copies the graph of gm into a new graph, dropping every activation post
    process node. For each node which is a key of
    node_to_instrument_inputs_to_ref_node_name, loggers are added on its first
    argument (one logger if it is a Node, one per element if it is an
    immutable_list). For each node which is a key of
    node_to_instrument_outputs_to_ref_node_name, a logger is added after the
    node. Returns a GraphModule built from gm and the new graph.

    Raises AssertionError if the first argument of a node whose inputs are to
    be logged is neither a Node nor an immutable_list.
    """

    named_modules = dict(gm.named_modules())
    instrumented_graph = Graph()
    # original node name -> corresponding node in instrumented_graph
    remapped: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: remapped[n.name])

    def attach_logger(after_node, user_node, ref_name, results_type, idx):
        # shared call shape for all loggers created by this pass
        return _insert_logger_after_node(
            after_node,
            gm,
            logger_cls,
            '_ns_logger_',
            user_node.name,
            model_name,
            ref_name,
            results_type,
            index_within_arg=idx)

    for cur in gm.graph.nodes:
        if cur.op == 'output':
            instrumented_graph.output(map_arg(cur.args[0], load_arg))
            continue

        if cur.op == 'call_module' and is_activation_post_process(
                named_modules[cur.target]):
            # drop the observer: alias it to whatever fed it
            remapped[cur.name] = remapped[cur.args[0].name]
            continue

        wants_input_log = cur in node_to_instrument_inputs_to_ref_node_name
        wants_output_log = cur in node_to_instrument_outputs_to_ref_node_name

        if wants_input_log:
            in_ref_name = node_to_instrument_inputs_to_ref_node_name[cur]
            first_arg = cur.args[0]
            if type(first_arg) == Node:
                # a single input logger
                upstream = remapped[first_arg.name]
                remapped[first_arg.name] = attach_logger(
                    upstream, cur, in_ref_name,
                    NSSingleResultValuesType.NODE_INPUT.value, 0)
            elif type(first_arg) == \
                    torch.fx.immutable_collections.immutable_list:
                # one input logger per list element
                for pos, element in enumerate(first_arg):
                    upstream = remapped[element.name]
                    remapped[upstream.name] = attach_logger(
                        upstream, cur, in_ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value, pos)
            else:
                raise AssertionError(
                    f"type {type(cur.args[0])} is not handled yet")

        # the base node is always copied, with or without instrumentation
        remapped[cur.name] = instrumented_graph.node_copy(cur, load_arg)

        if wants_output_log:
            out_ref_name = node_to_instrument_outputs_to_ref_node_name[cur]
            # downstream users must consume the logger, not the base node
            remapped[cur.name] = attach_logger(
                remapped[cur.name], cur, out_ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value, 0)

    return GraphModule(gm, instrumented_graph)
Esempio n. 6
0
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Args:
        name_a: model name passed to the loggers attached to the copy of A
        gm_a: module which owns the subgraphs being copied in as shadows
        name_b: model name passed to the loggers attached to the nodes of B
        gm_b: module whose graph is walked and copied; the returned
          GraphModule is built from gm_b and the new graph
        matched_subgraph_pairs: maps a match name to a
          (subgraph_a, subgraph_b) tuple; the match name is forwarded to the
          loggers as the reference name
        logger_cls: forwarded to the logger and dtype cast insertion helpers
          and to the node type lookup
        should_log_inputs: if True, input loggers are also added in front of
          the start node of B and in front of the shadow copy of A

    Returns:
        GraphModule(gm_b, graph_c), where graph_c is the instrumented copy of
        gm_b's graph with observers removed.

    Raises:
        AssertionError: if should_log_inputs is True and the first argument of
          a start node of B is neither a Node nor a list.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    # maps the name of a node in graph_b to its counterpart in graph_c
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # index the matches by the start node and by the end node of subgraph_b,
    # so that each node of graph_b can be classified with a dict lookup
    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_observer = \
            node_b.op == 'call_module' and is_activation_post_process(modules[node_b.target])
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if node_b_is_observer:
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(subgraph_a.start_node, gm_a, logger_cls)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(node_b, gm_b, logger_cls)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN
                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
                and node_output_type_b != NodeInputOrOutputType.UNKNOWN)
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}'
                    +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}'
                    + ', unknown dtype cast')
                # copy node_b as is, without any shadow or logger
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            '_ns_logger_b_inp_',
                            node_b.name,
                            name_b,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [
                            env_c[arg.name] for arg in node_b.args[0]
                        ]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[
                                prev_node_c.name] = _insert_logger_after_node(
                                    prev_node_c,
                                    gm_b,
                                    logger_cls,
                                    '_ns_logger_b_inp_',
                                    node_b.name,
                                    name_b,
                                    ref_name,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=arg_idx)
                    else:
                        # logging of inputs which are neither a Node nor a
                        # list is not supported yet
                        raise AssertionError(
                            f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b,
                    graph_c, node_b.name + '_dtype_cast_', logger_cls)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # The node whose input is being logged (the first node of
                    # the copy of subgraph_a) does not exist yet, so the
                    # logger is created with an empty ref_node_name. The real
                    # name is written onto the logger module further below,
                    # once the subgraph copy has been inserted.
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node,
                            gm_b,
                            logger_cls,
                            '_ns_logger_a_inp_',
                            ref_node_name,
                            name_a,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                                dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner,
                                gm_b,
                                logger_cls,
                                '_ns_logger_a_inp_',
                                ref_node_name,
                                name_a,
                                ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx)
                            new_loggers.append(dtype_cast_logger)
                        # from here on, the shadow copy consumes the loggers
                        # rather than the raw dtype casts
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second param as well. Note: dtype casting
                # for the second param is not implemented yet, it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(
                    subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg, subgraph_a,
                    gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    # walk up through first arguments until reaching the node
                    # which is fed directly by the input logger(s)
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b,
                                                       input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_a_',
                    node_a_shadows_c.name,
                    name_a,
                    ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_b_',
                    node_b.name,
                    name_b,
                    ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes inbetween.

        else:
            # node_b is not part of any match: copy it unchanged
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
Esempio n. 7
0
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Copies the graph of gm into a new graph, dropping every activation post
    process node. For each node which is a key of
    node_to_instrument_inputs_to_ref_node_name, loggers are added on the
    arguments selected by get_arg_indices_of_inputs_to_log (one logger for a
    Node argument, one per element for an immutable_list argument; arguments
    of any other type are left alone). For each node which is a key of
    node_to_instrument_outputs_to_ref_node_name, a logger is added after the
    node. Returns a GraphModule built from gm and the new graph.
    """

    named_modules = dict(gm.named_modules())
    instrumented_graph = Graph()
    # original node name -> corresponding node in instrumented_graph
    remapped: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: remapped[n.name])

    def attach_logger(after_node, user_node, ref_name, results_type,
                      idx_within_arg, idx_of_arg):
        # shared call shape for all loggers created by this pass
        return _insert_logger_after_node(
            after_node, gm, logger_cls, '_ns_logger_', user_node.name,
            model_name, ref_name, results_type,
            index_within_arg=idx_within_arg, index_of_arg=idx_of_arg)

    for cur in gm.graph.nodes:
        if cur.op == 'output':
            instrumented_graph.output(map_arg(cur.args[0], load_arg))
            continue

        if cur.op == 'call_module' and is_activation_post_process(named_modules[cur.target]):
            # drop the observer: alias it to whatever fed it
            remapped[cur.name] = remapped[cur.args[0].name]
            continue

        wants_input_log = cur in node_to_instrument_inputs_to_ref_node_name
        wants_output_log = cur in node_to_instrument_outputs_to_ref_node_name

        if wants_input_log:
            in_ref_name = node_to_instrument_inputs_to_ref_node_name[cur]
            # Ops such as add and mul are special because either
            # one or two of the first two arguments can be tensors,
            # and if one argument is a tensor it can be first or
            # second (x + 1 versus 1 + x).
            for arg_pos in get_arg_indices_of_inputs_to_log(cur):
                candidate = cur.args[arg_pos]
                if type(candidate) == Node:
                    # a single input logger
                    upstream = remapped[candidate.name]
                    remapped[candidate.name] = attach_logger(
                        upstream, cur, in_ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        0, arg_pos)
                elif type(candidate) == torch.fx.immutable_collections.immutable_list:
                    # one input logger per list element
                    for pos, element in enumerate(candidate):
                        upstream = remapped[element.name]
                        remapped[upstream.name] = attach_logger(
                            upstream, cur, in_ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            pos, arg_pos)
                # any other argument type is intentionally not logged

        # the base node is always copied, with or without instrumentation
        remapped[cur.name] = instrumented_graph.node_copy(cur, load_arg)

        if wants_output_log:
            out_ref_name = node_to_instrument_outputs_to_ref_node_name[cur]
            # downstream users must consume the logger, not the base node
            remapped[cur.name] = attach_logger(
                remapped[cur.name], cur, out_ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value, 0, 0)

    return GraphModule(gm, instrumented_graph)
Esempio n. 8
0
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node,
                                                                     Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Args:
        name_a: model name passed to the loggers attached to the copy of A
        gm_a: module which owns the subgraphs being copied in as shadows
        name_b: model name passed to the loggers attached to the nodes of B
        gm_b: module whose graph is walked and copied; the returned
          GraphModule is built from gm_b and the new graph
        matched_subgraph_pairs: maps a match name to
          ((node_start_a, node_end_a), (node_start_b, node_end_b))
        logger_cls: forwarded to the logger insertion helper

    Returns:
        GraphModule(gm_b, graph_c), where graph_c is the instrumented copy of
        gm_b's graph with observers removed.

    Raises:
        AssertionError: if the start and end node of a subgraph of B are not
          the same node.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    # maps the name of a node in graph_b to its counterpart in graph_c
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # the match names are not needed here, only the node pairs
    node_b_to_matched_subgraph_a = {}
    for match in matched_subgraph_pairs.values():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name], node_start_a, node_end_a, gm_a,
                gm_b, node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls,
                '_ns_logger_a_', name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            # node_b is not part of any match: copy it unchanged
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
Esempio n. 9
0
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node,
                                                                     Node]]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs (each value is
    ((start_node_a, end_node_a), (start_node_b, end_node_b))):
    {'op0': ((op0_fp32, op0_fp32), (op0_int8, op0_int8)),
     'op1': ((op1_fp32, op1_fp32), (op1_int8, op1_int8))}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    * if should_log_inputs is True, also adds loggers to the inputs of
      node_b and to the (dtype cast) inputs of the copy of node_a

    Activation post process (observer) modules found in gm_b are not copied
    into graph C; their consumers are rewired to the observer's input.

    Args:
        name_a: name forwarded to the loggers attached to the copy of A.
        gm_a: GraphModule which owns the nodes of the A subgraphs.
        name_b: name forwarded to the loggers attached to the nodes of B.
        gm_b: GraphModule whose graph is copied; also used as the root
          module of the returned GraphModule.
        matched_subgraph_pairs: maps a match name to
          ((start_node_a, end_node_a), (start_node_b, end_node_b)).
        logger_cls: forwarded to `_insert_logger_after_node` for every
          logger which is inserted.
        should_log_inputs: if True, inputs of the matched nodes are logged
          in addition to their outputs.

    Returns:
        A new GraphModule built from gm_b and the newly created graph C.

    Raises:
        AssertionError: if a matched subgraph of B spans more than one node,
          or if should_log_inputs is True and the first argument of a
          matched node of B is neither a Node nor a list.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # index the matches by the (single) node of B, for O(1) lookup while
    # iterating over the nodes of graph_b below
    node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a_and_name[node_end_b] = \
            ((node_start_a, node_end_a), match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a_and_name:
            (node_start_a, node_end_a), ref_name = \
                node_b_to_matched_subgraph_a_and_name[node_b]

            # if necessary, log the input of node_c
            if should_log_inputs:
                if isinstance(node_b.args[0], Node):
                    prev_node_c = env_c[node_b.args[0].name]
                    env_c[prev_node_c.name] = _insert_logger_after_node(
                        prev_node_c,
                        gm_b,
                        logger_cls,
                        '_ns_logger_b_inp_',
                        node_b.name,
                        name_b,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif isinstance(node_b.args[0], list):
                    # first, save the prev_node instances, because they
                    # will be overwritten in the env after the first logger
                    # is added
                    prev_node_c_list = [
                        env_c[arg.name] for arg in node_b.args[0]
                    ]

                    for arg_idx, prev_node_c in enumerate(prev_node_c_list):
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            '_ns_logger_b_inp_',
                            node_b.name,
                            name_b,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    # logging of inputs which are neither a Node nor a list
                    # is not supported yet
                    raise AssertionError(
                        f"type {type(node_b.args[0])} is not handled yet")
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)?

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            # note: not inserting to env_c because all nodes which use the dtype
            #   casts are copied from graph_a
            #
            # subgraph so far:
            #
            #           (dtype_cast_node)+
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # if input logging is enabled, log the input to the subgraph
            if should_log_inputs:
                # The input logger has to reference the first node of the
                # copy of subgraph A, which does not exist yet. Create the
                # logger with an empty ref_node_name for now; the real name
                # is filled in after the subgraph copy is created below.
                ref_node_name = ''
                if isinstance(dtype_cast_node, Node):
                    dtype_cast_node = _insert_logger_after_node(
                        dtype_cast_node,
                        gm_b,
                        logger_cls,
                        '_ns_logger_a_inp_',
                        ref_node_name,
                        name_a,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                    input_logger: Union[Node, List[Node]] = dtype_cast_node
                else:
                    assert isinstance(dtype_cast_node, list)
                    new_loggers = []
                    for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                            dtype_cast_node):
                        dtype_cast_logger = _insert_logger_after_node(
                            dtype_cast_node_inner,
                            gm_b,
                            logger_cls,
                            '_ns_logger_a_inp_',
                            ref_node_name,
                            name_a,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=dtype_cast_idx)
                        new_loggers.append(dtype_cast_logger)
                    dtype_cast_node = new_loggers
                    input_logger = dtype_cast_node
                # subgraph so far:
                #
                #       (dtype_cast_node)+ -> (logger_a_input)?
                #                  /
                # prev_node_c -> (logger_c_input)? -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                dtype_cast_node, node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            if should_log_inputs:
                # When we created the input logger, we left the ref_node_name
                # as an empty string, because the subgraph copy did not exist
                # yet. Now that the subgraph copy exists, we modify this name
                # to its true value.
                # Note: the alternative to this is to create the input logger
                # after creating the subgraph, which is slightly more
                # complicated. This is the lesser of two evils.

                # Find the first node in the subgraph, by walking up from
                # the end of the copy until the input logger is reached
                cur_node = node_a_shadows_c
                while cur_node.args[0] != input_logger:
                    cur_node = cur_node.args[0]  # type: ignore
                if isinstance(input_logger, Node):
                    input_logger_mod = getattr(gm_b, input_logger.name)
                    input_logger_mod.ref_node_name = cur_node.name
                else:
                    assert isinstance(input_logger, list)
                    for input_logger_inner in input_logger:
                        input_logger_mod = getattr(gm_b,
                                                   input_logger_inner.name)
                        input_logger_mod.ref_node_name = cur_node.name

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name],
                gm_b,
                logger_cls,
                '_ns_logger_b_',
                node_b.name,
                name_b,
                ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy
            #                  /
            # (prev_node_c+) -> (logger_c_input)? -> node_c -> logger_c

            # hook up a logger to the mod_a copy
            # Note: this logger receives the name of the shadow copy node and
            # the same ref_name as the logger attached to node_c above
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name],
                gm_b,
                logger_cls,
                '_ns_logger_a_',
                node_a_shadows_c.name,
                name_a,
                ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c -> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c