Exemplo n.º 1
0
    def test_insertion_point_setup(self):
        """Check the structure of an InsertionPointGraph built from a mock model graph.

        Every operator node of the original graph must appear in the insertion
        point graph as an OPERATOR node with exactly one pre-hook and one
        post-hook INSERTION_POINT node around it. Every original edge must still
        be traversable in the insertion point graph, passing only through
        insertion point nodes. The resulting graph must have no isolated nodes
        and no edges to unknown nodes.
        """
        # TODO: Change testing premises when module pre/post-op hooks and input/output nodes
        # are correctly handled
        mock_graph = get_two_branch_mock_model_graph()

        ip_graph = InsertionPointGraph(mock_graph)

        # 2 additional nodes (pre-hook and post-hook IP) per each operator node.
        ref_node_len = 3 * len(mock_graph.nodes)
        # 2 edges per operator node (pre-hook -> op, op -> post-hook), plus one
        # IP-to-IP edge per original edge. Note that `3 * len(edges)` would only
        # coincide with this for graphs having as many edges as nodes.
        ref_edge_len = 2 * len(mock_graph.nodes) + len(mock_graph.edges)

        assert len(ip_graph.nodes) == ref_node_len
        assert len(ip_graph.edges) == ref_edge_len

        for node_key in mock_graph.nodes:
            ip_graph_op_node = ip_graph.nodes[node_key]
            assert ip_graph_op_node[
                InsertionPointGraph.
                NODE_TYPE_NODE_ATTR] == InsertionPointGraphNodeType.OPERATOR

            # The operator node must be wrapped by exactly one pre-hook and one
            # post-hook insertion point node.
            preds = list(ip_graph.predecessors(node_key))
            succs = list(ip_graph.successors(node_key))
            assert len(preds) == 1
            assert len(succs) == 1
            pre_hook_ip_node_key = preds[0]
            post_hook_ip_node_key = succs[0]
            pre_hook_ip_node = ip_graph.nodes[pre_hook_ip_node_key]
            post_hook_ip_node = ip_graph.nodes[post_hook_ip_node_key]
            pre_hook_ip_node_type = pre_hook_ip_node[
                InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            post_hook_ip_node_type = post_hook_ip_node[
                InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            assert pre_hook_ip_node_type == InsertionPointGraphNodeType.INSERTION_POINT
            assert post_hook_ip_node_type == InsertionPointGraphNodeType.INSERTION_POINT

            # The operator node must record exactly its own two hook IP nodes.
            ref_associated_ip_node_keys_set = {
                pre_hook_ip_node_key, post_hook_ip_node_key
            }
            assert ref_associated_ip_node_keys_set == ip_graph_op_node[
                InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR]

            for neighbour in mock_graph.neighbors(node_key):
                # IP node insertion should not disrupt the graph superstructure:
                # each original edge must still be traversable, and only via
                # insertion point nodes. Without the non-emptiness check the
                # inner loop would pass vacuously for a disconnected pair.
                ip_graph_paths = list(
                    nx.all_simple_paths(ip_graph, node_key, neighbour))
                assert ip_graph_paths, \
                    "No path from {} to {} in the insertion point graph".format(
                        node_key, neighbour)
                for path in ip_graph_paths:
                    # Endpoints are the operator nodes themselves; check only
                    # the intermediate nodes.
                    for path_node_key in path[1:-1]:
                        path_node = ip_graph.nodes[path_node_key]
                        path_node_type = path_node[
                            InsertionPointGraph.NODE_TYPE_NODE_ATTR]
                        assert path_node_type == InsertionPointGraphNodeType.INSERTION_POINT

        # No node of the insertion point graph may be isolated.
        for node_key in ip_graph.nodes:
            preds = list(ip_graph.predecessors(node_key))
            succs = list(ip_graph.successors(node_key))
            assert preds or succs

        # Every edge must connect nodes that exist in the graph.
        for from_node_key, to_node_key in ip_graph.edges:
            assert from_node_key in ip_graph.nodes
            assert to_node_key in ip_graph.nodes
Exemplo n.º 2
0
    def test_insertion_point_data_in_ip_nodes(self):
        """Check the insertion point data stored in an operator's hook IP nodes.

        Builds a single-operator graph and verifies the pre-hook and post-hook
        insertion point nodes around the operator. Each must carry insertion
        point data with the matching insertion type (OPERATOR_PRE_HOOK /
        OPERATOR_POST_HOOK). Each must also carry the operator's input-agnostic
        execution context.
        """
        # TODO: extend for modules
        mock_graph = nx.DiGraph()
        ref_op_exec_context = OperationExecutionContext(
            "baz", Scope.from_str("Test/Scope[foo]/bar"), 0, [None])
        node_attrs = {NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: ref_op_exec_context}

        ref_node_key = 0
        mock_graph.add_node(ref_node_key, **node_attrs)

        ip_graph = InsertionPointGraph(mock_graph)

        for node_key in mock_graph.nodes:
            preds = list(ip_graph.predecessors(node_key))
            succs = list(ip_graph.successors(node_key))
            # Exactly one pre-hook and one post-hook IP node are expected. This
            # check makes a missing hook fail as an assertion rather than an
            # IndexError, and it stops extra neighbours from being ignored.
            assert len(preds) == 1
            assert len(succs) == 1
            pre_hook_ip_node = ip_graph.nodes[preds[0]]
            post_hook_ip_node = ip_graph.nodes[succs[0]]

            pre_hook_ip = pre_hook_ip_node[
                InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            post_hook_ip = post_hook_ip_node[
                InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            assert pre_hook_ip.insertion_type == InsertionType.OPERATOR_PRE_HOOK
            assert post_hook_ip.insertion_type == InsertionType.OPERATOR_POST_HOOK

            assert pre_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic
            assert post_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic