Example #1
    def replace_sub_graph(self, graph: Graph, external_match: dict):
        loop_node = external_match['while']
        body_graph = loop_node['body']
        body_pattern = KerasRNNOutputConcatenation.get_body_pattern()

        internal_matches = find_subgraph_match_to_pattern(
            body_graph, body_pattern)

        if len(internal_matches) == 1:
            internal_match = internal_matches[0]
            reserve_port_idx = compute_input_port_idx(
                external_match['reserve'], loop_node)
            stack_port_idx = external_match['stack'].in_port(
                0).get_source().idx
            # check that back edges connect the correct Parameter and Result nodes in the body
            # check connections between body input ports and external input ports of the Loop node
            # check connections between body output ports and external output ports of the Loop node
            if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
                                     internal_match['container'].internal_layer_id) and \
                    Loop.back_edge_exists(loop_node.back_edges,
                                          internal_match['increment_iteration_result'].internal_layer_id,
                                          internal_match['current_iteration'].internal_layer_id) and \
                    Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
                                           internal_match['container'].internal_layer_id) and \
                    Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
                                           internal_match['concatenation_result'].internal_layer_id):
                KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(
                    external_match, internal_match)
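
The chained condition above reduces to membership tests over the Loop node's back-edge and port-map records. Below is a minimal, self-contained sketch of the kind of lookup that Loop.back_edge_exists and Loop.inter_edge_exists are expected to perform, assuming the records are plain dictionaries with 'from_layer'/'to_layer' and 'external_port_id'/'internal_layer_id' keys (an assumption about the internal representation made for illustration, not the actual OpenVINO API):

def back_edge_exists(back_edges: list, from_layer: int, to_layer: int) -> bool:
    # True if some back edge carries data from 'from_layer' to 'to_layer' (assumed dict layout)
    return any(edge.get('from_layer') == from_layer and edge.get('to_layer') == to_layer
               for edge in back_edges)

def inter_edge_exists(port_map: list, external_port_id: int, internal_layer_id: int) -> bool:
    # True if an external Loop port is mapped to the given internal body layer (assumed dict layout)
    return any(record.get('external_port_id') == external_port_id
               and record.get('internal_layer_id') == internal_layer_id
               for record in port_map)

# toy usage
assert back_edge_exists([{'from_layer': 7, 'to_layer': 2}], 7, 2)
assert not inter_edge_exists([{'external_port_id': 1, 'internal_layer_id': 5}], 0, 5)
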
Example #2
    def find_and_replace_pattern(self, graph: Graph):
        for loop_node in graph.get_op_nodes(op='Loop'):
            loop_name = loop_node.soft_get('name', loop_node.id)
            body_graph = loop_node['body']
            body_pattern = MapFNOutputConcatenation.get_body_pattern()
            internal_matches = find_subgraph_match_to_pattern(
                body_graph, body_pattern)

            for internal_match in internal_matches:
                # check if TensorListReserve from the main graph is connected to the Parameter node from the body graph
                # that is assigned for storing intermediate output results of the While loop. If yes, the transformation
                # detects concatenation of intermediate outputs through this port and can use the Loop axis attribute
                reserve_node = Loop.get_external_nodes_by_internal_id(
                    loop_node, internal_match['container'].internal_layer_id)
                reserve_node = reserve_node[0] if (
                    len(reserve_node) == 1
                    and reserve_node[0].op == 'TensorListReserve') else None
                if reserve_node is None:
                    log.info(
                        "A sub-graph around the loop node {} does not match "
                        "TensorFlow 2 MapFN pattern for intermediate outputs concatenation"
                        .format(loop_name))
                    continue
                stack_node = Loop.get_external_nodes_by_internal_id(
                    loop_node,
                    internal_match['concatenation_result'].internal_layer_id)
                stack_node = stack_node[0] if len(stack_node) == 1 else None

                if stack_node is None:
                    log.info(
                        "A sub-graph around the loop node {} does not match "
                        "TensorFlow 2 MapFN pattern for intermediate outputs concatenation"
                        .format(loop_name))
                    continue

                # skip the StopGradient node if it exists between the While loop output port and the TensorListStack operation
                stack_node = skip_nodes_by_condition(
                    stack_node, lambda x: x.has_and_set('identity'), True)
                stack_node = stack_node if stack_node.op == 'TensorListStack' else None
                if stack_node is None:
                    log.info(
                        "A sub-graph around the loop node {} does not match "
                        "TensorFlow 2 MapFN pattern for intermediate outputs concatenation"
                        .format(loop_name))
                    continue

                external_match = {
                    'while': loop_node,
                    'reserve': reserve_node,
                    'stack': stack_node
                }
                # check that back edges connect the Parameter node (the container for intermediate output results)
                # with the concatenation result produced by the TensorListSetItem node
                if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
                                         internal_match['container'].internal_layer_id) and \
                        Loop.back_edge_exists(loop_node.back_edges,
                                              internal_match['increment_iteration_result'].internal_layer_id,
                                              internal_match['current_iteration'].internal_layer_id):
                    MapFNOutputConcatenation.transform_map_fn_output_concatenation(
                        external_match, internal_match)
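
The StopGradient/identity skipping above can be pictured as walking down a single-consumer chain until a non-identity operation is reached. A toy sketch of that idea, using a hypothetical FakeNode class rather than the real Node/port API the transformation operates on:

class FakeNode:
    def __init__(self, op, consumer=None):
        self.op = op              # operation type, e.g. 'StopGradient' or 'TensorListStack'
        self.consumer = consumer  # single downstream consumer (a simplification)

def skip_while(node, condition):
    # follow the consumer chain while the condition holds (e.g. identity-like ops)
    while node is not None and condition(node):
        node = node.consumer
    return node

stack = FakeNode('TensorListStack')
stop_gradient = FakeNode('StopGradient', consumer=stack)
result = skip_while(stop_gradient, lambda n: n.op in ('StopGradient', 'Identity'))
assert result.op == 'TensorListStack'
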
Example #3
    def find_and_replace_pattern(self, graph: Graph):
        for loop_node in graph.get_op_nodes(op='Loop'):
            loop_name = loop_node.soft_get('name', loop_node.id)
            body_graph = loop_node['body']
            body_pattern = MapFNInputSlicing.get_body_pattern()
            internal_matches = find_subgraph_match_to_pattern(
                body_graph, body_pattern)

            for internal_match in internal_matches:
                # check if TensorListGetItem from the body graph is connected to TensorListFromTensor
                # from the main graph. If yes, the transformation detects input slicing through this port
                # and can use the Loop axis attribute
                unstack_node = Loop.get_external_nodes_by_internal_id(
                    loop_node, internal_match['tensor_list'].internal_layer_id)
                unstack_node = unstack_node[0] if (
                    len(unstack_node) == 1
                    and unstack_node[0].op == 'TensorListFromTensor') else None
                if unstack_node is None:
                    log.info(
                        "A sub-graph around the loop node {} does not match "
                        "TensorFlow 2 MapFN pattern for input slicing".format(
                            loop_name))
                    continue

                external_match = {'while': loop_node, 'unstack': unstack_node}
                # check that back edges connect the correct Parameter and Result nodes in the body
                # check connections between body input ports and external input ports of the Loop node
                if Loop.back_edge_exists(
                        loop_node.back_edges,
                        internal_match['increment_iteration_result'].
                        internal_layer_id,
                        internal_match['current_iteration'].internal_layer_id):
                    MapFNInputSlicing.transform_map_fn_input_slicing(
                        external_match, internal_match)
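
Resolving an internal body layer to its external counterpart, as Loop.get_external_nodes_by_internal_id does above, amounts to scanning the Loop node's port map for the matching internal layer id and following the external port back to its producer. A simplified sketch under the assumption that the port map is a list of dictionaries and the external inputs are kept in a plain list (both assumptions made for illustration only; the real helper works on the Loop node's ports):

def external_nodes_for_internal_id(input_port_map: list, external_inputs: list,
                                   internal_layer_id: int) -> list:
    # collect every external input whose Loop port is mapped to the given body layer
    nodes = []
    for record in input_port_map:
        if record.get('internal_layer_id') == internal_layer_id:
            port_idx = record['external_port_id']
            if 0 <= port_idx < len(external_inputs):
                nodes.append(external_inputs[port_idx])
    return nodes

# toy usage: internal layer 3 is fed through external input port 1
port_map = [{'external_port_id': 1, 'internal_layer_id': 3}]
print(external_nodes_for_internal_id(port_map, ['Const', 'TensorListFromTensor'], 3))  # ['TensorListFromTensor']
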
Example #4
    def replace_sub_graph(self, graph: Graph, external_match: dict):
        loop_node = external_match['while']
        body_graph = loop_node['body']
        body_pattern = KerasRNNInputSlicing.get_body_pattern()
        internal_matches = find_subgraph_match_to_pattern(
            body_graph, body_pattern)

        # the case of multiple matches is not handled since it is unclear how to select the corresponding match
        if len(internal_matches) == 1:
            internal_match = internal_matches[0]
            unstack_port_idx = compute_input_port_idx(
                external_match['unstack'], loop_node)
            # check that back edges connect the correct Parameter and Result nodes in the body
            # check connections between body input ports and external input ports of the Loop node
            if Loop.back_edge_exists(loop_node.back_edges,
                                     internal_match['increment_iteration_result'].internal_layer_id,
                                     internal_match['current_iteration'].internal_layer_id) and \
                    Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
                                           internal_match['tensor_list'].internal_layer_id):
                # the sub-graph is transformed only if the inter-graph match checks above have passed
                KerasRNNInputSlicing.transform_keras_rnn_input_slicing(
                    external_match, internal_match)
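
The compute_input_port_idx helper used above answers the question "which Loop input port does this external producer feed?". A minimal sketch of that lookup, assuming the connections are modelled as simple (producer, port index) pairs rather than the real Node ports (a hypothetical representation for illustration only):

def input_port_idx(connections: list, producer: str) -> int:
    # return the first Loop input port index driven by 'producer', or -1 if none is found
    for node_name, port_idx in connections:
        if node_name == producer:
            return port_idx
    return -1

# toy usage
connections = [('TensorListFromTensor', 2), ('Const', 0)]
assert input_port_idx(connections, 'TensorListFromTensor') == 2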