Esempio n. 1
0
 def generate_port_map(node: Node, src_port_map, dir: str):
     """
     Prepare port map records of a node with a body graph for the "ports" section generation.

     The records are updated in place and collected, in the original order, into a new list.

     :param node: node owning the port map; its ``in_ports()`` and ``body`` are used
     :param src_port_map: iterable of port map records (dicts with 'external_port_id' and 'internal_layer_id')
     :param dir: direction of the port map; external port ids are shifted only when it equals 'out'
     :return: list holding the same (updated) record objects
     """
     is_output_map = dir == 'out'
     result_list = []
     for rec in src_port_map:
         # -1 marks a not-connected output which is used in the Loop operation only; its id is left as is
         is_connected = rec['external_port_id'] != -1
         if is_connected and is_output_map:
             # output port ids go after the input port ids in the generated "ports" section
             rec['external_port_id'] += len(node.in_ports())
         rec['internal_layer_id'] = TensorIterator.find_internal_layer_id(node.body, rec['internal_layer_id'])
         result_list.append(rec)
     return result_list
Esempio n. 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Normalize external ports of every TensorIterator in the graph.

        When none of the input port map records of a TensorIterator has an 'axis', each output port map
        record that has an 'axis' gets 'start' set to 0 and 'end' set to the size of the corresponding
        output port shape along that axis.

        :param graph: graph to look for 'TensorIterator' operations in
        """
        for ti_node in graph.get_op_nodes(type='TensorIterator'):
            self.external_nodes_normalization(ti_node)

            if any(rec.get('axis') is not None for rec in ti_node.input_port_map):
                continue

            for out_rec in ti_node.output_port_map:
                if out_rec.get('axis') is None:
                    continue
                out_rec['start'] = 0
                port_idx = TensorIterator.special_port_to_real_port(ti_node, out_rec['external_port_id'], 'out')
                port_shape = ti_node.out_port(port_idx).data.get_shape()
                assert port_shape is not None
                out_rec['end'] = port_shape[out_rec['axis']]
def ti_set_output_port_shape(cycle_node, internal_id, port_num,
                             iterations_count, axis):
    """
    Set the shape of an output port of a node with a body graph using the shape of a body 'Result'.

    :param cycle_node: node with the ``body`` graph whose output port shape is set
    :param internal_id: internal layer id used to look up the 'Result' node in the body
    :param port_num: index of the output port of ``cycle_node`` to set the shape for
    :param iterations_count: value written to dimension ``axis`` of the shape
    :param axis: dimension to overwrite with ``iterations_count``; None leaves the shape untouched
    """
    body_graph = cycle_node.body
    result_node = Node(body_graph, TensorIterator.find_internal_layer_id(body_graph, internal_id))
    assert result_node.op == 'Result'

    port_shape = result_node.in_port(0).data.get_shape().copy()
    # inside cycle node Unsqueeze was added to have the first dimension for concatenating results along it
    assert len(port_shape) >= 1
    if axis is not None:
        port_shape[axis] = iterations_count

    assert port_num in cycle_node.out_ports()
    cycle_node.out_port(port_num).data.set_shape(port_shape)
def ti_infer(step_node, port_num):
    """
    Infer the shape of output port ``port_num`` of a node with a body graph.

    The output port map record connected to the port is looked up, the iterations count is determined
    and the output shape is set with ``ti_set_output_port_shape``.

    :param step_node: node with 'output_port_map' and a body graph
    :param port_num: index of the output port of ``step_node`` to infer the shape for
    :raises AssertionError: if no output port map record refers to the requested port
    """
    # in the output port map external port ids are shifted by the number of input ports
    external_port_id = port_num + len(step_node.in_ports())

    # find out which internal layer maps to the requested port
    found_rec = next((rec for rec in step_node.output_port_map
                      if rec['external_port_id'] == external_port_id), None)
    assert found_rec is not None, \
        "External port {} is not connected with body in node {}".format(external_port_id,
                                                                        step_node.soft_get('name', step_node.id))

    # find out iterations count for TensorIterator to set output shape correctly
    iterations_count = get_iterations_count_from_output_record(found_rec)
    if iterations_count is dynamic_dimension_value:
        iterations_count = TensorIterator.find_iterations_count_for_output(
            step_node)

    # a record may have no 'axis' key at all; None makes ti_set_output_port_shape keep the
    # 'Result' shape unchanged instead of failing here with a KeyError
    ti_set_output_port_shape(step_node, found_rec['internal_layer_id'],
                             port_num, iterations_count, found_rec.get('axis'))
Esempio n. 5
0
    def external_nodes_normalization(ti):
        """
        Make every port map record of a TensorIterator refer to its own external port.

        An external port of TensorIterator may be connected with several internal layers, i.e. several
        port map records may share one 'external_port_id'. The transformation:
            - normalizes port maps (eliminating duplicated records)
            - for every record except the first one sharing an external port, adds a new external
              input/output port connected to the same source / the same destinations
            - writes the id of the new port to the matching input/output port map records
        """
        def group_by_external_port(port_map):
            # external_port_id -> list of records referring to it, in port map order
            groups = defaultdict(list)
            for rec in port_map:
                assert 'external_port_id' in rec
                groups[rec['external_port_id']].append(rec)
            return groups

        def next_external_port_id():
            # the id following the largest 'external_port_id' found on the edges of the node
            edges_attrs = list(ti.in_edges().values()) + list(ti.out_edges().values())
            return max([int(attrs['external_port_id']) for attrs in edges_attrs]) + 1

        def retarget_records(port_map, old_port_id, new_port_id, layer_id):
            # move records of the given internal layer from the old external port to the new one
            for rec in port_map:
                if rec['external_port_id'] == old_port_id and \
                        rec['internal_layer_id'] == layer_id:
                    rec['external_port_id'] = new_port_id

        NormalizeTI.maps_uniqueization(ti)

        body = ti.body

        for shared_port_id, records in group_by_external_port(ti.input_port_map).items():
            if len(records) == 1:
                continue

            real_port_id = TensorIterator.special_port_to_real_port(ti, shared_port_id, 'in')
            source = ti.in_port(real_port_id).get_source()

            for rec in records[1:]:
                assert 'internal_layer_id' in rec

                # both ids are computed before the new port is added to the node
                added_real_port_id = max(map(int, ti.in_ports().keys())) + 1
                added_external_port_id = next_external_port_id()

                ti.add_input_port(added_real_port_id)
                source.connect(ti.in_port(added_real_port_id))
                ti.in_edge(added_real_port_id)['external_port_id'] = added_external_port_id

                retarget_records(ti.input_port_map, shared_port_id, added_external_port_id,
                                 rec['internal_layer_id'])

        for shared_port_id, records in group_by_external_port(ti.output_port_map).items():
            if len(records) == 1:
                continue

            real_port_id = TensorIterator.special_port_to_real_port(ti, shared_port_id, 'out')
            destinations = ti.out_port(real_port_id).get_destinations()

            for rec in records[1:]:
                assert 'internal_layer_id' in rec

                # both ids are computed before the new port is added to the node
                added_real_port_id = max(map(int, ti.out_ports().keys())) + 1
                added_external_port_id = next_external_port_id()

                ti.add_output_port(added_real_port_id)
                for destination in destinations:
                    ti.out_port(added_real_port_id).connect(destination)

                retarget_records(ti.output_port_map, shared_port_id, added_external_port_id,
                                 rec['internal_layer_id'])

        body.clean_up()
Esempio n. 6
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'rnn_layer' sequence node with a TensorIterator.

        The TensorIterator body is built as Squeeze -> cell -> Unsqueeze, where the cell class is taken from
        ``self.get_rnn_cell`` by the op name of the layer. Nodes with op 'LSTM' are skipped.

        :param graph: graph containing the matched node; the TensorIterator is created in it
        :param match: pattern match; only the 'rnn_layer' entry is used
        """
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # body data nodes mirror the layer inputs; values are copied only for inputs 1 and 2
        inputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/inport/' + str(inp), {
                    'shape':
                    rnn_layer.in_node(inp).shape.copy(),
                    'value':
                    rnn_layer.in_node(inp).value.copy()
                    if rnn_layer.in_node(inp).value is not None
                    and inp in [1, 2] else None
                }) for inp in [0, 4, 1, 2]
        ]  # X, h_init, WR, B

        # the body input X has size 1 along the sequence dimension, which the Squeeze below removes
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(
            body,
            dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/input_squeeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], input_squeeze_dim],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        outputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/outport/' + str(out), {
                    'shape':
                    rnn_layer.out_node(out).shape.copy()
                    if out in rnn_layer.out_nodes() else None
                }) for out in [0]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        # the cell output has no sequence dimension; it is restored by the Unsqueeze created below
        outputs[0].shape = np.delete(outputs[0].shape.copy(),
                                     rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        output_unsqueeze = Unsqueeze(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze/',
                 internal_layer_id=2))

        # attributes forwarded from the sequence layer to the cell
        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs[
                'linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. ***Cell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
            body,
            dict(hidden_size=rnn_layer.hidden_size,
                 name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                 **additional_attrs,
                 internal_layer_id=1))

        # NOTE(review): five edge attribute dicts are passed for the four entries of `inputs` - confirm
        gru_cell = rnn_cell_op.create_node_with_data(inputs,
                                                     data_nodes=outputs,
                                                     edge_attrs=[{}, {
                                                         'internal_port_id':
                                                         1
                                                     }, {
                                                         'internal_port_id':
                                                         2
                                                     }, {
                                                         'bin':
                                                         'weights'
                                                     }, {
                                                         'bin':
                                                         'biases'
                                                     }])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        # from here on `gru_cell` is the data node produced by the output Unsqueeze
        gru_cell = output_unsqueeze.create_node_with_data(
            [gru_cell, output_unsqueeze_dim])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        # (this record has no 'axis' key)
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        ti_op = TensorIterator(
            graph,
            {
                'name':
                rnn_layer.name + '/TensorIterator',
                'body':
                body,
                # NOTE(review): 4 input ports are declared while only two inputs (X, h_init) are connected
                # below - confirm
                'in_ports_count':
                4,
                'out_ports_count':
                len(rnn_layer.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': rnn_layer.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                ],
                'output_port_map':
                output_port_map,
                # only for h state
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                ]
            })

        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        # the TensorIterator reuses the input (X, h_init) and output data nodes of the replaced layer
        outs = ti_op.create_node_with_data(
            [rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
            data_nodes=[
                rnn_layer.out_node(i)
                for i in range(len(rnn_layer.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }])

        # `outs` may be a single node rather than a list
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(rnn_layer.id)
        # external port ids of the outputs start from 3, matching output_port_map above
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
Esempio n. 7
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'lstm' sequence node with a TensorIterator.

        The TensorIterator body is built as Reshape (drops the sequence dimension) -> LSTMCell ->
        Reshape (restores the sequence dimension).

        :param graph: graph containing the matched node; the TensorIterator is created in it
        :param match: pattern match; only the 'lstm' entry is used
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # body data nodes mirror the layer inputs; values are copied only for inputs 1 and 2
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape':
                    lstm.in_node(inp).shape.copy(),
                    'value':
                    lstm.in_node(inp).value.copy() if lstm.in_node(inp).value
                    is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 5, 1, 2]
        ]  # X, h_init, c_init, WR, B

        # the body input X has size 1 along the sequence dimension
        inputs[0].shape[lstm.sequence_dim] = 1
        # Reshape target: the input shape with -1 at the batch dimension and the sequence dimension removed
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(
            body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {
            'name': lstm.name + '/input_squeeze_dim',
            'value': reshape_dim
        }).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], squeeze_dim_data],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        # when output `out` of the layer is absent, the shape of input 4 (h_init) is used instead
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape':
                    lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(4).shape.copy()
                }) for out in [0, 1]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        # the cell output has no sequence dimension; the output Reshape restores it with size 1
        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape,
                                          lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        # -1 at the batch dimension is set after the squeezed shape was computed, so only the Reshape
        # target below contains it
        unsqueezed_output_shape[lstm.batch_dim] = -1
        # NOTE(review): unlike the other names in this function there is no '/' before
        # 'output_unsqueeze' - confirm whether this is intended
        output_unsqueeze = Reshape(
            body, dict(name=lstm.name + 'output_unsqueeze',
                       internal_layer_id=2))
        unsqueeze_dim_data = Const(
            body, {
                'name': lstm.name + '/output_unsqueeze_dim',
                'value': unsqueezed_output_shape
            }).create_node_with_data()

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 activations=lstm.activations,
                 activation_alpha=lstm.activation_alpha,
                 activation_beta=lstm.activation_beta,
                 clip=lstm.clip,
                 input_forget=lstm.input_forget,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        # edge attributes follow the order of `inputs`: X, h_init, c_init, WR, B
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {
                'internal_port_id': 1
            }, {
                'internal_port_id': 2
            }, {
                'bin': 'weights'
            }, {
                'bin': 'biases'
            }])
        # internal ports for the two outputs of the cell
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        # from here on lstm_cell_node[0] is the data node produced by the output Reshape
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # output of the Reshape (internal layer 2) mapped along the sequence dimension
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        # (these records have no 'axis' key)
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name':
                lstm.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                3,
                'out_ports_count':
                len(lstm.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map':
                output_port_map,
                # cell output ports 4 and 5 are fed back to cell input ports 1 and 2 respectively
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        # the TensorIterator reuses the input (X, h_init, c_init) and output data nodes of the replaced layer
        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
            data_nodes=[
                lstm.out_node(i) for i in range(len(lstm.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }, {
                'external_port_id': 2
            }])

        # `outs` may be a single node rather than a list
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(lstm.id)
        # external port ids of the outputs start from 3, matching output_port_map above
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id
Esempio n. 8
0
    def normalize_ti(ti):
        """
        Re-point port maps and back edges of a TensorIterator to the nodes adjacent to its body boundary.

        - every input port map record is switched from the node found by its 'internal_layer_id'
          (presumably a Parameter) to the node consuming output port 0 of that node
        - every output port map record is switched from the node found by its 'internal_layer_id'
          (presumably a Result) to the node producing input port 0 of that node
        - back edges starting at a node of type 'Result' are switched to the producer of that node
        - 'internal_port_id' of the records is taken from the corresponding body edge
        - finally, 'Input' and 'Parameter' operations are removed from the body

        :param ti: node with valid 'input_port_map', 'output_port_map' and 'back_edges' attributes
        """
        for required_attr in ('input_port_map', 'output_port_map', 'back_edges'):
            assert ti.has_valid(required_attr)

        body = ti.body

        def check_port_map_record(rec):
            # records must not have been normalized yet, i.e. must have no 'internal_port_id'
            assert 'internal_layer_id' in rec
            assert 'internal_port_id' not in rec
            assert 'external_port_id' in rec

        def producer_of(result_node):
            # node feeding input port 0 of result_node and the internal port id stored on its out edge
            source_port = result_node.in_port(0).get_source()
            producer = source_port.node
            return producer, producer.out_edge(source_port.idx)['internal_port_id']

        for rec in ti.input_port_map:
            check_port_map_record(rec)
            old_layer_id = copy(rec['internal_layer_id'])
            parameter = get_internal_node_by_layer_id(ti, old_layer_id)

            consumer_port = parameter.out_port(0).get_destination()
            consumer = consumer_port.node
            consumer_port_id = consumer.in_edge(consumer_port.idx)['internal_port_id']

            rec['internal_layer_id'] = consumer.internal_layer_id
            rec['internal_port_id'] = consumer_port_id
            TensorIterator.update_back_edge_map(ti=ti,
                                                direction='to',
                                                old_layer_id=old_layer_id,
                                                old_port_id=None,
                                                new_layer_id=consumer.internal_layer_id,
                                                new_port_id=consumer_port_id)

        for rec in ti.output_port_map:
            check_port_map_record(rec)
            old_layer_id = copy(rec['internal_layer_id'])
            producer, producer_port_id = producer_of(get_internal_node_by_layer_id(ti, old_layer_id))

            rec['internal_layer_id'] = producer.internal_layer_id
            rec['internal_port_id'] = producer_port_id
            TensorIterator.update_back_edge_map(ti=ti,
                                                direction='from',
                                                old_layer_id=old_layer_id,
                                                old_port_id=None,
                                                new_layer_id=producer.internal_layer_id,
                                                new_port_id=producer_port_id)

        for back_edge in ti.back_edges:
            assert 'from_layer' in back_edge
            assert 'to_layer' in back_edge

            old_layer_id = back_edge['from_layer']
            start_node = get_internal_node_by_layer_id(ti, old_layer_id)
            if start_node.soft_get('type') != 'Result':
                continue

            assert 'from_port' not in back_edge
            producer, producer_port_id = producer_of(start_node)
            TensorIterator.update_back_edge_map(ti=ti,
                                                direction='from',
                                                old_layer_id=old_layer_id,
                                                old_port_id=None,
                                                new_layer_id=producer.internal_layer_id,
                                                new_port_id=producer_port_id)

        for removed_type in ('Input', 'Parameter'):
            body.remove_nodes_from([op.id for op in body.get_op_nodes(type=removed_type)])
Esempio n. 9
0
    def replace_pattern(graph, match: dict):
        # Here we will found all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
        # and make all checks needed for TensorIterator work
        cond_data = match['condition'].out_node(
            0) if not match['condition'].out_port(0).disconnected() else None
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_inp['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_inp['axis']
                    }).create_node_with_data()
                reshape_op.create_node_with_data(
                    [new_input_data, reshape_dim_data],
                    data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                reshape_op = Unsqueeze(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_out['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_out['axis']
                    }).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']

        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
Esempio n. 10
0
    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        """Swap the matched LSTM sequence node for a TensorIterator wrapping an LSTMCell.

        A body graph is assembled with three layers carrying fixed internal ids:
        0 is a Reshape that removes the sequence dimension from the input slice,
        1 is the LSTMCell, 2 is a Reshape that puts the sequence dimension back
        on the first cell output.  The port maps and the two back edges handed
        to the TensorIterator refer to these hard-coded layer/port ids.

        :param graph: graph that owns the matched node; the TensorIterator is
            created in it and the original LSTM node is removed from it.
        :param match: pattern match, the LSTM node is read from match['lstm'].
        """
        lstm = match['lstm']

        # The body is a separate graph that shares the layout of the outer one.
        body = nx.MultiDiGraph(name=lstm.name + '/sub_graph',
                               layout=graph.graph['layout'])

        # Data nodes that feed the cell; values are carried over only for
        # LSTM inputs 1 and 2 (the ones attached below with 'bin' edge attributes).
        cell_inputs = []
        for port in [0, 3, 4, 1, 2]:
            source = lstm.in_node(port)
            cell_inputs.append(
                Op._create_data_node(
                    body, lstm.name + '/inport/' + str(port), {
                        'shape': source.shape.copy(),
                        'value': source.value.copy()
                        if source.value is not None and port in [1, 2] else None,
                    }))

        # Inside the body the data input has size 1 along the sequence dimension;
        # layer 0 reshapes it to a shape without that dimension (batch set to -1).
        cell_inputs[0].shape[lstm.sequence_dim] = 1
        squeeze_dim = cell_inputs[0].shape.copy()
        squeeze_dim[lstm.batch_dim] = -1
        squeeze_dim = np.delete(squeeze_dim, lstm.sequence_dim)
        squeeze_op = Reshape(body, {
            'name': lstm.name + '/input_squeeze',
            'internal_layer_id': 0,
            'dim': squeeze_dim,
        })
        cell_inputs[0] = squeeze_op.create_node_with_data(
            [cell_inputs[0]], edge_attrs=[{'internal_port_id': 0}])

        cell_op = LSTMCell(body, {
            'hidden_size': lstm.hidden_size,
            'name': lstm.name + '/LSTMCell',
            'internal_layer_id': 1,
        })

        # Data nodes for the two cell outputs; when the LSTM node has no such
        # output, the shape of LSTM input 3 is used instead.
        cell_outputs = []
        for port in [0, 1]:
            if port in lstm.out_nodes():
                out_shape = lstm.out_node(port).shape.copy()
            else:
                out_shape = lstm.in_node(3).shape.copy()
            cell_outputs.append(
                Op._create_data_node(body,
                                     lstm.name + '/outport/' + str(port),
                                     {'shape': out_shape, 'is_output': True}))

        # The cell emits the squeezed shape; layer 2 reshapes it back to a shape
        # with a sequence dimension of size 1 (batch set to -1).
        unsqueezed_shape = cell_outputs[0].shape.copy()
        unsqueezed_shape[lstm.sequence_dim] = 1
        cell_outputs[0].shape = np.delete(unsqueezed_shape, lstm.sequence_dim)
        unsqueezed_shape[lstm.batch_dim] = -1
        # NOTE(review): this name is concatenated without a '/' separator, unlike
        # '/input_squeeze' above -- kept as is, confirm whether it is intended.
        unsqueeze_op = Reshape(body, {
            'name': lstm.name + 'output_unsqueeze',
            'dim': unsqueezed_shape,
            'internal_layer_id': 2,
        })

        # TODO edge attributes should be assigned by the op itself
        cell_results = cell_op.create_node_with_data(
            cell_inputs,
            data_nodes=cell_outputs,
            edge_attrs=[
                {},
                {'internal_port_id': 1},
                {'internal_port_id': 2},
                {'bin': 'weights'},
                {'bin': 'biases'},
            ])
        cell_node = cell_results[0].in_node()
        cell_node.out_edge(0)['internal_port_id'] = 4
        cell_node.out_edge(1)['internal_port_id'] = 5
        cell_results[0] = unsqueeze_op.create_node_with_data([cell_results[0]])
        cell_results[0].in_node().out_edge(0)['internal_port_id'] = 3
        cell_results[0]['is_output'] = True

        # Iteration parameters along the sequence axis depend on the direction.
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride, start, end = 1, None, None
        else:
            assert lstm.direction == 'reverse'
            stride, start, end = -1, -1, 0

        # Slicing attributes shared by the sliced input and the concatenated output.
        slicing = {
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }

        output_port_map = [
            dict(external_port_id=3, internal_layer_id=2, internal_port_id=3,
                 **slicing)
        ]
        # NOTE(review): records for external ports 4 and 5 are added only when the
        # LSTM node has exactly three outputs -- confirm the two-output case.
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([
                {'external_port_id': 4, 'internal_layer_id': 1, 'internal_port_id': 4},
                {'external_port_id': 5, 'internal_layer_id': 1, 'internal_port_id': 5},
            ])

        ti_op = TensorIterator(graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,
            'input_port_map': [
                dict(external_port_id=0, internal_layer_id=0, internal_port_id=0,
                     **slicing),
                {'external_port_id': 1, 'internal_layer_id': 1, 'internal_port_id': 1},
                {'external_port_id': 2, 'internal_layer_id': 1, 'internal_port_id': 2},
            ],
            'output_port_map': output_port_map,
            # cell output ports 4 and 5 are fed back to cell input ports 1 and 2
            'back_edges': [
                {'from_layer': 1, 'from_port': 4, 'to_layer': 1, 'to_port': 1},
                {'from_layer': 1, 'from_port': 5, 'to_layer': 1, 'to_port': 2},
            ],
        })

        connected_outputs = lstm.out_nodes()
        assert sorted(connected_outputs.keys()) == list(range(len(connected_outputs))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
        ti_outs = ti_op.create_node_with_data(
            [lstm.in_node(port) for port in [0, 3, 4]],
            data_nodes=[
                lstm.out_node(port) for port in range(len(connected_outputs))
            ],
            edge_attrs=[{'external_port_id': port} for port in range(3)])

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        graph.remove_node(lstm.id)
        # The first output goes to external port 3, the remaining ones to 4, 5, ...
        ti_outs[0].in_edge(0)['external_port_id'] = 3
        for port_id, extra_out in enumerate(ti_outs[1:], start=4):
            extra_out.in_edge()['external_port_id'] = port_id
Esempio n. 11
0
    def replace_pattern(graph, match: dict):
        """Assemble a TensorIterator operation from the loop parts found around a matched condition node.

        Steps performed by the visible code:
          1. Collect the TensorIteratorBackEdge / TensorIteratorInput / TensorIteratorOutput
             consumers of the condition outputs.
          2. Remove the condition, its output data nodes and the producer of its input 0 from `graph`.
          3. Ask `get_body` for the body node ids (plus extra inputs), copy those nodes and the
             edges between them into a new `body` graph and remove them from `graph`.
          4. Number layers and ports of the body using one shared counter (`internal_id_count`);
             the same counter also supplies the `external_port_id` values.
          5. Insert a Reshape on every input/output whose `axis` is not None.
          6. Create the TensorIterator node and write `external_port_id` on its output edges.

        :param graph: graph being transformed; it is modified in place.
        :param match: pattern match; only match['condition'] is read here.
        """
        # Here we find all parts of the TI: condition, inputs/outputs, back edges and body, create the
        # TensorIterator op and make all the checks needed for TensorIterator to work
        cond_data = match['condition'].out_node(0)
        # the second condition output is optional
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) > 1 else None
        name = match['condition'].name

        assert match['condition'].in_node(0).has_valid('value')

        # ids of the helper nodes attached to the condition outputs
        back_edges = []
        inputs = []
        outputs = []

        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

        if time_data is not None:
            # only inputs and outputs are expected among the consumers of the second output
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False

        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)
        graph.remove_nodes_from(
            [condition.id, cond_data.id, tensor_sequence_length.id])
        if time_data is not None:
            graph.remove_nodes_from([time_data.id])

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)
        # NOTE(review): cond_data is a Node object while body_nodes is used below as a collection of
        # node ids -- confirm that this subtraction actually removes anything
        body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        # from here on work with Node wrappers instead of raw ids
        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # for an input node with a valid 'axis' the external data is taken from port 1, otherwise from port 0
        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        # mirror of the above: for outputs the internal data is on the input side of the helper node
        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        # back-edge helper node: input 0 is the initial data, input 1 the data produced by an iteration,
        # output 0 the data consumed inside the body
        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        # copy body nodes and the edges between them into a separate graph
        body = nx.MultiDiGraph(name='body')
        body.graph['layout'] = graph.graph['layout']
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        # NOTE(review): the condition node id was already passed to remove_nodes_from above
        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        # single counter used for layer ids, port ids and external port ids
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            # re-wrap the nodes so that they refer to the body graph
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            edge['from_data_id']['is_output'] = True

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    # this case is treated as an error; the statement after `assert False` is not
                    # reached while assertions are enabled
                    assert False
                    real_edge['to_data_id'].out_node(
                    )['internal_layer_id'] = real_edge['to_data_id'].in_node(
                    ).internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                # the consumer edge must not be numbered yet and the initial data must have a single,
                # not yet numbered, outgoing edge
                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                # keep a copy of the attributes: the original edge disappears together with to_data_id below
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            # drop the data node that ended the back edge together with its producer
            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        # one record per (external input, consumer inside the body) pair
        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                # new body input whose shape has an extra dimension of size 1 inserted at `axis`
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=np.insert(shape, ext_inp['axis'], 1)))
                dim = shape.copy()
                # try to make it dynamically reshapable along one of the axes;
                # it is practically useful to reshape along the batch dimension, but here we cannot detect
                # where it is, so we are guessing based on other transformations that it is the major dimension
                dim[0] = -1
                reshape_op = Reshape(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze',
                         dim=dim))
                # the Reshape output is the original internal data node
                reshape_op.create_node_with_data(
                    [new_input_data], data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                # reuse ids that were assigned earlier, otherwise take fresh ones from the counter
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                dim = ext_out['internal_data_id'].shape.copy()
                # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                dim[0] = -1
                assert not ext_out['internal_data_id'].has_valid('value')
                reshape_op = Reshape(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze',
                         dim=np.insert(dim, ext_out['axis'], 1)))
                # the Reshape result becomes the body output
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id']])

            # TODO: add here working with simple outputs

            ext_out['internal_data_id']['is_output'] = True
            #assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            # number the producer of the output and its edge unless they already have ids
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        # input edges get their external_port_id right away; output edges are labelled afterwards
        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        # normalize the result to a list
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']