예제 #1
0
class Mul(Layer):
    """Translate a TFLite MUL operator into an ONNX ``Mul`` node.

    Constant operands of rank 0 (scalar) or rank 1 are materialised as ONNX
    constant nodes and appended to ``self.weight_node_list``.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Decode this operator's MulOptions table; defuse_activation_function()
        # reads the fused activation type from it.
        self.tflite_mul_parser = MulOptions()
        builtin_options = self.op.BuiltinOptions()
        self.tflite_mul_parser.Init(builtin_options.Bytes, builtin_options.Pos)

    def generate(self):
        """Build the ONNX ``Mul`` node for this operator.

        Starts from a copy of ``self.input_nodes_name`` and adds the name of
        every non-rank-4 input tensor (without duplicates) as a ``Mul`` input.

        Returns:
            ``(self.node_list, self.value_infos, self.weight_node_list, {})``.
            The trailing empty dict is always empty here; its meaning is
            defined by the caller (NOTE(review): confirm).
        """
        mul_inputs = self.input_nodes_name.copy()

        for input_idx in range(self.op.InputsLength()):
            detail = self.tflite_interpreter._get_tensor_details(
                self.op.Inputs(input_idx))
            # Assumes detail['shape'] is a NumPy array of dimensions, so
            # ``.size`` is the tensor's rank -- TODO confirm.
            rank = detail['shape'].size

            # Rank-4 inputs are treated as outputs of upstream op nodes and
            # skipped (presumably already listed in input_nodes_name).
            if rank == 4:
                continue

            tensor_name = detail['name']

            if rank == 0:
                # Scalar operand: multiply everything by a single value.
                self.weight_node_list.append(
                    make_onnx_constant_number(self.tflite_interpreter, detail))
            elif rank == 1:
                # 1-D operand: channel-wise multiply, built against this
                # layer's output shape.
                self.weight_node_list.append(
                    make_onnx_channelwise_constant_number(
                        self.tflite_interpreter,
                        detail,
                        shape=self.node_output_shape,
                    ))

            if tensor_name not in mul_inputs:
                mul_inputs.append(tensor_name)

        self.node_list.append(
            onnx.helper.make_node(
                'Mul',
                inputs=mul_inputs,
                outputs=[self.node_name],
                name=self.node_name,
            ))

        return self.node_list, self.value_infos, self.weight_node_list, {}

    def defuse_activation_function(self):
        """Return the result of ``defused_activation_node_generator`` for the
        activation fused into this MUL, as decoded from its MulOptions."""
        fused_type = self.tflite_mul_parser.FusedActivationFunction()
        return defused_activation_node_generator(
            activation_function_type=fused_type,
            op=self.op,
            tflite_interpreter=self.tflite_interpreter,
        )
예제 #2
0
    def _convert_elemwise(self, relay_op, op):
        """Convert a generic two-input TFLite element-wise operator.

        Both inputs are turned into Relay expressions and combined with
        ``relay_op(lhs, rhs)``. Either input may be a constant tensor that has
        no expression yet; such a tensor is materialised as a constant.
        If the operator carries Add/Sub/Mul/Div options with a fused
        activation other than NONE, that activation is applied to the result.

        Parameters
        ----------
        relay_op : callable
            Relay binary operator applied to the two converted inputs.
        op : tflite.Operator.Operator
            The TFLite operator to convert.

        Returns
        -------
        The Relay expression for the (optionally activated) result.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        """
        try:
            from tflite.Operator import Operator
            from tflite.AddOptions import AddOptions
            from tflite.SubOptions import SubOptions
            from tflite.MulOptions import MulOptions
            from tflite.DivOptions import DivOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError as err:
            raise ImportError("The tflite package must be installed") from err

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        def _tensor_to_expr(tensor):
            # In most cases TOCO fuses elemwise operators with constants, so
            # both inputs already have expressions. In some corner cases the
            # operator is not fused and an input (on either side) arrives as
            # a raw constant tensor, which we materialise here.
            if self.has_expr(tensor.tensor_idx):
                return self.get_expr(tensor.tensor_idx)
            type_str = self.get_tensor_type_str(tensor.tensor.Type())
            return self.exp_tab.new_const(
                self.get_tensor_value(tensor), dtype=type_str)

        lhs_expr = _tensor_to_expr(input_tensors[0])
        rhs_expr = _tensor_to_expr(input_tensors[1])
        out = relay_op(lhs_expr, rhs_expr)

        # Options (fused_activation_function): pick the options table class
        # matching this operator's builtin options type, if any.
        options_classes = {
            BuiltinOptions.AddOptions: AddOptions,
            BuiltinOptions.SubOptions: SubOptions,
            BuiltinOptions.MulOptions: MulOptions,
            BuiltinOptions.DivOptions: DivOptions,
        }
        options_cls = options_classes.get(op.BuiltinOptionsType())

        if options_cls is not None:
            options = options_cls()
            op_options = op.BuiltinOptions()
            options.Init(op_options.Bytes, op_options.Pos)
            fused_activation_fn = options.FusedActivationFunction()
            # Only wrap the result when an activation is actually fused.
            if fused_activation_fn != ActivationFunctionType.NONE:
                out = self.convert_fused_activation_function(
                    out, fused_activation_fn)

        return out
예제 #3
0
    def __init__(self, op, op_type, tflite_interpreter):
        """Initialise the base layer and decode the operator's MulOptions.

        The decoded options table is stored on ``self.tflite_mul_parser``.
        """
        Layer.__init__(self, op, op_type, tflite_interpreter)

        self.tflite_mul_parser = MulOptions()
        # Fetch the builtin options table once and reuse it for both fields.
        options_table = self.op.BuiltinOptions()
        self.tflite_mul_parser.Init(options_table.Bytes, options_table.Pos)