Exemplo n.º 1
0
    def test_send_op_names_info(self):
        """
        Check that send_op_names_info reports op counters accumulated over the main graph and its nested sub-graphs.
        """
        main_graph = Graph()
        main_graph.add_nodes_from(['node1'])
        main_graph.op_names_statistic = Counter(['a', 'a', 'a', 'b', 'b'])

        middle_graph = Graph()
        middle_graph.add_nodes_from(['node2'])
        middle_graph.op_names_statistic = Counter(['a', 'c', 'c'])

        # the innermost sub-graph has no nodes, only an op counter
        inner_graph = Graph()
        inner_graph.op_names_statistic = Counter(['a', 'd'])

        # 'node1' of the main graph owns the middle sub-graph
        outer_node = Node(main_graph, 'node1')
        outer_node['sub_graphs'] = ['sub_graph1']
        outer_node['sub_graph1'] = middle_graph

        # 'node2' of the middle sub-graph owns the innermost sub-graph
        nested_node = Node(middle_graph, 'node2')
        nested_node['sub_graphs'] = ['sub_graph2']
        nested_node['sub_graph2'] = inner_graph

        self.init_telemetry_mocks()

        send_op_names_info('framework', main_graph)

        # expected label -> total count over all three graphs
        expected_events = (
            ('framework_a', 5),
            ('framework_b', 2),
            ('framework_c', 2),
            ('framework_d', 1),
        )
        for op_label, op_total in expected_events:
            tm.Telemetry.send_event.assert_any_call('mo', 'op_count', op_label, op_total)
Exemplo n.º 2
0
 def test_sub_graph_between_nodes_branches_included(self):
     r"""
     Check that the function works correctly for tree like structures.
     1 -> 2 -> 3 -> 4
          \
          5 -> 6
         / \
     9 ->   -> 7 -> 8
     """
     all_nodes = list(range(1, 10))
     graph = Graph()
     graph.add_nodes_from(all_nodes)
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7), (7, 8), (9, 5)])

     # starting from node 1, every leaf used as the end node yields the complete node set
     for end_node in (4, 6, 8):
         collected = sub_graph_between_nodes(graph, [1], [end_node])
         self.assertListEqual(sorted(collected), all_nodes)

     # all nodes except 4 because it is a child of the end node
     collected = sub_graph_between_nodes(graph, [1], [3])
     self.assertListEqual(sorted(collected), [n for n in all_nodes if n != 4])

     # all nodes except 1 because it is a parent of the start node. The nodes 3 and 4 must be added because
     # after merging node 2 into the sub-graph the node 2 will be removed and it is not known how to calculate
     # the tensor between nodes 2 and 3.
     collected = sub_graph_between_nodes(graph, [2], [8])
     self.assertListEqual(sorted(collected), [n for n in all_nodes if n != 1])
Exemplo n.º 3
0
    def test_simple_dfs(self):
        """
        Check the node order returned by Graph.dfs for a small tree rooted at node 1.

        Node 1 has two children (2 and 3), so two orders are accepted depending on which branch is visited first.
        """
        graph = Graph()
        graph.add_nodes_from(list(range(1, 5)))
        graph.add_edges_from([(1, 2), (1, 3), (3, 4)])

        visited = set()
        order = graph.dfs(1, visited)
        # assertIn reports the actual order on failure, unlike assertTrue on a boolean expression
        self.assertIn(order, [[4, 3, 2, 1], [2, 4, 3, 1]])
Exemplo n.º 4
0
 def test_is_connected_component_connected(self):
     """
     Check that a node list covering the whole graph, including linking node 7, is reported as connected.
     """
     node_ids = list(range(1, 8))
     graph = Graph()
     graph.add_nodes_from(node_ids)
     graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])

     connected = is_connected_component(graph, list(range(1, 8)))
     self.assertTrue(connected)
Exemplo n.º 5
0
    def test_bfs_search_default_start_nodes(self):
        """
        Check that BFS automatically determines input nodes and starts searching from them.

        Nodes 1 and 2 have no incoming edges, so either of them may come first in the resulting order.
        """
        graph = Graph()
        graph.add_nodes_from(list(range(1, 6)))
        graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5)])

        order = bfs_search(graph)
        # assertIn reports the actual order on failure, unlike assertTrue on a boolean expression
        self.assertIn(order, [[1, 2, 3, 4, 5], [2, 1, 3, 4, 5]])
Exemplo n.º 6
0
    def test_bfs_search_specific_start_nodes(self):
        """
        Check that BFS starts from the user defined nodes and doesn't go in backward edge direction.

        Node 6 only feeds node 1 and node 2 only feeds node 3, so neither may appear when starting from node 1.
        """
        graph = Graph()
        graph.add_nodes_from(list(range(1, 7)))
        graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])

        order = bfs_search(graph, [1])
        # assertEqual keeps the same `==` check but reports both values on failure
        self.assertEqual(order, [1, 3, 4, 5])
Exemplo n.º 7
0
 def test_is_connected_component_two_separate_sub_graphs(self):
     """
     Check that the function returns False when the graph consists of two separate chains (1-2-3 and 4-5-6).
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 7)))
     graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6)])

     # every node list below is expected to be reported as not connected
     disconnected_sets = (list(range(1, 7)), [1, 3], [6, 4], [2, 5])
     for node_set in disconnected_sets:
         self.assertFalse(is_connected_component(graph, node_set))
Exemplo n.º 8
0
 def test_is_connected_component_two_separate_sub_graphs_divided_by_ignored_node(
         self):
     """
     Check that the function returns False when two sub-graphs are linked only through a node that is not part of
     the checked node list (node 7 here).
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 8)))
     graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])

     # nodes 1..6 only: the linking node 7 is deliberately left out
     checked_nodes = list(range(1, 7))
     self.assertFalse(is_connected_component(graph, checked_nodes))
Exemplo n.º 9
0
 def test_is_connected_component_edges_direction_is_ignored(self):
     """
     Check that the direction of edges does not matter when connectivity is checked.
     """
     node_names = list(range(1, 5))
     graph = Graph()
     graph.add_nodes_from(node_names)
     graph.add_edges_from([(2, 1), (2, 3), (4, 3)])

     # each node list is expected to be connected
     for node_set in (node_names, [2, 1], [4, 2, 3]):
         self.assertTrue(is_connected_component(graph, node_set))
Exemplo n.º 10
0
 def test_sub_graph_between_nodes_placeholder_included(self):
     r"""
     Check that the function doesn't allow to add Placeholders to the sub-graph. 5 is the Placeholder op.
       5->
          \
     1 -> 2 -> 3 -> 4
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.node[5]['op'] = 'Parameter'
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

     # collecting the sub-graph from 1 to 4 is expected to fail with Error
     with self.assertRaises(Error):
         sub_graph_between_nodes(graph, [1], [4])
Exemplo n.º 11
0
def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
                  edge_attrs: list = None):
    """
    Create a matcher that looks for the pattern defined by 'nodes' and 'edges' inside 'graph'.

    :param graph: graph to search in
    :param nodes: pattern nodes, passed to Graph.add_nodes_from of the pattern graph
    :param edges: pattern edges, passed to Graph.add_edges_from of the pattern graph
    :param node_attrs: not used; a warning is logged when it is not None
    :param edge_attrs: not used; a warning is logged when it is not None
    :return: ism.MultiDiGraphMatcher built for 'graph' and the pattern graph with node_match and edge_match
    """
    legacy_attrs_given = not (node_attrs is None and edge_attrs is None)
    if legacy_attrs_given:
        log.warning('\'edge_attrs\' or `\'node_attrs\'` parameter was passed to function \'find_pattern_matches\', '
                    'but they are not used anymore. Pattern matching proceeds according to \'nodes\' and \'edges\' '
                    'parameters. Please avoid passing \'edge_attrs\' and \'node_attrs\' parameters to any pattern '
                    'matching function like \'find_pattern_matches\', \'apply_pattern\' and \'pattern\' because it '
                    'will be deprecated in the next release.')

    pattern_graph = Graph(name='pattern')
    pattern_graph.add_nodes_from(nodes)
    pattern_graph.add_edges_from(edges)

    matcher = ism.MultiDiGraphMatcher(graph, pattern_graph, node_match, edge_match)
    return matcher
Exemplo n.º 12
0
 def test_sub_graph_between_nodes_multiple_inputs(self):
     r"""
     Check that the function works correctly when several start nodes are specified.
       5->
          \
     1 -> 2 -> 3 -> 4
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

     found_nodes = sub_graph_between_nodes(graph, [2, 5], [4])
     self.assertIsNotNone(found_nodes)
     # node 1 is not expected in the result
     self.assertListEqual(sorted(found_nodes), [2, 3, 4, 5])
Exemplo n.º 13
0
 def test_is_connected_component_edges_direction_is_ignored_not_connected(
         self):
     """
     Check that edges direction is ignored when checking for the connectivity. In this case the checked node lists
     are not connected.
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 5)))
     graph.add_edges_from([(2, 1), (2, 3), (4, 3)])

     # every node list below is expected to be reported as not connected
     disconnected_sets = ([1, 2, 4], [1, 4], [2, 4], [3, 4, 1])
     for node_set in disconnected_sets:
         self.assertFalse(is_connected_component(graph, node_set))
Exemplo n.º 14
0
 def test_sub_graph_between_nodes_do_not_include_incoming_edges_for_input_nodes(
         self):
     r"""
     Check that producers of the start nodes are not pulled into the sub-graph. In the graph below neither node 1
     nor node 5 must be added when the sub-graph is collected from node 2 till node 4.
       5->
          \
     1 -> 2 -> 3 -> 4
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

     found_nodes = sub_graph_between_nodes(graph, [2], [4])
     self.assertIsNotNone(found_nodes)
     expected_nodes = [2, 3, 4]
     self.assertListEqual(sorted(found_nodes), expected_nodes)
Exemplo n.º 15
0
 def test_sub_graph_between_nodes_placeholder_excluded(self):
     r"""
     Check that the Placeholder check is not applied to nodes which are not included into the sub-graph.
     Node 5 is a Placeholder but it stays outside of the sub-graph, so its 'op' attribute is ignored.
       5->
          \
     1 -> 2 -> 3 -> 4
     """
     graph = Graph()
     graph.add_nodes_from(list(range(1, 6)))
     graph.node[5]['op'] = 'Parameter'
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])

     found_nodes = sub_graph_between_nodes(graph, [2], [4])
     self.assertIsNotNone(found_nodes)
     expected_nodes = [2, 3, 4]
     self.assertListEqual(sorted(found_nodes), expected_nodes)
Exemplo n.º 16
0
 def test_sub_graph_between_nodes_control_flow_not_included_forward(self):
     r"""
     Check that the function works correctly for case when control flow edges should not be traversed (edge 3 -> 5).
        1 -> 2 -> 3 -> 4
                   \
                    -> 5 -> 6
     """
     control_flow_attrs = {'control_flow_edge': True}
     graph = Graph()
     graph.add_nodes_from(list(range(1, 7)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (3, 5, control_flow_attrs), (5, 6)])

     found_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
     self.assertIsNotNone(found_nodes)
     # nodes 5 and 6 are reachable only through the control flow edge and are not expected in the result
     self.assertListEqual(sorted(found_nodes), [1, 2, 3, 4])
Exemplo n.º 17
0
 def test_sub_graph_between_nodes_control_flow_included(self):
     r"""
     Check that the function works correctly for case when control flow edges must be traversed (edge 5 -> 2).
     6 -> 5->
             \
        1 -> 2 -> 3 -> 4
     """
     control_flow_attrs = {'control_flow_edge': True}
     graph = Graph()
     graph.add_nodes_from(list(range(1, 7)))
     graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2, control_flow_attrs), (6, 5)])

     found_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=True)
     self.assertIsNotNone(found_nodes)
     # nodes 5 and 6 come in through the control flow edge and are expected in the result
     self.assertListEqual(sorted(found_nodes), [1, 2, 3, 4, 5, 6])
Exemplo n.º 18
0
    def test_sub_graph_between_nodes_include_incoming_edges_for_internal_nodes(
            self):
        r"""
        Check that producers of the internal nodes of the sub-graph are added to it. In the graph below nodes 5 and 6
        must be added when the sub-graph is collected from node 1 till node 4.
        6 -> 5 ->
                 \
            1 -> 2 -> 3 -> 4
        :return:
        """
        graph = Graph()
        graph.add_nodes_from(list(range(1, 7)))
        graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])

        # end node 4: the complete graph is expected
        found_nodes = sub_graph_between_nodes(graph, [1], [4])
        self.assertIsNotNone(found_nodes)
        self.assertListEqual(sorted(found_nodes), list(range(1, 7)))

        # end node 2: nodes 3 and 4 are not expected, nodes 5 and 6 still are
        found_nodes = sub_graph_between_nodes(graph, [1], [2])
        self.assertIsNotNone(found_nodes)
        self.assertListEqual(sorted(found_nodes), [1, 2, 5, 6])
Exemplo n.º 19
0
    def replace_pattern(graph, match: dict):
        """
        Replace the matched 'condition' node and the loop sub-graph attached to it with one TensorIterator node.

        The function:
          1. collects the TensorIteratorInput / TensorIteratorOutput / TensorIteratorBackEdge nodes consuming the
             output data nodes of the condition;
          2. removes the condition with its data nodes and obtains the body nodes with get_body();
          3. copies the body nodes and the edges between them into a separate 'body' graph and removes them, together
             with the input/output nodes, from the main graph;
          4. assigns 'internal_layer_id' / 'internal_port_id' / 'external_port_id' values from one shared counter
             for back edges, inputs and outputs;
          5. creates the TensorIterator node with the port maps and back edges and runs its normalization helpers.

        :param graph: graph to be modified in place
        :param match: dictionary with the matched nodes; only the 'condition' key is read here
        :return: None
        """
        # Here we find all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
        # and make all checks needed for TensorIterator work
        cond_data = match['condition'].out_node(
            0) if not match['condition'].out_port(0).disconnected() else None
        # NOTE(review): the guard accepts a single output node but output 1 is read - confirm whether `> 1` or a
        # check of port 1 was intended
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        # ids of the nodes of each kind found among consumers of the condition outputs
        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            # consumers of any other kind are silently skipped here
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            # unlike for cond_data, only inputs and outputs are allowed as consumers of time_data
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        # the condition, its output data nodes and its first input are dropped before the body is searched
        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            # NOTE(review): body_nodes are used as node ids below (graph.node[node]) while cond_data is a Node
            # object - confirm that this subtraction has any effect
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        # from here on the three lists hold Node wrappers instead of node ids
        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # for an input node the external data is its input 1 when the node has a valid 'axis', otherwise input 0
        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        # for an output node the roles are mirrored: the external data is output 0 of the node
        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        # build the body graph from the body nodes and the edges that have both ends inside the body
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        # single counter used for all layer ids and port ids assigned below
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            # re-wrap the data nodes so that they refer to the body graph instead of the main graph
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    # NOTE: the assignment after `assert False` is reached only when assertions are disabled
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                # the attributes are copied because the original edge goes away with 'to_data_id' below
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            # the data node ending the back edge and its producer are not needed in the body anymore
            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        # one entry per (external input, consumer inside the body) pair
        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                # new body input data node whose shape has an extra dimension of size 1 inserted at 'axis'
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_inp['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_inp['axis']
                    }).create_node_with_data()
                # the Squeeze produces the original internal data node, so its consumers stay untouched
                reshape_op.create_node_with_data(
                    [new_input_data, reshape_dim_data],
                    data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                # ids assigned earlier (for example while handling back edges) are reused
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                reshape_op = Unsqueeze(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_out['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_out['axis']
                    }).create_node_with_data()
                # the data node produced by the Unsqueeze becomes the internal output data
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            # add an output marker to the body only if the data node is not consumed by a 'Result' op yet
            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        # note: 'input_port_map' is built from real_external_inputs (one entry per consumer) while
        # 'in_ports_count' is the number of external inputs
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        # connect the external data nodes to the new node; input edges get their 'external_port_id' right away
        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        # the result is wrapped into a list when it is not a list already
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        # output edges get their 'external_port_id' after the node is created
        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']

        # the created TensorIterator node is taken as the producer of the first output data node
        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
Exemplo n.º 20
0
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.
    ParallelComponent consists of several nested networks working in parallel: a VariadicSplit node is placed in
    front of the nested networks and their outputs are gathered by a concat node.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology; the nested networks and the new nodes are added to it
    :param prev_layer_id: id of the layer that feeds the parallel component
    :return: id of the concat node, which is the last node of the parallel component
    """
    nested_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nested_count))

    split_sizes = []
    subnet_outputs = []
    subnet_inputs = []

    for subnet_idx in range(nested_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        nested_graph = Graph()
        load_kalid_nnet1_model(nested_graph, file_descr, 'Nested_net_{}'.format(subnet_idx))

        # input to nnet1 models is of a rank 1 but we also insert batch_size to 0th axis
        # 1st axis contains input_size of the nested subnetwork
        # we split input from the main network to subnetworks
        parameter_node = Node(nested_graph, 'Parameter')
        split_sizes.append(parameter_node['shape'][1])
        nested_graph.remove_node(parameter_node.id)

        # nested node ids that already exist in the main graph get new unique ids
        renamed_ids = {}
        for node_id in nested_graph.nodes(data=False):
            if node_id in graph:
                renamed_ids[node_id] = graph.unique_id(node_id)
        nested_graph = nx.relabel_nodes(nested_graph, renamed_ids)
        for new_id in renamed_ids.values():
            nested_graph.node[new_id]['name'] = new_id

        # merge the nested network into the main graph
        graph.add_nodes_from(nested_graph.nodes(data=True))
        graph.add_edges_from(nested_graph.edges(data=True))

        # first and last nodes in topological order are used as the input and the output of the sub-network
        topo_order = tuple(nx.topological_sort(nested_graph))
        subnet_outputs.append(Node(graph, topo_order[-1]))
        subnet_inputs.append(Node(graph, topo_order[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    split_attrs = dict(out_ports_count=nested_count, size_splits=split_sizes, axis=1, name=split_id)
    variadic_split_node = AttributedVariadicSplit(graph, split_attrs).create_node()

    # previous layer -> VariadicSplit
    prev_layer_node = Node(graph, prev_layer_id)
    prev_layer_node.add_output_port(0)
    prev_to_split_attrs = create_edge_attrs(prev_layer_id, variadic_split_node.id, prev_layer_id)
    graph.create_edge(prev_layer_node, variadic_split_node, 0, 0, prev_to_split_attrs)

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # Connect the i-th output of variadic_split_node to the input of the i-th subnetwork
    # and the output of the i-th subnetwork to the i-th input of concat_node
    for port_idx, (input_node, output_node) in enumerate(zip(subnet_inputs, subnet_outputs)):
        output_node.add_output_port(0)
        concat_node.add_input_port(port_idx)

        to_concat_attrs = create_edge_attrs(output_node.id, concat_id, output_node.id, port_idx, 0)
        graph.create_edge(output_node, concat_node, 0, port_idx, to_concat_attrs)

        from_split_attrs = create_edge_attrs(variadic_split_node.id, input_node.id,
                                             variadic_split_node.id, 0, port_idx)
        graph.create_edge(variadic_split_node, input_node, port_idx, 0, from_split_attrs)

    return concat_id