def replace_pattern(graph: Graph, match: dict):
        """
        Insert a TensorIteratorCondition node for the matched dynamic decoder condition pattern and
        remove all matched nodes that are no longer needed (like Switch/Merge/Identity)
        """
        log.debug(
            '================== DynamicDecoderConditionFind  =================='
        )
        # Data nodes handed to the new condition node as its data nodes
        cond_output = match['loop_cond_data']
        counter_output = match['identity_data']

        # Build the condition operation; the input node of 'Less_enter' is its only input
        cond_name = match['loop_cond'].name + '/TensorIteratorCondition_'
        cond_op = TensorIteratorCondition(graph, attrs={'name': cond_name})
        cond_op.create_node_with_data(
            inputs=[match['Less_enter'].in_node()],
            data_nodes=[cond_output, counter_output])

        # Every matched node is dropped from the graph except the entries listed here
        preserved = (
            'loop_cond_data', 'identity_data', 'TensorIteratorOutput',
            'TensorIteratorOutput_1',
        )
        graph.remove_nodes_from(
            [matched.id for key, matched in match.items() if key not in preserved])
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched simple loop condition pattern with a TensorIteratorCondition node.

        The initial value and the step of the iteration counter are read from the matched
        'init_1_data' and 'add_1_y_data' nodes; both values must be known (not None).
        """
        log.debug('================== SimpleConditionFind ===============')
        # Initial value of the iteration counter
        start_value = match['init_1_data'].value
        assert start_value is not None
        start_value = int(start_value)

        # Increment of the iteration counter
        step_source = match['add_1_y_data']
        assert step_source.value is not None
        increment = int(step_source.value)

        # Drop the value stored on the loop condition data node
        match['loop_cond_data'].value = None

        # Build the condition node from the collected counter description
        cond_attrs = {
            'iter': {'init': start_value, 'step': increment},
            'name': match['loop_cond'].name + '/TensorIteratorCondition_',
        }
        cond_op = TensorIteratorCondition(graph, attrs=cond_attrs)
        cond_op.create_node_with_data(
            inputs=[match['Strided_slice_data']],
            data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

        # Remove every matched node except the ones listed here
        preserved = ('loop_cond_data', 'Identity_1_data', 'Strided_slice', 'Strided_slice_data')
        graph.remove_nodes_from(
            [matched.id for key, matched in match.items() if key not in preserved])
Example #3
0
    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
        """
        Replace the matched loop condition pattern with a TensorIteratorCondition node.

        Initial values and steps of both counters are read from the matched constant data nodes
        and stored in the 'iter' and 'time' attributes of the new node. All of these values must
        be known (not None).
        """
        log.debug('================== ConditionFind ===============')
        # The producer reached through input 1 of 'minimum' has to be a Maximum operation
        maximum = match['minimum'].in_node(1).in_node()
        assert maximum['kind'] == 'op' and maximum['op'] == 'Maximum'

        def const_as_int(key):
            # Value stored on the matched data node, required to be known
            raw = match[key].value
            assert raw is not None
            return int(raw)

        # Same evaluation order as before: init_1, init_2, step_1, step_2
        iter_init = const_as_int('init_1_data')
        time_init = const_as_int('init_2_data')
        iter_step = const_as_int('add_1_y_data')
        time_step = const_as_int('add_2_y_data')

        # Drop the values stored on the data nodes that are passed to the condition node
        for key in ('loop_cond_data', 'Identity_2_data'):
            match[key].value = None

        # Build the condition node from the collected counter descriptions
        cond_attrs = {
            'time': {'init': time_init, 'step': time_step},
            'iter': {'init': iter_init, 'step': iter_step},
            'name': match['loop_cond'].name + '/TensorIteratorCondition_',
        }
        cond_op = TensorIteratorCondition(graph, attrs=cond_attrs)
        cond_op.create_node_with_data(
            inputs=[match['Strided_slice_data'], match['minimum_data']],
            data_nodes=[match['loop_cond_data'], match['Identity_2_data']])

        # Remove every matched node except the ones listed here
        preserved = (
            'loop_cond_data', 'Identity_2_data', 'Strided_slice',
            'Strided_slice_data', 'minimum', 'minimum_data',
        )
        graph.remove_nodes_from(
            [matched.id for key, matched in match.items() if key not in preserved])
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched simple loop condition pattern with a TensorIteratorCondition node.

        When the output of 'Identity_1' is consumed by something other than 'add_1' or
        TensorIteratorInput/TensorIteratorOutput nodes, the counter related nodes are kept and
        'Identity_1' is replaced by a new Identity node with the same name, connected to output
        port 1 of 'Switch_1'. Otherwise they are removed together with the rest of the pattern.
        """
        log.debug('================== SimpleConditionFind ===============')
        ti_io_ops = ('TensorIteratorInput', 'TensorIteratorOutput')

        # Initial value of the iteration counter
        start_value = match['init_1_data'].value
        assert start_value is not None
        start_value = int(start_value)

        # Increment of the iteration counter
        step_source = match['add_1_y_data']
        assert step_source.value is not None
        increment = int(step_source.value)

        match['loop_cond_data'].value = None

        # Remember the name and the consumer ports of the identity node before the graph is modified
        old_identity = match['Identity_1']
        identity_name = old_identity.soft_get('name', old_identity.id)
        consumer_ports = old_identity.out_port(0).get_destinations()

        # Build the condition node from the collected counter description
        cond_attrs = {
            'iter': {'init': start_value, 'step': increment},
            'name': match['loop_cond'].name + '/TensorIteratorCondition_',
        }
        cond_op = TensorIteratorCondition(graph, attrs=cond_attrs)
        cond_op.create_node_with_data(
            inputs=[match['Strided_slice_data']],
            data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

        preserved = [
            'loop_cond_data', 'Identity_1_data', 'Strided_slice',
            'Strided_slice_data'
        ]

        # Is there a consumer that is neither the increment node nor a
        # TensorIteratorInput/TensorIteratorOutput node?
        has_extra_consumers = any(
            port.node.soft_get('op') not in ti_io_ops and port.node.id != match['add_1'].id
            for port in consumer_ports)

        if has_extra_consumers:
            # The counter sub-graph is still needed, so keep all of its nodes
            preserved.extend([
                'init_1', 'init_1_data', 'Enter_1', 'Enter_1_data', 'Merge_1',
                'Merge_1_data', 'Switch_1', 'Switch_1_data', 'add_1',
                'add_1_y', 'add_1_y_data', 'add_1_data', 'NextIteration_1'
            ])
            replacement = Identity(graph, dict(name=identity_name)).create_node()
            match['Switch_1'].out_port(1).connect(replacement.in_port(0))

            # The old identity node is removed first to avoid multiple producers
            # for the same input port
            graph.remove_nodes_from([old_identity.id])
            rename_nodes([(replacement, identity_name)])

            # Re-attach every consumer that is not a TensorIteratorInput/TensorIteratorOutput node
            for port in consumer_ports:
                if port.node.soft_get('op') in ti_io_ops:
                    continue
                port.get_connection().set_source(replacement.out_port(0))

        # Remove every matched node that was not marked as preserved
        graph.remove_nodes_from(
            [matched.id for key, matched in match.items() if key not in preserved])
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched loop condition pattern with a TensorIteratorCondition node.

        The initial values and steps of two counters are read from the matched data nodes
        ('init_1_data'/'add_1_y_data' go to the 'iter' attribute, 'init_2_data'/'add_2_y_data'
        go to the 'time' attribute of the new node). If one of the consumers of the iterator
        data node is a 'GreaterEqual' operation, the pattern is additionally processed depending
        on the result of self.check_dynamic_seq_len. Finally all matched nodes that are not
        listed in safe_nodes are removed from the graph.

        :param graph: graph that contains the matched nodes and is modified in place
        :param match: mapping from pattern node names to the matched nodes
        :raises AssertionError: if one of the init/step values is None
        """
        log.debug('================== ConditionFind ===============')
        # init_1
        init_1 = match['init_1_data'].value
        assert init_1 is not None
        init_1 = int(init_1)

        # init_2
        init_2 = match['init_2_data'].value
        assert init_2 is not None
        init_2 = int(init_2)

        # step_1
        assert match['add_1_y_data'].value is not None
        step_1 = int(match['add_1_y_data'].value)

        # step_2
        assert match['add_2_y_data'].value is not None
        step_2 = int(match['add_2_y_data'].value)

        # Evaluated before any modification of the graph below
        # NOTE(review): presumably tells whether the sequence length is dynamic - confirm in the helper
        dynamic_seq_len = self.check_dynamic_seq_len(graph, match)

        # Create condition node and delete all useless nodes from condition pattern
        loop_condition = match['loop_cond_data']
        iterator_data = self.looking_for_iteration_counter(graph, match)

        condition_attrs = dict(time=dict(init=init_2, step=step_2),
                               iter=dict(init=init_1, step=step_1),
                               name=match['loop_cond'].name +
                               '/TensorIteratorCondition_')
        condition = TensorIteratorCondition(graph, attrs=condition_attrs)
        # The returned data nodes are kept: the first one is used as a BackEdge input below
        condition_data = condition.create_node_with_data(
            inputs=[match['Strided_slice_data'], match['minimum_data']],
            data_nodes=[loop_condition, iterator_data])

        # Names of the matched nodes that must not be removed at the end of the function
        safe_nodes = [
            'loop_cond_data', 'Identity_1_data', 'Identity_2_data',
            'Strided_slice', 'Strided_slice_data', 'minimum', 'minimum_data'
        ]

        # Operation types of all consumers of the iterator data node
        identity_ops = [n.op for n in iterator_data.out_nodes()]
        if 'GreaterEqual' in identity_ops:
            # Only the first 'GreaterEqual' consumer is used in the dynamic branch
            # NOTE(review): confirm that there is at most one such consumer
            greater_equal_id = [
                n.id for n in iterator_data.out_nodes()
                if n.op == 'GreaterEqual'
            ][0]

            if dynamic_seq_len:
                # Add BackEdge for time iterator node
                backedge = TensorIteratorBackEdge(
                    graph, dict(name='/TimeIterator/TensorIteratorBackEdge_'))
                backedge_data = backedge.create_node_with_data(inputs=[
                    match['init_2_data'], match['add_2_data'],
                    condition_data[0]
                ], )

                # Input 0 of 'add_2' is re-connected to the BackEdge output
                graph.remove_edge(match['add_2'].in_node(0).id,
                                  match['add_2'].id)
                graph.add_edge(backedge_data.id, match['add_2'].id,
                               **{'in': 0})

                # The 'GreaterEqual' consumer reads the BackEdge output instead of the iterator data node
                graph.remove_edge(iterator_data.id, greater_equal_id)
                graph.add_edge(backedge_data.id, greater_equal_id, **{'in': 0})

                # nodes for time iterator
                # ('Identity_2_data' is already in safe_nodes; the duplicate does not change the result)
                safe_nodes += [
                    'init_2_data', 'init_2', 'Identity_2_data', 'add_2_data',
                    'add_2', 'add_2_y', 'add_2_y_data'
                ]

                # Manually reshape all iterator nodes (for time) from 0D to 1D
                iterator_data_nodes = [
                    backedge_data, match['add_2_data'], match['add_2_y_data'],
                    match['add_2_y'], match['init_2_data'], match['init_2']
                ]
                make_nodes_1D(iterator_data_nodes)
            else:
                # Delete Selects from this cycle to make it not dynamic:
                # all 'GreaterEqual' consumers are passed to the helper here, not only the first one
                greater_equal_idxs = [
                    n.id for n in iterator_data.out_nodes()
                    if n.op == 'GreaterEqual'
                ]
                delete_selects_from(graph, greater_equal_idxs)

        # Delete useless nodes
        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)