コード例 #1
0
ファイル: utils.py プロジェクト: xiaming9880/nncf
def get_sources_of_node(nncf_node: NNCFNode, graph: NNCFGraph, sources_types):
    """
    Collects the source nodes of the given node.

    A source is a node whose type is one of sources_types, such that there is a path from it
    to nncf_node and no other node on this path has a type from sources_types.
    If nncf_node itself has a type from sources_types, the search starts from its
    previous nodes instead of from nncf_node.

    :param nncf_node: NNCFNode to get sources for
    :param graph: NNCF graph to work with
    :param sources_types: collection of source type names
    :return: list of all source nodes
    """
    def is_source_type(type_name):
        return type_name in sources_types

    # A single visited map is shared by every traversal started below.
    visited = dict.fromkeys(graph.get_all_node_idxs(), False)
    step_fn = partial(traverse_function,
                      nncf_graph=graph,
                      type_check_fn=is_source_type,
                      visited=visited)

    if nncf_node.op_exec_context.operator_name in sources_types:
        start_nodes = graph.get_previous_nodes(nncf_node)
    else:
        start_nodes = [nncf_node]

    found_sources = []
    for start_node in start_nodes:
        # NOTE(review): the third argument (False) presumably selects backward traversal - confirm
        # against NNCFGraph.traverse_graph.
        found_sources += graph.traverse_graph(start_node, step_fn, False)
    return found_sources
コード例 #2
0
    def _paint_activation_quantizer_node(
            nncf_graph: NNCFGraph, quantizer_id: NonWeightQuantizerId,
            quantizer_info: 'NonWeightQuantizerInfo',
            bits_color_map: Dict[int, str],
            groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers):
        """
        Sets display attributes on the graph nodes related to one activation quantizer.

        For every insertion listed in quantizer_info.affected_insertions:
          * the affected operation node gets the label '<operator_name>_#<node_id>';
          * the node treated as the quantizer node is located (a predecessor of the affected node
            for input/pre-hook insertions, its single successor for post-hook insertions) and gets
            'color', 'style' and 'label' attributes; the color is looked up in bits_color_map by the
            quantizer's num_bits;
          * if groups_of_adjacent_quantizers is truthy, a '_G<group_id>' suffix is appended to that
            label ('_GUNDEFINED' and an error log entry when no group id is found).

        :param nncf_graph: graph whose node attribute mappings (as returned by get_nx_node_by_key)
            are updated
        :param quantizer_id: id passed to groups_of_adjacent_quantizers.get_group_id_for_quantizer
        :param quantizer_info: supplies affected_insertions and quantizer_module_ref
        :param bits_color_map: maps a bit width to the color value assigned to the quantizer node
        :param groups_of_adjacent_quantizers: quantizer grouping info; grouping labels are skipped
            when this object is falsy
        :raises AssertionError: if the expected quantizer node cannot be located unambiguously
        """
        #pylint:disable=too-many-branches
        affected_insertion_infos_list = quantizer_info.affected_insertions  # type: List[InsertionInfo]

        for insertion_info in affected_insertion_infos_list:
            input_agnostic_op_exec_context = insertion_info.op_exec_context.input_agnostic
            affected_nncf_node_key = nncf_graph.get_node_key_by_iap_context(
                input_agnostic_op_exec_context)
            affected_nx_node = nncf_graph.get_nx_node_by_key(
                affected_nncf_node_key)
            operator_name = affected_nx_node[
                NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
            affected_node_id = affected_nx_node[NNCFGraph.ID_NODE_ATTR]

            affected_nncf_node = nncf_graph.get_node_by_id(affected_node_id)
            affected_nx_node['label'] = '_#'.join(
                [operator_name, str(affected_node_id)])

            if insertion_info.is_input:
                # Module UpdateInputs pre-op used for activation quantization
                previous_nodes = nncf_graph.get_previous_nodes(
                    affected_nncf_node)

                # Relying on the _quantize_inputs behaviour of only being able to quantize 0-th input

                # previous_nodes are either UpdateWeights, or UpdateWeights + UpdateInputs
                assert len(previous_nodes) in (1, 2)

                if len(previous_nodes) == 2:
                    # Pick whichever of the two predecessors is the UpdateInputs node
                    if "UpdateInputs" in str(
                            previous_nodes[0].op_exec_context.input_agnostic):
                        target_node = previous_nodes[0]
                    else:
                        target_node = previous_nodes[1]
                else:
                    target_node = previous_nodes[0]
                target_nncf_node_key = nncf_graph.get_node_key_by_id(
                    target_node.node_id)
            else:
                in_port_id = insertion_info.in_port_id

                if in_port_id is None:
                    # Post-hooking used for activation quantization
                    # Currently only a single post-hook can immediately follow an operation
                    succs = list(
                        nncf_graph.get_successors(affected_nncf_node_key))
                    assert len(succs) == 1
                    target_nncf_node_key = succs[0]
                else:
                    # Pre-hooking used for activation quantization:
                    # the quantizer is the predecessor connected to the requested input port
                    previous_nodes = nncf_graph.get_previous_nodes(
                        affected_nncf_node)
                    target_node = None
                    for prev_node in previous_nodes:
                        prev_edge = nncf_graph.get_nx_edge(
                            prev_node, affected_nncf_node)
                        if prev_edge[NNCFGraph.
                                     IN_PORT_NAME_EDGE_ATTR] == in_port_id:
                            target_node = prev_node
                            break

                    assert target_node is not None, "Could not find a pre-hook quantizer node for a specific " \
                                                    "input port!"
                    target_nncf_node_key = nncf_graph.get_node_key_by_id(
                        target_node.node_id)

            activation_fq_node = nncf_graph.get_nx_node_by_key(
                target_nncf_node_key)
            bits = quantizer_info.quantizer_module_ref.num_bits
            activation_fq_node['color'] = bits_color_map[bits]
            activation_fq_node['style'] = 'filled'
            fq_node_id = activation_fq_node[NNCFGraph.ID_NODE_ATTR]

            activation_fq_node['label'] = 'AFQ_[{}]_#{}'.format(
                quantizer_info.quantizer_module_ref.get_current_config(),
                str(fq_node_id))
            grouped_mode = bool(groups_of_adjacent_quantizers)
            if grouped_mode:
                group_id_str = 'UNDEFINED'
                group_id = groups_of_adjacent_quantizers.get_group_id_for_quantizer(
                    quantizer_id)
                # The group id (not the node id) decides whether the quantizer belongs to a group;
                # a missing group keeps the 'UNDEFINED' placeholder in the label.
                if group_id is None:
                    nncf_logger.error(
                        'No group for activation quantizer: {}'.format(
                            target_nncf_node_key))
                else:
                    group_id_str = str(group_id)
                activation_fq_node['label'] += "_G" + group_id_str