def test_alternating_longest():
    """An alternation matches its longest alternative, whatever the order.

    ``N('b') | N('b') + N('c')`` and ``N('b') + N('c') | N('b')`` must both
    yield the same match, which includes the ``c`` node.
    """
    g = nx.DiGraph()
    #   b c
    # a     d
    #   b
    add_nodes(g, ['a', 'b', 'c', 'b', 'd'])
    g.add_edges_from([(1, 2), (2, 3), (3, 5), (1, 4), (4, 5)])
    ex = N('a') + (N('b') | N('b') + N('c'))
    ex2 = N('a') + (N('b') + N('c') | N('b'))

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    matches2 = search_all(g, ex2)
    assert matches2 == matches == [[1, 2, 3]]
def get_matching_nncf_graph_pattern_io_list(
        self, expression: Expression) -> List[NNCFGraphPatternIO]:
    """Describe the inputs/outputs of every subgraph matching ``expression``.

    Runs ``search_all`` over this graph's ``_nx_graph``. Each match it returns
    (presumably a sequence of node keys) is converted with
    ``self._get_nncf_graph_pattern_io_list``.

    :param expression: the pattern to look for.
    :return: one converted entry per match, in the order ``search_all``
        produced the matches.
    """
    return list(map(self._get_nncf_graph_pattern_io_list,
                    search_all(self._nx_graph, expression)))
def test_graph_branching():
    """Only the ``a -> b`` edge satisfies ``N('a') + N('b')`` in a diamond."""
    g = nx.DiGraph()
    add_nodes(g, ['a', 'b', 'a', 'c'])
    g.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4)])
    ex = N('a') + N('b')

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    assert matches == [[1, 2]]
def test_two_matched():
    """Both occurrences of ``b -> c`` in a linear chain are reported."""
    g = nx.DiGraph()
    add_nodes(g, ['a', 'b', 'c', 'a', 'b', 'c'])
    g.add_edges_from([(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)])
    ex = N('b') + N('c')

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    assert matches == [[2, 3], [5, 6]]
def test_simple():
    """A two-node sequence is found inside a linear chain."""
    g = nx.DiGraph()
    add_nodes(g, ['a', 'b', 'c', 'a'])
    g.add_edges_from([(1, 2), (2, 3), (3, 4)])
    ex = N('b') + N('c')

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    assert matches == [[2, 3]]
def test_alternating():
    """An alternation after a fixed node matches whichever branch is present."""
    g = nx.DiGraph()
    add_nodes(g, ['a', 'b'])
    g.add_edges_from([(1, 2)])
    ex = N('a') + (N('a') | N('b'))

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    assert matches == [[1, 2]]
def test_branching_expression2():
    """A three-way branch ``(b & c & e)`` between ``a`` and ``d`` is matched."""
    g = nx.DiGraph()
    #   b
    # a e d
    #   c
    add_nodes(g, ['a', 'b', 'c', 'd', 'e'])
    g.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (1, 5), (5, 4)])
    c_ = (N('b') & N('c') & N('e'))
    n = N('a')
    node_expression = N('d')
    ex = n + c_ + node_expression

    # search_all takes the graph first and the expression second.
    matches = search_all(g, ex)
    assert matches == [[1, 2, 3, 5, 4]]
def get_merged_original_graph_with_patterns(self, original_graph: NNCFGraph):
    """Return a copy of ``original_graph``'s nx graph with pattern matches collapsed.

    The patterns are ``LINEAR_OPS + ANY_BN_ACT_COMBO`` and
    ``LINEAR_OPS + ELTWISE_UNIFORM_OPS``. Every match of more than one node is
    replaced by a single node:

    * its key is the concatenation of the matched node keys, each followed
      by ``'\\n'`` (also stored under ``NNCFGraph.KEY_NODE_ATTR``);
    * ``self.NODES_GRAPH_ATTR`` holds the original attribute dicts of the
      matched nodes (the same objects, not copies);
    * ``self.IS_MERGED_GRAPH_ATTR`` is ``True``. Every other node gets
      ``False``.

    Every edge that connects any matched node to a node outside the match is
    re-attached to the merged node. This includes edges of intermediate
    nodes, e.g. the second input of an eltwise op that ends a match.

    :param original_graph: graph to merge; its nx graph is deep-copied and the
        copy is what gets modified.
    :return: the merged networkx graph.
    """
    import nncf.dynamic_graph.patterns as p
    from nncf.dynamic_graph.graph_matching import search_all

    pattern = p.LINEAR_OPS + p.ANY_BN_ACT_COMBO | p.LINEAR_OPS + p.ELTWISE_UNIFORM_OPS
    # pylint: disable=protected-access
    original_nx_graph = original_graph._nx_graph
    matches = search_all(original_nx_graph, pattern)
    merged_graph = deepcopy(original_nx_graph)
    nx.set_node_attributes(merged_graph, False, self.IS_MERGED_GRAPH_ATTR)
    for match in matches:
        if len(match) == 1:
            continue
        matched_keys = set(match)

        # Snapshot every edge crossing the match boundary before the matched
        # nodes are removed. Previously only the first node's in-edges and the
        # last node's out-edges were kept, so edges touching intermediate
        # nodes vanished from the merged graph.
        # In-edges are gathered first-node-first and out-edges last-node-first.
        # When two boundary edges collapse onto the same merged edge, the
        # attributes of the first/last node's edge therefore win (see the
        # has_edge checks below).
        boundary_in_edges = []
        for node_key in match:
            for src, dst in merged_graph.in_edges(node_key):
                if src not in matched_keys:
                    boundary_in_edges.append(
                        (src, deepcopy(merged_graph.edges[src, dst])))
        boundary_out_edges = []
        for node_key in reversed(match):
            for src, dst in merged_graph.out_edges(node_key):
                if dst not in matched_keys:
                    boundary_out_edges.append(
                        (dst, deepcopy(merged_graph.edges[src, dst])))

        merged_node_key = "".join(node_key + '\n' for node_key in match)
        merged_nodes = [original_nx_graph.nodes[node_key] for node_key in match]
        merged_graph.remove_nodes_from(match)

        merged_node_attrs = {
            NNCFGraph.KEY_NODE_ATTR: merged_node_key,
            self.NODES_GRAPH_ATTR: merged_nodes,
            self.IS_MERGED_GRAPH_ATTR: True
        }
        merged_graph.add_node(merged_node_key, **merged_node_attrs)
        for src, edge_attrs in boundary_in_edges:
            if not merged_graph.has_edge(src, merged_node_key):
                merged_graph.add_edge(src, merged_node_key, **edge_attrs)
        for dst, edge_attrs in boundary_out_edges:
            if not merged_graph.has_edge(merged_node_key, dst):
                merged_graph.add_edge(merged_node_key, dst, **edge_attrs)
    return merged_graph
def test_branching_expression():
    """A two-way branch ``(b & c)`` between ``a`` and ``d`` is matched."""
    g = nx.DiGraph()
    #   b
    # a   d
    #   c
    add_nodes(g, ['a', 'b', 'c', 'd'])
    g.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4)])
    # The original built this expression twice (once via throwaway locals)
    # and immediately overwrote the first one; only this form was ever used.
    ex = N('a') + (N('b') & N('c')) + N('d')

    matches = search_all(g, ex)
    assert matches == [[1, 2, 3, 4]]
def get_insertion_point_nodes_after_pattern(self, expression: Expression) -> List[NNCFNode]:
    """Return the nodes chosen as insertion points for matches of ``expression``.

    All matches of ``expression`` in ``self._nx_graph`` are handed to
    ``self._find_insertion_points``. Each node key it returns is converted to
    an ``NNCFNode`` via ``NNCFGraph._nx_node_to_nncf_node``.

    :param expression: the pattern to search for.
    :return: the insertion-point nodes, in the order the keys were returned.
    """
    nx_nodes = self._nx_graph.nodes
    found = search_all(self._nx_graph, expression)
    insertion_keys = self._find_insertion_points(found)
    return [NNCFGraph._nx_node_to_nncf_node(nx_nodes[node_key])
            for node_key in insertion_keys]
def get_ip_graph_with_merged_hw_optimized_operations(
        self, hw_config: Optional[HWConfig] = None) -> 'InsertionPointGraph':
    """Return a copy of this graph with mergeable operator sequences fused.

    Patterns come from ``self._get_mergeable_operator_patterns(hw_config)`` and
    are searched in ``self._base_nx_graph``. Matches are skipped if they:

    * contain a single node, or
    * are reported by ``_base_graph_match_has_breaking_output_edges``.

    For every other match, the matched operator nodes are replaced in a deep
    copy of ``self`` by one node:

    * its key is the concatenation of the matched keys, each followed by
      ``'\\n'``;
    * its attributes are a deep copy of the first matched node's attributes.
      The associated-IP-node-keys set is narrowed to the insertion point
      nodes that lie on the match's incoming edge or its outgoing edges.
      All other associated insertion point nodes are removed.

    :param hw_config: optional HW config forwarded to the pattern builder.
    :return: the merged ``InsertionPointGraph``.
    """
    # pylint:disable=too-many-branches
    from nncf.dynamic_graph.graph_matching import search_all

    ip_keys_attr = InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR
    result = deepcopy(self)
    mergeable_pattern = self._get_mergeable_operator_patterns(hw_config)
    for op_keys in search_all(self._base_nx_graph, mergeable_pattern):
        if len(op_keys) == 1:
            continue

        # If a subgraph has output edges in its middle, should skip merging it
        # Example:
        #  (conv2d)
        #     |------\
        #    (BN)    |
        #     |      |
        #   (RELU)   |
        #     |      |
        #   (cat)----/
        #     |
        #    ...
        if self._base_graph_match_has_breaking_output_edges(op_keys):
            continue

        first_op_key = op_keys[0]
        last_op_key = op_keys[-1]
        incoming = list(self.in_edges(first_op_key))
        outgoing = list(self.out_edges(last_op_key))
        assert len(incoming) <= 1  # TODO: change to == 1 when input nodes are handled correctly

        # Edges whose insertion point nodes survive the merge.
        preserved_edges = list(outgoing)
        if incoming:
            incoming_key = incoming[0]
            incoming_attrs = deepcopy(self.edges[incoming_key])
            preserved_edges.append(incoming_key)
        outgoing_attrs = {edge_key: deepcopy(self.edges[edge_key])
                          for edge_key in outgoing}

        merged_attrs = deepcopy(self.nodes[first_op_key])
        merged_attrs[ip_keys_attr] = set()
        for op_key in op_keys:
            for ip_key in self.nodes[op_key][ip_keys_attr]:
                # An IP node is kept only if it is an endpoint of a preserved edge.
                if any(ip_key in edge_key for edge_key in preserved_edges):
                    merged_attrs[ip_keys_attr].add(ip_key)
                else:
                    result.remove_node(ip_key)
            result.remove_node(op_key)

        merged_key = ''.join(op_key + '\n' for op_key in op_keys)
        result.add_node(merged_key, **merged_attrs)
        if incoming:
            result.add_edge(incoming_key[0], merged_key, **incoming_attrs)
        for (_, dst_key), edge_attrs in outgoing_attrs.items():
            result.add_edge(merged_key, dst_key, **edge_attrs)
    return result