Esempio n. 1
0
 def extract(cls, node):
     """Fill ``node`` with Transpose attributes read from its module.

     The 'op' attribute is taken from ``__class__.op`` and the permutation
     from ``node.module.order``; both are applied through
     ``Transpose.update_node_stat``.

     Returns ``cls.enabled``.
     """
     transpose_attrs = dict(op=__class__.op, order=node.module.order)
     Transpose.update_node_stat(node, transpose_attrs)
     return cls.enabled
Esempio n. 2
0
    def test_transpose_infer_1(self, order):
        """Shape inference must permute the input shape according to ``order``."""
        graph = self._create_graph_with_transpose(order)
        node = Node(graph, 'transpose')

        Transpose.infer(node)

        # Expected output shape: input dimensions picked in the requested order
        input_shape = node.in_node().shape
        expected = np.array([input_shape[axis] for axis in order])
        self.assertTrue(np.array_equal(node.out_node().shape, expected))
Esempio n. 3
0
 def extract(cls, node):
     """Translate the ONNX 'perm' attribute of ``node`` into Transpose attributes.

     When 'perm' is present it becomes the int64 'order' array and
     'reverse_order' is False; when it is absent 'order' is None and
     'reverse_order' is True.

     Returns ``cls.enabled``.
     """
     perm = onnx_attr(node, 'perm', 'ints', default=None)
     if perm is None:
         # An ONNX Transpose without 'perm' reverses the dimensions
         attrs = {'order': None, 'reverse_order': True}
     else:
         attrs = {'order': int64_array(perm), 'reverse_order': False}
     Transpose.update_node_stat(node, attrs)
     return cls.enabled
Esempio n. 4
0
    def test_transpose_infer_2(self):
        """With no order and 'reverse_order' set, inference must reverse the input shape."""
        graph = self._create_graph_with_transpose(None)
        node = Node(graph, 'transpose')
        node['reverse_order'] = True
        Transpose.infer(node)

        expected = np.array(list(reversed(node.in_node().shape)))
        shapes_match = np.array_equal(node.out_node().shape, expected)
        failure_message = "Shapes are not the same: {} and {}".format(
            node.out_node().shape, expected)
        self.assertTrue(shapes_match, failure_message)
Esempio n. 5
0
    def replace_pattern(self, graph: Graph, match: [str, Node]):
        """Turn the matched single-input node into a Transpose with a constant order input.

        The node's 'order' attribute is moved into a new Const node connected
        to input port 1, the node is updated with Transpose attributes (with
        shape inference requested), and the 'order' attribute is removed.
        """
        node = match['op']
        assert len(node.in_ports()) == 1
        assert node.has_and_set('order')
        permutation = node.order

        # The permutation is now supplied through a second input
        node.add_input_port(1)
        order_name = node.soft_get('name', node.id) + '/Order'
        order_const = Const(graph, {'value': permutation, 'name': order_name}).create_node()
        order_const.out_port(0).connect(node.in_port(1))

        Transpose.update_node_stat(node, {'need_shape_inference': True})

        # The attribute is superseded by the constant input created above
        del node['order']
Esempio n. 6
0
    def add_output_reshape(graph: Graph, match: dict):
        """
        Append Transpose and Reshape after the matched RNN layer to restore the MXNet output layout.

        MXNet expects the Y output as [batch_size, seq_len, hidden_size * num_directions], while the layer
        produces the common format [batch_size, num_directions, seq_len, hidden_size]. The layer output is
        redirected to a new data node holding the common-format shape, which is then transposed with order
        [0, 2, 1, 3] and reshaped to the original output shape; the Reshape writes into the original output
        data node. Nothing is done when the layer has no 'has_num_directions' set.
        """
        layer = match['rnn_layer']
        layer_input = match['input']
        if not layer.has_num_directions:
            return

        original_out = layer.out_node(0)
        directions = 2 if layer.direction == 'bidirectional' else 1
        target_shape = layer.out_node(0).shape.copy()

        # The two leading dimensions follow the layer's own batch/sequence ordering
        if layer.batch_dim == 0:
            leading_dims = (layer.batch_dim, layer.sequence_dim)
        else:
            leading_dims = (layer.sequence_dim, layer.batch_dim)
        internal_shape = shape_array(
            [layer_input.shape[dim] for dim in leading_dims] + [layer.hidden_size])

        # 'has_num_directions' is guaranteed by the early return above
        internal_shape = shape_insert(internal_shape, 1, np.int64(directions))

        base_name = layer.soft_get('name', layer.id)

        # Redirect the layer output to a fresh data node carrying the common-format shape
        internal_out = Op._create_data_node(
            graph,
            name=base_name + '/Data/Reshape_mxnet/',
            attrs={'shape': internal_shape})
        graph.remove_edge(layer.id, original_out.id)
        graph.add_edge(layer.id, internal_out.id, key=0, out=0)

        # Transpose: swap dimensions 1 and 2
        order_attrs = {
            'name': base_name + '/Transpose_mxnet_order',
            'value': int64_array([0, 2, 1, 3])
        }
        order_data = Const(graph, order_attrs).create_node_with_data()
        transpose_op = Transpose(graph, {'name': base_name + '/Transpose_mxnet/'})
        transposed = transpose_op.create_node_with_data([internal_out, order_data])

        # Reshape back to the shape the original output data node had
        reshape_op = Reshape(graph, {'name': base_name + '/Reshape_mxnet/'})
        dim_attrs = {
            'name': base_name + '/Reshape_mxnet_dim',
            'value': int64_array(unmask_shape(target_shape))
        }
        dim_data = Const(graph, dim_attrs).create_node_with_data()

        reshape_op.create_node_with_data([transposed, dim_data], {}, data_nodes=[original_out])
Esempio n. 7
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with a Transpose -> Gather -> Transpose chain.

        An int32 count and a blob of that many int32 values are read from
        ``node.parameters``; the values, decremented by one, become a constant
        'indexes' input of the Gather (presumably a 1-based to 0-based index
        conversion -- confirm against the source format). Both Transposes
        share the same [1, 0] order constant.

        Returns a list with the id of the final Transpose node.
        """
        stream = node.parameters
        count = read_binary_integer32_token(stream)
        indices = read_blob(stream, count, dtype=np.int32) - 1

        prefix = node.soft_get('name', node.id)
        indexes_attrs = {
            'name': prefix + '/indexes',
            'value': mo_array(indices),
            'shape': [count],
            'data_type': np.int32
        }
        indexes_const = Const(graph).create_node(attrs=indexes_attrs)

        # One [1, 0] order constant feeds both the input and the output Transpose
        swap_order = Const(graph, {'value': int64_array([1, 0]), 'name': prefix + '/order'}).create_node()

        input_transpose = Transpose(graph, {'name': prefix + '/input_permute'}).create_node([node.in_node(0)])
        input_transpose.in_port(0).connect(node.in_port(0).get_source())
        input_transpose.in_port(1).connect(swap_order.out_port(0))

        # Gather receives a constant 0 on input port 2
        gather = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': prefix + '/gather'})
        gather.in_port(0).connect(input_transpose.out_port(0))
        gather.in_port(1).connect(indexes_const.out_port(0))

        output_transpose = Transpose(graph, {'name': prefix + '/output_permute'}).create_node()
        output_transpose.in_port(0).connect(gather.out_port(0))
        output_transpose.in_port(1).connect(swap_order.out_port(0))

        return [output_transpose.id]
Esempio n. 8
0
 def extract(cls, node):
     """Translate the MXNet 'axes' attribute of ``node`` into Transpose attributes.

     When 'axes' is given, it is stored as the int32 'order' array. When it is
     missing or empty, MXNet reverses the dimensions, so 'order' is set to None
     and 'reverse_order' to True instead of failing on ``list(None)``.

     Returns ``cls.enabled``.
     """
     attrs = get_mxnet_layer_attrs(node.symbol_dict)
     axes = attrs.tuple("axes", int, None)
     if axes is None or len(axes) == 0:
         # No explicit permutation: fall back to the "reverse all dimensions" default
         Transpose.update_node_stat(node, {'order': None, 'reverse_order': True})
     else:
         Transpose.update_node_stat(node, {'order': np.array(list(axes), dtype=np.int32)})
     return cls.enabled
Esempio n. 9
0
 def extract(cls, node):
     """Copy the permute order from the node's protobuf into the Transpose 'order' attribute.

     The order is read from ``node.pb.permute_param.order`` and stored as an
     int32 array. Returns ``cls.enabled``.
     """
     raw_order = node.pb.permute_param.order
     transpose_attrs = {'order': mo_array(raw_order, dtype=np.int32)}
     Transpose.update_node_stat(node, transpose_attrs)
     return cls.enabled
Esempio n. 10
0
 def extract(cls, node):
     """Apply Transpose attributes to ``node`` with the 'order' attribute set to None.

     Returns ``cls.enabled``.
     """
     undefined_order = dict(order=None)
     Transpose.update_node_stat(node, undefined_order)
     return cls.enabled
Esempio n. 11
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Wrap every operation whose layout differs from the graph layout with a pair of Transposes.

        For each 'op' node that has a valid 'layout' attribute different from
        ``indices_mapping[len(node.layout)][graph.graph['layout']]``:
          1. a Transpose with order ``permutation.perm`` is inserted between the input data node and the node;
          2. a new data node is attached to the node output and a Transpose with order ``permutation.inv``
             is inserted between it and the original output data node;
          3. ``node.permute_attrs.permute_attrs(node)`` is called while the node's input and output data
             nodes temporarily carry the 'permutation' attribute, which is reset to None afterwards.

        :param graph: graph to transform in place
        """
        # Iterate over a snapshot of the node ids because new nodes are added inside the loop
        for node in list(graph.nodes()):
            node = Node(graph, node)
            node_name = node.soft_get('name', node.id)
            # Check that node layout mismatch with graph layout
            # For example: NHWC and NCHW or NCDHW and NDHWC
            if node.kind == 'op' and node.has_valid(
                    'layout') and node.layout != indices_mapping[len(
                        node.layout)][graph.graph['layout']]:
                input = node.in_node()
                output = node.out_node()

                # Calculate permutation for further Transpose operations
                # NOTE(review): the two branch comments below appear to contradict the conditions they sit
                # under (this branch is taken when the *graph* layout is NCHW) -- confirm the intended wording
                if graph.graph['layout'] == 'NCHW':
                    # if Node has NCHW and graph has NHWC layout
                    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(
                        len(node.layout))
                else:
                    # if Node has NHWC and graph has NCHW layout
                    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(
                        len(node.layout))

                # Schematic representation of transformation below
                #
                #                                           \            NCHW                              NCHW
                #            NHWC                        --  \            |  permutation       permutation  |
                #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
                #                                           /   data->Transpose->data->Convolution->data->Transpose->data

                # 1. Insert input Transpose
                #    This Transpose will permute input from original input layout to operation layout
                # Keep the attributes of the removed edge so the new edge into the node can reuse them
                edge_attrs = graph.get_edge_data(input.id, node.id)[0]
                graph.remove_edge(input.id, node.id)

                input_permute_name = node_name + '/input_transpose'
                input_order_const = Const(
                    graph, {
                        'name': input_permute_name + '/order',
                        'value': permutation.perm
                    }).create_node_with_data()
                input_permute_op = Transpose(graph,
                                             {'name': input_permute_name})
                input_permute_data_node = input_permute_op.create_node_with_data(
                    [input, input_order_const])

                graph.add_edge(input_permute_data_node.id, node.id,
                               **edge_attrs)

                # 2. Insert output Transpose
                #    This Transpose will permute output from operation layout to original input layout
                edge_attrs = graph.get_edge_data(node.id, output.id)[0]
                graph.remove_edge(node.id, output.id)

                # Despite its name, this data node is created for the *output* of the operation (it is built
                # from `node` and the removed output edge attributes) and serves as the input of the output
                # Transpose; its shape is the original output shape permuted with permutation.perm
                input_data_node = Op.create_data_node(
                    graph, node, {'shape': output.shape[permutation.perm]},
                    edge_attrs)

                output_permute_name = node_name + '/output_transpose'
                output_order_const = Const(
                    graph, {
                        'name': output_permute_name + '/order',
                        'value': permutation.inv
                    }).create_node_with_data()
                # The result is written into the original output data node; the returned value assigned to
                # `output_permute_op` is not used afterwards
                output_permute_op = Transpose(graph, {
                    'name': output_permute_name
                }).create_node_with_data([input_data_node, output_order_const],
                                         data_nodes=output)

                # 3. Add permutations for Node
                #    Here we use permutation mechanism where data nodes takes permutation attribute.
                #    And then we call permute_attrs method that permutes node attributes according to permutations on
                #    data nodes.
                node.in_node()['permutation'] = permutation
                node.out_node()['permutation'] = permutation
                node.permute_attrs.permute_attrs(node)

                # Reset the temporary markers once the node attributes have been permuted
                node.in_node()['permutation'] = None
                node.out_node()['permutation'] = None