Exemplo n.º 1
0
def _guess_type_proto_str(data_type, dims):
    """
    Builds the tensor type object matching an ONNX type string.

    :param data_type: type string such as ``"tensor(float)"``
    :param dims: dimensions, forwarded unchanged to the type constructor
    :return: instance of the matching ``*TensorType`` class
    :raises NotImplementedError: if *data_type* is none of the
        strings handled here
    """
    # This could be moved to onnxconverter_common.
    candidates = [
        ("tensor(float)", FloatTensorType),
        ("tensor(double)", DoubleTensorType),
        ("tensor(string)", StringTensorType),
        ("tensor(int64)", Int64TensorType),
        ("tensor(int32)", Int32TensorType),
        ("tensor(bool)", BooleanTensorType),
        ("tensor(int8)", Int8TensorType),
        ("tensor(uint8)", UInt8TensorType),
    ]
    if Complex64TensorType is not None:
        # Complex types are only considered when Complex64TensorType is set.
        candidates.append(("tensor(complex64)", Complex64TensorType))
        candidates.append(("tensor(complex128)", Complex128TensorType))

    for type_name, tensor_class in candidates:
        if data_type == type_name:
            return tensor_class(dims)

    message = ("Unsupported data_type '{}'. You may raise an issue "
               "at https://github.com/onnx/sklearn-onnx/issues.")
    raise NotImplementedError(message.format(data_type))
Exemplo n.º 2
0
def _guess_type_proto(data_type, dims):
    # This could be moved to onnxconverter_common.
    for d in dims:
        if d == 0:
            raise RuntimeError("Dimension should not be null: {}.".format(
                list(dims)))
    if data_type == onnx_proto.TensorProto.FLOAT:
        return FloatTensorType(dims)
    if data_type == onnx_proto.TensorProto.DOUBLE:
        return DoubleTensorType(dims)
    if data_type == onnx_proto.TensorProto.STRING:
        return StringTensorType(dims)
    if data_type == onnx_proto.TensorProto.INT64:
        return Int64TensorType(dims)
    if data_type == onnx_proto.TensorProto.INT32:
        return Int32TensorType(dims)
    if data_type == onnx_proto.TensorProto.BOOL:
        return BooleanTensorType(dims)
    if data_type == onnx_proto.TensorProto.INT8:
        return Int8TensorType(dims)
    if data_type == onnx_proto.TensorProto.UINT8:
        return UInt8TensorType(dims)
    if Complex64TensorType is not None:
        if data_type == onnx_proto.TensorProto.COMPLEX64:
            return Complex64TensorType(dims)
        if data_type == onnx_proto.TensorProto.COMPLEX128:
            return Complex128TensorType(dims)
    raise NotImplementedError(
        "Unsupported data_type '{}'. You may raise an issue "
        "at https://github.com/onnx/sklearn-onnx/issues."
        "".format(data_type))
Exemplo n.º 3
0
def _guess_type_proto(data_type, dims):
    """
    Builds the tensor type object matching an ONNX element type.

    :param data_type: value compared against the
        ``onnx_proto.TensorProto`` element type constants
    :param dims: dimensions, forwarded unchanged to the type constructor
    :return: instance of the matching ``*TensorType`` class
    :raises NotImplementedError: if *data_type* is not one of
        FLOAT, DOUBLE, STRING, INT64, INT32 or BOOL
    """
    # This could be moved to onnxconverter_common.
    tensor_proto = onnx_proto.TensorProto
    candidates = (
        (tensor_proto.FLOAT, FloatTensorType),
        (tensor_proto.DOUBLE, DoubleTensorType),
        (tensor_proto.STRING, StringTensorType),
        (tensor_proto.INT64, Int64TensorType),
        (tensor_proto.INT32, Int32TensorType),
        (tensor_proto.BOOL, BooleanTensorType),
    )
    for elem_type, tensor_class in candidates:
        if data_type == elem_type:
            return tensor_class(dims)

    message = ("Unsupported data_type '{}'. You may raise an issue "
               "at https://github.com/onnx/sklearn-onnx/issues.")
    raise NotImplementedError(message.format(data_type))
Exemplo n.º 4
0
def _guess_numpy_type(data_type, dims):
    # This could be moved to onnxconverter_common.
    if data_type == np.float32:
        return FloatTensorType(dims)
    elif data_type in (np.str, str,
                       object) or str(data_type) in ('<U1', ):  # noqa
        return StringTensorType(dims)
    elif data_type in (np.int64, np.uint64) or str(data_type) == '<U6':
        return Int64TensorType(dims)
    elif data_type in (np.int32,
                       np.uint32) or str(data_type) in ('<U4', ):  # noqa
        return Int32TensorType(dims)
    elif data_type == np.bool:
        return BooleanTensorType(dims)
    else:
        raise NotImplementedError(
            "Unsupported data_type '{}'. You may raise an issue "
            "at https://github.com/onnx/sklearn-onnx/issues."
            "".format(data_type))
Exemplo n.º 5
0
def _guess_numpy_type(data_type, dims):
    """
    Builds the tensor type object matching a numpy or Python type.

    :param data_type: numpy scalar type, numpy dtype or Python type
    :param dims: dimensions, forwarded unchanged to the type constructor
    :return: instance of the matching ``*TensorType`` class
    :raises NotImplementedError: if *data_type* is not handled here
    """
    # This could be moved to onnxconverter_common.
    if data_type == np.float32:
        return FloatTensorType(dims)
    if data_type == np.float64:
        return DoubleTensorType(dims)
    # Strings: the types themselves, ``object``, the '<U1' dtype string,
    # or any object whose ``type`` attribute is ``np.str_``.
    if data_type in (np.str_, str, object) or str(data_type) in ('<U1', ) or (
            hasattr(data_type, 'type') and data_type.type is np.str_):  # noqa
        return StringTensorType(dims)
    # NOTE(review): the '<U6' and '<U4' checks below are kept exactly as
    # they were; their intent is not visible from this function.
    if data_type in (np.int64, ) or str(data_type) == '<U6':
        return Int64TensorType(dims)
    if data_type in (np.int32, ) or str(data_type) in ('<U4', ):  # noqa
        return Int32TensorType(dims)
    if data_type == np.uint8:
        return UInt8TensorType(dims)
    if data_type in (np.bool_, bool):
        return BooleanTensorType(dims)
    # A second ``data_type in (np.str_, str)`` test used to sit here; it
    # could never be reached because the string check above already
    # returns for those values, so it was removed.
    if data_type == np.int8:
        return Int8TensorType(dims)
    if data_type == np.int16:
        return Int16TensorType(dims)
    if data_type == np.uint64:
        return UInt64TensorType(dims)
    if data_type == np.uint32:
        return UInt32TensorType(dims)
    if data_type == np.uint16:
        return UInt16TensorType(dims)
    if data_type == np.float16:
        return Float16TensorType(dims)
    if Complex64TensorType is not None:
        # Complex types are only considered when Complex64TensorType is set.
        if data_type == np.complex64:
            return Complex64TensorType(dims)
        if data_type == np.complex128:
            return Complex128TensorType(dims)
    raise NotImplementedError(
        "Unsupported data_type %r (type=%r). You may raise an issue "
        "at https://github.com/onnx/sklearn-onnx/issues."
        "" % (data_type, type(data_type)))
Exemplo n.º 6
0
    def from_pb(obj):
        """
        Creates a data type from a protobuf object.

        :param obj: a single protobuf object exposing ``name`` and
            ``type.tensor_type``, or any object with an ``extend``
            attribute, which is treated as a collection of such objects
        :return: a :class:`Variable`, or a list of them when *obj*
            is a collection
        :raises NotImplementedError: if *obj* has no tensor type or
            its element type is not handled here
        """
        if hasattr(obj, 'extend'):
            # Collection: every element is converted on its own.
            return [Variable.from_pb(item) for item in obj]

        name = obj.name
        if not obj.type.tensor_type:
            raise NotImplementedError("Unsupported type '{}' as "
                                      "a string ({}).".format(type(obj), obj))

        tensor = obj.type.tensor_type
        elem = tensor.elem_type
        shape = [dim.dim_value for dim in tensor.shape.dim]

        tensor_proto = onnx_proto.TensorProto
        candidates = (
            (tensor_proto.FLOAT, FloatTensorType),
            (tensor_proto.BOOL, BooleanTensorType),
            (tensor_proto.DOUBLE, DoubleTensorType),
            (tensor_proto.STRING, StringTensorType),
            (tensor_proto.INT64, Int64TensorType),
            (tensor_proto.INT32, Int32TensorType),
        )
        for elem_type, tensor_class in candidates:
            if elem == elem_type:
                ty = tensor_class(shape)
                break
        else:
            # No candidate matched the element type.
            raise NotImplementedError("Unsupported type '{}' "
                                      "(elem_type={}).".format(
                                          type(obj.type.tensor_type),
                                          elem))

        return Variable(name, name, None, ty)