Example #1
def permute_before_and_after(inp: Node, middle: Node, out: Node, order):
    ''' Insert two permutes: one before the middle node and one after it.

        The first permute has the given order, the second permute has the
        inverse order.
    '''

    permute = Permute(middle.graph, dict(order=np.array(order)))

    edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
    middle.graph.remove_edge(inp.id, middle.id)
    new_inp = permute.create_node_with_data([inp],
                                            dict(name=middle.name +
                                                 '/InputPermute'))
    middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)

    permute = Permute(middle.graph, dict(order=inverse_perm(np.array(order))))

    middle.graph.remove_edge(middle.id, out.id)
    new_out = Op._create_data_node(middle.graph,
                                   name=middle.name + '/WithoutPermute',
                                   attrs={'shape': out.shape[order]})
    middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
    permute.create_node_with_data([new_out],
                                  dict(name=middle.name + '/OutputPermute'),
                                  data_nodes=out)
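Aside (illustration only, not part of the example above): inverse_perm is not shown in this snippet; assuming it is equivalent to np.argsort of the order, indexing a shape by the order and then by its inverse is the identity. A minimal standalone sketch:

import numpy as np

# Minimal sketch; assumption: inverse_perm(order) == np.argsort(order).
order = np.array([0, 2, 3, 1])          # e.g. NCHW -> NHWC
inverse = np.argsort(order)             # [0, 3, 1, 2], i.e. NHWC -> NCHW

shape = np.array([1, 3, 224, 224])
permuted = shape[order]                 # [1, 224, 224, 3]
restored = permuted[inverse]            # back to [1, 3, 224, 224]
assert np.array_equal(restored, shape)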
Example #2
def conv_flatten_concat_action(graph: Graph, match: dict):
    assert graph.graph['layout'] == 'NHWC'
    reshape_node = match['reshape']
    reshape_data_node = match['reshape_data']
    conv_name = match['conv'].name
    conv_data_node = match['conv_data']
    # the pattern should be applied only when the reshape operation changes the number of dimensions
    if len(reshape_data_node.shape) == len(conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
        return

    if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
        reshape_data_node.out_node().type == 'FullyConnected' and \
            can_repack_fully_connected_weights_nhwc_to_nchw(reshape_data_node.out_node()):
        log.info(
            'There is a FullyConnected layer after the node "{}" whose weights will be repacked, so there is no '
            'need to insert a Permute'.format(reshape_node.soft_get('name')))
        return
    graph.remove_edge(conv_data_node.id, reshape_node.id)

    permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(
        len(conv_data_node.shape)).perm
    new_permute_op = Permute(graph, {'order': permutation_order})
    permute_data_node = new_permute_op.create_node_with_data(
        [conv_data_node], dict(name=conv_name + '/Permute_'))
    graph.create_edge(permute_data_node, reshape_node)
    # Disable permutation for Reshape and Concat layers attributes
    PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
    reshape_node['nchw_layout'] = True
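Aside (illustration only): for a rank-4 tensor the NCHW-to-NHWC move corresponds to the standard transpose order [0, 2, 3, 1]; a self-contained numpy sketch, independent of the PermuteAttrs helper used above:

import numpy as np

# Standalone sketch: moving a rank-4 tensor from NCHW to NHWC and back.
x_nchw = np.random.rand(1, 3, 8, 8)              # N, C, H, W
x_nhwc = x_nchw.transpose(0, 2, 3, 1)            # N, H, W, C
assert x_nhwc.shape == (1, 8, 8, 3)
assert np.array_equal(x_nhwc.transpose(0, 3, 1, 2), x_nchw)   # inverse order restores NCHW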
Example #3
    def replace_pattern(self, graph: Graph, match: dict):
        if graph.graph['layout'] != "NCHW":
            return

        node = match['op']

        in_node = node.in_node(0)
        out_node = node.out_node(0)
        group = int(node['group'])

        graph.remove_edge(in_node.id, node.id)
        graph.remove_edge(node.id, out_node.id)

        rows = group
        cols = in_node.shape[1] // group

        if rows * cols != in_node.shape[1]:
            raise Error("Group {} should divide input channels number {} without reminder for node {}"
                        "".format(group, in_node.shape[1], node.id))

        reshape_split = Reshape(graph, attrs={'name': node.id + '/Reshape_split_',
                                              'dim': np.array([in_node.shape[0], rows, cols, -1])})
        reshape_split_node = reshape_split.create_node_with_data([in_node])
        transpose = Permute(graph, attrs={'name': node.id + '/Transpose_',
                                          'order': np.array([0, 2, 1, 3])})
        transpose_node = transpose.create_node_with_data([reshape_split_node])
        reshape_concat = Reshape(graph, attrs={'name': node.id + '/Reshape_concat_',
                                               'dim': out_node.shape})
        reshape_concat.create_node_with_data([transpose_node], data_nodes=[out_node])
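Aside (illustration only): the reshape-transpose-reshape chain above is the standard channel-shuffle decomposition; a self-contained numpy sketch of the same data movement:

import numpy as np

def channel_shuffle_nchw(x, group):
    # Sketch of channel shuffle via reshape -> transpose -> reshape.
    n, c, h, w = x.shape
    assert c % group == 0, "group must divide the number of channels"
    x = x.reshape(n, group, c // group, h * w)   # split channels into groups
    x = x.transpose(0, 2, 1, 3)                  # swap the group and per-group axes
    return x.reshape(n, c, h, w)                 # merge channels back

x = np.random.rand(1, 6, 4, 4)
assert channel_shuffle_nchw(x, 2).shape == (1, 6, 4, 4)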
Example #4
def permute_before_and_after(inp: Node, middle: Node, out: Node, input_order, output_order):
    """
        Insert two permutes: one before the middle node and one after it.

        Each permute has its own given order (input_order / output_order).
    """
    # Permute before input
    permute = Permute(middle.graph, dict(order=np.array(input_order)))

    edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
    middle.graph.remove_edge(inp.id, middle.id)
    new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
    middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)

    # Permute after output
    permute = Permute(middle.graph, dict(order=output_order))

    middle.graph.remove_edge(middle.id, out.id)
    new_out = Op._create_data_node(middle.graph, name=middle.name + '/WithoutPermute',
                                   attrs={'shape': out.shape[output_order]})
    middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
    permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
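Aside (illustration only): both permute_before_and_after variants preserve the attributes of the removed edge and re-apply them to the new edge; the same pattern on a plain networkx MultiDiGraph (the [0] selects key 0 of the multi-edge):

import networkx as nx
from copy import deepcopy

g = nx.MultiDiGraph()
g.add_edge('inp', 'middle', **{'in': 0, 'out': 0})

attrs = deepcopy(g.get_edge_data('inp', 'middle')[0])   # attributes of the key-0 edge
g.remove_edge('inp', 'middle')
g.add_edge('new_inp', 'middle', **attrs)                # rewire with the same attributes

print(g.get_edge_data('new_inp', 'middle')[0])          # {'in': 0, 'out': 0}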
Example #5
    def add_output_reshape(graph: Graph, match: dict):
        """
        Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions], we need to add a
        reshape from the common format above, [batch_size, num_directions, seq_len, hidden_size], to the MXNet format.
        """
        lstm = match['rnn_layer']
        input = match['input']
        if not lstm.has_num_directions:
            return
        old_data_node = lstm.out_node(0)
        num_directions = 2 if lstm.direction in ['bidirectional'] else 1
        mxnet_shape = lstm.out_node(0).shape.copy()

        if lstm.batch_dim == 0:
            mo_shape = np.array([input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim],
                                 lstm.hidden_size], dtype=np.int64)
        else:
            mo_shape = np.array([input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim],
                                 lstm.hidden_size], dtype=np.int64)

        if lstm.has_num_directions:
            mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

        new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/',
                                        attrs={'shape': mo_shape})
        graph.remove_edge(lstm.id, old_data_node.id)
        graph.add_edge(lstm.id, new_data.id, key=0, out=0)

        # Add Permute
        permute_order = np.array([0, 2, 1, 3], dtype=np.int64)
        permute = Permute(graph, dict(order=permute_order))
        permute_data = permute.create_node_with_data([new_data], dict(name=lstm.name + '/Permute_mxnet/'))

        # Add Reshape
        reshape = Reshape(graph, dict(dim=mxnet_shape))
        reshape.create_node_with_data([permute_data],
                                      dict(name=lstm.name + '/Reshape_mxnet/'),
                                      data_nodes=[old_data_node])
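Aside (shape-only illustration): the Permute with order [0, 2, 1, 3] followed by the Reshape above turns [batch, num_directions, seq_len, hidden] into the MXNet layout [batch, seq_len, hidden * num_directions]; a quick numpy check:

import numpy as np

batch, num_directions, seq_len, hidden = 2, 2, 5, 16
y = np.random.rand(batch, num_directions, seq_len, hidden)
y_mxnet = y.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_directions * hidden)
assert y_mxnet.shape == (2, 5, 32)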
Example #6
    def replace_pattern(graph: Graph, match: dict):
        reshape = match['reshape']
        assert len(reshape.in_nodes()) > 0
        if graph.graph['layout'] == 'NCHW' or reshape.has_and_set('nchw_layout') or\
                reshape.soft_get('correct_data_layout') is True:
            return

        input_node = reshape.in_node()
        output_node = reshape.out_node()
        input_shape = input_node.shape
        output_shape = output_node.shape

        if len(input_shape) >= 4 and len(output_shape) == 3:
            # Check whether the later permutation pass would change the shapes handled by this Reshape
            layout = 'NCHW'
            c_idx = get_features_dim(layout, len(input_shape))
            hw_idx = [
                get_width_dim(layout, len(input_shape)),
                get_height_dim(layout, len(input_shape))
            ]
            if input_shape[c_idx] != 1 and np.any(input_shape[hw_idx] != [1, 1]):
                # then the nhwc -> nchw permutation can change shapes significantly,
                # so we wrap the node with NCHW -> NHWC permutes and don't touch it later
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(
                    len(input_shape))
                permutation_back = PermuteAttrs.get_nchw_to_nhwc_permutation(
                    len(input_shape))

                # 1. Insert input Permute
                #    This Permute will permute input from original input layout to operation layout
                edge_attrs = graph.get_edge_data(input_node.id, reshape.id)[0]
                graph.remove_edge(input_node.id, reshape.id)

                permute_op = Permute(graph, {
                    'order': permutation.perm,
                    'name': reshape.name + '/Permute_'
                })
                permute_data_node = permute_op.create_node_with_data(
                    [input_node])

                graph.add_edge(permute_data_node.id, reshape.id, **edge_attrs)
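Aside (illustration only): the check above matters because flattening the same data in NHWC versus NCHW order produces a different element sequence, so a dimension-collapsing Reshape cannot be left untouched by the layout permutation pass. A standalone numpy demonstration:

import numpy as np

x_nhwc = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)   # N, H, W, C
x_nchw = x_nhwc.transpose(0, 3, 1, 2)                   # same data, N, C, H, W

# Collapsing the trailing dimensions gives a different element order per layout.
print(np.array_equal(x_nhwc.reshape(2, -1), x_nchw.reshape(2, -1)))   # False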
Example #7
    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        if graph.graph['layout'] != 'NHWC':
            return

        if self.is_reshape_bad(match['reshape_pack'], match['reshape_unpack'],
                               match['strided_slice']):
            log.info("Reshape that pack/unpack several dimensions detected {}".
                     format(match['reshape_pack'].id))
            node_split = match['reshape_split']

            # insert Permute before reshape
            data_node = Op._create_data_node(
                graph, node_split.name + "/Permute_before_data")
            permute_before = Permute(
                graph,
                dict(name=node_split.name + "/Permute_before",
                     order=np.array([0, 2, 3, 1])))
            in_node = node_split.in_node(0)
            attrs = deepcopy(graph.get_edge_data(in_node.id, node_split.id)[0])
            graph.remove_edge(in_node.id, node_split.id)
            permute_before_node = permute_before.create_node_with_data(
                [in_node], permute_before.attrs, data_nodes=[data_node])
            graph.add_edge(permute_before_node.id, node_split.id, **attrs)

            node = match['reshape_pack']
            node['nchw_layout'] = True
            new_reshape_shape = np.concatenate(
                (np.array([node.in_node(0).shape[0]]),
                 np.array([np.prod(node.in_node(0).shape[[1, 2, 3]])]),
                 np.array([node.in_node(0).shape[-1]])))

            node.dim = new_reshape_shape

            # insert Permute after reshape
            data_node = Op._create_data_node(graph,
                                             node.name + "/Permute_after_data",
                                             {'shape': node.dim})
            permute_after = Permute(
                graph,
                dict(name=node.name + "/Permute_after",
                     order=np.array([0, 2, 1])))
            out_node = node.out_node(0)
            out_node.shape = new_reshape_shape[np.array([0, 2, 1])]
            attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
            graph.remove_edge(node.id, out_node.id)

            permute_after_node = permute_after.create_node_with_data(
                [data_node], permute_after.attrs, data_nodes=[out_node])
            graph.add_edge(node.id, data_node.id, **attrs)

            # update softmax shape
            node_softmax = match['softmax']
            node_softmax.out_node(0).shape = out_node.shape

            # revert strided slice and reshape
            node_ss = match['strided_slice']
            node_unpack = match['reshape_unpack']

            unpack_out = node_unpack.out_node(0).id
            ss_out = node_ss.out_node(0).id

            # gather edge attributes
            soft_reshape_attrs = deepcopy(
                graph.get_edge_data(
                    node_softmax.out_node(0).id, node_unpack.id)[0])
            reshape_data_attrs = deepcopy(
                graph.get_edge_data(node_unpack.id, unpack_out)[0])
            reshape_ss_attrs = deepcopy(
                graph.get_edge_data(unpack_out, node_ss.id)[0])
            ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])

            # remove all edges in the Softmax->Reshape->StridedSlice chain
            graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
            graph.remove_edge(node_unpack.id, unpack_out)
            graph.remove_edge(unpack_out, node_ss.id)
            graph.remove_edge(node_ss.id, ss_out)

            # add new edges to get the chain Softmax->StridedSlice->Reshape
            graph.add_edge(
                node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
            graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
            graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
            graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)

            # update output shape and parameters for StridedSlice
            node_ss.out_node(0).shape = np.zeros(3)
            node_ss.out_node(0).shape[0] = out_node.shape[0]
            node_ss.out_node(0).shape[1] = 1
            node_ss.out_node(0).shape[2] = out_node.shape[2]

            old_slices = node_ss.slices.copy()
            node_ss.slices = []
            node_ss.slices.append(old_slices[0])
            node_ss.slices.append(old_slices[-1])
            node_ss.slices.append(slice(0, out_node.shape[2], 1))
            node_ss.shrink_axis_mask = [False, False, False]
            node_ss.new_axis_mask = [False, False, False]

            # update the Reshape dim attribute
            node_unpack.dim = np.delete(node_unpack.dim, 4)
            # prevent permutation of this Reshape because it would give a wrong result
            node_unpack['nchw_layout'] = True
            node_unpack.out_node(0)['nchw_layout'] = True
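Aside (illustration only): the slices attribute rebuilt above holds plain Python slice objects; for reference, numpy applies such a list directly once it is converted to a tuple:

import numpy as np

x = np.arange(2 * 1 * 6).reshape(2, 1, 6)
slices = [slice(0, 2, 1), slice(0, 1, 1), slice(0, 6, 1)]
print(x[tuple(slices)].shape)   # (2, 1, 6)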
Example #8
    def find_and_replace_pattern(self, graph: Graph):
        for node in list(graph.nodes()):
            node = Node(graph, node)
            # Check whether the node layout mismatches the graph layout
            # For example: NHWC vs NCHW, or NCDHW vs NDHWC
            if node.kind == 'op' and node.has_valid('layout') and \
                    node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]:
                input = node.in_node()
                output = node.out_node()

                # Calculate permutation for further Permute operations
                if graph.graph['layout'] == 'NCHW':
                    # if Node has NCHW and graph has NHWC layout
                    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(
                        len(node.layout))
                else:
                    # if Node has NHWC and graph has NCHW layout
                    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(
                        len(node.layout))

                # Schematic representation of transformation below
                #
                #                                           \            NCHW                              NCHW
                #            NHWC                        --  \            |  permutation       permutation  |
                #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
                #                                           /   data->Permute->data->Convolution->data->Permute->data

                # 1. Insert input Permute
                #    This Permute will permute input from original input layout to operation layout
                edge_attrs = graph.get_edge_data(input.id, node.id)[0]
                graph.remove_edge(input.id, node.id)

                input_permute_op = Permute(graph, {'order': permutation.perm})
                input_permute_data_node = input_permute_op.create_node_with_data(
                    [input], dict(name=node.name + '/Permute_'))

                graph.add_edge(input_permute_data_node.id, node.id,
                               **edge_attrs)

                # 2. Insert output Permute
                #    This Permute will permute output from operation layout to original input layout
                edge_attrs = graph.get_edge_data(node.id, output.id)[0]
                graph.remove_edge(node.id, output.id)

                input_data_node = Op.create_data_node(
                    graph, node, {'shape': output.shape[permutation.perm]},
                    edge_attrs)

                output_permute_op = Permute(graph, {'order': permutation.inv})
                output_permute_op.create_node_with_data([input_data_node],
                                                        dict(name=node.name + '/Permute_'),
                                                        data_nodes=output)

                # 3. Add permutations for Node
                #    Here we use the permutation mechanism where data nodes take a permutation attribute.
                #    Then the permute_attrs method permutes node attributes according to the permutations
                #    on the data nodes.
                node.in_node()['permutation'] = permutation
                node.out_node()['permutation'] = permutation
                node.permute_attrs.permute_attrs(node)

                node.in_node()['permutation'] = None
                node.out_node()['permutation'] = None
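Aside (assumption, for illustration only): indices_mapping is not shown in this snippet; it can be pictured as a lookup from tensor rank and graph layout to the layout string a node is expected to carry, along the lines of the hypothetical table below (not the real Model Optimizer definition):

# Hypothetical stand-in for indices_mapping, matching the comment
# "NHWC vs NCHW, or NCDHW vs NDHWC" above: rank -> graph layout -> expected node layout.
indices_mapping_sketch = {
    4: {'NCHW': 'NCHW', 'NHWC': 'NHWC'},
    5: {'NCHW': 'NCDHW', 'NHWC': 'NDHWC'},
}

# A rank-4 node carrying 'NHWC' in an 'NCHW' graph mismatches and gets wrapped with Permutes.
assert 'NHWC' != indices_mapping_sketch[4]['NCHW']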