def find_and_replace_pattern(self, graph: Graph):
    """
    Find TensorFlow 2 MapFN intermediate-output concatenation inside Loop bodies and transform it.

    Every Loop node's body graph is matched against MapFNOutputConcatenation.get_body_pattern(). An internal
    match is passed to MapFNOutputConcatenation.transform_map_fn_output_concatenation only when:
      1. the body node 'container' is connected to exactly one external node, and it is a TensorListReserve;
      2. the body node 'concatenation_result' is connected to exactly one external node which, after skipping
         nodes that have the 'identity' attribute set (e.g. a StopGradient), is a TensorListStack;
      3. back edges connect 'concatenation_result' with 'container' and 'increment_iteration_result' with
         'current_iteration'.
    A match failing check 1 or 2 is reported with log.info and skipped; a match failing check 3 is skipped
    without a message.

    :param graph: graph whose Loop nodes are processed
    """
    for loop_node in graph.get_op_nodes(op='Loop'):
        loop_name = loop_node.soft_get('name', loop_node.id)
        # one diagnostic shared by every external-pattern mismatch below
        mismatch_message = ("A sub-graph around the loop node {} does not match "
                            "TensorFlow 2 MapFN pattern for intermediate outputs concatenation").format(loop_name)
        body_graph = loop_node['body']
        body_pattern = MapFNOutputConcatenation.get_body_pattern()
        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

        for internal_match in internal_matches:
            # check if TensorListReserve from the main graph is connected with the Parameter node from the body
            # graph that stores intermediate output results of the While loop. If yes, the transformation
            # detects intermediate outputs concatenation by this port and can use the Loop axis attribute
            reserve_nodes = Loop.get_external_nodes_by_internal_id(
                loop_node, internal_match['container'].internal_layer_id)
            if len(reserve_nodes) != 1 or reserve_nodes[0].op != 'TensorListReserve':
                log.info(mismatch_message)
                continue
            reserve_node = reserve_nodes[0]

            stack_nodes = Loop.get_external_nodes_by_internal_id(
                loop_node, internal_match['concatenation_result'].internal_layer_id)
            if len(stack_nodes) != 1:
                log.info(mismatch_message)
                continue

            # skip StopGradient node if it exists between While loop output port and TensorListStack operation
            stack_node = skip_nodes_by_condition(stack_nodes[0], lambda x: x.has_and_set('identity'), True)
            if stack_node.op != 'TensorListStack':
                log.info(mismatch_message)
                continue

            external_match = {'while': loop_node, 'reserve': reserve_node, 'stack': stack_node}
            # check that back edges connect the Parameter node (container with intermediate output results)
            # with the concatenation result produced by the TensorListSetItem node, and the iteration
            # increment result with the current-iteration input
            if Loop.back_edge_exists(loop_node.back_edges,
                                     internal_match['concatenation_result'].internal_layer_id,
                                     internal_match['container'].internal_layer_id) and \
                    Loop.back_edge_exists(loop_node.back_edges,
                                          internal_match['increment_iteration_result'].internal_layer_id,
                                          internal_match['current_iteration'].internal_layer_id):
                MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
def find_and_replace_pattern(self, graph: Graph):
    """
    Find TensorFlow 2 MapFN input slicing inside Loop bodies and transform it.

    Each Loop node's body is matched against MapFNInputSlicing.get_body_pattern(). For an internal match, the
    body node 'tensor_list' must be connected to exactly one external node, a TensorListFromTensor; otherwise
    the match is reported with log.info and skipped. A surviving match is passed to
    MapFNInputSlicing.transform_map_fn_input_slicing when a back edge connects 'increment_iteration_result'
    with 'current_iteration'.

    :param graph: graph whose Loop nodes are processed
    """
    for loop in graph.get_op_nodes(op='Loop'):
        name = loop.soft_get('name', loop.id)
        matches = find_subgraph_match_to_pattern(loop['body'], MapFNInputSlicing.get_body_pattern())
        for match in matches:
            # TensorListGetItem in the body must read from a TensorListFromTensor of the main graph; only then is
            # input slicing detected by this port, so the Loop axis attribute can be used
            external_nodes = Loop.get_external_nodes_by_internal_id(loop, match['tensor_list'].internal_layer_id)
            if not (len(external_nodes) == 1 and external_nodes[0].op == 'TensorListFromTensor'):
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for input slicing".format(name))
                continue

            # the iteration counter must be fed back: increment result -> current-iteration input
            counter_fed_back = Loop.back_edge_exists(loop.back_edges,
                                                     match['increment_iteration_result'].internal_layer_id,
                                                     match['current_iteration'].internal_layer_id)
            if counter_fed_back:
                MapFNInputSlicing.transform_map_fn_input_slicing({'while': loop, 'unstack': external_nodes[0]},
                                                                 match)
def find_and_replace_pattern(self, graph: Graph):
    """
    Find the TensorFlow 2 EmptyTensorList->TensorListPushBack output concatenation inside Loop bodies and
    transform it.

    Each Loop node's body is matched against TensorListOutputConcatenation.get_body_pattern(). For an internal
    match, the body node 'container' must be connected to exactly one external node, an EmptyTensorList;
    otherwise the match is reported with log.info and skipped. A surviving match is passed to
    TensorListOutputConcatenation.transform_tensor_list_output_concatenation when a back edge connects
    'concatenation_result' with 'container'.

    :param graph: graph whose Loop nodes are processed
    """
    for loop in graph.get_op_nodes(op='Loop'):
        name = loop.soft_get('name', loop.id)
        body = loop['body']
        pattern = TensorListOutputConcatenation.get_body_pattern()
        for match in find_subgraph_match_to_pattern(body, pattern):
            # EmptyTensorList from the main graph must feed the body Parameter that stores intermediate output
            # results of the While loop; then concatenation is detected by this port and the Loop axis is usable
            external_nodes = Loop.get_external_nodes_by_internal_id(loop, match['container'].internal_layer_id)
            if len(external_nodes) != 1 or external_nodes[0].op != 'EmptyTensorList':
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 EmptyTensorList->TensorListPushBack pattern for intermediate "
                         "outputs concatenation".format(name))
                continue

            # the concatenation result produced by TensorListPushBack must be fed back into the container
            if not Loop.back_edge_exists(loop.back_edges,
                                         match['concatenation_result'].internal_layer_id,
                                         match['container'].internal_layer_id):
                continue

            TensorListOutputConcatenation.transform_tensor_list_output_concatenation(
                {'while': loop, 'reserve': external_nodes[0]}, match)