Ejemplo n.º 1
0
    def add_shape_tensor_from_axis_arg(self, op):
        """Replace ``op``'s axis argument with a constant shape-tensor input.

        The new INT32 tensor holds the shape of ``op.input[0]`` with every
        dimension listed in the axis argument removed. It is appended to
        ``op.input`` and the axis argument is deleted from ``op``.

        Args:
            op: operator definition rewritten in place. Its first input must
                be the first output of some op in ``self._model.op`` whose
                ``output_shape[0]`` is populated.

        Failures (missing axis argument, unknown producer of the first input,
        axis out of range for the input rank) are reported through
        ``mace_check`` before the model is modified.
        """
        list_value_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
        # NOTE(review): get_arg presumably returns None when the argument is
        # absent; guard it so we fail via mace_check instead of an
        # AttributeError.
        mace_check(
            list_value_arg is not None and list_value_arg.ints is not None,
            op.name + ': ' +
            MaceKeyword.mace_axis_str + ' value ints should not be None')
        axes = list_value_arg.ints

        # Find the op whose first output feeds our first input; ops with no
        # outputs are skipped instead of raising IndexError.
        producer = next(
            (p for p in self._model.op
             if p.output and p.output[0] == op.input[0]),
            None)
        mace_check(
            producer is not None and len(producer.output_shape) > 0,
            op.name + ': cannot find output shape of input ' + op.input[0])
        input_tensor_shape = list(producer.output_shape[0].dims)
        rank = len(input_tensor_shape)

        # Normalise negative axes and de-duplicate them, so exactly one
        # dimension is removed per distinct axis regardless of sign or order.
        removed_axes = set()
        for axis in axes:
            mace_check(
                -rank <= axis < rank,
                op.name + ': axis ' + str(axis) +
                ' is out of range for input rank ' + str(rank))
            removed_axes.add(axis % rank)
        output_shape = [dim for i, dim in enumerate(input_tensor_shape)
                        if i not in removed_axes]

        # All validation is done; only now mutate the model.
        shape_tensor = self._model.tensors.add()
        shape_tensor.name = op.name + '/' + MaceKeyword.mace_axis_str + ':0'
        shape_tensor.data_type = mace_pb2.DT_INT32
        shape_tensor.dims.extend([len(output_shape)])
        shape_tensor.int32_data.extend(output_shape)
        op.input.extend([shape_tensor.name])
        ConverterUtil.del_arg(op, MaceKeyword.mace_axis_str)
Ejemplo n.º 2
0
    def ensure_binary_input(self):
        """Give single-input SUM/PROD Eltwise ops an explicit second input.

        For every Eltwise op with exactly one input and an element type of
        SUM or PROD, the scalar stored in the op's scalar-input argument is
        materialised as a new one-element constant tensor. That tensor is
        appended as the op's extra input, and both the scalar-input and
        scalar-input-index arguments are removed from the op.

        For UINT8/INT16 outputs the constant is written as a quantized tensor
        whose only int32 value is 1, with zero_point 0 and the scalar as its
        scale (presumably so it dequantizes to the scalar — TODO confirm the
        dequantization formula). For FLOAT outputs the scalar goes straight
        into float_data.

        NOTE(review): for any other output type the tensor is still created
        and attached, but neither data_type nor any data field is set.
        """
        for op in self._model.op:
            is_unary_eltwise = (op.type == MaceOp.Eltwise.name and
                                len(op.input) == 1)
            if not is_unary_eltwise:
                continue

            elt_type = ConverterUtil.get_arg(
                op, MaceKeyword.mace_element_type_str).i
            # Only these element types are handled here.
            if elt_type not in (EltwiseType.SUM.value,
                                EltwiseType.PROD.value):
                continue

            scalar_arg = ConverterUtil.get_arg(
                op, MaceKeyword.mace_scalar_input_str)
            mace_check(
                scalar_arg.f is not None,
                op.name + ': ' + MaceKeyword.mace_scalar_input_str +
                ' value float should not be None')
            value = scalar_arg.f

            tensor_name = (op.name + '/' +
                           MaceKeyword.mace_scalar_input_str + ':0')
            const_tensor = self._model.tensors.add()
            const_tensor.name = tensor_name
            const_tensor.dims.extend([1])

            out_type = op.output_type[0]
            if out_type in (mace_pb2.DT_UINT8, mace_pb2.DT_INT16):
                # Quantized encoding: stored value 1, zero point 0,
                # scale carries the scalar.
                const_tensor.data_type = out_type
                const_tensor.quantized = True
                const_tensor.scale = value
                const_tensor.zero_point = 0
                const_tensor.int32_data.extend([1])
            elif out_type == mace_pb2.DT_FLOAT:
                const_tensor.data_type = mace_pb2.DT_FLOAT
                const_tensor.float_data.extend([value])

            op.input.extend([const_tensor.name])
            for arg_name in (MaceKeyword.mace_scalar_input_str,
                             MaceKeyword.mace_scalar_input_index_str):
                ConverterUtil.del_arg(op, arg_name)