def replace_pattern(graph: Graph, match: dict):
    """
    Create the condition node and delete all useless nodes (such as Switch/Merge/Identity)
    from the condition pattern.
    """
    log.debug('================== DynamicDecoderConditionFind ==================')
    # Create and connect the condition node for the dynamic decoder in TF
    loop_condition = match['loop_cond_data']
    iterator_data = match['identity_data']

    condition_attrs = dict(name=match['loop_cond'].name + '/TensorIteratorCondition_')
    condition = TensorIteratorCondition(graph, attrs=condition_attrs)
    condition.create_node_with_data(inputs=[match['Less_enter'].in_node()],
                                    data_nodes=[loop_condition, iterator_data])

    # Delete useless nodes
    safe_nodes = ['loop_cond_data', 'identity_data', 'TensorIteratorOutput', 'TensorIteratorOutput_1']
    nodes_for_remove = []
    for node in match.keys():
        if node not in safe_nodes:
            nodes_for_remove.append(match[node].id)
    graph.remove_nodes_from(nodes_for_remove)

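# Each replace_pattern in this file ends with the same idiom: remove every
# matched node whose pattern name is not listed in safe_nodes. The following
# is a minimal, stdlib-only sketch of that filter, assuming `match` maps
# pattern names to objects exposing an `.id` attribute (as the graph nodes do).
# `_SketchNode` and `nodes_to_remove` are hypothetical names used only for
# illustration; they are not part of the Model Optimizer API.
from collections import namedtuple

# Stand-in for a matched graph node; only the `id` attribute is used here.
_SketchNode = namedtuple('_SketchNode', ['id'])


def nodes_to_remove(match: dict, safe_nodes: list) -> list:
    """Return the ids of matched nodes whose pattern names are not in safe_nodes."""
    return [match[name].id for name in match if name not in safe_nodes]


# Usage:
#   nodes_to_remove({'loop_cond_data': _SketchNode('c'), 'Switch': _SketchNode('s')},
#                   ['loop_cond_data'])
# returns ['s'], i.e. only the Switch node would be removed.
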
def replace_pattern(graph: Graph, match: dict):
    log.debug('================== SimpleConditionFind ===============')
    # init_1: initial value of the iteration counter
    init_1 = match['init_1_data'].value
    assert init_1 is not None
    init_1 = int(init_1)

    # step_1: increment of the iteration counter
    assert match['add_1_y_data'].value is not None
    step_1 = int(match['add_1_y_data'].value)

    match['loop_cond_data'].value = None

    # Create condition node and delete all useless nodes from condition pattern
    condition_attrs = dict(iter=dict(init=init_1, step=step_1),
                           name=match['loop_cond'].name + '/TensorIteratorCondition_')
    condition = TensorIteratorCondition(graph, attrs=condition_attrs)
    condition.create_node_with_data(inputs=[match['Strided_slice_data']],
                                    data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

    # Delete useless nodes
    safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Strided_slice', 'Strided_slice_data']
    nodes_for_remove = []
    for node in match.keys():
        if node not in safe_nodes:
            nodes_for_remove.append(match[node].id)
    graph.remove_nodes_from(nodes_for_remove)

def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    log.debug('================== ConditionFind ===============')
    max_node = match['minimum'].in_node(1).in_node()
    assert max_node['kind'] == 'op' and max_node['op'] == 'Maximum'

    # init_1: initial value of the iteration counter
    init_1 = match['init_1_data'].value
    assert init_1 is not None
    init_1 = int(init_1)

    # init_2: initial value of the time counter
    init_2 = match['init_2_data'].value
    assert init_2 is not None
    init_2 = int(init_2)

    # step_1: increment of the iteration counter
    assert match['add_1_y_data'].value is not None
    step_1 = int(match['add_1_y_data'].value)

    # step_2: increment of the time counter
    assert match['add_2_y_data'].value is not None
    step_2 = int(match['add_2_y_data'].value)

    match['loop_cond_data'].value = None
    match['Identity_2_data'].value = None

    # Create condition node and delete all useless nodes from condition pattern
    condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
                           name=match['loop_cond'].name + '/TensorIteratorCondition_')
    condition = TensorIteratorCondition(graph, attrs=condition_attrs)
    condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
                                    data_nodes=[match['loop_cond_data'], match['Identity_2_data']])

    # Delete useless nodes
    safe_nodes = ['loop_cond_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
                  'minimum', 'minimum_data']
    nodes_for_remove = []
    for node in match.keys():
        if node not in safe_nodes:
            nodes_for_remove.append(match[node].id)
    graph.remove_nodes_from(nodes_for_remove)

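# The condition node above is configured with counters of the form
# dict(init=..., step=...). This is a minimal sketch of the values such a
# counter takes, assuming the loop keeps running while `counter < limit`
# (a Less comparison, as the 'Less_enter' pattern name suggests) and that
# step is a positive integer. `counter_values` is a hypothetical helper for
# illustration only; it is equivalent to list(range(init, limit, step)).
def counter_values(init: int, step: int, limit: int) -> list:
    """Enumerate the counter values for which `counter < limit` holds."""
    if step <= 0:
        raise ValueError('step must be positive for the loop to terminate')
    values = []
    counter = init
    while counter < limit:
        values.append(counter)
        counter += step
    return values


# Usage: counter_values(0, 1, 3) returns [0, 1, 2]; counter_values(0, 2, 5) returns [0, 2, 4].
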
def replace_pattern(graph: Graph, match: dict):
    log.debug('================== SimpleConditionFind ===============')
    # init_1: initial value of the iteration counter
    init_1 = match['init_1_data'].value
    assert init_1 is not None
    init_1 = int(init_1)

    # step_1: increment of the iteration counter
    assert match['add_1_y_data'].value is not None
    step_1 = int(match['add_1_y_data'].value)

    match['loop_cond_data'].value = None

    # Compute the destination (consumer) ports of the time node
    identity_node_name = match['Identity_1'].soft_get('name', match['Identity_1'].id)
    time_dsts = match['Identity_1'].out_port(0).get_destinations()

    # Create condition node and delete all useless nodes from condition pattern
    condition_attrs = dict(iter=dict(init=init_1, step=step_1),
                           name=match['loop_cond'].name + '/TensorIteratorCondition_')
    condition = TensorIteratorCondition(graph, attrs=condition_attrs)
    condition.create_node_with_data(inputs=[match['Strided_slice_data']],
                                    data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

    safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Strided_slice', 'Strided_slice_data']

    # Check whether the time node has consumers other than the increment node
    # and the input slicing / output concatenation nodes
    other_time_consumers = False
    for time_consumer in time_dsts:
        if time_consumer.node.soft_get('op') not in ['TensorIteratorInput', 'TensorIteratorOutput'] and \
                time_consumer.node.id != match['add_1'].id:
            other_time_consumers = True
            break

    if other_time_consumers:
        # Keep the time-related nodes, since they have consumers other than
        # the input slicing and output concatenation nodes
        safe_nodes += ['init_1', 'init_1_data', 'Enter_1', 'Enter_1_data', 'Merge_1', 'Merge_1_data',
                       'Switch_1', 'Switch_1_data', 'add_1', 'add_1_y', 'add_1_y_data', 'add_1_data',
                       'NextIteration_1']
        switch_node = match['Switch_1']
        new_identity_node = Identity(graph, dict(name=identity_node_name)).create_node()
        switch_node.out_port(1).connect(new_identity_node.in_port(0))

        # Make the graph consistent to avoid multiple producers for the same input port
        graph.remove_nodes_from([match['Identity_1'].id])
        rename_nodes([(new_identity_node, identity_node_name)])

        for time_consumer in time_dsts:
            if time_consumer.node.soft_get('op') not in ['TensorIteratorInput', 'TensorIteratorOutput']:
                time_consumer.get_connection().set_source(new_identity_node.out_port(0))

    # Delete useless nodes
    nodes_for_remove = []
    for node in match.keys():
        if node not in safe_nodes:
            nodes_for_remove.append(match[node].id)
    graph.remove_nodes_from(nodes_for_remove)

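# The check above keeps the time-counter nodes only when Identity_1 feeds a
# consumer other than the increment node (add_1) or the TensorIteratorInput /
# TensorIteratorOutput slicing nodes. Below is a stdlib-only sketch of the same
# predicate, assuming each consumer is described by an (op, node_id) pair
# (op may be None, mirroring soft_get('op') on a node without an 'op' attribute).
# `has_other_time_consumers` is a hypothetical name used for illustration only.
from typing import Iterable, Optional, Tuple

_SLICING_OPS = ('TensorIteratorInput', 'TensorIteratorOutput')


def has_other_time_consumers(consumers: Iterable[Tuple[Optional[str], str]],
                             increment_id: str) -> bool:
    """Return True if any consumer is neither a slicing op nor the increment node."""
    return any(op not in _SLICING_OPS and node_id != increment_id
               for op, node_id in consumers)


# Usage: has_other_time_consumers([('TensorIteratorInput', 'a'), ('Add', 'add_1')], 'add_1')
# returns False, because the only non-slicing consumer is the increment node itself.
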
def replace_pattern(self, graph: Graph, match: dict):
    log.debug('================== ConditionFind ===============')
    # init_1: initial value of the iteration counter
    init_1 = match['init_1_data'].value
    assert init_1 is not None
    init_1 = int(init_1)

    # init_2: initial value of the time counter
    init_2 = match['init_2_data'].value
    assert init_2 is not None
    init_2 = int(init_2)

    # step_1: increment of the iteration counter
    assert match['add_1_y_data'].value is not None
    step_1 = int(match['add_1_y_data'].value)

    # step_2: increment of the time counter
    assert match['add_2_y_data'].value is not None
    step_2 = int(match['add_2_y_data'].value)

    dynamic_seq_len = self.check_dynamic_seq_len(graph, match)

    # Create condition node and delete all useless nodes from condition pattern
    loop_condition = match['loop_cond_data']
    iterator_data = self.looking_for_iteration_counter(graph, match)

    condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
                           name=match['loop_cond'].name + '/TensorIteratorCondition_')
    condition = TensorIteratorCondition(graph, attrs=condition_attrs)
    condition_data = condition.create_node_with_data(
        inputs=[match['Strided_slice_data'], match['minimum_data']],
        data_nodes=[loop_condition, iterator_data])

    safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
                  'minimum', 'minimum_data']

    identity_ops = [n.op for n in iterator_data.out_nodes()]
    if 'GreaterEqual' in identity_ops:
        greater_equal_id = [n.id for n in iterator_data.out_nodes() if n.op == 'GreaterEqual'][0]

        if dynamic_seq_len:
            # Add BackEdge for time iterator node
            backedge = TensorIteratorBackEdge(graph, dict(name='/TimeIterator/TensorIteratorBackEdge_'))
            backedge_data = backedge.create_node_with_data(
                inputs=[match['init_2_data'], match['add_2_data'], condition_data[0]])

            # Connect backedge_data to input port 0 of add_2 and of GreaterEqual,
            # replacing their previous producers
            graph.remove_edge(match['add_2'].in_node(0).id, match['add_2'].id)
            graph.add_edge(backedge_data.id, match['add_2'].id, **{'in': 0})

            graph.remove_edge(iterator_data.id, greater_equal_id)
            graph.add_edge(backedge_data.id, greater_equal_id, **{'in': 0})

            # Keep the nodes of the time iterator
            safe_nodes += ['init_2_data', 'init_2', 'Identity_2_data', 'add_2_data', 'add_2', 'add_2_y',
                           'add_2_y_data']

            # Manually reshape all iterator nodes (for time) from 0D to 1D
            iterator_data_nodes = [backedge_data, match['add_2_data'], match['add_2_y_data'], match['add_2_y'],
                                   match['init_2_data'], match['init_2']]
            make_nodes_1D(iterator_data_nodes)
        else:
            # Delete Selects from this cycle to make it not dynamic
            greater_equal_idxs = [n.id for n in iterator_data.out_nodes() if n.op == 'GreaterEqual']
            delete_selects_from(graph, greater_equal_idxs)

    # Delete useless nodes
    nodes_for_remove = []
    for node in match.keys():
        if node not in safe_nodes:
            nodes_for_remove.append(match[node].id)
    graph.remove_nodes_from(nodes_for_remove)