def replace_pattern(graph: Graph, match: dict):
    """Rewrite a matched TensorArrayWrite sub-graph into a TensorIteratorOutput node.

    The write's index (input port 1) and value (input port 2) feed the new
    TensorIteratorOutput, whose result replaces 'TensorArrayRead_data'.
    All matched nodes except the kept data nodes are removed from the graph.
    """
    log.debug('================== SimpleOutputFind ===============')
    # The Enter data value carries the external port id of the iterator output.
    assert match['WriteEnter_data'].value is not None

    write_index = match['TensorArrayWrite'].in_node(1)
    write_value = match['TensorArrayWrite'].in_node(2)

    # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
    # condition)
    ti_output = TensorIteratorOutput(graph, dict(
        external_port_id=str(match['WriteEnter_data'].value),
        internal_layer_id=write_value.id,
        name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'))
    ti_output.create_node_with_data(
        inputs=[write_value, write_index],
        data_nodes=[match['TensorArrayRead_data']])

    # Delete useless nodes
    safe_nodes = ['TensorArrayRead_data', 'Condition_data']
    graph.remove_nodes_from(
        [match[name].id for name in match if name not in safe_nodes])
def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """Rewrite a matched TensorArrayWrite/Gather sub-graph into a TensorIteratorOutput node.

    Requires the matched range to start at 0 with step 1 (start_data/delta_data).
    The TensorArray size, the write value (port 2) and index (port 1) feed the
    new TensorIteratorOutput, whose result replaces 'TensorArrayGather_data'.
    All matched nodes except the kept data nodes are removed from the graph.
    """
    log.debug('================== SmartOutputFind ===============')
    # The Enter data value carries the external port id of the iterator output.
    assert match['WriteEnter_data'].value is not None
    # Only the trivial iteration range (start=0, delta=1) is supported here.
    assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1

    ta_size = match['TensorArray'].in_node()
    write_index = match['TensorArrayWrite'].in_node(1)
    write_value = match['TensorArrayWrite'].in_node(2)

    # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
    # condition)
    ti_output = TensorIteratorOutput(graph, dict(
        axis=0, start=None, stride=None, part_size=None,
        external_port_id=str(match['WriteEnter_data'].value),
        internal_layer_id=write_value.id,
        name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'))
    ti_output.create_node_with_data(
        inputs=[ta_size, write_value, write_index],
        data_nodes=[match['TensorArrayGather_data']])

    # Delete useless nodes
    safe_nodes = ['TensorArrayGather_data', 'Condition_data']
    graph.remove_nodes_from(
        [match[name].id for name in match if name not in safe_nodes])
def replace_pattern(self, graph: Graph, match: dict):
    """Rewrite a matched loop back-edge into a TensorIteratorBackEdge node.

    When the matched Switch also has an Exit path (output port 0), that path
    is additionally rewritten into a TensorIteratorOutput without partition
    attributes. All matched nodes except the kept ones (condition, back-edge
    data nodes) are removed from the graph.
    """
    log.debug('================== BackEdgeFind ===============')
    nodes_for_remove = []
    from_body_data = match['NextIteration'].in_node()

    # If Exit path is exist -> create TensorIteratorOutput for this
    if 0 in match['Switch_1'].out_nodes():
        exit_node = match['Switch_1'].out_node(0).out_node(0)  # Switch -> Switch_data -> Exit
        assert exit_node.has_valid('op') and exit_node.op == 'Exit'
        output_data = exit_node.out_node(0)

        nodes_for_remove.append(match['Switch_1'].out_node(0).id)
        nodes_for_remove.append(exit_node.id)

        # Creating TensorIteratorOutput without partition
        ti_output = TensorIteratorOutput(graph, dict(
            external_port_id=None,
            internal_layer_id=None,
            name=exit_node.name + '/TensorIteratorOutput_'))
        ti_output.create_node_with_data(
            inputs=[from_body_data, match['condition_cond_data']],
            data_nodes=[output_data])

    # The back-edge must connect two distinct data nodes.
    assert match['NextIteration_data'].id != match['Enter_1_data'].id

    backedge = TensorIteratorBackEdge(graph, dict(
        name=match['Identity_1'].name + '/TensorIteratorBackEdge_'))
    backedge.create_node_with_data(
        inputs=[match['Enter_1_data'], from_body_data, match['condition_cond_data']],
        data_nodes=[match['Identity_1_data']])

    # Delete useless nodes
    safe_nodes = ['Identity_1_data', 'condition', 'condition_cond_data', 'Enter_1_data']
    for name in match:
        if name not in safe_nodes:
            nodes_for_remove.append(match[name].id)
    graph.remove_nodes_from(nodes_for_remove)