Example #1
    def replace_pattern(graph: Graph, match: dict):
        split = match['split']

        # Check whether we need to add fake outputs (the case when input.shape[axis] != sum(outputs.shape[axis]))
        axis = split.axis
        input_shape = split.in_port(0).data.get_shape()[axis]

        output_shape = sum([split.out_node(port).shape[axis] for port in split.out_nodes()])

        # In that case there is nothing to do
        if input_shape == output_shape:
            return

        # Adding fake outputs
        n_parts = int(input_shape/split.size_splits[0])
        part_shape = split.in_port(0).data.get_shape().copy()
        part_shape[axis] = split.size_splits[0]

        out_ports = split.out_ports()
        for i in range(n_parts):
            if i in out_ports and not split.out_port(i).disconnected():
                continue

            if i not in out_ports:
                split.add_output_port(i)

            output = Result(graph).create_node(attrs={'name': split.name + '/Fake_output_{}/'.format(i)})

            split.out_port(i).connect(output.in_port(0))
            output.in_port(0).data.set_shape(part_shape)
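Note: the example above assumes uniform splits (n_parts is derived from size_splits[0]). The idiom shared by most examples on this page is terminating an unconnected output port with a Result node so that the port survives graph clean-up. A minimal sketch of that idiom, reusing the port API shown in these examples (the helper name add_fake_result is hypothetical):

    def add_fake_result(node, port_idx):
        # Create the output port if the op declares it but it was never added
        if port_idx not in node.out_ports():
            node.add_output_port(port_idx)
        # Terminate a dangling port with a Result so it is kept in the final IR
        if node.out_port(port_idx).disconnected():
            res = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(port_idx),
                                      'keep_output_port': True}).create_node()
            node.out_port(port_idx).connect(res.in_port(0))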
Example #2
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['pooling']
        node.type = 'MaxPool'
        del node['pool_method']
        if 'exclude_pad' in node:
            del node['exclude_pad']

        # add missing outputs for the MaxPool node
        if node.out_port(0).disconnected():
            output = Result(
                node.graph, {
                    'name': node.name + '/Result_port_0/',
                    'keep_output_port':
                    node.has_and_set('remove_values_output')
                }).create_node()
            node.out_port(0).get_connection().set_destination(
                output.in_port(0))

        if node.out_port(1).disconnected():
            output = Result(
                node.graph, {
                    'name': node.name + '/Result_port_1/',
                    'keep_output_port':
                    node.has_and_set('remove_values_output')
                }).create_node()
            node.out_port(1).get_connection().set_destination(
                output.in_port(0))
Example #3
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('Does not support IfDefined with t >= 0')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
                                       'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                    'offset': np.array([0]), 'axis': np.array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        for ctc_greedy_decoder_tf in graph.get_op_nodes(op='CTCGreedyDecoderSeqLen', output_sparse_format=True):
            ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

            # TF CTCGreedyDecoder has 4 output tensors. If any of them is connected to a non-Result operation,
            # the transformation is not applicable
            for port_num in ctc_greedy_decoder_tf.out_ports():
                if not ctc_greedy_decoder_tf.out_port(port_num).disconnected()\
                        and ctc_greedy_decoder_tf.out_port(port_num).get_destination().node.soft_get('op') != 'Result':
                    return

            # If the first or second output is not connected to a Result operation,
            # create a Result operation and connect it to the corresponding output
            if ctc_greedy_decoder_tf.out_port(0).disconnected():
                first_result = Result(graph,
                                       {'name': ctc_greedy_decoder_tf_name + '/decoded_classes'}
                                       ).create_node()
                ctc_greedy_decoder_tf.out_port(0).connect(first_result.in_port(0))

            if ctc_greedy_decoder_tf.out_port(1).disconnected():
                second_result = Result(graph,
                                       {'name': ctc_greedy_decoder_tf_name + '/seq_lengths_output'}
                                       ).create_node()
                ctc_greedy_decoder_tf.out_port(1).connect(second_result.in_port(0))


            # To normalize the input, the data needs to be transposed from [T, N, C] to [N, T, C],
            # which is the layout supported by the CTCGreedyDecoderSeqLen op.
            log.warning('Found TF CTCGreedyDecoder operation at the end of the network. '
                        'PLEASE NOTE: the corresponding network output operation CTCGreedyDecoderSeqLen {} '
                        'will produce dense format, not sparse format!'.format(ctc_greedy_decoder_tf_name))
            ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                           {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

            assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
                'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(
                    ctc_greedy_decoder_tf_name)

            ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
            ctc_greedy_decoder_tf.in_port(0).disconnect()
            ctc_data_permute.out_port(0).connect(ctc_greedy_decoder_tf.in_port(0))

            del ctc_greedy_decoder_tf['output_sparse_format']

            for port_num in [2, 3]:  # MO CTCGreedyDecoderSeqLen may have 2 outputs
                if port_num in ctc_greedy_decoder_tf.out_ports():
                    if not ctc_greedy_decoder_tf.out_port(port_num).disconnected():
                        ctc_greedy_decoder_tf.out_port(port_num).disconnect()
Example #5
    def find_and_replace_pattern(self, graph: Graph):
        for offset_node in graph.get_op_nodes(op='MemoryOffset',
                                              splitted=False):
            paired_node = MemoryOffset(
                graph, {
                    'name': offset_node.pair_name,
                    'splitted': True,
                    'pair_name': offset_node.id,
                    't': offset_node.t,
                    'has_default': offset_node.has_default
                }).create_node()
            offset_node['splitted'] = True
            offset_node.out_port(0).get_connection().set_source(
                paired_node.out_port(0))
            res_node = Result(graph, {
                'name': offset_node.id + "_output"
            }).create_node()
            offset_node.out_port(0).connect(res_node.in_port(0))

            # If 'element_size' was previously copied from a Parameter or from a node with a defined dim
            if offset_node.has_valid('element_size'):
                paired_node['element_size'] = offset_node['element_size']
            # Otherwise copy the shape from the previous node; typically (but not always) this is the case for TDNN blocks
            else:
                paired_node['element_size'] = offset_node.in_port(
                    0).data.get_shape()[1]
Example #6
def add_output_in_body(node,
                       port_num,
                       cur_graph,
                       cur_max_layer_id,
                       tracks,
                       track_index,
                       add_unsqueeze=True):
    port = node.out_port(port_num)
    if add_unsqueeze:
        unsq_name = port.node.soft_get('name', port.node.id) + "/Unsqueeze"
        unsq_node = create_op_node_with_second_input(cur_graph, Unsqueeze,
                                                     int64_array([0]),
                                                     {'name': unsq_name})
        port.connect(unsq_node.in_port(0))
        unsq_node['internal_layer_id'] = cur_max_layer_id + 1
        cur_max_layer_id += 1
        tracks.insert(track_index, {'node': unsq_node, 'graph': cur_graph})
        port = unsq_node.out_port(0)

    out_name = port.node.soft_get('name', port.node.id) + ":" + str(port_num)
    res_node = Result(cur_graph, {'name': out_name}).create_node()
    port.connect(res_node.in_port(0))
    res_node['internal_layer_id'] = cur_max_layer_id + 1
    cur_max_layer_id += 1
    tracks.insert(track_index, {'node': res_node, 'graph': cur_graph})

    return res_node
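A hedged usage sketch for add_output_in_body: body_node, body_graph, and path are assumed to come from a traversal such as the one in add_output_for_path (Example #19 below), and max_internal_layer_id is the helper referenced there.

    # Illustrative only: expose output port 0 of a node inside a subgraph body
    # as a new Result, inserting an Unsqueeze in front of it by default.
    max_id = max_internal_layer_id(body_graph)
    res = add_output_in_body(body_node, 0, body_graph, max_id, path, len(path))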
Example #7
 def replace_pattern(graph: Graph, match: dict):
     offset_node = match['mem_offset']
     paired_node = MemoryOffset(graph, {'name': offset_node.pair_name, 'splitted': True, 'pair_name': offset_node.id,
                                        't': offset_node.t, 'has_default': offset_node.has_default}).create_node()
     offset_node['splitted'] = True
     offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
     res_node = Result(graph, {'name': offset_node.id+"_output"}).create_node()
     offset_node.out_port(0).connect(res_node.in_port(0))
Example #8
 def normalize_outputs(node: Node):
     """
     This function adds missed outputs for TopK node.
     """
     if node.out_port(0).disconnected():
         output = Result(
             node.graph, {
                 'name': node.name + '/Result_port_0/',
                 'remove_from_xml': node.has_and_set('remove_values_output')
             }).create_node()
         node.out_port(0).get_connection().set_destination(
             output.in_port(0))
     if node.out_port(1).disconnected():
         output = Result(node.graph, {
             'name': node.name + '/Result_port_1/'
         }).create_node()
         node.out_port(1).get_connection().set_destination(
             output.in_port(0))
Example #9
def assign_add_output_result(op: Node):
    """
    Adds the Result output node required by an Assign node.
    :param op: the Assign node to terminate with a Result
    :return: None
    """
    assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
                                            ''.format(op.soft_get('type'))
    tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
    op.out_port(0).connect(tmp_result.in_port(0))
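A minimal usage sketch, assuming `graph` is a Model Optimizer Graph: walk all Assign nodes and terminate any dangling output with a Result via the helper above.

    # Illustrative only: make sure every Assign write survives graph clean-up.
    for assign_node in graph.get_op_nodes(op='Assign'):
        if assign_node.out_port(0).disconnected():
            assign_add_output_result(assign_node)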
Example #10
 def split_normalize_outputs(node: Node):
     if node.has_valid('out_ports_count') and len(
             node.out_edges()) < node.out_ports_count:
         for p in range(node.out_ports_count):
             if p not in node.out_ports():
                 node.add_output_port(p)
             if node.out_port(p).disconnected():
                 res_node = Result(
                     node.graph, {
                         'name': node.name + '/Fake_output_{}/'.format(p),
                         'keep_output_port': True
                     }).create_node()
                 node.out_port(p).connect(res_node.in_port(0))
Example #11
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.has_valid('out_ports_count') and len(
                node.out_edges()) < node.out_ports_count:
            for p in range(node.out_ports_count):
                if p not in node.out_ports():
                    node.add_output_port(p)
                if node.out_port(p).disconnected():
                    res_node = Result(
                        graph, {
                            'name': node.name + '/Fake_output_{}/'.format(p)
                        }).create_node()
                    node.out_port(p).connect(res_node.in_port(0))
Example #12
 def normalize_outputs(node: Node):
     if node.has_valid('out_ports_count') and len(
             node.out_edges()) < node.out_ports_count:
         from mo.ops.result import Result  # Import is here to avoid circular import error
         for p in range(node.out_ports_count):
             if p not in node.out_ports():
                 node.add_output_port(p)
             if node.out_port(p).disconnected():
                 res_node = Result(
                     node.graph, {
                         'name': node.name + '/Fake_output_{}/'.format(p),
                         'keep_output_port': True
                     }).create_node()
                 node.out_port(p).connect(res_node.in_port(0))
Example #13
    def replace_pattern(graph: Graph, match: dict):
        node = match['result']
        is_scalar = graph.graph['cmd_params'].generate_experimental_IR_V10

        reshape = create_op_node_with_second_input(
            graph, Reshape,
            int64_array([]) if is_scalar else int64_array([1]),
            {'override_output_shape': True})
        node.in_port(1).get_connection().insert_node(reshape)

        if node.out_port(0).disconnected():
            output = Result(
                graph, {
                    'name': node.name + '/Result_port_0/',
                    'remove_from_xml': node.has_and_set('remove_values_output')
                }).create_node()
            node.out_port(0).get_connection().set_destination(
                output.in_port(0))
        if node.out_port(1).disconnected():
            output = Result(graph, {
                'name': node.name + '/Result_port_1/'
            }).create_node()
            node.out_port(1).get_connection().set_destination(
                output.in_port(0))
Example #14
    def replace_pattern(graph: Graph, match: dict):
        node = match['result']
        k = node.in_port(1).data.get_value()
        if not isinstance(k, np.ndarray):
            node.in_port(1).data.set_value(int64_array([k]))
        elif k.ndim == 0:
            node.in_port(1).data.set_value(int64_array([k.item()]))
        else:
            log.debug(
                'The "k" input to the TopK layer "{}" is already 1D'.format(
                    node.soft_get('name')))

        if node.out_port(0).disconnected():
            output = Result(graph, {
                'name': node.name + '/Result_port_0/'
            }).create_node()
            node.out_port(0).get_connection().set_destination(
                output.in_port(0))
        if node.out_port(1).disconnected():
            output = Result(graph, {
                'name': node.name + '/Result_port_1/'
            }).create_node()
            node.out_port(1).get_connection().set_destination(
                output.in_port(0))
Example #15
 def split_offset(offset_node: Node):
     paired_node = MemoryOffset(
         offset_node.graph, {
             'name': offset_node.pair_name,
             'splitted': True,
             'pair_name': offset_node.id,
             'element_size': offset_node['element_size'],
             't': offset_node.t,
             'has_default': offset_node.has_default
         }).create_node()
     offset_node['splitted'] = True
     offset_node.out_port(0).get_connection().set_source(
         paired_node.out_port(0))
     res_node = Result(offset_node.graph, {
         'name': offset_node.id + '_output'
     }).create_node()
     offset_node.out_port(0).connect(res_node.in_port(0))
Example #16
    def test_leaky_relu_mul_multiple_consumers(self):
        # multiple consumers of Mul operation
        graph = build_graph_with_edge_attrs(nodes, edges, {})
        additional_result = Result(graph, {'name': 'result_2'}).create_node()
        Node(graph, 'mul').out_port(0).connect(additional_result.in_port(0))

        ref_nodes = {
            **regular_op_with_shaped_data('input', shape, {
                'type': 'Parameter',
                'op': 'Parameter'
            }),
            **regular_op_with_shaped_data('mul', shape, {
                'type': 'Multiply',
                'name': 'mul'
            }),
            **regular_op_with_shaped_data('max', shape, {
                'type': 'Maximum',
                'name': 'final_max'
            }),
            **valued_const_with_data('const', float_array([0.5])),
            **regular_op_with_shaped_data('leaky_relu', shape, {
                'type': 'LeakyReLU',
                'name': 'max_final',
                'negative_slope': None
            }),
            **result('result'),
            **result('result_2')
        }
        ref_edges = [
            *connect('input:0', '0:mul'), *connect('const', '1:mul'),
            *connect('max:0', 'result'), *connect('mul:0', 'result_2'),
            *connect_data('input', 'leaky_relu'),
            *connect('leaky_relu', 'result')
        ]
        graph_ref = build_graph_with_edge_attrs(ref_nodes, ref_edges)

        LeakyReLUFusion().find_and_replace_pattern(graph)
        graph.clean_up()

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result_2')
        self.assertTrue(flag, resp)
Example #17
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check whether the data read from memory equals 1.
        # If it does, the context is fully gathered and the new value can be saved to memory;
        # otherwise the buffer still contains garbage and the initial memory state must not be changed.
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example #18
    def insert_select(graph: Graph, node: Node):
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
                                                               ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                                    offset=int64_array([1]),
                                                                                    dim=int64_array([context_len - 1]))),
                                                               ('crop_mem_in_data', dict()),
                                                               ('concat', dict(op='Concat', axis=1)),
                                                               ('concat_data', dict()),
                                                               ('const_1', dict(op='Const')),
                                                               ('const_1_data', dict()),
                                                               ('mem_out', dict(op='Assign')),
                                                               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                                 offset=int64_array([0]),
                                                                                 dim=int64_array([1]))),
                                                               ('crop_out_data', dict()),
                                                               ('select', dict(op='Select'))
                                                               ],
                                                 edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                        ('crop_mem_in', 'crop_mem_in_data'),
                                                        ('crop_mem_in_data', 'concat', {'in': 0}),
                                                        ('const_1', 'const_1_data'),
                                                        ('const_1_data', 'concat', {'in': 1}),
                                                        ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                        ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                        ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(graph, {'name': 'iteration_number',
                                        'variable_id': 'iteration_' + node.name}).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                     'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                    'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check whether the data read from memory equals 1.
        # If it does, the context is fully gathered and the new value can be saved to memory;
        # otherwise the buffer still contains garbage and the initial memory state must not be changed.
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example #19
    def add_output_for_path(graphs_nodes_path):
        # add output to nodes according to path
        step_node = graphs_nodes_path[-1]['node']
        cur_graph = graphs_nodes_path[-1]['graph']

        ports_to_add_nodes = []
        for o_p in step_node.out_ports():
            ports_to_add_nodes.append(o_p)

        # update internal_layer_id for new Results
        for i in range(len(graphs_nodes_path) - 1, 0, -1):
            cur_max_layer_id = max_internal_layer_id(cur_graph) + 1
            cur_loop_node = graphs_nodes_path[i - 1]['node']
            new_out_ports = []
            if cur_loop_node.op != 'If':
                # add Unsqueeze and Result for TensorIterator and Loop and update output_port_map
                for p_num in ports_to_add_nodes:
                    res_node = add_output_in_body(step_node, p_num, cur_graph,
                                                  cur_max_layer_id,
                                                  graphs_nodes_path, i)

                    # The IR reader fixes the output port map for Loop but does not change it for TensorIterator
                    new_port_id = len(cur_loop_node.out_ports())
                    if cur_loop_node.op == 'TensorIterator':
                        new_port_id = new_port_id + len(
                            cur_loop_node.in_ports())
                    cur_loop_node.output_port_map.append({
                        'axis': 0,
                        'stride': 1,
                        'part_size': 1,
                        'start': 0,
                        'end': -1,
                        'external_port_id': new_port_id,
                        'internal_layer_id': res_node['internal_layer_id']
                    })
                    port_id = new_port_id
                    if cur_loop_node.op == 'TensorIterator':
                        port_id = port_id - len(cur_loop_node.in_ports())

                    new_out_ports.append(port_id)
                    cur_loop_node.add_output_port(port_id)
            else:
                # add Result nodes for If and update output_id
                for p_num in ports_to_add_nodes:
                    res_node = add_output_in_body(step_node,
                                                  p_num,
                                                  cur_graph,
                                                  cur_max_layer_id,
                                                  graphs_nodes_path,
                                                  i,
                                                  add_unsqueeze=False)

                    if cur_loop_node.then_graph == cur_graph:
                        new_port_id = len(cur_loop_node.out_ports())
                        res_node['output_id'] = new_port_id
                        cur_loop_node.add_output_port(new_port_id)
                        new_out_ports.append(new_port_id)
                    else:
                        res_node['output_id'] = list(
                            cur_loop_node.out_ports().keys())[-1]
            ports_to_add_nodes = new_out_ports
            step_node = cur_loop_node
            cur_graph = graphs_nodes_path[i - 1]['graph']

        for p_num in ports_to_add_nodes:
            port = step_node.out_port(p_num)
            out_name = step_node.soft_get('name',
                                          step_node.id) + "." + str(p_num)
            res_node = Result(cur_graph, {'name': out_name}).create_node()
            port.connect(res_node.in_port(0))
            # add name of Result to fw_tensor_debug_info to avoid renaming
            if step_node.out_nodes()[p_num].has_and_set(
                    'fw_tensor_debug_info'):
                step_node.out_nodes()[p_num]['fw_tensor_debug_info'].append(
                    out_name)
            else:
                step_node.out_nodes()[p_num]['fw_tensor_debug_info'] = [[
                    out_name, out_name
                ]]
            if step_node.op == 'TensorIterator':
                step_node.out_edges()[len(step_node.out_edges())-1]['external_port_id'] = p_num + \
                                                                                          len(step_node.in_ports())
            graphs_nodes_path.insert(0, {'node': res_node, 'graph': cur_graph})
        return graphs_nodes_path
Example #20
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('Does not support IfDefined with t >= 0')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = create_zero_value_with_batch_from_input(
            input_port, in_shape[1] * node_t)
        memory_out = ReadValue(graph, {
            'name': pair_name,
            'variable_id': node_name + pair_name
        }).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            crop_concat = Crop(
                graph, {
                    'name': 'Memory_crop',
                    'dim': np.array([in_shape[1] * (node_t - 1)]),
                    'offset': np.array([in_shape[1]]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(
                graph, {
                    'name': 'Memory_crop_out',
                    'dim': np.array([in_shape[1]]),
                    'offset': np.array([0]),
                    'axis': np.array([1])
                }).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {
                'name': node_name,
                'variable_id': node_name + pair_name
            }).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        if not graph.graph['cmd_params'].static_shape:
            log.error(
                "Model cannot be translated in a reshape-able way.\n"
                "The Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "Changing the input shapes of the model with the help of the "
                "InferenceEngine reshape method will not succeed",
                extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Example #21
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in',
                    dict(op='Memory',
                         index=1,
                         shape=int64_array([context_len]))),
                   ('mem_in_data', dict()),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()),
                   ('mem_out',
                    dict(op='Memory',
                         index=0,
                         shape=int64_array([context_len]))),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            mem_out = Memory(
                graph, {
                    'name': 'iteration_number',
                    'size': 2,
                    'index': 1,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len]),
                    'dst_type': np.int32
                }).create_node()
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Memory(
                graph, {
                    'name': 'iteration_number_out',
                    'size': 2,
                    'index': 0,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len])
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        select_node.in_port(0).connect(input_port)
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Example #22
 def insert_result(node, child_node, name):
     res_op = Result(node.graph, {
         'name': f'Result_{name}_{node.name}'
     }).create_node()
     child_node.out_port(0).connect(res_op.in_port(0))
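An illustrative call (the node names are assumptions, not from the source): capture the output of a node's consumer as an extra network output without disturbing its existing consumers.

    # Illustrative only: `conv` feeds `bias_add`; this exposes bias_add's
    # output through a Result named 'Result_debug_<conv.name>'.
    insert_result(conv, bias_add, 'debug')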
Example #23
    def replace_op(self, graph: Graph, node: Node):
        input_out_port = node.in_port(0).get_source()

        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'out-size': node.gifo_x_weights_shape[0],
            'transpose_weights': True,
            'bias_term': True,
        }

        fc_layer_after_input = FullyConnected(
            graph, fc_layer_after_input_attrs).create_node()
        fc_layer_after_input.in_port(0).connect(input_out_port)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                       'weights', node.gifo_x_weights)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                       'biases', node.gifo_biases)

        init_value_prev_lstm_output = create_zero_value_with_batch_from_input(
            input_out_port, node.gifo_r_weights_shape[1])
        prev_lstm_output = ReadValue(graph, {
            'name': 'prev_memory_output',
            'variable_id': memory_pair_input
        }).create_node()
        prev_lstm_output.in_port(0).connect(
            init_value_prev_lstm_output.out_port(0))

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'out-size': node.gifo_r_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False,
        }

        fc_layer_from_prev_state = FullyConnected(
            graph, fc_layer_from_prev_state_attrs).create_node()
        fc_layer_from_prev_state.in_port(0).connect(
            prev_lstm_output.out_port(0))
        input_as_const(fc_layer_from_prev_state,
                       fc_layer_from_prev_state_attrs, 1, 'weights',
                       node.gifo_r_weights)

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise'
        }).create_node()
        join_input_prev_state_sum.in_port(0).connect(
            fc_layer_from_prev_state.out_port(0))
        join_input_prev_state_sum.in_port(1).connect(
            fc_layer_after_input.out_port(0))

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input_axis = Const(graph, {
            'value': np.int64(1)
        }).create_node()
        split_joined_input = Split(graph, {
            'name': 'join_input_split',
            'num_splits': 4,
            'out_ports_count': 4
        }).create_node()
        split_joined_input.in_port(0).connect(
            join_input_prev_state_sum.out_port(0))
        split_joined_input.in_port(1).connect(
            split_joined_input_axis.out_port(0))

        # prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
        #                                 'id': memory_pair_output,
        #                                 'index': 1,
        #                                 'size': 2,
        #                                 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
        #                                 }).create_node()
        init_value_prev_lstm_state = create_zero_value_with_batch_from_input(
            split_joined_input.out_port(0), node.input_gate_weights.shape[0])
        prev_lstm_state = ReadValue(graph, {
            'name': 'prev_memory_state',
            'variable_id': memory_pair_output
        }).create_node()
        prev_lstm_state.in_port(0).connect(
            init_value_prev_lstm_state.out_port(0))

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node()
        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                       'weights', node.input_gate_weights)

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        state_forget_scaleshift = ScaleShiftOp(
            graph, state_forget_scaleshift_attrs).create_node()
        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs,
                       1, 'weights', node.forget_gate_weights)

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise'
            }).create_node()
        join_prev_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(1))
        join_prev_lstm_input_joined_input_sum.in_port(1).connect(
            state_input_scaleshift.out_port(0))
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node()
        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
            split_joined_input.out_port(2))
        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
            state_forget_scaleshift.out_port(0))

        # Split -> Tanh
        remember_tanh = Tanh(graph, {'name': 'remember_tanh'}).create_node()
        remember_tanh.in_port(0).connect(split_joined_input.out_port(0))

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node()
        remember_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_input_sum.out_port(0))

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node()
        forget_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_forget_sum.out_port(0))

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul'
        }).create_node()
        join_forget_prev_state_mul.in_port(0).connect(
            forget_sigmoid.out_port(0))
        join_forget_prev_state_mul.in_port(1).connect(
            prev_lstm_state.out_port(0))

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul'
            }).create_node()
        join_remember_candidates_mul.in_port(0).connect(
            remember_tanh.out_port(0))
        join_remember_candidates_mul.in_port(1).connect(
            remember_sigmoid.out_port(0))

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum'
        }).create_node()
        join_forget_remember_sum.in_port(0).connect(
            join_forget_prev_state_mul.out_port(0))
        join_forget_remember_sum.in_port(1).connect(
            join_remember_candidates_mul.out_port(0))

        # (7)Eltwise(sum) -> Clamp
        join_forget_clamp = create_op_with_const_inputs(
            graph, Clamp, {
                1: np.array(-node.clip_value, dtype=np.float32),
                2: np.array(node.clip_value, dtype=np.float32)
            }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)
        # Clamp -> (2)Memory(state)
        next_lstm_state = Assign(graph, {
            'name': 'next_lstm_state',
            'variable_id': memory_pair_output
        }).create_node()
        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
        res_node.in_port(0).connect(next_lstm_state.out_port(0))

        # Clamp -> (2)Tanh
        state_filtered_tanh = Tanh(graph, {
            'name': 'state_filtered_tanh'
        }).create_node()
        state_filtered_tanh.in_port(0).connect(join_forget_clamp.out_port(0))

        # Clamp -> (2)ScaleShift
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        clamp_scaleshift = ScaleShiftOp(graph,
                                        clamp_scaleshift_attrs).create_node()
        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
        input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                       node.output_gate_weights)

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node()
        join_next_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(3))
        join_next_lstm_input_joined_input_sum.in_port(1).connect(
            clamp_scaleshift.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node()
        output_sigmoid.in_port(0).connect(
            join_next_lstm_input_joined_input_sum.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node()
        joined_output_mul.in_port(0).connect(state_filtered_tanh.out_port(0))
        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

        # (5)Eltwise(mul) -> (3)FullyConnected
        fc_output_attrs = {
            'name': 'FullyConnected',
            'out-size': node.projection_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False
        }
        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
        input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                       node.projection_weights)

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        next_lstm_output = Assign(graph, {
            'name': 'next_lstm_output',
            'variable_id': memory_pair_input
        }).create_node()
        next_lstm_output.in_port(0).connect(fc_output.out_port(0))

        res_node_lstm_output = Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node()
        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

        return [fc_output.id]
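Across Examples #3, #17, #18, #20, and #23 the same ReadValue/Assign pairing recurs: the state is read at the start of the step, updated, written back under the same variable_id, and the Assign is terminated with a Result so the write is not pruned. A minimal sketch of the pattern, assuming init_value and new_state are existing nodes in `graph`:

    # Read the persistent state (initialized from init_value on the first step)
    state_read = ReadValue(graph, {'name': 'state_read', 'variable_id': 'state'}).create_node()
    init_value.out_port(0).connect(state_read.in_port(0))
    # ... compute new_state from state_read ...
    # Write the updated state back under the same variable_id
    state_write = Assign(graph, {'name': 'state_write', 'variable_id': 'state'}).create_node()
    new_state.out_port(0).connect(state_write.in_port(0))
    # Terminate the Assign with a Result so the write survives clean-up
    sink = Result(graph, {'name': 'state_sink'}).create_node()
    state_write.out_port(0).connect(sink.in_port(0))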