def replace_sub_graph(self, graph: Graph, external_match: dict):
    """Apply the Keras RNN output-concatenation rewrite to a matched Loop node.

    The body of ``external_match['while']`` is searched for
    ``KerasRNNOutputConcatenation.get_body_pattern()``. The rewrite is only
    attempted when there is exactly one such match and the Loop's back edges
    and input/output port maps agree with it; in that case the work is
    delegated to ``transform_keras_rnn_output_concatenation``.

    :param graph: the graph being transformed (not referenced by this method)
    :param external_match: nodes bound by the outer pattern; the keys
        'while' (the Loop node), 'reserve' and 'stack' are read here
    """
    loop = external_match['while']
    matches = find_subgraph_match_to_pattern(loop['body'],
                                             KerasRNNOutputConcatenation.get_body_pattern())
    if len(matches) != 1:
        # zero or ambiguous multiple matches: leave the graph untouched
        return

    match = matches[0]
    reserve_idx = compute_input_port_idx(external_match['reserve'], loop)
    stack_idx = external_match['stack'].in_port(0).get_source().idx

    # The checks below run one at a time, in the same order as a single
    # short-circuiting condition, so later lookups only happen if earlier ones pass.

    # back edge: concatenation result must loop back into the container node
    if not Loop.back_edge_exists(loop.back_edges,
                                 match['concatenation_result'].internal_layer_id,
                                 match['container'].internal_layer_id):
        return
    # back edge: incremented iteration counter must loop back into current_iteration
    if not Loop.back_edge_exists(loop.back_edges,
                                 match['increment_iteration_result'].internal_layer_id,
                                 match['current_iteration'].internal_layer_id):
        return
    # the Loop input port fed by 'reserve' must map to the body's container node
    if not Loop.inter_edge_exists(loop.input_port_map, reserve_idx,
                                  match['container'].internal_layer_id):
        return
    # the Loop output port consumed by 'stack' must map to the concatenation result
    if not Loop.inter_edge_exists(loop.output_port_map, stack_idx,
                                  match['concatenation_result'].internal_layer_id):
        return

    KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, match)
def replace_sub_graph(self, graph: Graph, external_match: dict):
    """Apply the Keras RNN input-slicing rewrite to a matched Loop node.

    The body of ``external_match['while']`` is searched for
    ``KerasRNNInputSlicing.get_body_pattern()``. The rewrite is only
    attempted when there is exactly one such match and the Loop's back edge
    and input port map agree with it; in that case the work is delegated to
    ``KerasRNNInputSlicing.transform_keras_rnn_input_slicing``.

    :param graph: the graph being transformed (not referenced by this method)
    :param external_match: nodes bound by the outer pattern; the keys
        'while' (the Loop node) and 'unstack' are read here
    """
    loop_node = external_match['while']
    body_graph = loop_node['body']
    body_pattern = KerasRNNInputSlicing.get_body_pattern()
    internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

    # a case of multiple matches is not handled since it is not clear how to select corresponding match
    if len(internal_matches) != 1:
        return
    internal_match = internal_matches[0]

    unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)

    # check that back edges connect correct Parameter and Result nodes in the body
    # check connections between body input ports and external inputs ports of Loop node
    if Loop.back_edge_exists(loop_node.back_edges,
                             internal_match['increment_iteration_result'].internal_layer_id,
                             internal_match['current_iteration'].internal_layer_id) and \
            Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
                                   internal_match['tensor_list'].internal_layer_id):
        # only if inter-graph match passed it starts to process the sub-graph
        KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)