def test_insertion_point_setup(self):
    """Check the structure of an InsertionPointGraph built from a mock model graph.

    Verifies that:
      * the node count is three per original node and the edge count is two
        per original node plus one per original edge;
      * every original node is an OPERATOR node wrapped by exactly one
        INSERTION_POINT predecessor and one INSERTION_POINT successor, and
        these two keys form the node's associated-IP-node-keys set;
      * every original neighbour stays reachable, and any simple path to it
        passes only through INSERTION_POINT nodes;
      * no node of the insertion point graph is isolated.
    """
    # TODO: Change testing premises when module pre/post-op hooks and input/output nodes
    # are correctly handled
    mock_graph = get_two_branch_mock_model_graph()

    ip_graph = InsertionPointGraph(mock_graph)

    # 2 additional nodes (pre-hook and post-hook) per each operator node
    ref_node_len = 3 * len(mock_graph.nodes)
    # Each operator node contributes a pre-hook -> op and an op -> post-hook edge
    # (asserted below via exactly one predecessor/successor), and each original
    # edge (u, v) becomes a single post-hook(u) -> pre-hook(v) edge.
    ref_edge_len = 2 * len(mock_graph.nodes) + len(mock_graph.edges)

    assert len(ip_graph.nodes) == ref_node_len
    assert len(ip_graph.edges) == ref_edge_len

    for node_key in mock_graph.nodes:
        ip_graph_op_node = ip_graph.nodes[node_key]
        assert ip_graph_op_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] == \
            InsertionPointGraphNodeType.OPERATOR

        preds = list(ip_graph.predecessors(node_key))
        succs = list(ip_graph.successors(node_key))
        assert len(preds) == 1
        assert len(succs) == 1

        pre_hook_ip_node_key = preds[0]
        post_hook_ip_node_key = succs[0]
        pre_hook_ip_node = ip_graph.nodes[pre_hook_ip_node_key]
        post_hook_ip_node = ip_graph.nodes[post_hook_ip_node_key]
        assert pre_hook_ip_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] == \
            InsertionPointGraphNodeType.INSERTION_POINT
        assert post_hook_ip_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] == \
            InsertionPointGraphNodeType.INSERTION_POINT

        ref_associated_ip_node_keys_set = {pre_hook_ip_node_key, post_hook_ip_node_key}
        assert ref_associated_ip_node_keys_set == \
            ip_graph_op_node[InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR]

        for neighbour in mock_graph.neighbors(node_key):
            # IP node insertion should not disrupt the graph superstructure
            ip_graph_paths = list(nx.all_simple_paths(ip_graph, node_key, neighbour))
            # The neighbour must remain reachable; otherwise the path check
            # below would iterate over nothing and pass vacuously.
            assert ip_graph_paths
            for path in ip_graph_paths:
                # Only the interior of the path (endpoints excluded) must be IP nodes
                for path_node_key in path[1:-1]:
                    path_node = ip_graph.nodes[path_node_key]
                    assert path_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] == \
                        InsertionPointGraphNodeType.INSERTION_POINT

    # No node of the insertion point graph may be left isolated
    for node_key in ip_graph.nodes:
        preds = list(ip_graph.predecessors(node_key))
        succs = list(ip_graph.successors(node_key))
        assert len(preds) != 0 or len(succs) != 0

    for from_node_key, to_node_key in ip_graph.edges:
        assert from_node_key in ip_graph.nodes
        assert to_node_key in ip_graph.nodes
def test_insertion_point_data_in_ip_nodes(self):
    """Check the insertion point data stored on pre/post-hook IP nodes.

    Builds a single-operator mock graph and verifies that the operator is
    wrapped by exactly one pre-hook and one post-hook node. The pre-hook data
    must have OPERATOR_PRE_HOOK type and the post-hook data OPERATOR_POST_HOOK
    type. Both must refer to the operator's input-agnostic execution context.
    """
    # TODO: extend for modules
    mock_graph = nx.DiGraph()
    ref_op_exec_context = OperationExecutionContext("baz",
                                                    Scope.from_str("Test/Scope[foo]/bar"),
                                                    0,
                                                    [None])
    node_attrs = {NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: ref_op_exec_context}

    node_key = 0
    mock_graph.add_node(node_key, **node_attrs)

    ip_graph = InsertionPointGraph(mock_graph)

    for op_node_key in mock_graph.nodes:
        preds = list(ip_graph.predecessors(op_node_key))
        succs = list(ip_graph.successors(op_node_key))
        # Pin exactly one hook on each side (as test_insertion_point_setup does),
        # so extra hooks are not silently ignored by the [0] indexing below and
        # missing hooks fail with a clear assertion rather than an IndexError.
        assert len(preds) == 1
        assert len(succs) == 1

        pre_hook_ip_node = ip_graph.nodes[preds[0]]
        post_hook_ip_node = ip_graph.nodes[succs[0]]

        pre_hook_ip = pre_hook_ip_node[InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
        post_hook_ip = post_hook_ip_node[InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
        assert pre_hook_ip.insertion_type == InsertionType.OPERATOR_PRE_HOOK
        assert post_hook_ip.insertion_type == InsertionType.OPERATOR_POST_HOOK
        assert pre_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic
        assert post_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic