示例#1
0
    def find_and_replace_pattern(self, graph: Graph):
        """Detect TF2 MapFN intermediate-output concatenation in Loop bodies and hand matches to the transform.

        For each ``Loop`` op in ``graph`` the body is searched for
        ``MapFNOutputConcatenation.get_body_pattern()``. A match is passed to
        ``MapFNOutputConcatenation.transform_map_fn_output_concatenation`` only if:

        * ``Loop.get_external_nodes_by_internal_id`` for the ``container`` body node yields
          exactly one node, and that node is a ``TensorListReserve``;
        * the same lookup for the ``concatenation_result`` body node yields exactly one node
          which, after skipping nodes whose ``identity`` attribute is set, is a ``TensorListStack``;
        * back edges exist for (concatenation_result, container) and
          (increment_iteration_result, current_iteration).

        A failure of either of the first two checks is logged at INFO level and the match is
        skipped; a failed back-edge check skips the match without logging.

        :param graph: graph to scan for ``Loop`` operations
        """
        not_matched_template = ("A sub-graph around the loop node {} does not match "
                                "TensorFlow 2 MapFN pattern for intermediate outputs concatenation")

        for loop in graph.get_op_nodes(op='Loop'):
            name = loop.soft_get('name', loop.id)
            matches = find_subgraph_match_to_pattern(loop['body'],
                                                     MapFNOutputConcatenation.get_body_pattern())

            for match in matches:
                container_id = match['container'].internal_layer_id

                # The body container (storage for per-iteration outputs) must be tied to a single
                # TensorListReserve outside the Loop; only then can the Loop axis attribute be used.
                reserve_candidates = Loop.get_external_nodes_by_internal_id(loop, container_id)
                if len(reserve_candidates) != 1 or reserve_candidates[0].op != 'TensorListReserve':
                    log.info(not_matched_template.format(name))
                    continue
                reserve = reserve_candidates[0]

                result_id = match['concatenation_result'].internal_layer_id
                stack_candidates = Loop.get_external_nodes_by_internal_id(loop, result_id)
                if len(stack_candidates) != 1:
                    log.info(not_matched_template.format(name))
                    continue

                # A StopGradient (a node with 'identity' set) may sit between the Loop output port and
                # TensorListStack; walk past such nodes before checking the op type.
                stack = skip_nodes_by_condition(stack_candidates[0],
                                                lambda node: node.has_and_set('identity'),
                                                True)
                if stack.op != 'TensorListStack':
                    log.info(not_matched_template.format(name))
                    continue

                # Require a back edge between the concatenation result and the container, and one
                # between the incremented iteration counter and the current iteration Parameter.
                back_edges = loop.back_edges
                if Loop.back_edge_exists(back_edges, result_id, container_id) and \
                        Loop.back_edge_exists(back_edges,
                                              match['increment_iteration_result'].internal_layer_id,
                                              match['current_iteration'].internal_layer_id):
                    MapFNOutputConcatenation.transform_map_fn_output_concatenation(
                        {'while': loop, 'reserve': reserve, 'stack': stack}, match)
示例#2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Detect TF2 MapFN input slicing in Loop bodies and hand matches to the transform.

        For each ``Loop`` op in ``graph`` the body is searched for
        ``MapFNInputSlicing.get_body_pattern()``. A match is passed to
        ``MapFNInputSlicing.transform_map_fn_input_slicing`` only if:

        * ``Loop.get_external_nodes_by_internal_id`` for the ``tensor_list`` body node yields
          exactly one node, and that node is a ``TensorListFromTensor``;
        * a back edge exists for (increment_iteration_result, current_iteration).

        A failure of the first check is logged at INFO level and the match is skipped; a failed
        back-edge check skips the match without logging.

        :param graph: graph to scan for ``Loop`` operations
        """
        for loop in graph.get_op_nodes(op='Loop'):
            name = loop.soft_get('name', loop.id)
            matches = find_subgraph_match_to_pattern(loop['body'], MapFNInputSlicing.get_body_pattern())

            for match in matches:
                # The body tensor list read by TensorListGetItem must be produced by a single
                # TensorListFromTensor in the outer graph; then the slicing maps to the Loop axis.
                candidates = Loop.get_external_nodes_by_internal_id(loop, match['tensor_list'].internal_layer_id)
                if len(candidates) != 1 or candidates[0].op != 'TensorListFromTensor':
                    log.info("A sub-graph around the loop node {} does not match "
                             "TensorFlow 2 MapFN pattern for input slicing".format(name))
                    continue

                # Only the iteration-counter back edge is verified here.
                if not Loop.back_edge_exists(loop.back_edges,
                                             match['increment_iteration_result'].internal_layer_id,
                                             match['current_iteration'].internal_layer_id):
                    continue

                MapFNInputSlicing.transform_map_fn_input_slicing({'while': loop, 'unstack': candidates[0]}, match)
    def find_and_replace_pattern(self, graph: Graph):
        """Detect TF2 EmptyTensorList->TensorListPushBack output concatenation in Loop bodies.

        For each ``Loop`` op in ``graph`` the body is searched for
        ``TensorListOutputConcatenation.get_body_pattern()``. A match is passed to
        ``TensorListOutputConcatenation.transform_tensor_list_output_concatenation`` only if:

        * ``Loop.get_external_nodes_by_internal_id`` for the ``container`` body node yields
          exactly one node, and that node is an ``EmptyTensorList``;
        * a back edge exists for (concatenation_result, container).

        A failure of the first check is logged at INFO level and the match is skipped; a failed
        back-edge check skips the match without logging.

        :param graph: graph to scan for ``Loop`` operations
        """
        for loop in graph.get_op_nodes(op='Loop'):
            name = loop.soft_get('name', loop.id)
            body = loop['body']
            pattern = TensorListOutputConcatenation.get_body_pattern()

            for match in find_subgraph_match_to_pattern(body, pattern):
                container_id = match['container'].internal_layer_id

                # The body container (storage for per-iteration outputs) must be tied to a single
                # EmptyTensorList outside the Loop; only then can the Loop axis attribute be used.
                found = Loop.get_external_nodes_by_internal_id(loop, container_id)
                if len(found) != 1 or found[0].op != 'EmptyTensorList':
                    log.info("A sub-graph around the loop node {} does not match "
                             "TensorFlow 2 EmptyTensorList->TensorListPushBack pattern for intermediate "
                             "outputs concatenation".format(name))
                    continue

                # The TensorListPushBack result must be fed back into the container across iterations.
                result_id = match['concatenation_result'].internal_layer_id
                if not Loop.back_edge_exists(loop.back_edges, result_id, container_id):
                    continue

                TensorListOutputConcatenation.transform_tensor_list_output_concatenation(
                    {'while': loop, 'reserve': found[0]}, match)