Esempio n. 1
0
    def use_shapes_from_ir(node: Node):
        """Set the output shapes of ``node`` to the shapes restored from IR.

        Used in place of the operation's own shape inference function. It first
        checks that the input shapes computed by the current shape inference
        match the ones restored from IR (``node.old_input_shapes``). Only
        connected input ports are considered, so absent optional inputs are
        skipped. The freshly computed input shapes are stored on the node as
        ``new_input_shapes``.

        Each output port then takes its shape from the first element of the
        matching ``node.ports`` entry. Output entries are numbered right after
        the connected inputs. An output port without such an entry must already
        have a shape set.

        :param node: the node whose output shapes are restored.
        :raises AssertionError: if the input shapes differ in count or value, or
            if an output port has neither an IR entry nor an already-set shape.
        """
        mismatch_msg = 'Something wrong happened while {} node with type {} copy shape inference!'

        # Disconnected input ports stand for optional inputs that are absent.
        connected_ports = [node.in_port(idx) for idx in node.in_ports()
                           if not node.in_port(idx).disconnected()]

        node['new_input_shapes'] = []
        node.new_input_shapes.extend(port.data.get_shape() for port in connected_ports)

        new_shapes, old_shapes = node.new_input_shapes, node.old_input_shapes
        assert len(new_shapes) == len(old_shapes), mismatch_msg.format(node.name, node.type)
        for new_shape, old_shape in zip(new_shapes, old_shapes):
            assert np.array_equal(new_shape, old_shape), mismatch_msg.format(node.name, node.type)

        # In node.ports the output entries are numbered after the connected
        # inputs, so the first output port maps to index len(connected_ports).
        for ir_idx, out_idx in enumerate(node.out_ports(), start=len(connected_ports)):
            if ir_idx not in node.ports:
                assert node.out_port(out_idx).data.get_shape() is not None, "Newly added port does not have set shape"
                continue
            node.out_port(out_idx).data.set_shape(int64_array(node.ports[ir_idx][0]))
Esempio n. 2
0
    def extend(op: Node):
        """Update a Loop op restored from IR in place.

        Steps:
        * copy the parent graph's ``cmd_params`` and ``ir_version`` into the body
          graph, name it ``<op.name>/body``, record each body op's id as
          ``internal_layer_id``, and replace ``op.body`` (an IREngine at this
          point) with ``copy_graph_with_ops`` of its graph;
        * make sure every input/output port-map record has all optional keys;
        * rebase output records' ``external_port_id`` by the number of input ports;
        * rename back-edge keys ``from-layer``/``to-layer`` to
          ``from_layer``/``to_layer`` and set both edge ports to 0.

        :param op: the Loop node to update; it must have a ``body`` attribute.
        :raises AssertionError: if ``op`` has no ``body`` attribute.
        """
        def normalize_port_map(port_map: list):
            # port_map is a list of dict records; add every optional key that is
            # missing, with the value None.
            for port in port_map:
                for elem in ('axis', 'stride', 'part_size', 'start', 'end', 'purpose'):
                    port.setdefault(elem, None)

        assert op.has('body'), 'There is no "body" attribute in the Loop op {}.'.format(op.name)

        # At this point op.body is an IREngine; configure its graph, then replace
        # op.body with a copy of that graph.
        body_graph = op.body.graph
        body_graph.graph['cmd_params'] = op.graph.graph['cmd_params']
        body_graph.graph['ir_version'] = op.graph.graph['ir_version']
        body_graph.name = op.name + '/body'

        for node in body_graph.get_op_nodes():
            node['internal_layer_id'] = int(node.id)

        op.body = copy_graph_with_ops(body_graph)

        normalize_port_map(op.input_port_map)
        normalize_port_map(op.output_port_map)

        # 'external_port_id' uses end-to-end numbering of input and output ports,
        # but at this moment output ports are numbered separately from input
        # ports, so decrease each output port_id by the number of input ports.
        # Records whose id is -1 are left unchanged.
        for record in op.output_port_map:
            if record['external_port_id'] != -1:
                record['external_port_id'] -= len(op.in_ports())

        # Rename the dash-style back-edge keys to underscore-style keys; both
        # edge ports are set to 0.
        for edge in op.back_edges:
            edge['from_layer'] = edge['from-layer']
            edge['to_layer'] = edge['to-layer']

            edge['to_port'] = 0
            edge['from_port'] = 0

            del edge['from-layer']
            del edge['to-layer']