def get_sources_of_node(nncf_node: NNCFNode, graph: NNCFGraph, sources_types):
    """
    Collect the source nodes of `nncf_node`.

    A source is a node of one of the `sources_types` types such that there is a path
    from it to `nncf_node` and no node on that path has one of the `sources_types` types.
    If `nncf_node` itself has one of these types, the search starts from its previous
    nodes (as returned by `graph.get_previous_nodes`) instead of from the node itself.

    :param nncf_node: NNCFNode to get the sources of
    :param graph: NNCF graph to work with
    :param sources_types: collection of source types
    :return: list of all source nodes
    """
    # One "visited" map is shared by every traversal started below, so a node
    # marked as visited by one traversal stays marked for the following ones.
    visited = dict.fromkeys(graph.get_all_node_idxs(), False)

    def _is_source_type(node_type):
        return node_type in sources_types

    traverse_fn = partial(traverse_function,
                          nncf_graph=graph,
                          type_check_fn=_is_source_type,
                          visited=visited)

    if nncf_node.op_exec_context.operator_name in sources_types:
        start_nodes = graph.get_previous_nodes(nncf_node)
    else:
        start_nodes = [nncf_node]

    # NOTE(review): the trailing positional False is forwarded unchanged to
    # traverse_graph; presumably it selects backward traversal — confirm.
    return [source
            for start_node in start_nodes
            for source in graph.traverse_graph(start_node, traverse_fn, False)]
def _paint_activation_quantizer_node(
        nncf_graph: NNCFGraph,
        quantizer_id: NonWeightQuantizerId,
        quantizer_info: 'NonWeightQuantizerInfo',
        bits_color_map: Dict[int, str],
        groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers):
    """
    Colour and label the graph nodes that implement one activation (non-weight) quantizer.

    For every insertion point in `quantizer_info.affected_insertions` this:
      * relabels the affected operation node as "<operator_name>_#<node_id>";
      * locates the node that performs the activation quantization for that insertion
        point (an "UpdateInputs" pre-op, a post-hook successor, or a pre-hook predecessor
        attached to the matching input port);
      * fills that node with the colour mapped to the quantizer's bitwidth and labels it
        "AFQ_[<current config>]_#<node_id>", followed by "_G<group id>" when adjacent
        quantizer groups are supplied ("_GUNDEFINED" if the quantizer has no group).

    The node attribute mappings returned by `nncf_graph.get_nx_node_by_key` are modified
    in place; nothing is returned.

    :param nncf_graph: NNCF graph whose node attributes are painted
    :param quantizer_id: id of the activation quantizer, used to look up its group
    :param quantizer_info: quantizer description; `affected_insertions` and
        `quantizer_module_ref` (`num_bits`, `get_current_config()`) are read
    :param bits_color_map: maps a bitwidth to the colour used to fill the quantizer node
    :param groups_of_adjacent_quantizers: groups of adjacent quantizers; when falsy, no
        group suffix is added to the labels
    """
    #pylint:disable=too-many-branches
    affected_insertion_infos_list = quantizer_info.affected_insertions  # type: List[InsertionInfo]
    # Loop-invariant: whether group ids should be appended to quantizer labels.
    grouped_mode = bool(groups_of_adjacent_quantizers)

    for insertion_info in affected_insertion_infos_list:
        input_agnostic_op_exec_context = insertion_info.op_exec_context.input_agnostic
        affected_nncf_node_key = nncf_graph.get_node_key_by_iap_context(
            input_agnostic_op_exec_context)
        affected_nx_node = nncf_graph.get_nx_node_by_key(affected_nncf_node_key)
        operator_name = affected_nx_node[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
        node_id = affected_nx_node[NNCFGraph.ID_NODE_ATTR]
        affected_nncf_node = nncf_graph.get_node_by_id(node_id)
        affected_nx_node['label'] = '_#'.join([operator_name, str(node_id)])

        if insertion_info.is_input:
            # Module UpdateInputs pre-op used for activation quantization
            previous_nodes = nncf_graph.get_previous_nodes(affected_nncf_node)
            # Relying on the _quantize_inputs behaviour of only being able to quantize 0-th input
            # previous_nodes are either UpdateWeights, or UpdateWeights + UpdateInputs
            assert len(previous_nodes) == 2 or len(previous_nodes) == 1
            if len(previous_nodes) == 2:
                if "UpdateInputs" in str(previous_nodes[0].op_exec_context.input_agnostic):
                    target_node = previous_nodes[0]
                else:
                    target_node = previous_nodes[1]
            else:
                # NOTE(review): a single predecessor is taken as the quantizer node as-is;
                # presumably it is the UpdateInputs op — confirm.
                target_node = previous_nodes[0]
            target_nncf_node_key = nncf_graph.get_node_key_by_id(target_node.node_id)
        else:
            in_port_id = insertion_info.in_port_id
            if in_port_id is None:
                # Post-hooking used for activation quantization
                # Currently only a single post-hook can immediately follow an operation
                succs = list(nncf_graph.get_successors(affected_nncf_node_key))
                assert len(succs) == 1
                target_nncf_node_key = succs[0]
            else:
                # Pre-hooking used for activation quantization: the quantizer is the
                # predecessor whose edge enters the affected node at `in_port_id`.
                previous_nodes = nncf_graph.get_previous_nodes(affected_nncf_node)
                target_node = None
                for prev_node in previous_nodes:
                    prev_edge = nncf_graph.get_nx_edge(prev_node, affected_nncf_node)
                    if prev_edge[NNCFGraph.IN_PORT_NAME_EDGE_ATTR] == in_port_id:
                        target_node = prev_node
                        break

                assert target_node is not None, "Could not find a pre-hook quantizer node for a specific " \
                                                "input port!"
                target_nncf_node_key = nncf_graph.get_node_key_by_id(target_node.node_id)

        activation_fq_node = nncf_graph.get_nx_node_by_key(target_nncf_node_key)
        bits = quantizer_info.quantizer_module_ref.num_bits
        activation_fq_node['color'] = bits_color_map[bits]
        activation_fq_node['style'] = 'filled'
        node_id = activation_fq_node[NNCFGraph.ID_NODE_ATTR]

        activation_fq_node['label'] = 'AFQ_[{}]_#{}'.format(
            quantizer_info.quantizer_module_ref.get_current_config(), str(node_id))
        if grouped_mode:
            group_id_str = 'UNDEFINED'
            group_id = groups_of_adjacent_quantizers.get_group_id_for_quantizer(quantizer_id)
            # Check the looked-up group id itself: a quantizer without a group must be
            # reported and labelled 'UNDEFINED' rather than rendered as "_GNone".
            if group_id is None:
                nncf_logger.error('No group for activation quantizer: {}'.format(target_nncf_node_key))
            else:
                group_id_str = str(group_id)
            activation_fq_node['label'] += "_G" + group_id_str