Ejemplo n.º 1
0
    def convert_reshape(self, op):
        """Convert a TFLite RESHAPE operator into an ``_op.reshape`` expression.

        Only the first input (the data tensor) is turned into an expression;
        the target shape is taken from the operator's ``ReshapeOptions``
        ``new_shape`` field, not from the second input tensor.

        Raises:
            ImportError: if the ``tflite`` package is not installed.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.ReshapeOptions import ReshapeOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        tensors = self.get_input_tensors(op)
        assert len(tensors) == 2, "input tensors length should be 2"
        data_tensor = tensors[0]

        # Decode the flatbuffer options table attached to the operator.
        assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
        raw_options = op.BuiltinOptions()
        options = ReshapeOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)
        new_shape = tuple(options.NewShapeAsNumpy())

        data_expr = self.get_expr(data_tensor.tensor_idx)
        return _op.reshape(data_expr, newshape=new_shape)
Ejemplo n.º 2
0
    def convert_reshape(self, op):
        """Convert a TFLite RESHAPE operator, keeping the channel-first layout.

        The target shape in ``ReshapeOptions.new_shape`` is expressed
        channel-last. The data tensor is therefore transposed to channel-last
        before the reshape, and the result is transposed back to channel-first
        afterwards. The layout convention used by this converter (channel
        directly after N) is:

          rank 1: N*H*W*C                      -- no transpose
          rank 2: N*H*W, C                     -- no transpose
          rank 3: N, C, H*W   <->  N, H*W, C
          rank 4: N, C, H, W  <->  N, H, W, C

        Raises:
            ImportError: if the ``tflite`` package is not installed.
            tvm.error.OpAttributeInvalid: if the input rank or the target
                rank is outside 1..4.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.ReshapeOptions import ReshapeOptions
        except ImportError as err:
            raise ImportError("The tflite package must be installed") from err

        # Transpose permutations keyed by rank; None means "no transpose".
        to_channels_last = {1: None, 2: None, 3: (0, 2, 1), 4: (0, 2, 3, 1)}
        to_channels_first = {1: None, 2: None, 3: (0, 2, 1), 4: (0, 3, 1, 2)}

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"
        input_tensor = input_tensors[0]

        assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
        op_options = op.BuiltinOptions()
        reshape_options = ReshapeOptions()
        reshape_options.Init(op_options.Bytes, op_options.Pos)
        target_shape = reshape_options.NewShapeAsNumpy()
        input_rank = len(input_tensor.tensor.ShapeAsNumpy())
        target_rank = len(target_shape)

        in_expr = self.get_expr(input_tensor.tensor_idx)

        if input_rank not in to_channels_last:
            msg = 'Input shape length {} for operator Reshape is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(input_rank))
        before_perm = to_channels_last[input_rank]
        if before_perm is not None:
            in_expr = _op.transpose(in_expr, axes=before_perm)

        out = _op.reshape(in_expr, newshape=tuple(target_shape))

        if target_rank not in to_channels_first:
            # Only ranks 1..4 have a defined layout mapping (see docstring).
            raise tvm.error.OpAttributeInvalid(
                'Length of target shape must be between 1 and 4 for operator Reshape.'
            )
        after_perm = to_channels_first[target_rank]
        if after_perm is not None:
            out = _op.transpose(out, axes=after_perm)

        return out
Ejemplo n.º 3
0
class Reshape(Layer):
    """Translate a TFLite RESHAPE operator into ONNX Transpose/Reshape nodes.

    The emitted subgraph is::

        input -> Transpose (moves axis 1 to the last position)
              -> Constant shape + Reshape
              -> Transpose (moves the last axis back to position 1)
                 [only when the operator carries builtin options]

    The two transposes presumably convert channel-first <-> channel-last
    around the TFLite reshape. When the trailing Transpose is emitted,
    consumers in ``self.output_nodes`` are rewired to read from it.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Parser over the operator's ReshapeOptions table, or None when the
        # operator carries no builtin options.
        builtin_options = op.BuiltinOptions()
        if builtin_options is None:
            self.tflite_reshape_parser = None
        else:
            self.tflite_reshape_parser = ReshapeOptions()
            self.tflite_reshape_parser.Init(builtin_options.Bytes,
                                            builtin_options.Pos)

    def generate(self):
        """Append the ONNX nodes for this reshape to ``self.node_list``.

        Returns:
            tuple: ``(self.node_list, self.value_infos, self.weight_node_list)``.
        """
        node_output_detail = self.tflite_interpreter._get_tensor_details(
            self.op.Outputs(0))
        node_input_detail = self.tflite_interpreter._get_tensor_details(
            self.op.Inputs(0))

        out_dim = node_output_detail['shape']
        in_dim = node_input_detail['shape']

        # Move axis 1 to the end, e.g. rank 4 -> [0, 2, 3, 1].
        dims = list(range(len(in_dim)))
        dims = dims[:1] + dims[2:] + dims[1:2]

        transpose_before_node_name = 'transpose_node_before_reshape_' + self.node_name
        transpose_before_node = onnx.helper.make_node(
            'Transpose',
            inputs=self.input_nodes_name,
            outputs=[transpose_before_node_name],
            perm=dims,
            name=transpose_before_node_name)
        self.node_list.append(transpose_before_node)

        reshape_node_name = self.node_name
        shape_tensor_name = 'shape_tensor_' + self.node_name
        shape_node_name = 'shape_const_' + self.node_name

        # Prefer the explicit 'new_shape' attribute. Otherwise (no options
        # at all, e.g. a squeeze-like op, or options without 'new_shape')
        # fall back to the output shape reported by the interpreter so that
        # 'new_shape' is always bound.
        parser = self.tflite_reshape_parser
        if parser is not None and not parser.NewShapeIsNone():
            new_shape = np.array(parser.NewShapeAsNumpy(), dtype='int64')
        else:
            new_shape = np.array(out_dim, dtype='int64')

        shape_tensor = onnx.helper.make_tensor(shape_tensor_name,
                                               TensorProto.INT64,
                                               new_shape.shape, new_shape)
        shape_node = onnx.helper.make_node("Constant", [], [shape_node_name],
                                           name=shape_node_name,
                                           value=shape_tensor)

        reshape_node = onnx.helper.make_node(
            'Reshape',
            inputs=[transpose_before_node_name, shape_node_name],
            outputs=[reshape_node_name],
            name=reshape_node_name)
        self.node_list.append(shape_node)
        self.node_list.append(reshape_node)

        # Without builtin options the Reshape output is consumed directly.
        # With options, the layout is always transposed back; the
        # 'new_shape' presence check must not gate this step.
        if parser is not None:
            out_rank = len(out_dim)
            # Move the last axis to position 1, e.g. rank 4 -> [0, 3, 1, 2].
            # Ranks below 2 have nothing to move; the slicing formula would
            # produce the invalid perm [0, 0] for rank 1.
            if out_rank >= 2:
                dims = [0, out_rank - 1] + list(range(1, out_rank - 1))
            else:
                dims = list(range(out_rank))

            transpose_after_node_name = 'transpose_node_after_reshape_' + self.node_name
            transpose_after_node = onnx.helper.make_node(
                'Transpose',
                inputs=[reshape_node_name],
                outputs=[transpose_after_node_name],
                perm=dims,
                name=transpose_after_node_name)
            self.node_list.append(transpose_after_node)

            # Point every consumer that read this node's output at the
            # trailing Transpose instead.
            for consumer in self.output_nodes:
                for idx, input_name in enumerate(consumer.input_nodes_name):
                    if input_name == self.node_name:
                        consumer.input_nodes_name[idx] = transpose_after_node_name

        return self.node_list, self.value_infos, self.weight_node_list