def replace_pattern(graph: Graph, match: dict):
    """Turn a matched TensorArrayWrite/TensorArrayRead chain into a TensorIteratorOutput.

    The value written by ``TensorArrayWrite`` (its input 2) and the write
    index (its input 1) become the inputs of a new ``TensorIteratorOutput``
    op. The op's output data node is the already existing
    ``TensorArrayRead_data`` node. Afterwards every matched node except
    ``TensorArrayRead_data`` and ``Condition_data`` is removed from ``graph``.

    :param graph: graph being transformed; modified in place.
    :param match: mapping from pattern node names to the matched nodes.
    """
    log.debug('================== SimpleOutputFind ===============')
    enter_data = match['WriteEnter_data']
    assert enter_data.value is not None

    write_op = match['TensorArrayWrite']
    idx_node = write_op.in_node(1)
    val_node = write_op.in_node(2)

    # Only the external port id, internal layer id and a name are set for this
    # output; no axis/start/stride/part_size attributes are passed here.
    output_attrs = {
        'external_port_id': str(enter_data.value),
        'internal_layer_id': val_node.id,
        'name': write_op.name + '/TensorIteratorOutput_',
    }
    TensorIteratorOutput(graph, output_attrs).create_node_with_data(
        inputs=[val_node, idx_node],
        data_nodes=[match['TensorArrayRead_data']],
    )

    # Drop the rest of the matched subgraph; the data node reused as the new
    # op's output and Condition_data are kept.
    keep = ('TensorArrayRead_data', 'Condition_data')
    graph.remove_nodes_from([match[key].id for key in match if key not in keep])
Esempio n. 2
0
    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
        """Turn a matched TensorArrayWrite/TensorArrayGather chain into a TensorIteratorOutput.

        Requires the matched ``start_data`` value to be 0 and ``delta_data``
        value to be 1. The new ``TensorIteratorOutput`` op takes three inputs:
        the ``TensorArray`` op's input (presumably its size; confirm), the
        value written by ``TensorArrayWrite`` (input 2), and the write index
        (input 1). It writes into the existing ``TensorArrayGather_data``
        node. All other matched nodes except ``Condition_data`` are then
        removed from ``graph``.

        :param graph: graph being transformed; modified in place.
        :param match: mapping from pattern node names to the matched nodes.
        """
        log.debug('================== SmartOutputFind ===============')

        assert match['WriteEnter_data'].value is not None
        assert match['start_data']['value'] == 0
        assert match['delta_data']['value'] == 1

        array_size = match['TensorArray'].in_node()
        write_op = match['TensorArrayWrite']
        idx_node, val_node = write_op.in_node(1), write_op.in_node(2)

        # TensorArray elements are always iterated along axis 0. start, stride
        # and part_size are left as None here and are meant to be filled in
        # later (with the condition).
        TensorIteratorOutput(graph, dict(
            axis=0,
            start=None,
            stride=None,
            part_size=None,
            external_port_id=str(match['WriteEnter_data'].value),
            internal_layer_id=val_node.id,
            name=write_op.name + '/TensorIteratorOutput_',
        )).create_node_with_data(
            inputs=[array_size, val_node, idx_node],
            data_nodes=[match['TensorArrayGather_data']],
        )

        # Remove every matched node except the reused output data node and
        # Condition_data.
        keep = {'TensorArrayGather_data', 'Condition_data'}
        graph.remove_nodes_from([match[key].id for key in match if key not in keep])
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched NextIteration back edge with a TensorIteratorBackEdge op.

        A ``TensorIteratorBackEdge`` is created with three inputs:
        ``Enter_1_data``, the data node feeding ``NextIteration``, and
        ``condition_cond_data``. It writes into the existing
        ``Identity_1_data`` node.

        If ``Switch_1`` also has an output on port 0, that path must be
        Switch -> data -> Exit. It is replaced by a ``TensorIteratorOutput``
        (created without partition, with ``external_port_id`` and
        ``internal_layer_id`` set to None) that writes into the Exit op's
        output data node. The Exit op and its input data node are scheduled
        for removal.

        Finally, every matched node except ``Identity_1_data``,
        ``condition``, ``condition_cond_data`` and ``Enter_1_data`` is
        removed from ``graph``.

        :param graph: graph being transformed; modified in place.
        :param match: mapping from pattern node names to the matched nodes.
        """
        log.debug('================== BackEdgeFind ===============')

        to_remove = []
        body_result = match['NextIteration'].in_node()
        switch_op = match['Switch_1']
        cond_data = match['condition_cond_data']

        # Optional Exit path: Switch -(port 0)-> data -> Exit.
        if 0 in switch_op.out_nodes():
            switch_out = switch_op.out_node(0)
            exit_op = switch_out.out_node(0)
            assert exit_op.has_valid('op') and exit_op.op == 'Exit'
            exit_data = exit_op.out_node(0)

            to_remove += [switch_out.id, exit_op.id]

            # Output without partition; port and layer ids are left as None.
            TensorIteratorOutput(graph, dict(
                external_port_id=None,
                internal_layer_id=None,
                name=exit_op.name + '/TensorIteratorOutput_',
            )).create_node_with_data(
                inputs=[body_result, cond_data],
                data_nodes=[exit_data],
            )

        # Sanity check: the NextIteration data and the Enter data must be
        # distinct nodes.
        assert match['NextIteration_data'].id != match['Enter_1_data'].id
        TensorIteratorBackEdge(
            graph, dict(name=match['Identity_1'].name + '/TensorIteratorBackEdge_')
        ).create_node_with_data(
            inputs=[match['Enter_1_data'], body_result, cond_data],
            data_nodes=[match['Identity_1_data']],
        )

        # Everything else in the match is now redundant.
        keep = {'Identity_1_data', 'condition', 'condition_cond_data', 'Enter_1_data'}
        to_remove.extend(match[key].id for key in match if key not in keep)
        graph.remove_nodes_from(to_remove)