def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
    :param external_match: a match used for handling a part of the main graph responsible for input slicing
    :param internal_match: a match used for handling a part of the body graph responsible for input slicing
    """
    loop_node = external_match['while']
    unstack_node = external_match['unstack']
    body_graph = loop_node['body']

    get_item_node = internal_match['slicing']
    unstack_placeholder = internal_match['tensor_list']
    get_item_name = get_item_node.soft_get('name', get_item_node.id)

    # Step 1: rework the body graph so it no longer needs the unsupported
    # TensorListGetItem operation. A Squeeze(axis=0) takes its place, and the
    # Loop input port is marked to iterate along axis 0 so each body iteration
    # receives one slice of the input tensor.
    squeeze_slice = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
                                                {'name': 'TensorListGetItemSqueeze'})
    get_item_node.in_port(0).get_connection().set_destination(squeeze_slice.in_port(0))
    get_item_node.out_port(0).get_connection().set_source(squeeze_slice.out_port(0))
    rename_nodes([(get_item_node, get_item_name + '/AbandonedName'),
                  (squeeze_slice, get_item_name)])
    Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id',
                                   unstack_placeholder.internal_layer_id, 'axis', 0)

    # Step 2: in the main graph, by-pass the unsupported TensorListFromTensor
    # (a.k.a. "unstack") node — the Loop now consumes the original tensor directly.
    original_source = unstack_node.in_port(0).get_connection().get_source()
    unstack_node.out_port(0).get_connection().set_source(original_source)
def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
    :param external_match: a match used for handling a part of the main graph responsible for output concatenation
    :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
    """
    loop_node = external_match['while']
    stack_node = external_match['stack']
    list_reserve_node = external_match['reserve']
    body_graph = loop_node['body']

    tensor_list_set_item_node = internal_match['concatenation']
    tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id)
    list_result_node = internal_match['concatenation_result']

    # replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node
    # to concatenate results from different iterations
    unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
                                                         {'name': 'TensorListSetItemUnsqueeze'})
    tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0))
    tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
    rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'),
                  (unsqueeze_list_element, tensor_list_set_item_node_name)])
    list_result_node_layer_id = list_result_node.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id',
                                   list_result_node_layer_id, 'axis', 0)

    # remove TensorListStack to by-pass the node since the result from the Loop node is already concatenated
    stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source())

    # disconnect ListReserve node because it is no longer needed for Loop
    list_reserve_node.out_port(0).disconnect()

    # connect a number of iterations with trip count that can be received from the second input of ListReserve
    # create a constant network with True value for execution_condition so that IE can ignore execution condition
    # and perform trip_counts iterations. This approach with known trip count value allows to avoid dynamism.
    loop_node.in_port(1).disconnect()
    list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
    for record in loop_node.output_port_map:
        if 'purpose' in record and record['purpose'] == 'execution_condition':
            exec_cond_layer_id = record['internal_layer_id']
            exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
            # NOTE: plain `bool` is used here; the `np.bool` alias was deprecated
            # in NumPy 1.20 and removed in 1.24 (using it raised AttributeError)
            const_true = Const(body_graph, {'value': np.array(True, dtype=bool)}).create_node()
            exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
def transform_tensor_list_output_concatenation(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
    :param external_match: a match used for handling a part of the main graph responsible for output concatenation
    :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
    """
    loop_node = external_match['while']
    empty_tensor_list_node = external_match['reserve']
    body_graph = loop_node['body']

    tensor_list_push_back_node = internal_match['concatenation']
    tensor_list_push_back_node_name = tensor_list_push_back_node.soft_get('name', tensor_list_push_back_node.id)
    list_result_node = internal_match['concatenation_result']

    # replace TensorListPushBack with Unsqueeze and use axis attribute for corresponding Result node
    # to concatenate results from different iterations
    unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
                                                         {'name': tensor_list_push_back_node_name +
                                                                  '/TensorListPushBackUnsqueeze'})
    tensor_list_push_back_node.in_port(1).get_connection().set_destination(unsqueeze_list_element.in_port(0))
    tensor_list_push_back_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
    rename_nodes([(tensor_list_push_back_node, tensor_list_push_back_node_name + '/AbandonedName'),
                  (unsqueeze_list_element, tensor_list_push_back_node_name)])
    list_result_node_layer_id = list_result_node.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id',
                                   list_result_node_layer_id, 'axis', 0)

    # disconnect EmptyTensorList node because it is no longer needed for Loop
    empty_tensor_list_node.out_port(0).disconnect()

    loop_node.in_port(1).disconnect()
    empty_tensor_list_node.in_port(1).get_source().connect(loop_node.in_port(1))

    # remove the back edge corresponding to the concatenation Result node.
    # Iterate over a snapshot of the list: calling remove() while iterating the
    # list itself skips the element after each removal, which could leave a
    # matching back-edge record in place.
    for record in list(loop_node.back_edges):
        if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
            loop_node.back_edges.remove(record)