Esempio n. 1
0
class Concatenation(Layer):
    """ONNX converter layer for a TFLite CONCATENATION operator.

    The operator's ``ConcatenationOptions`` flatbuffer table is parsed once at
    construction; ``generate`` then emits a single ONNX ``Concat`` node whose
    axis is translated with
    ``tflite_utils.channel_last_2_channel_first_axis_mapping``.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Parse the builtin options table (axis + fused activation) of the op.
        builtin_options = self.op.BuiltinOptions()
        self.tflite_concat_parser = ConcatenationOptions()
        self.tflite_concat_parser.Init(builtin_options.Bytes, builtin_options.Pos)

    def generate(self):
        """Append the ONNX ``Concat`` node to ``self.node_list``.

        Returns:
            A 4-tuple ``(node_list, value_infos, weight_node_list, {})``.
        """
        # Copy so the node's input list is not aliased to the layer's own list.
        input_names = self.input_nodes_name.copy()
        output_name = self.node_name

        tflite_axis = self.tflite_concat_parser.Axis()
        onnx_axis = tflite_utils.channel_last_2_channel_first_axis_mapping(
            [tflite_axis]
        )[0]

        self.node_list.append(
            onnx.helper.make_node(
                'Concat',
                inputs=input_names,
                outputs=[output_name],
                axis=onnx_axis,
                name=output_name,
            )
        )

        return self.node_list, self.value_infos, self.weight_node_list, {}

    def defuse_activation_function(self):
        """Build the node(s) for the op's fused activation, via the shared helper."""
        fused_activation = self.tflite_concat_parser.FusedActivationFunction()
        return defused_activation_node_generator(
            activation_function_type=fused_activation,
            op=self.op,
            tflite_interpreter=self.tflite_interpreter,
        )
Esempio n. 2
0
    def convert_concatenation(self, op):
        """Convert a TFLite CONCATENATION operator to a Relay expression.

        The TFLite axis is forwarded to ``_op.concatenate`` unchanged (no
        layout remapping), i.e. the converted graph keeps TFLite's N H W C
        layout.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The CONCATENATION operator to convert.

        Returns
        -------
        The concatenated expression, passed through the fused activation
        function when the operator declares one.

        Raises
        ------
        ImportError
            If the ``tflite`` package cannot be imported (original error is
            chained as the cause).
        """
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError as err:
            # Chain the cause so a broken install is distinguishable from a
            # missing one.
            raise ImportError("The tflite package must be installed") from err

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors should be at least 1"
        in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors should be 1"

        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        concatenation_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()

        # Axis is used as-is: the converted graph stays in N H W C.
        out = _op.concatenate(in_exprs, axis=concatenation_axis)

        # Apply the fused activation, if any.
        if fused_activation_fn != ActivationFunctionType.NONE:
            out = self.convert_fused_activation_function(out, fused_activation_fn)
        return out
Esempio n. 3
0
    def convert_concatenation(self, op):
        """Convert a TFLite CONCATENATION operator to a Relay expression.

        TFLite tensors are channel-last (N H W C) while the converted graph is
        channel-first (N C H W), so the TFLite axis is normalized (negative
        values count from the back, as in TFLite) and then remapped before
        calling ``_op.concatenate``.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The CONCATENATION operator to convert.

        Returns
        -------
        The concatenated expression, passed through the fused activation
        function when the operator declares one.

        Raises
        ------
        ImportError
            If the ``tflite`` package cannot be imported (original error is
            chained as the cause).
        NotImplementedError
            If the first input has more than 4 dimensions.
        ValueError
            If the axis is outside ``[-rank, rank)`` for the input rank.
        """
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError as err:
            # Chain the cause so a broken install is distinguishable from a
            # missing one.
            raise ImportError("The tflite package must be installed") from err

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors should be at least 1"
        in_exprs = [
            self.get_expr(input_tensor.tensor_idx)
            for input_tensor in input_tensors
        ]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors should be 1"

        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        tflite_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()
        input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())

        if input_shape_length > 4:
            raise NotImplementedError(
                "Not support input shape length {} of concatenation".format(
                    input_shape_length))

        # TFLite accepts a negative axis counted from the back; normalize it
        # against the input rank so the lookup below indexes a real dimension
        # (indexing the map with -1 is wrong whenever the map's length differs
        # from the rank).
        concatenation_axis = tflite_axis
        if concatenation_axis < 0:
            concatenation_axis += input_shape_length
        if not 0 <= concatenation_axis < input_shape_length:
            raise ValueError(
                "concatenation axis {} out of range for rank {}".format(
                    tflite_axis, input_shape_length))

        # TFLite is N H W C, our layout is N C H W: the trailing channel axis
        # moves to position 1 and the axes between batch and channel shift
        # right by one. A rank-1 tensor has nothing to permute.
        if input_shape_length >= 2:
            axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
        else:
            axis_convert_map = [0]
        concatenation_axis = axis_convert_map[concatenation_axis]

        # Axis is now expressed in N C H W.
        out = _op.concatenate(in_exprs, axis=concatenation_axis)

        # Apply the fused activation, if any.
        if fused_activation_fn != ActivationFunctionType.NONE:
            out = self.convert_fused_activation_function(
                out, fused_activation_fn)
        return out