Exemplo n.º 1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched op with a Reshape(6D) -> Transpose -> Reshape(4D) chain.

        The matched op (presumably a DepthToSpace, given its ``block_size`` attribute)
        is disconnected from its input and output data nodes. The three-op chain is
        then built between those same data nodes, and the final Reshape writes into
        the original output data node.

        :param graph: graph being transformed; modified in place.
        :param match: pattern match providing the 'op', 'in_data' and 'out_data' nodes.
        """
        op_node = match['op']
        src_data = match['in_data']
        dst_data = match['out_data']

        # The input shape is unpacked as (N, H, W, C), so it must be 4D. N itself is
        # not needed: the batch entry of both target shapes below is the literal 0.
        _, height, width, channels = src_data.shape
        bs = op_node['block_size']
        new_channels = int(channels / (bs ** 2))

        # Disconnect the original op from both of its data nodes.
        graph.remove_edge(src_data.id, op_node.id)
        graph.remove_edge(op_node.id, dst_data.id)

        target_6d = int64_array([0, bs, bs, new_channels, height, width])
        perm_6d = int64_array([0, 3, 4, 1, 5, 2])
        target_4d = int64_array([0, int(height * bs), int(width * bs), new_channels])

        # Stage 1: Reshape the input to 6D.
        reshape_to_6d = Reshape(graph, {'name': op_node.id + '/Reshape_to_6D'})
        target_6d_data = Const(graph, {'value': target_6d}).create_node_with_data()
        data_6d = reshape_to_6d.create_node_with_data([src_data, target_6d_data])
        mark_as_correct_data_layout(data_6d.in_node(0))

        # Stage 2: Transpose the 6D tensor.
        perm_data = Const(graph, {'value': perm_6d}).create_node_with_data()
        transpose = Transpose(graph, {'name': op_node.id + '/Transpose'})
        data_transposed = transpose.create_node_with_data([data_6d, perm_data])
        mark_as_correct_data_layout(data_transposed.in_node(0))

        # Stage 3: Reshape back to 4D, writing into the original output data node.
        reshape_to_4d = Reshape(graph, {'name': op_node.id + '/Reshape_to_4D'})
        target_4d_data = Const(graph, {'value': target_4d}).create_node_with_data()
        data_4d = reshape_to_4d.create_node_with_data([data_transposed, target_4d_data],
                                                      data_nodes=[dst_data])
        final_reshape = data_4d.in_node(0)
        mark_input_as_in_correct_layout(final_reshape, 0)
        mark_output_as_in_correct_layout(final_reshape, 0)
    def find_and_replace_pattern(self, graph: Graph):
        """Surround ops whose layout disagrees with the graph layout with Transposes.

        For every op node that has a valid ``layout`` attribute different from the
        graph layout (mapped to the node's rank through ``indices_mapping``):

        1. its port-0 input data node is routed through a new Transpose that applies
           ``permutation.perm``;
        2. a new data node is created for the op's output, with shape equal to the
           original output shape indexed by ``permutation.perm``. That node is fed
           through a Transpose applying ``permutation.inv`` into the original output
           data node;
        3. ``permutation`` is set temporarily on the op's input/output data nodes, and
           ``node.permute_attrs.permute_attrs(node)`` is called so the op's own
           attributes get permuted. The markers are then reset to ``None``.

        Schematic (example: NHWC op in an NCHW graph)::

                                                   \            NCHW                              NCHW
                    NHWC                        --  \            |  permutation       permutation  |
           data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
                                                   /   data->Transpose->data->Convolution->data->Transpose->data

        :param graph: graph to transform; modified in place.
        """
        # Snapshot the node ids first: the loop body adds nodes and edges to the graph.
        for node_id in list(graph.nodes()):
            node = Node(graph, node_id)
            # Skip anything that is not an op whose layout mismatches the graph layout
            # for its rank, e.g. NHWC vs NCHW or NCDHW vs NDHWC. The graph layout is
            # looked up lazily, as in the condition's short-circuit order.
            if not (node.kind == 'op' and node.has_valid('layout') and
                    node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]):
                continue

            in_data = node.in_node()
            out_data = node.out_node()
            rank = len(node.layout)

            # Choose the permutation from the graph layout. The check above guarantees
            # the node's layout differs from the graph's. So in an NCHW graph the node
            # presumably uses the NHWC-style layout, and otherwise (NHWC graph) the
            # NCHW-style one.
            if graph.graph['layout'] == 'NCHW':
                permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)
            else:
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

            # 1. Re-route the op's input through a Transpose applying permutation.perm,
            #    reusing the attributes of the edge being replaced.
            edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
            graph.remove_edge(in_data.id, node.id)

            in_order = Const(graph, {'value': permutation.perm}).create_node_with_data()
            in_transpose = Transpose(graph, dict(name=node.name + '/Transpose_'))
            in_transposed = in_transpose.create_node_with_data([in_data, in_order])

            graph.add_edge(in_transposed.id, node.id, **edge_attrs)

            # 2. Give the op a fresh output data node, then map it back into the
            #    original output data node with a Transpose applying permutation.inv.
            edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
            graph.remove_edge(node.id, out_data.id)

            op_out_data = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]},
                                              edge_attrs)

            out_order = Const(graph, {'value': permutation.inv}).create_node_with_data()
            # The returned data node is the existing out_data, so the result is not kept.
            Transpose(graph, dict(name=node.name + '/Transpose_')).create_node_with_data(
                [op_out_data, out_order], data_nodes=out_data)

            # 3. Permute the op's own attributes. permute_attrs works from the
            #    'permutation' attribute of the op's data nodes, so set it for the call
            #    and clear it afterwards.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)

            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None