Ejemplo n.º 1
0
def conv_flatten_concat_action(graph: Graph, match: dict):
    """Insert a Permute between a convolution output and a rank-changing Reshape.

    Reads the 'reshape', 'reshape_data', 'conv' and 'conv_data' entries of
    ``match``. The graph is left untouched when:

    * the Reshape output has the same number of dimensions as its input,
    * the Reshape already carries the 'nchw_layout' flag, or
    * the single consumer of the Reshape output is a 'FullyConnected' node
      for which ``can_repack_fully_connected_weights_nhwc_to_nchw`` is truthy.

    Otherwise the conv_data -> reshape edge is replaced by
    conv_data -> Permute -> reshape, permutation of the Reshape attributes is
    disabled and the Reshape is flagged with 'nchw_layout'.

    The graph layout is asserted to be 'NHWC'.
    """
    assert graph.graph['layout'] == 'NHWC'

    reshape_op = match['reshape']
    reshape_out = match['reshape_data']
    conv_label = match['conv'].name
    conv_out = match['conv_data']

    # Only a Reshape that changes the number of dimensions is of interest.
    if len(reshape_out.shape) == len(conv_out.shape):
        return
    # Skip Reshape nodes already flagged (the flag is set at the end of this function).
    if reshape_op.has_and_set('nchw_layout'):
        return

    if len(reshape_out.out_nodes()) == 1:
        consumer = reshape_out.out_node()
        repackable_fc = (
            consumer.has_valid('type')
            and consumer.type == 'FullyConnected'
            and can_repack_fully_connected_weights_nhwc_to_nchw(consumer)
        )
        if repackable_fc:
            log.info(
                'There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
                'need to insert Permute'.format(reshape_op.soft_get('name')))
            return

    # Detach the Reshape from the convolution output before wiring the Permute in.
    graph.remove_edge(conv_out.id, reshape_op.id)

    rank = len(conv_out.shape)
    order = PermuteAttrs.get_nchw_to_nhwc_permutation(rank).perm
    permute_op = Permute(graph, {'order': order})
    permuted_data = permute_op.create_node_with_data(
        [conv_out], dict(name=conv_label + '/Permute_'))
    graph.create_edge(permuted_data, reshape_op)

    # Disable attribute permutation for the Reshape and mark it as processed.
    PermuteAttrs.set_permutation(reshape_op, reshape_out, None)
    reshape_op['nchw_layout'] = True
Ejemplo n.º 2
0
    def replace_pattern(graph: Graph, match: dict):
        """Put a Permute in front of a Reshape that collapses a >=4D input to 3D.

        ``match['reshape']`` is the Reshape op node. Nothing is done when the
        graph layout is 'NCHW', when the Reshape is flagged 'nchw_layout', or
        when its 'correct_data_layout' attribute is exactly ``True``.

        For the remaining Reshapes whose input has at least 4 dimensions and
        whose output has exactly 3, a Permute using the NCHW -> NHWC
        permutation is inserted on the input edge, provided that the feature
        dimension of the input shape is not 1 and at least one of the
        width/height dimensions is not 1 (indices computed for 'NCHW').

        Only the input-side Permute is inserted by this block.
        """
        reshape = match['reshape']
        assert len(reshape.in_nodes()) > 0
        if graph.graph['layout'] == 'NCHW' or reshape.has_and_set('nchw_layout') or\
                reshape.soft_get('correct_data_layout') is True:
            return

        input_node = reshape.in_node()
        output_node = reshape.out_node()
        input_shape = input_node.shape
        output_shape = output_node.shape

        # Only reshapes from rank >= 4 down to rank 3 are handled.
        if len(input_shape) < 4 or len(output_shape) != 3:
            return

        # Check that the permutation pass would actually change this Reshape's shapes.
        layout = 'NCHW'
        rank = len(input_shape)
        c_idx = get_features_dim(layout, rank)
        hw_idx = [
            get_width_dim(layout, rank),
            get_height_dim(layout, rank)
        ]
        # assumes input_shape is a numpy array so that list indexing selects both dims
        # -- TODO confirm against callers
        if not (input_shape[c_idx] != 1 and np.any(
                input_shape[hw_idx] != [1, 1])):
            return

        # The nhwc -> nchw permutation can change shapes significantly here, so the
        # Reshape input is explicitly permuted.
        permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

        # Insert the input Permute: it permutes the input from the original input
        # layout to the operation layout. The original edge attributes are kept
        # and re-applied to the new Permute -> Reshape edge.
        edge_attrs = graph.get_edge_data(input_node.id, reshape.id)[0]
        graph.remove_edge(input_node.id, reshape.id)

        permute_op = Permute(graph, {
            'order': permutation.perm,
            'name': reshape.name + '/Permute_'
        })
        permute_data_node = permute_op.create_node_with_data([input_node])

        graph.add_edge(permute_data_node.id, reshape.id, **edge_attrs)
    def find_and_replace_pattern(self, graph: Graph):
        """Wrap every op whose layout mismatches the graph layout in Transposes.

        For each 'op' node that has a valid 'layout' attribute different from
        ``indices_mapping[len(node.layout)][graph.graph['layout']]``:

        1. a Transpose (order ``permutation.perm``) is inserted on the input edge,
        2. a Transpose (order ``permutation.inv``) is inserted on the output edge,
           writing into the original output data node,
        3. the node's attributes are permuted via ``node.permute_attrs`` using a
           temporary 'permutation' attribute on its adjacent data nodes.

        :param graph: graph to transform in place.
        """
        # Snapshot the node ids: new nodes are added to the graph inside the loop.
        for node_id in list(graph.nodes()):
            node = Node(graph, node_id)
            # Check that node layout mismatch with graph layout
            # For example: NHWC and NCHW or NCDHW and NDHWC
            if not (node.kind == 'op' and node.has_valid('layout')
                    and node.layout != indices_mapping[len(
                        node.layout)][graph.graph['layout']]):
                continue

            in_data = node.in_node()
            out_data = node.out_node()

            # Calculate permutation for further Transpose operations
            if graph.graph['layout'] == 'NCHW':
                # Graph layout is NCHW and the node layout differs from it
                permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(
                    len(node.layout))
            else:
                # Graph layout is not NCHW and the node layout differs from it
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(
                    len(node.layout))

            # Schematic representation of transformation below
            #
            #                                           \            NCHW                              NCHW
            #            NHWC                        --  \            |  permutation       permutation  |
            #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
            #                                           /   data->Transpose->data->Convolution->data->Transpose->data

            # 1. Insert input Transpose
            #    This Transpose will permute input from original input layout to operation layout
            edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
            graph.remove_edge(in_data.id, node.id)

            input_order_const = Const(graph, {
                'value': permutation.perm
            }).create_node_with_data()
            input_permute_op = Transpose(
                graph, dict(name=node.name + '/Transpose_'))
            input_permute_data_node = input_permute_op.create_node_with_data(
                [in_data, input_order_const])

            graph.add_edge(input_permute_data_node.id, node.id,
                           **edge_attrs)

            # 2. Insert output Transpose
            #    This Transpose will permute output from operation layout to original input layout
            edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
            graph.remove_edge(node.id, out_data.id)

            # New data node for the op's own output, with the shape of the
            # original output reordered by the permutation.
            op_out_data = Op.create_data_node(
                graph, node, {'shape': out_data.shape[permutation.perm]},
                edge_attrs)

            output_order_const = Const(graph, {
                'value': permutation.inv
            }).create_node_with_data()
            # The Transpose result is written into the original output data node,
            # so the returned node does not need to be kept.
            Transpose(
                graph, dict(name=node.name +
                            '/Transpose_')).create_node_with_data(
                                [op_out_data, output_order_const],
                                data_nodes=out_data)

            # 3. Add permutations for Node
            #    Here we use permutation mechanism where data nodes takes permutation attribute.
            #    And then we call permute_attrs method that permutes node attributes according to permutations on
            #    data nodes.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)

            # Clear the temporary attribute again once the node attributes are permuted.
            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None
Ejemplo n.º 4
0
def permute_nchw_to_nhwc(shape):
    """Return ``shape`` reordered by the NCHW -> NHWC permutation.

    The permutation is looked up for ``len(shape)`` dimensions and applied by
    indexing a numpy array built from ``shape``; the result is a numpy array.
    """
    rank = len(shape)
    order = PermuteAttrs.get_nchw_to_nhwc_permutation(rank).perm
    return np.array(shape)[order]
Ejemplo n.º 5
0
def permute_nchw_to_nhwc(shape, use_new_frontend=False):
    """Return ``shape`` reordered by the NCHW -> NHWC permutation.

    When ``use_new_frontend`` is truthy the very same ``shape`` object is
    returned unchanged. Otherwise the permutation for ``len(shape)``
    dimensions is applied by indexing a numpy array built from ``shape``,
    and that numpy array is returned.
    """
    if not use_new_frontend:
        order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm
        return np.array(shape)[order]
    # The new-frontend path hands the shape back untouched.
    return shape