示例#1
0
    def re_numerate_input_ports(loop_node: Node):
        """
        Compact the input port ids of the Loop node so they become 0 .. num_input_ports - 1 (preserving their
        relative order) and keep the Loop input_port_map in sync with the new ids.

        :param loop_node: the Loop node, modified in place
        :return: None
        """
        def move_input_port(node: Node, src_id: int, dst_id: int):
            # make sure the target port exists, re-attach the incoming connection to it and rewrite the
            # 'external_port_id' of the matching input_port_map records
            node.add_input_port(dst_id, skip_if_exist=True)
            node.in_port(src_id).get_connection().set_destination(node.in_port(dst_id))
            Loop.update_port_map_value(node.input_port_map, 'external_port_id', src_id, dst_id)

        # snapshot of the original ids; moving a port only ever adds an id lower than the one being visited,
        # so membership of the ids still to be visited is not affected by the moves
        original_ids = set(loop_node.in_ports().keys())
        if not original_ids:
            return

        highest_id = max(original_ids)
        next_free_id = 0
        for candidate_id in range(highest_id + 1):
            if candidate_id not in original_ids:
                continue
            if candidate_id != next_free_id:
                move_input_port(loop_node, candidate_id, next_free_id)
            next_free_id += 1

        # ids from next_free_id up to the old maximum are leftovers of the moves; drop them from the top down
        for stale_id in range(highest_id, next_free_id - 1, -1):
            if stale_id in loop_node.in_ports():
                loop_node.delete_input_port(stale_id)
示例#2
0
    def pull_constant_inputs_into_body(loop_node: Node):
        """
        Move Const producers of the Loop inputs into the Loop body.

        For every connected input port with index greater than 1 whose producer has type 'Const', the body
        Parameter bound to that port is replaced with a copy of the Const created in the body. This happens
        only when Loop.parameter_unchanged_after_iteration approves it. The Parameter node is then removed
        from the body and the Loop input port is deleted.

        :param loop_node: the Loop node, modified in place
        :return: None
        """
        for port_idx, in_port in reversed(loop_node.in_ports().items()):
            # input ports 0 and 1 are never pulled in; a disconnected port has no producer to inspect
            if port_idx <= 1 or in_port.disconnected():
                continue
            producer = in_port.get_source().node
            if producer.soft_get('type') != 'Const':
                continue

            body_parameter = Loop.external_port_id_to_body_node(loop_node, port_idx, loop_node.input_port_map)
            # a back edge into the Parameter may update its value on every iteration, and then it cannot be
            # frozen into a Const: only Parameters passed to the next iteration unchanged are replaced
            if not Loop.parameter_unchanged_after_iteration(loop_node, body_parameter):
                continue

            body_const = Const(loop_node.body, producer.attrs()).create_node()
            body_parameter.out_port(0).get_connection().set_source(body_const.out_port(0))
            loop_node.body.remove_nodes_from([body_parameter.id])
            loop_node.delete_input_port(port_idx)
示例#3
0
    def remove_unused_ops_from_port_map(loop_node: Node,
                                        port_map: list,
                                        port_map_attr: str,
                                        dir: 'str | None' = None):
        """
        Find port_map records referencing operations that are no longer used in the Loop body, remove those
        records and, for input/output port_maps, the Loop ports connected to them.

        A record is unused when the body has no op whose `internal_layer_id` equals `record[port_map_attr]`.
        For `dir == 'out'` a record is also unused when its `external_port_id` is not an existing output port of
        the Loop. The only exception is the "execution_condition" output record, which is mandatory.
        Loop input ports with index 0 and 1 are mandatory and are never deleted, although their records may
        still be dropped from the port_map.

        :param loop_node: the Loop node to update
        :param port_map: list of port_map records (input, output or back edges); modified in place
        :param port_map_attr: the record key containing the `internal_layer_id` of the body operation
        :param dir: the direction of the port_map, 'in' or 'out', meaning the Loop input or output ports are
                    updated as well; any other value (e.g. None for back edges) only edits the port_map
        :return: None
        """
        # output ports are not modified while records are inspected, so query them once (only when needed)
        existing_out_ports = loop_node.out_ports() if dir == 'out' else {}

        def is_unused(record: dict) -> bool:
            if len(loop_node.body.get_op_nodes(internal_layer_id=record[port_map_attr])) == 0:
                return True
            return dir == 'out' and record.get('purpose', "") != 'execution_condition' and \
                record['external_port_id'] not in existing_out_ports

        record_ids_to_remove = [record_id for record_id, record in enumerate(port_map) if is_unused(record)]

        # delete from the end so that the indices of the records still to be removed stay valid
        for record_id_to_remove in reversed(record_ids_to_remove):
            # only input/output records carry an 'external_port_id' bound to a Loop port
            if dir in ('in', 'out'):
                port_to_remove = port_map[record_id_to_remove]['external_port_id']
                # records with external_port_id == -1 have no Loop port to delete
                if port_to_remove != -1:
                    if dir == 'in':
                        # input ports 0 and 1 are mandatory for the Loop node
                        if port_to_remove not in (0, 1) and port_to_remove in loop_node.in_ports():
                            loop_node.delete_input_port(port_to_remove)
                    elif port_to_remove in loop_node.out_ports():
                        # re-query: an earlier iteration may already have deleted this output port
                        loop_node.delete_output_port(port_to_remove)
            del port_map[record_id_to_remove]