def re_numerate_input_ports(loop_node: Node):
    """
    Compact the Loop input port indices so they run consecutively from 0 to
    (number of input ports - 1), rewriting the matching 'external_port_id'
    values in the Loop's input_port_map.

    Ports are visited in ascending index order. Each existing port is moved
    down to the next free index when there is a gap before it. Any ports left
    at or above the final free index are deleted afterwards.

    :param loop_node: the Loop node whose input ports are renumbered
    :return: None
    """
    def move_input(node: Node, source_idx: int, target_idx: int):
        # Create the target port if needed, re-attach the incoming connection
        # to it and point the port_map record at the new index.
        node.add_input_port(target_idx, skip_if_exist=True)
        connection = node.in_port(source_idx).get_connection()
        connection.set_destination(node.in_port(target_idx))
        Loop.update_port_map_value(node.input_port_map, 'external_port_id', source_idx, target_idx)

    if len(loop_node.in_ports()) == 0:
        return

    highest_idx = max(loop_node.in_ports().keys())
    next_free_idx = 0
    for candidate_idx in range(highest_idx + 1):
        if candidate_idx not in loop_node.in_ports():
            continue
        if candidate_idx != next_free_idx:
            move_input(loop_node, candidate_idx, next_free_idx)
        next_free_idx += 1

    # Indices from next_free_idx upward are leftovers of the moves above;
    # drop them from the highest index down.
    for stale_idx in range(highest_idx, next_free_idx - 1, -1):
        if stale_idx in loop_node.in_ports().keys():
            loop_node.delete_input_port(stale_idx)
def pull_constant_inputs_into_body(loop_node: Node):
    """
    Replace body Parameters that are fed by external 'Const' nodes with in-body Const copies.

    For every Loop input port with index > 1 that is connected to a node of type 'Const',
    the mapped body Parameter is replaced by a new Const created in the body from the
    original Const's attributes. The Parameter is then removed from the body and the Loop
    input port is deleted. A Parameter that is not passed unchanged to the next iteration
    (per Loop.parameter_unchanged_after_iteration) is left as is.

    :param loop_node: the Loop node to update
    :return: None
    """
    # Snapshot the ports before iterating: the loop deletes input ports as it goes, so it must not
    # walk a live view of them. A plain list is also reversible on every Python 3 version, which
    # dict views are not (reversible only since 3.8).
    for port_idx, in_port in reversed(list(loop_node.in_ports().items())):
        # input ports 0 and 1 are mandatory for the Loop node, so they are never pulled into the body
        if port_idx <= 1 or in_port.disconnected():
            continue
        original_const_node = in_port.get_source().node
        if original_const_node.soft_get('type') != 'Const':
            continue

        body_parameter = Loop.external_port_id_to_body_node(loop_node, port_idx, loop_node.input_port_map)
        # if there is a back edge into a body Parameter then we cannot replace it with a Const if the value
        # is updated during each iteration. So we need to check that the tensor is passed to the next iteration
        # unchanged
        if not Loop.parameter_unchanged_after_iteration(loop_node, body_parameter):
            continue

        new_const_node = Const(loop_node.body, original_const_node.attrs()).create_node()
        body_parameter.out_port(0).get_connection().set_source(new_const_node.out_port(0))
        loop_node.body.remove_nodes_from([body_parameter.id])
        loop_node.delete_input_port(port_idx)
def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_attr: str, dir: [None, str] = None):
    """
    Drop port_map records that no longer describe a live mapping, together with the Loop ports they use.

    A record is dropped when either of these holds:
      * no operation in the Loop body has the `internal_layer_id` stored under `port_map_attr`;
      * `dir` is 'out', the record's purpose is not 'execution_condition', and its
        'external_port_id' is not among the Loop's output ports.

    When `dir` is 'in' or 'out', the Loop port referenced by each dropped record is deleted
    as well. Input ports 0 and 1 are mandatory and are never deleted, although their
    records may still be dropped.

    :param loop_node: the Loop node to update
    :param port_map: the port_map records (input, output or back edges), indexed and deleted
                     positionally, so effectively a list of dicts; modified in place
    :param port_map_attr: the record key holding the body `internal_layer_id`
    :param dir: 'in' or 'out' when the records describe Loop input/output ports; any other value
                (e.g. None) only drops records and leaves the Loop ports untouched
    :return: None
    """
    def is_unused(record) -> bool:
        # The referenced body operation is gone.
        if len(loop_node.body.get_op_nodes(internal_layer_id=record[port_map_attr])) == 0:
            return True
        # Only output records can additionally be stale because their Loop port is gone;
        # the execution condition output is always kept.
        if dir != 'out' or record.get('purpose', "") == 'execution_condition':
            return False
        return record['external_port_id'] not in loop_node.out_ports()

    def delete_loop_port(port_id):
        # records whose external_port_id is -1 have no Loop port to delete here
        if port_id == -1:
            return
        if dir == 'in':
            # input port 0 and 1 are mandatory for the Loop node
            if port_id not in (0, 1) and port_id in loop_node.in_ports().keys():
                loop_node.delete_input_port(port_id)
        elif dir == 'out' and port_id in loop_node.out_ports():
            loop_node.delete_output_port(port_id)

    stale_indices = [idx for idx, record in enumerate(port_map) if is_unused(record)]

    # Delete from the back so earlier removals do not shift the indices still pending.
    for idx in reversed(stale_indices):
        if dir in ('in', 'out'):
            delete_loop_port(port_map[idx]['external_port_id'])
        del port_map[idx]