def add_shape_tensor_from_axis_arg(self, op):
    """Replace ``op``'s axis argument with a constant target-shape input.

    Finds the recorded output shape of the tensor feeding ``op.input[0]``,
    removes the dimensions listed in the op's ``MaceKeyword.mace_axis_str``
    argument, and stores the remaining dims in a new DT_INT32 model tensor
    named ``<op.name>/<axis_str>:0``. That tensor's name is appended to
    ``op.input`` and the axis argument is then deleted from ``op``.

    Axes may be negative (counted from the end) and are de-duplicated, so
    e.g. ``[-2, 2]`` on a rank-3 input keeps only dim 0.

    Args:
        op: operator definition whose first input's shape is being reduced
            (presumably a ``mace_pb2.OperatorDef`` -- confirm with callers).

    Raises:
        Whatever ``mace_check`` raises when the axis argument is missing,
        no op in the model records a shape for ``op.input[0]``, or an axis
        is out of range for the input rank.  All checks run before the
        model is mutated.
    """
    axis_str = MaceKeyword.mace_axis_str
    list_value_arg = ConverterUtil.get_arg(op, axis_str)
    # NOTE(review): get_arg appears to yield a falsy/None value when the
    # argument is absent; check the argument itself before touching .ints.
    mace_check(
        list_value_arg is not None and list_value_arg.ints is not None,
        op.name + ': ' + axis_str + ' value ints should not be None')

    input_name = op.input[0]
    # zip() pairs each output with its recorded shape, so producers with no
    # outputs or missing output_shape entries are skipped, not IndexError'd.
    input_tensor_shape = next(
        (list(shape.dims)
         for producer in self._model.op
         for name, shape in zip(producer.output, producer.output_shape)
         if name == input_name),
        None)
    mace_check(
        input_tensor_shape is not None,
        op.name + ': cannot find the output shape of input tensor '
        + input_name)

    rank = len(input_tensor_shape)
    dropped = set()
    for axis in list_value_arg.ints:
        mace_check(
            -rank <= axis < rank,
            op.name + ': ' + axis_str + ' ' + str(axis)
            + ' is out of range for input rank ' + str(rank))
        # Normalize negative axes so mixed-sign lists remove the right dims.
        dropped.add(axis % rank)
    kept_dims = [dim for idx, dim in enumerate(input_tensor_shape)
                 if idx not in dropped]

    shape_tensor = self._model.tensors.add()
    shape_tensor.name = op.name + '/' + axis_str + ':0'
    shape_tensor.data_type = mace_pb2.DT_INT32
    shape_tensor.dims.extend([len(kept_dims)])
    shape_tensor.int32_data.extend(kept_dims)
    op.input.extend([shape_tensor.name])
    ConverterUtil.del_arg(op, axis_str)
def ensure_binary_input(self):
    """Turn single-input SUM/PROD Eltwise ops into two-input ops.

    For every ``MaceOp.Eltwise`` op that has exactly one input and whose
    element type is ``EltwiseType.SUM`` or ``EltwiseType.PROD``, the float
    held in the ``mace_scalar_input_str`` argument is materialized as a
    one-element constant tensor named ``<op.name>/<scalar_str>:0``. That
    name is appended to the op's inputs, and both scalar-input arguments
    are removed from the op.

    The constant is always appended as the *last* input, so the original
    operand position (``mace_scalar_input_index_str``) is discarded;
    presumably this is acceptable only because SUM and PROD are commutative.

    Raises:
        Whatever ``mace_check`` raises when the scalar argument is missing,
        or when the op's first output type is not DT_UINT8, DT_INT16 or
        DT_FLOAT.  Previously an unsupported type silently produced a
        constant tensor with no data type and no data.
    """
    scalar_str = MaceKeyword.mace_scalar_input_str
    quantized_types = (mace_pb2.DT_UINT8, mace_pb2.DT_INT16)
    for op in self._model.op:
        if op.type != MaceOp.Eltwise.name or len(op.input) != 1:
            continue
        eltwise_type = ConverterUtil.get_arg(
            op, MaceKeyword.mace_element_type_str).i
        if eltwise_type not in (EltwiseType.SUM.value,
                                EltwiseType.PROD.value):
            continue

        float_value_arg = ConverterUtil.get_arg(op, scalar_str)
        # NOTE(review): get_arg appears to yield None for an absent
        # argument; check it before reading .f.
        mace_check(
            float_value_arg is not None and float_value_arg.f is not None,
            op.name + ': ' + scalar_str
            + ' value float should not be None')
        scalar = float_value_arg.f

        output_type = op.output_type[0] if op.output_type else None
        # Validate before adding anything to the model so a failure never
        # leaves a half-built tensor behind.
        mace_check(
            output_type == mace_pb2.DT_FLOAT
            or output_type in quantized_types,
            op.name + ': unsupported output type ' + str(output_type)
            + ' for ' + scalar_str)

        const_tensor = self._model.tensors.add()
        const_tensor.name = op.name + '/' + scalar_str + ':0'
        const_tensor.dims.extend([1])
        if output_type == mace_pb2.DT_FLOAT:
            const_tensor.data_type = mace_pb2.DT_FLOAT
            const_tensor.float_data.extend([scalar])
        else:
            # Quantized: store the integer 1 with scale=scalar and
            # zero_point=0, so that (presumably, with affine dequantization
            # (q - zero_point) * scale) the constant decodes to `scalar`.
            # NOTE(review): scalar <= 0 gives a non-positive scale --
            # confirm the backend accepts that.
            const_tensor.data_type = output_type
            const_tensor.scale = scalar
            const_tensor.zero_point = 0
            const_tensor.quantized = True
            const_tensor.int32_data.extend([1])

        op.input.extend([const_tensor.name])
        ConverterUtil.del_arg(op, scalar_str)
        ConverterUtil.del_arg(op, MaceKeyword.mace_scalar_input_index_str)