def replace_sub_graph(self, graph: Graph, external_match: dict):
    """Apply the Keras RNN output-concatenation transformation to a matched Loop node.

    The body of ``external_match['while']`` is searched for
    ``KerasRNNOutputConcatenation.get_body_pattern()``. Nothing is done unless exactly one
    occurrence is found. It is ambiguous which of several matches would correspond to the
    external nodes, so multiple matches are ignored.

    For the single match, all of these connections must hold:
      * a back edge from 'concatenation_result' to 'container';
      * a back edge from 'increment_iteration_result' to 'current_iteration';
      * the Loop input port associated with ``external_match['reserve']`` maps to 'container'
        in ``input_port_map``;
      * the output port feeding input 0 of ``external_match['stack']`` maps to
        'concatenation_result' in ``output_port_map``.
    Only then is ``KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation`` called.

    :param graph: the main graph (not read by this method)
    :param external_match: matched main-graph nodes under the keys 'while', 'reserve' and 'stack'
    """
    while_node = external_match['while']
    body_matches = find_subgraph_match_to_pattern(while_node['body'],
                                                  KerasRNNOutputConcatenation.get_body_pattern())

    # only an unambiguous (single) match is processed
    if len(body_matches) != 1:
        return
    body_match = body_matches[0]

    # presumably the Loop input port that the 'reserve' node is connected to
    reserve_port_idx = compute_input_port_idx(external_match['reserve'], while_node)
    # index of the output port that produces input 0 of the 'stack' node
    stack_port_idx = external_match['stack'].in_port(0).get_source().idx

    # the checks are evaluated lazily in this order; the first failing one stops the evaluation
    connections_ok = (
        Loop.back_edge_exists(while_node.back_edges,
                              body_match['concatenation_result'].internal_layer_id,
                              body_match['container'].internal_layer_id)
        and Loop.back_edge_exists(while_node.back_edges,
                                  body_match['increment_iteration_result'].internal_layer_id,
                                  body_match['current_iteration'].internal_layer_id)
        and Loop.inter_edge_exists(while_node.input_port_map, reserve_port_idx,
                                   body_match['container'].internal_layer_id)
        and Loop.inter_edge_exists(while_node.output_port_map, stack_port_idx,
                                   body_match['concatenation_result'].internal_layer_id)
    )
    if connections_ok:
        KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, body_match)
def find_and_replace_pattern(self, graph: Graph):
    """Detect TensorFlow 2 MapFN concatenation of intermediate outputs in every Loop node.

    The body of each Loop node is searched for ``MapFNOutputConcatenation.get_body_pattern()``.
    For every body match, the external nodes around the Loop are resolved:
      * 'reserve' is the only external node bound to the body 'container' node. Its op must be
        'TensorListReserve'.
      * 'stack' is the only external node bound to the body 'concatenation_result' node. Nodes
        with the 'identity' attribute set (e.g. StopGradient) are first skipped in the forward
        direction. Its op must be 'TensorListStack'.
    If either lookup fails, an info message is logged and the match is skipped.

    The transformation ``MapFNOutputConcatenation.transform_map_fn_output_concatenation`` is
    applied only when the Loop back edges connect both of these pairs:
      * 'concatenation_result' -> 'container';
      * 'increment_iteration_result' -> 'current_iteration'.

    :param graph: the main graph whose Loop nodes are processed
    """
    mismatch_message = "A sub-graph around the loop node {} does not match " \
                       "TensorFlow 2 MapFN pattern for intermediate outputs concatenation"

    for loop_node in graph.get_op_nodes(op='Loop'):
        loop_name = loop_node.soft_get('name', loop_node.id)
        body_matches = find_subgraph_match_to_pattern(loop_node['body'],
                                                      MapFNOutputConcatenation.get_body_pattern())

        for body_match in body_matches:
            # The body Parameter that stores intermediate outputs must be fed by exactly one
            # external TensorListReserve node. If so, the concatenation of intermediate outputs
            # is detected through this port and the Loop axis attribute can be used.
            reserve_candidates = Loop.get_external_nodes_by_internal_id(
                loop_node, body_match['container'].internal_layer_id)
            if len(reserve_candidates) != 1 or reserve_candidates[0].op != 'TensorListReserve':
                log.info(mismatch_message.format(loop_name))
                continue
            reserve_node = reserve_candidates[0]

            stack_candidates = Loop.get_external_nodes_by_internal_id(
                loop_node, body_match['concatenation_result'].internal_layer_id)
            if len(stack_candidates) != 1:
                log.info(mismatch_message.format(loop_name))
                continue

            # step over a StopGradient node that may sit between the Loop output and TensorListStack
            stack_node = skip_nodes_by_condition(stack_candidates[0],
                                                 lambda node: node.has_and_set('identity'), True)
            if stack_node.op != 'TensorListStack':
                log.info(mismatch_message.format(loop_name))
                continue

            external_match = {'while': loop_node, 'reserve': reserve_node, 'stack': stack_node}

            # The back edges must connect the container Parameter with the concatenation result
            # produced by TensorListSetItem, and the iteration counter with its increment.
            if not Loop.back_edge_exists(loop_node.back_edges,
                                         body_match['concatenation_result'].internal_layer_id,
                                         body_match['container'].internal_layer_id):
                continue
            if not Loop.back_edge_exists(loop_node.back_edges,
                                         body_match['increment_iteration_result'].internal_layer_id,
                                         body_match['current_iteration'].internal_layer_id):
                continue

            MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, body_match)
def find_and_replace_pattern(self, graph: Graph):
    """Detect TensorFlow 2 MapFN slicing of an input tensor in every Loop node.

    The body of each Loop node is searched for ``MapFNInputSlicing.get_body_pattern()``. For
    every body match, the external node bound to the body 'tensor_list' node is looked up. It
    must be the only such node and its op must be 'TensorListFromTensor'. Otherwise an info
    message is logged and the match is skipped.

    ``MapFNInputSlicing.transform_map_fn_input_slicing`` is applied only when a back edge
    connects 'increment_iteration_result' -> 'current_iteration'. No separate input-port-map
    check is done here. The external node was already found from the body node's id.

    :param graph: the main graph whose Loop nodes are processed
    """
    for loop_node in graph.get_op_nodes(op='Loop'):
        loop_name = loop_node.soft_get('name', loop_node.id)
        body_matches = find_subgraph_match_to_pattern(loop_node['body'],
                                                      MapFNInputSlicing.get_body_pattern())

        for body_match in body_matches:
            # The body node that reads slices (TensorListGetItem input) must be fed by exactly
            # one external TensorListFromTensor node. If so, input slicing is detected through
            # this port and the Loop axis attribute can be used.
            unstack_candidates = Loop.get_external_nodes_by_internal_id(
                loop_node, body_match['tensor_list'].internal_layer_id)
            if len(unstack_candidates) != 1 or unstack_candidates[0].op != 'TensorListFromTensor':
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for input slicing".format(loop_name))
                continue

            external_match = {'while': loop_node, 'unstack': unstack_candidates[0]}

            # the iteration counter must be fed back through a back edge in the body
            iteration_back_edge = Loop.back_edge_exists(
                loop_node.back_edges,
                body_match['increment_iteration_result'].internal_layer_id,
                body_match['current_iteration'].internal_layer_id)
            if iteration_back_edge:
                MapFNInputSlicing.transform_map_fn_input_slicing(external_match, body_match)
def replace_sub_graph(self, graph: Graph, external_match: dict):
    """Apply the Keras RNN input-slicing transformation to a matched Loop node.

    The body of ``external_match['while']`` is searched for
    ``KerasRNNInputSlicing.get_body_pattern()``. Nothing is done unless exactly one occurrence
    is found. It is not clear how to select the corresponding match among several, so that case
    is left unhandled.

    For the single match, both of these must hold:
      * a back edge from 'increment_iteration_result' to 'current_iteration';
      * the Loop input port associated with ``external_match['unstack']`` maps to the body
        'tensor_list' node in ``input_port_map``.
    Only then is ``KerasRNNInputSlicing.transform_keras_rnn_input_slicing`` called.

    :param graph: the main graph (not read by this method)
    :param external_match: matched main-graph nodes under the keys 'while' and 'unstack'
    """
    loop_node = external_match['while']
    body_graph = loop_node['body']
    body_pattern = KerasRNNInputSlicing.get_body_pattern()
    internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

    # a case of multiple matches is not handled since it is not clear how to select corresponding match
    if len(internal_matches) != 1:
        return
    internal_match = internal_matches[0]

    # presumably the Loop input port that the 'unstack' node is connected to
    unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)

    # check that back edges connect correct Parameter and Result nodes in the body
    # check connections between body input ports and external inputs ports of Loop node
    if Loop.back_edge_exists(loop_node.back_edges,
                             internal_match['increment_iteration_result'].internal_layer_id,
                             internal_match['current_iteration'].internal_layer_id) and \
            Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
                                   internal_match['tensor_list'].internal_layer_id):
        # only if inter-graph match passed it starts to process the sub-graph
        KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)