Ejemplo n.º 1
0
    def __init__(self):
        """Initialise per-instance state: a new graph, empty hooks, sync primitives.

        Tracing starts enabled (``is_tracing`` is set to True here).
        """
        # Each instance gets its own, newly constructed NNCFGraph.
        self.graph = NNCFGraph()

        # No save context yet (presumably assigned later by other methods — confirm).
        self._save_context = None

        # Hook registries start empty; the nested-hook counter starts at zero.
        self._post_hooks, self._pre_hooks = {}, {}
        self._num_nested_hooks = 0

        # threading.local gives every thread its own attribute namespace.
        self._thread_local = threading.local()

        # Instance counter and condition variable; presumably used together to
        # coordinate concurrent users of this object — confirm in the callers.
        self._n_instance = 0
        self._cond = threading.Condition()

        self.is_tracing = True
        self._input_comparators_per_scope = []
 def reset_graph(self):
     """Replace ``self.graph`` with a newly constructed NNCFGraph."""
     fresh_graph = NNCFGraph()
     self.graph = fresh_graph
Ejemplo n.º 3
0
def test_graph_pattern_io_building():
    """Check ``NNCFGraph._get_nncf_graph_pattern_io_list`` on a hand-built graph.

    For every pattern (a list of node keys), the computed input/output edges
    and input/output nodes must equal the reference values as multisets, so
    ordering differences are ignored.
    """
    graph = NNCFGraph()
    # Edges point downwards:
    #   1
    # /   \
    # 2   |
    # |   |
    # 3   |
    # \   /
    #   4
    # / | \
    # 5 6 7
    # |
    # 8

    #pylint:disable=protected-access
    node_keys = ['1', '2', '3', '4', '5', '6', '7', '8']
    # Node IDs are the 1-based positions in node_keys. Build the mapping once
    # so the node attributes and the graph's ID -> key table cannot diverge.
    node_id_to_key = dict(enumerate(node_keys, start=1))
    for node_id, node_key in node_id_to_key.items():
        attrs = {
            NNCFGraph.ID_NODE_ATTR: node_id,
            NNCFGraph.KEY_NODE_ATTR: node_key,
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: None,
        }
        graph._nx_graph.add_node(node_key, **attrs)

    edge_attr = {NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR: None}
    graph._nx_graph.add_edges_from([('1', '2'), ('1', '4'), ('2', '3'),
                                    ('3', '4'), ('4', '5'), ('4', '6'),
                                    ('4', '7'), ('5', '8')], **edge_attr)
    graph._node_id_to_key_dict.update(node_id_to_key)

    def make_mock_edge(from_id: int, to_id: int):
        """Build an edge between two nodes that carry only an ID."""
        return NNCFGraphEdge(NNCFNode(from_id, None), NNCFNode(to_id, None),
                             None)

    def make_mock_node(id_: int):
        """Build a node that carries only an ID."""
        return NNCFNode(id_, None)

    ref_patterns_and_ios = [
        (['1', '2'],
         NNCFGraphPatternIO(
             input_edges=[],
             input_nodes=[make_mock_node(1)],
             output_edges=[make_mock_edge(2, 3),
                           make_mock_edge(1, 4)],
             output_nodes=[])),
        (['3'],
         NNCFGraphPatternIO(input_edges=[make_mock_edge(2, 3)],
                            input_nodes=[],
                            output_edges=[make_mock_edge(3, 4)],
                            output_nodes=[])),
        (['1', '2', '3'],
         NNCFGraphPatternIO(
             input_edges=[],
             input_nodes=[make_mock_node(1)],
             output_edges=[make_mock_edge(3, 4),
                           make_mock_edge(1, 4)],
             output_nodes=[])),
        (['4'],
         NNCFGraphPatternIO(
             input_edges=[make_mock_edge(3, 4),
                          make_mock_edge(1, 4)],
             input_nodes=[],
             output_edges=[
                 make_mock_edge(4, 5),
                 make_mock_edge(4, 6),
                 make_mock_edge(4, 7)
             ],
             output_nodes=[])),
        (['5', '6', '8'],
         NNCFGraphPatternIO(
             input_edges=[make_mock_edge(4, 5),
                          make_mock_edge(4, 6)],
             input_nodes=[],
             output_edges=[],
             output_nodes=[make_mock_node(6),
                           make_mock_node(8)])),
        (['7'],
         NNCFGraphPatternIO(input_edges=[make_mock_edge(4, 7)],
                            input_nodes=[],
                            output_edges=[],
                            output_nodes=[make_mock_node(7)]))
    ]

    # Same fields, same order as the original four explicit assertions.
    io_fields = ('input_edges', 'output_edges', 'input_nodes', 'output_nodes')
    for pattern, ref_pattern_io in ref_patterns_and_ios:
        test_pattern_io = graph._get_nncf_graph_pattern_io_list(pattern)
        for field in io_fields:
            # Counter compares as multisets, ignoring the order of edges/nodes;
            # the message pinpoints which pattern and field disagree.
            assert Counter(getattr(test_pattern_io, field)) == Counter(
                getattr(ref_pattern_io, field)
            ), 'Mismatch in {} for pattern {}'.format(field, pattern)