def replace_pattern(graph: Graph, match: dict):
        """
        Reorder the shape input (port 2) of the matched op from NHWC to NCHW with a Gather.

        Nothing is done unless all of the following hold:
          * the op has a connected input port 2 and its ``shape_input`` attribute is set;
          * the op has a valid ``layout`` that does not start with 'NC';
          * ``graph.graph['layout']`` is 'NCHW'.
        In that case the edge ``in_node(2) -> op`` is replaced by
        ``in_node(2) -> Gather(indices=permutation.perm, axis=0) -> op``, reusing the
        attributes of the original edge (key 0).
        """
        target_op = match['op']

        has_shape_input = (target_op.has_port('in', 2)
                           and not target_op.in_port(2).disconnected()
                           and target_op.has_and_set('shape_input'))
        if not has_shape_input:
            return

        needs_reorder = (target_op.has_valid('layout')
                         and not target_op.layout.startswith('NC')
                         and graph.graph['layout'] == 'NCHW')
        if not needs_reorder:
            return

        rank = len(target_op.in_port(0).data.get_shape())
        permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)
        shape_data = target_op.in_node(2)
        gather_name = target_op.soft_get('name', target_op.id) + '/ShapeGather'

        # Gather indices: the NHWC->NCHW order for the data rank.
        order_const = Const(graph, {
            'value': permutation.perm,
            'name': gather_name + '/Const',
            'need_shape_inference': True,
        }).create_node_with_data()
        axis_const = Const(graph, {
            'value': int64_array(0),
            'name': gather_name + '/Axis',
        }).create_node_with_data()
        gathered = Gather(graph, {
            'name': gather_name,
            'need_shape_inference': True,
        }).create_node_with_data([shape_data, order_const, axis_const])

        # Re-route the consumer edge through the Gather, keeping the original edge attributes.
        edge_attrs = graph.get_edge_data(shape_data.id, target_op.id, key=0).copy()
        graph.add_edge(gathered.id, target_op.id, **edge_attrs)
        graph.remove_edge(shape_data.id, target_op.id)
示例#2
0
    def infer(node: Node):
        """
        Run TF StridedSlice shape inference, then prepare the node for NHWC->NCHW conversion.

        The extra steps run only when the graph layout is 'NHWC' and the output port 0 has no
        computed value:
          * the five mask attributes are registered for permutation (``permute_masks`` on 'input:0');
          * every input from port 1 on that has a value and whose first dimension exceeds 3 is
            reordered with ``permute_array_with_ellipsis`` (padding value 0);
          * every set bit of ``ellipsis_mask`` is cleared in place.
        """
        tf_strided_slice_infer(node)

        if node.graph.graph['layout'] != 'NHWC' or node.out_port(0).data.get_value() is not None:
            return

        mask_names = ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask')
        PermuteAttrs.create_permute_attrs(
            node, attrs=[(mask_name, 'input:0', permute_masks) for mask_name in mask_names])

        for port in range(1, len(node.in_nodes())):
            source = node.in_node(port)
            if source.value is not None and source.shape[0] > 3:
                perm = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
                source.value = permute_array_with_ellipsis(node, perm, source.value, 0)

        # Masks and inputs were extended for the NHWC->NCHW permutation above,
        # so the ellipsis bits are cleared.
        ellipsis_positions = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[ellipsis_positions] = 0
示例#3
0
def permute_array_with_ellipsis(node: Node,
                                permutation: PermuteAttrs.Permutation,
                                array: np.array, ins_value: int):
    """
    Reorder a StridedSlice mask/input ``array`` from NHWC to NCHW order.

    Three cases are handled:
    * If both the input and the output of ``node`` have rank < 4, the values are returned
      unchanged as a Python list.
    * If ``array`` is shorter than the output rank plus the number of set bits in
      ``node.shrink_axis_mask``, it is first padded with ``ins_value``. The padding goes
      right after the ellipsis position when ``node.ellipsis_mask`` has a set bit (exactly
      one is asserted); otherwise it goes at the end. The padded sequence is then permuted.
    * Otherwise the leading (output rank + shrunk axes) entries are permuted, any extra
      trailing entries keep their positions, and an int64 array is returned.

    Note: ``permutation`` is not used. The permutation is recomputed from the length of
    the (possibly padded) sequence.
    """
    values = list(array)

    # Rank below 4 on both sides: nothing to reorder.
    if len(node.in_node().shape) < 4 and len(node.out_node().shape) < 4:
        return values

    target_len = len(node.out_node(0).shape) + np.count_nonzero(node.shrink_axis_mask)

    if len(values) >= target_len:
        order = list(PermuteAttrs.get_nhwc_to_nchw_permutation(target_len).perm)
        # Indices past target_len are appended in order, so the extra tail stays in place.
        order.extend(range(target_len, len(values)))
        return np.array(values, dtype=np.int64)[np.array(order)]

    # Too short: decide where the padding goes.
    if np.any(node.ellipsis_mask):
        ellipsis_positions = np.nonzero(node.ellipsis_mask)
        assert len(ellipsis_positions[0]) == 1
        insert_after = ellipsis_positions[0][0]
    else:
        insert_after = len(values) - 1

    missing = target_len - len(values)
    # Same result as inserting ins_value one by one at insert_after + 1, + 2, ...
    values[insert_after + 1:insert_after + 1] = [ins_value] * missing

    order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(values)).perm
    return np.array(values)[order]
示例#4
0
def permute_array(node: Node, array: np.array):
    """
    Reorder mask/input ``array`` of ``node`` from NHWC to NCHW order.

    ``array`` is expected to be at least as long as the output rank plus the number of
    set bits in ``node.shrink_axis_mask``; no padding is performed.

    If both the input and the output ranks (port 0) are below 4, the values are returned
    unchanged as a Python list. Otherwise the leading entries are permuted, any extra
    trailing entries keep their positions, and an int64 array is returned.
    """
    values = list(array)

    in_rank = len(node.in_port(0).data.get_shape())
    out_shape = node.out_port(0).data.get_shape()
    if in_rank < 4 and len(out_shape) < 4:
        return values

    head_len = len(out_shape) + np.count_nonzero(node.shrink_axis_mask)
    order = list(PermuteAttrs.get_nhwc_to_nchw_permutation(head_len).perm)
    # Identity indices for the tail, so entries beyond head_len are not moved.
    order += range(head_len, len(values))
    return int64_array(values)[int64_array(order)]
示例#5
0
def strided_slice(op_node: Node, port_info: str, input_port: int):
    """
    Reorder input ``input_port`` of a StridedSlice (begin/end/strides) from NHWC to NCHW.

    StridedSlice must be permuted even when its input or output tensor has rank below 4.
    What matters is the slice rank, i.e. the length of begin, end and strides (all equal):
      e.g. input_shape = (1, 10, 10),    out = input[:, 0:10, :, new_axis]  -> input rank < 4
           input_shape = (1, 10, 10, 3), out = input[:, 0:5, 0:4, 0]        -> output rank < 4
    In both examples the slice rank is >= 4.

    Asserts that the data node resolved for ``port_info`` carries a ``permutation``.
    Does nothing if that permutation is empty.
    """
    data_node = get_node_with_permutation(op_node, port_info)
    assert data_node.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            data_node.id, op_node.id, port_info)

    if len(data_node.permutation.perm) == 0:
        return

    from mo.ops.op import PermuteAttrs

    # The slice rank is the length of the 1-D begin/end/strides input.
    slice_rank = op_node.in_port(input_port).data.get_shape()[0]
    gather_indices = PermuteAttrs.get_nhwc_to_nchw_permutation(slice_rank).perm
    reorder_inputs_for_shape_or_slice(op_node, input_port, gather_indices)
    def find_and_replace_pattern(self, graph: Graph):
        """
        Wrap every layout-mismatched operation in a pair of Transpose operations.

        A node is processed when it is an 'op' node with a valid ``layout`` that differs from
        ``indices_mapping[len(layout)][graph.graph['layout']]`` (e.g. NHWC vs NCHW, or
        NCDHW vs NDHWC). For each such node:
          1. its input (``node.in_node()``) is routed through Transpose(input, permutation.perm);
          2. its output data node is now produced by Transpose(new_data, permutation.inv), where
             new_data is a fresh data node whose shape is the old output shape reordered by
             permutation.perm;
          3. the node's attributes are permuted via ``node.permute_attrs`` while the permutation
             is temporarily attached to its input/output data nodes.
        """
        for node_id in list(graph.nodes()):
            node = Node(graph, node_id)
            if node.kind != 'op' or not node.has_valid('layout'):
                continue
            if node.layout == indices_mapping[len(node.layout)][graph.graph['layout']]:
                continue

            src_data = node.in_node()
            dst_data = node.out_node()

            # NHWC->NCHW permutation when the graph layout is 'NCHW', NCHW->NHWC otherwise.
            make_permutation = (PermuteAttrs.get_nhwc_to_nchw_permutation
                                if graph.graph['layout'] == 'NCHW'
                                else PermuteAttrs.get_nchw_to_nhwc_permutation)
            permutation = make_permutation(len(node.layout))

            # Schematic representation of the transformation below
            #
            #                                           \            NCHW                              NCHW
            #            NHWC                        --  \            |  permutation       permutation  |
            #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
            #                                           /   data->Transpose->data->Convolution->data->Transpose->data

            # 1. Input side: src_data -> Transpose(perm) -> node, keeping the original edge attributes.
            in_edge_attrs = graph.get_edge_data(src_data.id, node.id)[0]
            graph.remove_edge(src_data.id, node.id)
            in_order = Const(graph, {'value': permutation.perm}).create_node_with_data()
            transpose_name = node.name + '/Transpose_'
            transposed_in = Transpose(graph, {'name': transpose_name}).create_node_with_data(
                [src_data, in_order])
            graph.add_edge(transposed_in.id, node.id, **in_edge_attrs)

            # 2. Output side: node -> fresh data (permuted shape) -> Transpose(inv) -> dst_data.
            out_edge_attrs = graph.get_edge_data(node.id, dst_data.id)[0]
            graph.remove_edge(node.id, dst_data.id)
            node_out = Op.create_data_node(graph, node, {'shape': dst_data.shape[permutation.perm]},
                                           out_edge_attrs)
            out_order = Const(graph, {'value': permutation.inv}).create_node_with_data()
            Transpose(graph, {'name': transpose_name}).create_node_with_data(
                [node_out, out_order], data_nodes=dst_data)

            # 3. permute_attrs reads the permutation from the data nodes, so attach it
            #    only for the duration of the call.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)
            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None
示例#7
0
    def create_topK_net(shape, k, ir_version, use_new_frontend):
        """
            Build a TensorFlow TopK model and, when applicable, its reference IR graph.

            Tensorflow net:

                          |-> Values
            Input -> TopK |
                          |-> Indices


            IR net:

                          |-> Values
            Input -> TopK |
                          |-> Indices

            Returns (tf_net, ref_net). ref_net is None unless
            check_ir_version(10, None, ir_version) holds and use_new_frontend is falsy.
        """
        import tensorflow as tf

        # TensorFlow model: the placeholder uses the shape converted by permute_nchw_to_nhwc.
        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session() as sess:
            tf_input = tf.compat.v1.placeholder(tf.int32,
                                                shape=permute_nchw_to_nhwc(shape),
                                                name='Input')
            tf.nn.top_k(tf_input, k=k, sorted=True, name='Operation')
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        # Reference IR: tf.nn.top_k works on the last TF dimension, so map that axis
        # through the inverse NHWC->NCHW permutation.
        out_shape = shape.copy()
        rank = len(out_shape)
        inverse_perm = PermuteAttrs.get_nhwc_to_nchw_permutation(rank).inv
        topk_axis = permute_axis(rank - 1, inverse_perm)
        out_shape[topk_axis] = k

        if not check_ir_version(10, None, ir_version) or use_new_frontend:
            return tf_net, None

        def op_attrs(op_type, **extra):
            return {'kind': 'op', 'type': op_type, **extra}

        def data_attrs(data_shape):
            return {'shape': data_shape, 'kind': 'data'}

        nodes_attributes = {
            'input': op_attrs('Parameter'),
            'input_data': data_attrs(shape),
            'Const_k_input_data': data_attrs([]),
            'Const_k': op_attrs('Const'),
            'Const_k_data': data_attrs([]),
            'TopK': op_attrs('TopK', axis=topk_axis, mode='max', sort='value'),
            'TopK_data_1': data_attrs(out_shape),
            'TopK_data_2': data_attrs(out_shape),
            'result_1': op_attrs('Result'),
            'result_2': op_attrs('Result'),
        }
        edges = [
            ('input', 'input_data'),
            ('input_data', 'TopK', {'in': 0}),
            ('Const_k_input_data', 'Const_k'),
            ('Const_k', 'Const_k_data'),
            ('Const_k_data', 'TopK', {'in': 1}),
            ('TopK', 'TopK_data_1', {'out': 0}),
            ('TopK', 'TopK_data_2', {'out': 1}),
            ('TopK_data_1', 'result_1'),
            ('TopK_data_2', 'result_2'),
        ]
        return tf_net, build_graph(nodes_attributes, edges)
示例#8
0
def permute_nhwc_to_nchw(shape):
    """Return ``shape`` as a NumPy array reordered by the NHWC->NCHW permutation for its rank."""
    order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(shape)).perm
    return np.array(shape)[order]
示例#9
0
def permute_nhwc_to_nchw(shape, use_new_frontend=False):
    """
    Reorder ``shape`` from NHWC to NCHW order.

    When ``use_new_frontend`` is truthy, ``shape`` is returned as-is (same object).
    Otherwise a NumPy array indexed by the NHWC->NCHW permutation for ``len(shape)``
    is returned.
    """
    if not use_new_frontend:
        order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(shape)).perm
        return np.array(shape)[order]
    return shape