Пример #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Swap the matched op for a Reshape (6D) -> Transpose -> Reshape (4D) chain.

        The edges connecting the matched ``op`` to its input and output data nodes are removed.
        The input data node is then reshaped to six dimensions, permuted, and reshaped back to
        four dimensions; the final Reshape reuses the original ``out_data`` node as its output.

        :param graph: graph the matched nodes belong to
        :param match: mapping with the ``op``, ``in_data`` and ``out_data`` nodes of the match
        """
        op_node = match['op']
        src_data = match['in_data']

        # only the last three entries of the 4-element input shape are used below
        _, H, W, C = src_data.shape
        block_size = op_node['block_size']

        # detach the matched op from both of its data nodes
        graph.remove_edge(src_data.id, op_node.id)
        dst_data = match['out_data']
        graph.remove_edge(op_node.id, dst_data.id)

        # number of channels that remain after a block_size x block_size group is moved out of C
        channels_left = int(C / (block_size ** 2))

        dim_6D = int64_array([0, block_size, block_size, channels_left, H, W])
        order_6D = int64_array([0, 3, 4, 1, 5, 2])
        dim_4D = int64_array([0, int(H * block_size), int(W * block_size), channels_left])

        # 4D -> 6D
        to_6d_data = Reshape(graph, dict(name=op_node.id + '/Reshape_to_6D')).create_node_with_data(
            [src_data, Const(graph, dict(value=dim_6D)).create_node_with_data()])
        # in_node(0) of the returned data node is presumably the freshly created op node -- TODO confirm
        mark_as_correct_data_layout(to_6d_data.in_node(0))

        # permute the six axes
        permutation_data = Const(graph, dict(value=order_6D)).create_node_with_data()
        transposed_data = Transpose(graph, dict(name=op_node.id + '/Transpose')).create_node_with_data(
            [to_6d_data, permutation_data])
        mark_as_correct_data_layout(transposed_data.in_node(0))

        # 6D -> 4D, writing into the data node the matched op used to produce
        to_4d_data = Reshape(graph, dict(name=op_node.id + '/Reshape_to_4D')).create_node_with_data(
            [transposed_data, Const(graph, dict(value=dim_4D)).create_node_with_data()],
            data_nodes=[dst_data])
        mark_input_as_in_correct_layout(to_4d_data.in_node(0), 0)
        mark_output_as_in_correct_layout(to_4d_data.in_node(0), 0)
Пример #2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Mark ops selected by ``self.op_conditions``, and the nodes reachable from them, as correct-layout.

        For every op node of ``graph`` accepted by one of the predicates in ``self.op_conditions``,
        ``self.bfs`` is run twice from that node (once per direction flag), sharing a single
        ``visited`` set across all runs. The nodes returned by both searches are collected, together
        with the bordering input/output nodes that fail the opposite-direction condition. Each
        collected node is passed to ``mark_as_correct_data_layout`` and gets ``nchw_layout = True``.

        :param graph: graph to process
        """

        def forward_ok(n):
            return not InsertLayoutPropagationTranspose.is_nhwc_to_nchw_transpose_needed(n)

        def backward_ok(n):
            return not InsertLayoutPropagationTranspose.is_nchw_to_nhwc_transpose_needed(n)

        seen = set()
        to_mark = set()

        for is_selected in self.op_conditions:
            for op in graph.get_op_nodes():
                if not is_selected(op):
                    continue

                log.debug('Detected node "{}" as a node which should be executed in the original layout'
                          ''.format(op.soft_get('name', op.id)))
                # the last bfs argument presumably selects the traversal direction -- TODO confirm
                reached_forward = self.bfs([op], seen, forward_ok, True)
                reached_backward = self.bfs([op], seen, backward_ok, False)

                # find "reinterp_shape" like ops which change rank of input to 4D or 5D from smaller dimensions
                to_mark.update(
                    producer
                    for reached in reached_backward
                    for producer in self.get_input_nodes(reached)
                    if producer not in reached_backward and not forward_ok(producer))

                # find "reinterp_shape" like ops which change rank of input from 4D or 5D to smaller dimensions
                to_mark.update(
                    consumer
                    for reached in reached_forward
                    for consumer in self.get_output_nodes(reached)
                    if consumer not in reached_forward and not backward_ok(consumer))

                to_mark.update(reached_forward + reached_backward)

        if not to_mark:
            return

        log.debug('The following nodes will be executed in the original layout: {}'
                  ''.format([n.soft_get('name', n.id) for n in to_mark]))

        # mark all matched nodes as in correct layout and disable attributes permutation for them
        for layout_node in to_mark:
            mark_as_correct_data_layout(layout_node)
            layout_node['nchw_layout'] = True