Пример #1
0
class Concatenation(Layer):
    """Translate a TFLite CONCATENATION operator into an ONNX ``Concat`` node.

    The operator's ``ConcatenationOptions`` flatbuffer table is decoded once
    in ``__init__``. ``generate`` reads the concat axis from it, and
    ``defuse_activation_function`` reads the fused activation type from it.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Decode the options table attached to this operator.
        self.tflite_concat_parser = ConcatenationOptions()
        options_table = self.op.BuiltinOptions()
        self.tflite_concat_parser.Init(options_table.Bytes, options_table.Pos)

    def generate(self):
        """Append this layer's ONNX ``Concat`` node and return the tables.

        Returns:
            ``(self.node_list, self.value_infos, self.weight_node_list)``,
            after the new node has been appended to ``self.node_list``.
        """
        # The TFLite axis is remapped through the project's
        # channel-last -> channel-first helper before being handed to ONNX.
        # NOTE(review): assumes the helper copes with negative TFLite axes
        # (e.g. -1); confirm in utils.
        onnx_axis = utils.channel_last_2_channel_first_axis_mapping(
            [self.tflite_concat_parser.Axis()])[0]

        # update tables
        self.node_list.append(
            onnx.helper.make_node(
                'Concat',
                inputs=self.input_nodes_name.copy(),
                outputs=[self.node_name],
                axis=onnx_axis,
                name=self.node_name))

        return self.node_list, self.value_infos, self.weight_node_list

    def defuse_activation_function(self):
        """Build the node(s) for the activation fused into this concat op.

        Delegates to ``defused_activation_node_generator``. It passes the
        fused activation type read from the operator's options.
        """
        fused_type = self.tflite_concat_parser.FusedActivationFunction()
        return defused_activation_node_generator(
            activation_function_type=fused_type,
            op=self.op,
            tflite_interpreter=self.tflite_interpreter)
Пример #2
0
    def convert_concatenation(self, op):
        """Convert a TFLite CONCATENATION operator into a Relay concatenate.

        The TFLite axis is forwarded unchanged, i.e. it is interpreted in the
        N H W C layout of the TFLite tensors. When a fused activation
        function is present, it is applied to the concatenated result.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The CONCATENATION operator to convert.

        Returns
        -------
        The Relay expression for the concatenation, with any fused
        activation applied.

        Raises
        ------
        ImportError
            If the ``tflite`` package cannot be imported.
        AssertionError
            If ``op`` is malformed: wrong type, no inputs, not exactly one
            output, or non-concatenation options. These checks are skipped
            when Python runs with ``-O``.
        """
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError as err:
            raise ImportError("The tflite package must be installed") from err

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors length should be >= 1"
        in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors should be 1"

        # Decode the flatbuffer ConcatenationOptions attached to the operator.
        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        concatenation_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()

        # with axis in N H W C (no layout conversion is applied here)
        out = _op.concatenate(in_exprs, axis=concatenation_axis)

        # if we have activation fn
        if fused_activation_fn != ActivationFunctionType.NONE:
            out = self.convert_fused_activation_function(out, fused_activation_fn)
        return out
Пример #3
0
    def convert_concatenation(self, op):
        """Convert a TFLite CONCATENATION operator into a Relay concatenate.

        TFLite tensors are N H W C while the produced graph is N C H W. The
        TFLite axis is therefore normalised (negative values count from the
        end) and remapped before building the Relay op. When a fused
        activation function is present, it is applied to the result.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The CONCATENATION operator to convert.

        Returns
        -------
        The Relay expression for the concatenation, with any fused
        activation applied.

        Raises
        ------
        ImportError
            If the ``tflite`` package cannot be imported.
        NotImplementedError
            If the first input has more than 4 dimensions.
        ValueError
            If the concatenation axis is out of range for the input rank.
        AssertionError
            If ``op`` is malformed: wrong type, no inputs, not exactly one
            output, or non-concatenation options. These checks are skipped
            when Python runs with ``-O``.
        """
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError as err:
            raise ImportError("The tflite package must be installed") from err

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors length should be >= 1"
        in_exprs = [
            self.get_expr(input_tensor.tensor_idx)
            for input_tensor in input_tensors
        ]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors should be 1"

        # Decode the flatbuffer ConcatenationOptions attached to the operator.
        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        tflite_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()
        input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())

        if input_shape_length > 4:
            raise NotImplementedError(
                "Not support input shape length {} of concatenation".format(
                    str(input_shape_length)))

        # Normalise a negative (count-from-the-end) axis up front. Otherwise
        # Python's negative list indexing on the remap table below would pick
        # a wrong axis whenever the table length differs from the rank.
        concatenation_axis = tflite_axis
        if concatenation_axis < 0:
            concatenation_axis += input_shape_length
        if not 0 <= concatenation_axis < input_shape_length:
            raise ValueError(
                "Concatenation axis {} is out of range for input shape "
                "length {}".format(tflite_axis, input_shape_length))

        # TFLite is N H W C, our layout is N C H W: the trailing channel axis
        # moves to position 1 and the spatial axes shift right by one.
        # Rank < 2 tensors have no channel axis to move, so keep the axis.
        if input_shape_length >= 2:
            axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
            concatenation_axis = axis_convert_map[concatenation_axis]

        # axis is now expressed in the N C H W layout
        out = _op.concatenate(in_exprs, axis=concatenation_axis)

        # if we have activation fn
        if fused_activation_fn != ActivationFunctionType.NONE:
            out = self.convert_fused_activation_function(
                out, fused_activation_fn)
        return out