Beispiel #1
0
class Add(Layer):
    """Layer that emits an ONNX ``Add`` node for a TFLite operator.

    The operator's builtin options are decoded as ``AddOptions`` at
    construction time so the fused activation type can be queried later.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Point an AddOptions accessor at the operator's builtin-options table.
        builtin_options = self.op.BuiltinOptions()
        add_options = AddOptions()
        add_options.Init(builtin_options.Bytes, builtin_options.Pos)
        self.tflite_add_parser = add_options

    def generate(self):
        """Build the ONNX ``Add`` node and any constant nodes it needs.

        Returns a 4-tuple: ``(node_list, value_infos, weight_node_list, {})``.
        """
        interpreter = self.tflite_interpreter

        # The output tensor details are fetched but not consulted below; the
        # lookup is retained so the sequence of interpreter calls is unchanged.
        interpreter._get_tensor_details(self.op.Outputs(0))

        onnx_inputs = self.input_nodes_name.copy()

        for position in range(self.op.InputsLength()):
            detail = interpreter._get_tensor_details(self.op.Inputs(position))
            shape_size = detail['shape'].size

            # A shape with four entries is treated as the output of another
            # op node: nothing to create and nothing to add to the inputs.
            if shape_size == 4:
                continue

            tensor_name = detail['name']

            # Otherwise the input is treated as a parameter that needs its own
            # constant node.
            # NOTE(review): shapes with 2 or 3 entries create no constant node
            # but may still be appended as inputs below -- confirm intended.
            if shape_size == 0:
                # Scalar operand: one value added to everything.
                self.weight_node_list.append(
                    make_onnx_constant_number(interpreter, detail))
            elif shape_size == 1:
                # One-dimensional operand: channel-wise add.
                self.weight_node_list.append(
                    make_onnx_channelwise_constant_number(
                        interpreter, detail, shape=self.node_output_shape))

            if tensor_name not in onnx_inputs:
                onnx_inputs.append(tensor_name)

        own_name = self.node_name
        self.node_list.append(
            onnx.helper.make_node(
                'Add',
                inputs=onnx_inputs,
                outputs=[own_name],
                name=own_name,
            )
        )

        # Point every consumer that refers to this layer by name at the last
        # node this layer emitted.
        for consumer in self.output_nodes:
            consumer_inputs = consumer.input_nodes_name
            for slot, referenced_name in enumerate(consumer_inputs):
                if referenced_name == self.node_name:
                    consumer_inputs[slot] = self.node_list[-1].name

        return self.node_list, self.value_infos, self.weight_node_list, {}

    def defuse_activation_function(self):
        """Delegate to ``defused_activation_node_generator`` for this op's fused activation."""
        fused_activation = self.tflite_add_parser.FusedActivationFunction()
        return defused_activation_node_generator(
            activation_function_type=fused_activation,
            op=self.op,
            tflite_interpreter=self.tflite_interpreter,
        )
Beispiel #2
0
    def _convert_elemwise(self, relay_op, op):
        """Generic method to convert a two-input TFLite elemwise operator.

        Each operand is resolved to an expression, ``relay_op`` is applied to
        the pair, and, when the operator carries Add/Sub/Mul/Div options whose
        fused activation is not ``NONE``, the result is passed through
        ``convert_fused_activation_function``.

        Parameters
        ----------
        relay_op : callable
            Invoked as ``relay_op(lhs_expr, rhs_expr)`` to build the result.
        op : tflite.Operator.Operator
            The operator to convert; it must have exactly two input tensors.

        Returns
        -------
        The value returned by ``relay_op``, or by
        ``convert_fused_activation_function`` when an activation is fused.

        Raises
        ------
        ImportError
            If the ``tflite`` package cannot be imported.
        AssertionError
            If ``op`` is not an ``Operator`` or does not have two inputs.
        """
        try:
            from tflite.Operator import Operator
            from tflite.AddOptions import AddOptions
            from tflite.SubOptions import SubOptions
            from tflite.MulOptions import MulOptions
            from tflite.DivOptions import DivOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        def operand_expr(tensor):
            """Return the expression for one operand, creating a constant if needed."""
            if self.has_expr(tensor.tensor_idx):
                # In most cases, we can assume that TOCO fuses elemwise
                # operators with constants - the operand is then a tensor with
                # an already-registered expression.
                return self.get_expr(tensor.tensor_idx)
            # However, in some corner cases the elemwise operator is not fused
            # and the operand arrives as a constant. This applies to either
            # side (e.g. ``c - x`` as well as ``x - c``), so both operands go
            # through the same fallback.
            type_str = self.get_tensor_type_str(tensor.tensor.Type())
            return self.exp_tab.new_const(
                self.get_tensor_value(tensor), dtype=type_str)

        lhs_expr = operand_expr(input_tensors[0])
        rhs_expr = operand_expr(input_tensors[1])
        out = relay_op(lhs_expr, rhs_expr)

        # Options (fused_activation_function): only these four option tables
        # are inspected; any other options type leaves ``out`` untouched.
        options_by_type = {
            BuiltinOptions.AddOptions: AddOptions,
            BuiltinOptions.SubOptions: SubOptions,
            BuiltinOptions.MulOptions: MulOptions,
            BuiltinOptions.DivOptions: DivOptions,
        }
        options_cls = options_by_type.get(op.BuiltinOptionsType())

        if options_cls is not None:
            options = options_cls()
            op_options = op.BuiltinOptions()
            options.Init(op_options.Bytes, op_options.Pos)
            fused_activation_fn = options.FusedActivationFunction()
            # if we have activation fn
            if fused_activation_fn != ActivationFunctionType.NONE:
                out = self.convert_fused_activation_function(
                    out, fused_activation_fn)

        return out