Exemplo n.º 1
0
def build_graph(graph_attrs, meta_data, nodes, edges):
    """ Assemble a Graph from explicit node and edge descriptions.
     :param graph_attrs: dictionary stored as the graph-level attributes
     :param meta_data: value stored on the graph as ``meta_data``
     :param nodes: list of nodes where each node is tuple (node_name, type, attrs)
                  nodes=[
                      ('input', 'Parameter',  {}),
                      ('weights', 'Const', {}),
                      ('conv', 'Convolution', {}),
                      ('output', 'Result', {})
                  ]
     :param edges: list of edges where each edge is tuple (node_out, node_in, attrs);
                  attrs may hold the 'out' and 'in' port indices, both default to 0
                  edges=[
                      ('input', 'conv', {'out': 0, 'in': 0}),
                      ('weights', 'conv', {'out': 0, 'in': 1}),
                      ('conv', 'output', {'out': 0, 'in': 0})
                  ]
     :return: the graph, after ``clean_up()`` has been called on it.
    """
    result = Graph()
    result.graph = graph_attrs
    result.meta_data = meta_data

    # Create every node before wiring anything, edges refer to nodes by name.
    for node_desc in nodes:
        node_name, node_type, node_attrs = node_desc[0], node_desc[1], node_desc[2]
        create_node(result, node_name, node_type, node_attrs)

    for edge_desc in edges:
        port_attrs = edge_desc[2]
        src_port = port_attrs.get('out', 0)
        dst_port = port_attrs.get('in', 0)
        src_name, dst_name = edge_desc[0], edge_desc[1]
        connect_nodes_by_name(result, src_name, src_port, dst_name, dst_port)

    result.clean_up()
    return result
Exemplo n.º 2
0
def copy_graph_with_ops(graph: Graph) -> Graph:
    """
    Re-create every operation of ``graph`` in a new Graph, applying extenders on the way.

    The source graph is modified too: ports are restored, preprocessing callbacks are run
    and the ``old_input_shapes`` attribute is written to each of its operations.
    :param graph: Graph to copy
    :return: new graph with re-created operations and restored connections
    """
    restored = Graph()
    restored.stage = 'back'
    # The attribute dictionary object itself is assigned, not a copy of it
    restored.graph = graph.graph

    outputs_by_old_id = {}
    new_id_by_old_id = {}

    restore_correct_ports(graph)

    # Preprocessing of the source graph. Const nodes go first because the
    # preprocessing of the other types assumes Const nodes are already done.
    for const_op in graph.get_op_nodes(type='Const'):
        preprocess = preprocessing_op_nodes[const_op.type]
        preprocess(const_op)

    for src_op in graph.get_op_nodes():
        src_type = src_op.soft_get('type')
        if src_type == 'Const' or src_type not in preprocessing_op_nodes:
            continue
        preprocessing_op_nodes[src_op.type](src_op)

    # Re-create each operation in the new graph with correct attributes
    # (shape & type infer, backend attrs etc.)
    for src_op in graph.get_op_nodes():

        # Remember the input shapes that were restored from IR
        src_op['old_input_shapes'] = []
        for in_idx in src_op.in_nodes():
            restored_shape = int64_array(src_op.in_node(in_idx).shape)
            src_op.old_input_shapes.append(restored_shape)

        # Extend the source node when an extender is registered for its type
        if src_op.type not in Extender.registered_ops:
            log.debug('Extender for node {} with type={} not found, please note.'.format(src_op.name, src_op.type))
        else:
            Extender.get_extender_class_by_name(src_op.type).extend(src_op)

        # Type of the node to create in the new graph
        type_to_create = src_op.soft_get('type_to_create', src_op.type)

        if type_to_create in custom_ops:
            new_op = custom_ops[type_to_create](restored, src_op.attrs()).create_node()
        else:
            if type_to_create in Op.registered_ops:
                op_class = Op.get_op_class_by_name(type_to_create)
                new_op = op_class(restored, src_op.attrs()).create_node()
            else:
                log.warning('Operation {} is not found in MO operations, please check it! '
                            'Simple shape infer function is used'.format(type_to_create))
                new_op = Op(restored, src_op.attrs()).create_node()
                assert 'type' in new_op, \
                    'Operation {} have no `type` attribute.'.format(new_op.soft_get('name'))
                new_op['op'] = new_op.type
                new_op['infer'] = Extender.use_shapes_from_ir
                if 'ir_data_attrs' in src_op:
                    layer_attrs = [('id', lambda node: node.node), 'name', 'type', 'version']
                    layer_children = [('data', list(src_op.ir_data_attrs.keys()), []), '@ports', '@consts']
                    new_op['IE'] = [('layer', layer_attrs, layer_children)]

            # Fill out_ports_count (not done for nodes created through custom_ops)
            if 'out_ports_count' not in new_op and new_op.soft_get('type') != 'Result':
                new_op['out_ports_count'] = len(src_op.out_edges())

        # ir_data_attrs is not needed on the new node any more
        if 'ir_data_attrs' in new_op:
            del new_op['ir_data_attrs']

        if src_op.has_and_set('need_copy_input_blobs'):
            copy_input_blobs(src_op, new_op)

        # Remember the id mapping and the outgoing connections of the source node
        new_id_by_old_id[src_op.id] = new_op.id
        outputs_by_old_id[src_op.id] = collect_node_outputs(src_op)

    # Re-wire the new graph following the connections collected above
    for old_src_id, consumers_by_port in outputs_by_old_id.items():
        for src_port_idx, consumers in consumers_by_port.items():
            for dst_port_idx, old_dst_id in consumers:
                producer = Node(restored, new_id_by_old_id[old_src_id])
                consumer = Node(restored, new_id_by_old_id[old_dst_id])
                producer.out_port(src_port_idx).connect(consumer.in_port(dst_port_idx))

    # Postprocessing of the nodes in the new graph
    for new_op in restored.get_op_nodes():
        # Normalizing node outputs connects temporary Result operations to disconnected output ports, which is
        # needed for correct shape inference. These Result operations will be removed during IR emitting.
        # TopK is skipped here: its outputs should be normalized with TopKNormalizer.normalize_outputs.
        if new_op.soft_get('type') != 'TopK':
            Op.normalize_outputs(new_op)

        # Mark the data node of every Const so that restored values are processed correctly
        if new_op.soft_get('type') == 'Const':
            assert len(new_op.out_nodes()) == 1 and new_op.out_node(0).soft_get('kind') == 'data', \
                'Const node {} not properly corrected to appropriate data node'.format(new_op.soft_get('name'))
            new_op.out_node(0)['correct_data_type'] = True

            if new_op.has_and_set('rt_info'):
                new_op.out_node(0)['rt_info'] = new_op.rt_info

        # Type-specific postprocessing
        if new_op.soft_get('type') in postprocessing_op_nodes:
            postprocessing_op_nodes[new_op.type](new_op)

        restore_tensor_names(new_op)

    # Clean the graph up before shape inference
    restored.clean_up()

    return restored
Exemplo n.º 3
0
    def replace_pattern(graph, match: dict):
        """
        Replace the matched loop condition and the nodes hanging on it with one TensorIterator operation.

        Steps performed, in order:
          1. collect ids of TensorIteratorBackEdge / TensorIteratorInput / TensorIteratorOutput
             operations that consume the output data nodes of the condition;
          2. remove the condition, its output data nodes and its first input node from ``graph``;
          3. obtain body node ids and extra inputs from ``get_body`` and move the body nodes
             and the edges between them into a new Graph named 'body';
          4. assign ``internal_layer_id`` / ``internal_port_id`` / ``external_port_id`` values
             from a single running counter for back edges, inputs and outputs;
          5. create the TensorIterator node and run the TensorIterator normalization helpers on it.

        :param graph: graph being transformed, it is modified in place
        :param match: dictionary with the matched node stored under the 'condition' key
        """
        # Find all parts of the TI here: condition, inputs/outputs, back edges and body, create the
        # TensorIterator Op and make all checks needed for TensorIterator to work
        # NOTE(review): the guard below checks for ">= 1" output nodes but then reads output port 1;
        # confirm this is intended for a condition that only has port 0 connected.
        cond_data = match['condition'].out_node(
            0) if not match['condition'].out_port(0).disconnected() else None
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        # Ids of the helper operations found among the consumers of cond_data / time_data
        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            # Unlike cond_data, any other kind of consumer of time_data is treated as an error
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        # The condition and the data nodes around it are dropped before the body is collected
        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        # NOTE(review): body_nodes are used as node ids below (graph.node[node]) while cond_data
        # is a node object, so this set difference may never remove anything; confirm.
        if cond_data is not None:
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        # From here on the three lists hold Node wrappers instead of ids
        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # The external data node is read from input port 1 when the helper node has a valid 'axis',
        # from port 0 otherwise
        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        # Mirror of external_inputs: here the internal data node is the one on the input side
        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        # Build the body graph out of the body nodes and of the edges that connect two body nodes
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        # The body nodes and the helper input/output operations leave the main graph
        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        # One counter provides all layer ids and port ids assigned below
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            # Re-wrap the nodes so that they refer to the body graph
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                # The first branch is expected never to be taken: it fails on the assert, the
                # assignment after the assert is unreachable
                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                # Copy the attributes, the original edge goes away with to_data_id below
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            # The data node that ended the back edge and its producer are removed from the body
            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        # One entry per (external input, consumer inside the body) pair
        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                # New body input: same shape with a dimension of size 1 inserted at 'axis'
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_inp['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_inp['axis']
                    }).create_node_with_data()
                # The Squeeze writes into the original internal data node
                reshape_op.create_node_with_data(
                    [new_input_data, reshape_dim_data],
                    data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                # Existing ids are reused, new ones are taken from the counter
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                reshape_op = Unsqueeze(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_out['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_out['axis']
                    }).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            # add_opoutput is only called when no 'Result' consumer exists yet
            # (presumably add_opoutput attaches one; verify against its definition)
            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        # Note: the port counts come from external_inputs / external_outputs while input_port_map is
        # built from real_external_inputs (one entry per consumer)
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        # The result is handled as a list even when it is not returned as one
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']

        # The TensorIterator node itself is the producer of the first output data node
        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'lstm' node with a TensorIterator whose body is Squeeze -> LSTMCell -> Unsqueeze.

        The ids used in the body are hard-coded: internal_layer_id 0 is the input Squeeze, 1 is the
        LSTMCell and 2 is the output Unsqueeze; internal ports 1 and 2 are cell inputs, 4 and 5 are
        cell outputs and 3 is the Unsqueeze output.

        :param graph: graph being transformed, it is modified in place
        :param match: dictionary with the matched node stored under the 'lstm' key
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # Values are copied only for input ports 1 and 2, the other body inputs get value None
        inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                       {'shape': lstm.in_node(inp).shape.copy(),
                                        'value': lstm.in_node(inp).value.copy()
                                        if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 5, 1, 2]]  # X, h_init, c_init, WR, B

        # Inside the body the sequence dimension of X has size 1 and is squeezed out
        inputs[0].shape[lstm.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                        'value': [lstm.sequence_dim]}).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        # When output 1 is not connected, the shape of input 4 is used for its data node
        outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                        {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                        else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
        # NOTE(review): unlike the other names built here, this one has no '/' separator between
        # the LSTM name and the suffix; confirm whether that is intended.
        output_unsqueeze = Unsqueeze(body, dict(name=lstm.name + 'output_unsqueeze', internal_layer_id=2))
        unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                          'value': [lstm.sequence_dim]}).create_node_with_data()

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                           activations=lstm.activations,
                                           activation_alpha=lstm.activation_alpha,
                                           activation_beta=lstm.activation_beta,
                                           clip=lstm.clip,
                                           input_forget=lstm.input_forget,
                                           name=lstm.name + '/LSTMCell',
                                           internal_layer_id=1))
        # edge_attrs follow the order of `inputs`: X, h_init, c_init, WR, B
        lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                            edge_attrs=[{}, {'internal_port_id': 1},
                                                                        {'internal_port_id': 2}, {'bin': 'weights'},
                                                                        {'bin': 'biases'}])
        # Internal port ids of the two cell outputs, they are referenced by back_edges below
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # The Unsqueeze output is the only output mapped with an axis
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,

            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 3,
            'out_ports_count': len(lstm.out_nodes()),

            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': lstm.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
                {
                    'external_port_id': 2,
                    'internal_layer_id': 1,
                    'internal_port_id': 2,
                },
            ],

            'output_port_map': output_port_map,

            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
                {
                    'from_layer': 1,
                    'from_port': 5,
                    'to_layer': 1,
                    'to_port': 2,
                },
            ]
        })

        # Output ports must be numbered 0..N-1 because they are read by index right below
        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        # The existing output data nodes of the LSTM are reused as the TensorIterator outputs
        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                           data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                       {'external_port_id': 2}])

        # The result is handled as a list even when it is not returned as one
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(lstm.id)
        # External output ports: 3 for the first output, 4 and up for the remaining ones
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        # The TensorIterator node itself is the producer of the first output data node
        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'rnn_layer' node with a TensorIterator whose body is Squeeze -> cell -> Unsqueeze.

        Nodes whose 'op' is 'LSTM' are skipped. The cell class is obtained from ``self.get_rnn_cell``.
        The ids used in the body are hard-coded: internal_layer_id 0 is the input Squeeze, 1 is the
        cell and 2 is the output Unsqueeze; internal port 1 is the state input of the cell, 4 is its
        state output and 3 is the Unsqueeze output.

        :param graph: graph being transformed, it is modified in place
        :param match: dictionary with the matched node stored under the 'rnn_layer' key
        """
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # Values are copied only for input ports 1 and 2, the other body inputs get value None
        inputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/inport/' + str(inp), {
                    'shape':
                    rnn_layer.in_node(inp).shape.copy(),
                    'value':
                    rnn_layer.in_node(inp).value.copy()
                    if rnn_layer.in_node(inp).value is not None
                    and inp in [1, 2] else None
                }) for inp in [0, 4, 1, 2]
        ]  # X, h_init, WR, B

        # Inside the body the sequence dimension of X has size 1 and is squeezed out
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(
            body,
            dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/input_squeeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], input_squeeze_dim],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        outputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/outport/' + str(out), {
                    'shape':
                    rnn_layer.out_node(out).shape.copy()
                    if out in rnn_layer.out_nodes() else None
                }) for out in [0]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = shape_delete(outputs[0].shape,
                                        rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        output_unsqueeze = Unsqueeze(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze/',
                 internal_layer_id=2))

        # Attributes passed to every cell type; 'linear_before_reset' is added for GRU only
        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs[
                'linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. ***Cell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
            body,
            dict(hidden_size=rnn_layer.hidden_size,
                 name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                 **additional_attrs,
                 internal_layer_id=1))

        # The variable is called gru_cell but it holds the output of whatever cell type was created
        # NOTE(review): edge_attrs has five entries while `inputs` has four elements; confirm how
        # create_node_with_data pairs them.
        gru_cell = rnn_cell_op.create_node_with_data(inputs,
                                                     data_nodes=outputs,
                                                     edge_attrs=[{}, {
                                                         'internal_port_id':
                                                         1
                                                     }, {
                                                         'internal_port_id':
                                                         2
                                                     }, {
                                                         'bin':
                                                         'weights'
                                                     }, {
                                                         'bin':
                                                         'biases'
                                                     }])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        gru_cell = output_unsqueeze.create_node_with_data(
            [gru_cell, output_unsqueeze_dim])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        # NOTE(review): 'in_ports_count' is 4 while only two external inputs (X, h_init) are
        # connected and mapped below; confirm the intended value.
        ti_op = TensorIterator(
            graph,
            {
                'name':
                rnn_layer.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                4,
                'out_ports_count':
                len(rnn_layer.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': rnn_layer.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                ],
                'output_port_map':
                output_port_map,
                # only for h state
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                ]
            })

        # Output ports must be numbered 0..N-1 because they are read by index right below
        # NOTE(review): the message names GRUSequence although any non-LSTM op reaches this point.
        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        # The existing output data nodes of the layer are reused as the TensorIterator outputs
        outs = ti_op.create_node_with_data(
            [rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
            data_nodes=[
                rnn_layer.out_node(i)
                for i in range(len(rnn_layer.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }])

        # The result is handled as a list even when it is not returned as one
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(rnn_layer.id)
        # External output ports: 3 for the first output, 4 and up for the remaining ones
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        # The TensorIterator node itself is the producer of the first output data node
        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)