示例#1
0
    def convert_squeeze(self, op):
        """Convert a TFLite SQUEEZE operator into a Relay ``squeeze``.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The TFLite operator to convert. It must have exactly one input
            tensor, exactly one output tensor and ``SqueezeOptions`` attached.

        Returns
        -------
        The Relay expression produced by ``_op.squeeze``.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        AssertionError
            If ``op`` is not an Operator, has the wrong number of tensors,
            or does not carry SqueezeOptions.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.SqueezeOptions import SqueezeOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        output_tensors = self.get_output_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        assert len(output_tensors) == 1, "output tensors length should be 1"
        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        assert op.BuiltinOptionsType() == BuiltinOptions.SqueezeOptions
        op_options = op.BuiltinOptions()
        squeeze_options = SqueezeOptions()
        squeeze_options.Init(op_options.Bytes, op_options.Pos)
        squeeze_dims = squeeze_options.SqueezeDimsAsNumpy()

        # The generated flatbuffers accessor returns the int 0 (not an empty
        # array) when the squeeze_dims vector is absent, so calling tuple()
        # on it would raise TypeError. TFLite treats a missing or empty
        # squeeze_dims as "remove every dimension of size 1", which is what
        # Relay's squeeze does for axis=None. An explicit empty tuple may
        # instead be read as "remove nothing", so map both cases to None.
        if isinstance(squeeze_dims, int) or len(squeeze_dims) == 0:
            squeeze_axis = None
        else:
            # Normalise numpy int32 scalars to plain Python ints.
            squeeze_axis = tuple(int(dim) for dim in squeeze_dims)

        in_expr = self.get_expr(input_tensor_idx)
        out = _op.squeeze(in_expr, axis=squeeze_axis)

        return out
示例#2
0
class Squeeze(Layer):
    """Convert a TFLite SQUEEZE operator into an ONNX ``Squeeze`` node."""

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Decode the operator's SqueezeOptions table once; generate() reads
        # the squeeze dimensions from this parser.
        builtin_options = op.BuiltinOptions()
        self.tflite_squeeze_parser = SqueezeOptions()
        self.tflite_squeeze_parser.Init(builtin_options.Bytes,
                                        builtin_options.Pos)

    def generate(self):
        """Build the ONNX Squeeze node and append it to ``self.node_list``.

        Returns:
            The tuple ``(self.node_list, self.value_infos,
            self.weight_node_list)``.
        """
        squeeze_node_name = self.node_name

        # The generated flatbuffers accessor returns the int 0 (which has no
        # .tolist()) when the squeeze_dims vector is absent. TFLite treats a
        # missing or empty squeeze_dims as "remove every size-1 dimension",
        # which is ONNX Squeeze's behaviour when the `axes` attribute is
        # omitted. So `axes` is only emitted when dims are actually listed.
        squeeze_dims = self.tflite_squeeze_parser.SqueezeDimsAsNumpy()
        node_attrs = {}
        if not isinstance(squeeze_dims, int) and len(squeeze_dims) > 0:
            # TFLite axes are channel-last; remap them for the ONNX graph.
            axes = utils.channel_last_2_channel_first_axis_mapping(
                squeeze_dims.tolist())
            node_attrs['axes'] = axes

        squeeze_node = onnx.helper.make_node(
            'Squeeze',
            inputs=self.input_nodes_name,
            outputs=[squeeze_node_name],
            name=squeeze_node_name,
            **node_attrs)

        # update tables
        self.node_list.append(squeeze_node)

        return self.node_list, self.value_infos, self.weight_node_list
示例#3
0
文件: tflite.py 项目: zhyj3038/tvm
    def convert_squeeze(self, op):
        """Convert a TFLite SQUEEZE operator into Relay expressions.

        TFLite tensors are N H W C, while this converter keeps them as
        N C H W (N C H*W for rank 3). The input is first transposed back to
        TFLite's layout so that TFLite's squeeze axes line up. The squeezed
        result is then transposed back to the channel-first layout.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The TFLite operator to convert. It must have exactly one input
            tensor, exactly one output tensor and ``SqueezeOptions`` attached.

        Returns
        -------
        The Relay expression for the squeezed output in channel-first layout.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        tvm.error.OpAttributeInvalid
            If the input or output rank is not in the range 1..4.
        AssertionError
            If ``op`` is not an Operator, has the wrong number of tensors,
            or does not carry SqueezeOptions.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.SqueezeOptions import SqueezeOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        output_tensors = self.get_output_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        assert len(output_tensors) == 1, "output tensors length should be 1"
        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        assert op.BuiltinOptionsType() == BuiltinOptions.SqueezeOptions
        op_options = op.BuiltinOptions()
        squeeze_options = SqueezeOptions()
        squeeze_options.Init(op_options.Bytes, op_options.Pos)
        squeeze_dims = squeeze_options.SqueezeDimsAsNumpy()

        # The generated flatbuffers accessor returns the int 0 (not an empty
        # array) when the squeeze_dims vector is absent, so calling tuple()
        # on it would raise TypeError. TFLite treats a missing or empty
        # squeeze_dims as "remove every dimension of size 1", which is what
        # Relay's squeeze does for axis=None. An explicit empty tuple may
        # instead be read as "remove nothing", so map both cases to None.
        if isinstance(squeeze_dims, int) or len(squeeze_dims) == 0:
            squeeze_axis = None
        else:
            # Normalise numpy int32 scalars to plain Python ints.
            squeeze_axis = tuple(int(dim) for dim in squeeze_dims)

        input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
        output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())

        in_expr = self.get_expr(input_tensor_idx)

        # TFLite is N H W C, our layout is N C H W. Bring the input back to
        # TFLite's layout so squeeze_axis (TFLite indices) addresses the
        # intended dimensions.
        if input_shape_length in (1, 2):
            # The rule is channel first (after N but before H, W).
            # length of 1 means N*H*W*C, do nothing.
            # length of 2 means N*H*W, C, do nothing.
            pass
        elif input_shape_length == 3:
            # convert N C H*W to N H*W C
            in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
        elif input_shape_length == 4:
            # convert N C H W to N H W C before squeezing
            in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
        else:
            msg = 'Input shape length {} for operator Squeeze is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

        out = _op.squeeze(in_expr, axis=squeeze_axis)

        # Convert the squeezed TFLite-layout result back to channel first.
        # 1: N*H*W*C, nothing to do
        # 2: N*H*W, C, nothing to do
        # 3: N H*W C, transpose to N C H*W
        # 4: N H W C, transpose to N C H W
        # add more if we need target shapes in future
        if output_shape_length in (1, 2):
            pass
        elif output_shape_length == 3:
            out = _op.transpose(out, axes=(0, 2, 1))
        elif output_shape_length == 4:
            out = _op.transpose(out, axes=(0, 3, 1, 2))
        else:
            msg = 'Output shape length {} for operator Squeeze is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))

        return out