Esempio n. 1
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Normalize every TensorIterator node and fill in slicing bounds of its outputs.

        Each TensorIterator first goes through ``external_nodes_normalization``.
        Then, when none of its input port map records has an 'axis', every output
        port map record that does have an 'axis' gets:
            - 'start' = 0
            - 'end'   = size of the corresponding output data along that axis
        """
        for ti_node in graph.get_op_nodes(type='TensorIterator'):
            self.external_nodes_normalization(ti_node)

            # Bounds are only filled in when no input is sliced along an axis.
            if any(entry.get('axis') is not None for entry in ti_node.input_port_map):
                continue

            sliced_outputs = (entry for entry in ti_node.output_port_map
                              if entry.get('axis') is not None)
            for entry in sliced_outputs:
                entry['start'] = 0
                out_idx = TensorIterator.special_port_to_real_port(ti_node, entry['external_port_id'], 'out')
                shape = ti_node.out_port(out_idx).data.get_shape()
                assert shape is not None
                entry['end'] = shape[entry['axis']]
Esempio n. 2
0
    def external_nodes_normalization(ti):
        """
        TensorIterator external ports may have several internal layer connections.

        Current transformation does the following:
            - normalizes port maps (eliminating duplicated records)
            - replicates external input/output port for each internal Parameter/Result it is connected to
            - updates input and output port maps according to previous step replications

        Every replicated port receives a fresh ``external_port_id``. The id is written both
        into the port map record and onto the new edge of ``ti``, for inputs and outputs alike.
        This keeps records and edges consistent, so later lookups by ``external_port_id``
        (e.g. ``TensorIterator.special_port_to_real_port``) find exactly one edge.

        :param ti: TensorIterator node; it is modified in place (ports, edges, port maps).
        """
        def update_external_port_id(ti, port_type, old_external_port_id,
                                    new_external_port_id, internal_layer_id):
            # Re-point the port map record(s) of one internal layer from the old external id to the new one.
            assert port_type in ['in', 'out']

            port_map = ti.input_port_map if port_type == 'in' else ti.output_port_map
            for record in port_map:
                if record['external_port_id'] == old_external_port_id and \
                        record['internal_layer_id'] == internal_layer_id:
                    record['external_port_id'] = new_external_port_id

        def group_by_external_port(port_map):
            # Map external_port_id -> list of port map records referencing it (in port map order).
            groups = defaultdict(list)
            for record in port_map:
                assert 'external_port_id' in record
                groups[record['external_port_id']].append(record)
            return groups

        def next_external_port_id():
            # One past the largest external_port_id currently stored on any in/out edge of ti.
            edges = list(ti.in_edges().values()) + list(ti.out_edges().values())
            return max(int(d['external_port_id']) for d in edges) + 1

        NormalizeTI.maps_uniqueization(ti)

        body = ti.body

        # Inputs: keep the first record on the original port, give every other record its own
        # new input port fed from the same source.
        external_input_ports = group_by_external_port(ti.input_port_map)
        for external_port_id, record_list in external_input_ports.items():
            if len(record_list) == 1:
                continue

            real_external_port_id = TensorIterator.special_port_to_real_port(
                ti, external_port_id, 'in')
            source = ti.in_port(real_external_port_id).get_source()

            for record in record_list[1:]:
                assert 'internal_layer_id' in record

                new_real_input_port_id = max(map(int, ti.in_ports().keys())) + 1
                new_external_port_id = next_external_port_id()

                ti.add_input_port(new_real_input_port_id)
                source.connect(ti.in_port(new_real_input_port_id))

                ti.in_edge(new_real_input_port_id)['external_port_id'] = new_external_port_id
                update_external_port_id(ti, 'in', external_port_id,
                                        new_external_port_id,
                                        record['internal_layer_id'])

        # Outputs: grouped only after the input pass, which rewrites input records only.
        external_output_ports = group_by_external_port(ti.output_port_map)
        for external_port_id, record_list in external_output_ports.items():
            if len(record_list) == 1:
                continue

            real_external_port_id = TensorIterator.special_port_to_real_port(
                ti, external_port_id, 'out')
            dsts = ti.out_port(real_external_port_id).get_destinations()

            for record in record_list[1:]:
                assert 'internal_layer_id' in record

                new_real_output_port_id = max(map(int, ti.out_ports().keys())) + 1
                new_external_port_id = next_external_port_id()

                ti.add_output_port(new_real_output_port_id)
                new_port = ti.out_port(new_real_output_port_id)
                for dst in dsts:
                    new_port.connect(dst)

                # Tag the new output edge with its external id, mirroring the input branch.
                # Without this, the record updated below would reference an id that no edge
                # carries, and next_external_port_id() would not account for it.
                # Guarded because with no destinations presumably no edge is created.
                new_out_edge = ti.out_edges().get(new_real_output_port_id)
                if new_out_edge is not None:
                    new_out_edge['external_port_id'] = new_external_port_id

                update_external_port_id(ti, 'out', external_port_id,
                                        new_external_port_id,
                                        record['internal_layer_id'])

        body.clean_up()