コード例 #1
0
ファイル: tfgen_dataset.py プロジェクト: mingkaic/onnxds
def tfwrite_tensor(ox: onnx_pb2.TensorProto, tens: tf.Tensor):
    """Serialize the contents of ``tens`` into the tensor message ``ox``.

    The tensor is converted with ``tens.numpy()``. Its shape is appended to
    ``ox.dims``, ``ox.data_type`` is set from the numpy dtype, and the
    flattened values are appended to the typed repeated field that matches
    the element type. Arrays whose first element is ``bytes`` or ``str`` are
    written as ``STRING`` tensors into ``ox.string_data``.

    Args:
        ox: Message filled in place. Its repeated fields are extended, not
            cleared, so it is expected to be freshly created.
        tens: Object whose ``numpy()`` method returns a numpy array.

    Raises:
        AssertionError: If the element type is neither one of the supported
            numeric dtypes nor a string type. ``ox.dims`` has already been
            extended and ``ox.data_type`` set to ``UNDEFINED`` at that point.
    """
    elem = onnx_pb2.TensorProto.DataType
    # (numpy type, element type tag, typed repeated field holding the values).
    # float32 is tagged FLOAT (not FLOAT16) and float64 is stored in
    # double_data; see the field comments on the ONNX TensorProto message.
    layouts = (
        (np.int8, elem.INT8, 'int32_data'),
        (np.uint8, elem.UINT8, 'int32_data'),
        (np.int16, elem.INT16, 'int32_data'),
        (np.uint16, elem.UINT16, 'int32_data'),
        (np.int32, elem.INT32, 'int32_data'),
        (np.uint32, elem.UINT32, 'uint64_data'),
        (np.int64, elem.INT64, 'int64_data'),
        (np.uint64, elem.UINT64, 'uint64_data'),
        (np.single, elem.FLOAT, 'float_data'),
        (np.double, elem.DOUBLE, 'double_data'),
    )

    data = tens.numpy()
    flat = data.flatten()
    ox.dims.extend(data.shape)

    layout = next(((tag, field) for np_type, tag, field in layouts
                   if data.dtype == np_type), None)
    if layout is not None:
        tag, field = layout
        ox.data_type = tag
        # tolist() hands protobuf native Python scalars instead of numpy ones.
        getattr(ox, field).extend(flat.tolist())
        return

    if flat.size == 0:
        # There is no first element to inspect, so decide from the dtype.
        # NOTE(review): an empty object-dtype array is assumed to be an empty
        # string tensor -- confirm against callers.
        if data.dtype.kind in ('O', 'S', 'U'):
            ox.data_type = elem.STRING
            return
        ox.data_type = elem.UNDEFINED
        raise AssertionError(
            'failed to serialize unknown datatype {}'.format(data.dtype))

    first = flat[0]
    if not isinstance(first, (bytes, str)):
        ox.data_type = elem.UNDEFINED
        # Raised explicitly (rather than ``assert False``) so the check still
        # fires when Python runs with -O.
        raise AssertionError(
            'failed to serialize unknown datatype {}: {}...'.format(
                type(first), first))

    ox.data_type = elem.STRING
    # string_data holds bytes, so text elements are encoded before storing.
    ox.string_data.extend([
        item.encode('utf-8') if isinstance(item, str) else item
        for item in flat.tolist()
    ])
コード例 #2
0
ファイル: helper.py プロジェクト: SeanHsieh/onnx
def make_tensor(name, data_type, dims, vals, raw=False):
    '''
    Make a TensorProto with specified arguments.  If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Inputs:
        name: the name assigned to the tensor.
        data_type: a TensorProto element type.
        dims: the dimensions, appended to tensor.dims.
        vals: the values, or the serialized bytes when raw is True.
        raw: whether vals should be stored in the raw_data field.
    Returns:
        the populated TensorProto.
    Raises:
        AssertionError: if raw is True for a STRING tensor.
    '''
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING:
        # Raised explicitly so the check is not stripped under python -O.
        if raw:
            raise AssertionError("Can not use raw_data to store string type")
        # Strings are written exactly once here; they must not also go
        # through the generic mapping branch below, which would store the
        # payload a second time.
        tensor.string_data.extend(vals)
    elif raw:
        # raw vals are already serialized bytes and are stored untouched.
        tensor.raw_data = vals
    else:
        if (data_type == TensorProto.COMPLEX64
                or data_type == TensorProto.COMPLEX128):
            # Only typed storage needs the complex values split up.
            vals = split_complex_to_pairs(vals)
        field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
            mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]]
        getattr(tensor, field).extend(vals)

    tensor.dims.extend(dims)
    return tensor
コード例 #3
0
def make_tensor(name, data_type, dims, vals, raw=False):
    '''
    Make a TensorProto with specified arguments.  If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Inputs:
        name: the name assigned to the tensor.
        data_type: a TensorProto element type.
        dims: the dimensions, appended to tensor.dims.
        vals: the values, or the serialized bytes when raw is True.
        raw: whether vals should be stored in the raw_data field.
    Returns:
        the populated TensorProto.
    Raises:
        AssertionError: if raw is True for a STRING tensor.
        RuntimeError: if data_type is not one of the handled types.
    '''
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING:
        # Raised explicitly so the check is not stripped under python -O.
        if raw:
            raise AssertionError("Can not use raw_data to store string type")
        tensor.string_data.extend(vals)
    elif data_type in [
            TensorProto.UINT8, TensorProto.INT8, TensorProto.UINT16,
            TensorProto.INT16, TensorProto.INT32, TensorProto.INT64,
            TensorProto.FLOAT16, TensorProto.BOOL, TensorProto.FLOAT
    ]:
        # INT64 has to be in the list above, otherwise the int64_data branch
        # below can never be reached and INT64 tensors are rejected.
        if raw:
            tensor.raw_data = vals
        elif data_type == TensorProto.FLOAT:
            tensor.float_data.extend(vals)
        elif data_type == TensorProto.INT64:
            tensor.int64_data.extend(vals)
        else:
            # All remaining accepted types share the int32_data field.
            tensor.int32_data.extend(vals)
    else:
        raise RuntimeError('Unrecognized data_type: {}'.format(data_type))
    tensor.dims.extend(dims)
    return tensor
コード例 #4
0
def from_array(arr, name=None):
    """Converts a numpy array to a tensor def.

    Inputs:
        arr: a numpy array.
        name: (optional) the name of the tensor.
    Returns:
        tensor_def: the converted tensor def.
    Raises:
        NotImplementedError: if arr has object dtype (e.g. strings).
        RuntimeError: if arr.dtype has no entry in
            mapping.NP_TYPE_TO_TENSOR_TYPE.
    """
    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # Compare against the builtin `object`: the `np.object` alias was
    # deprecated in NumPy 1.20 and removed in 1.24.
    if arr.dtype == object:
        # Special care for strings.
        raise NotImplementedError("Need to properly implement string.")
    # For numerical types, directly use numpy raw bytes.
    try:
        dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype]
    except KeyError:
        raise RuntimeError("Numpy data type not understood yet: {}".format(
            str(arr.dtype)))
    tensor.data_type = dtype
    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.

    return tensor
コード例 #5
0
    def test_check_string_tensor(self):
        """A STRING tensor passes the checker with string_data, not raw_data."""
        proto = TensorProto()
        proto.data_type = TensorProto.STRING
        payload = 'Test'.encode('utf-8')

        # Payload stored in the dedicated string field: accepted.
        proto.string_data.append(payload)
        checker.check_tensor(proto)

        # Same payload moved into raw_data: the checker must reject it,
        # because string data should not be stored in the raw_data field.
        del proto.string_data[:]
        proto.raw_data = payload
        with self.assertRaises(ValueError):
            checker.check_tensor(proto)
コード例 #6
0
ファイル: helper.py プロジェクト: zgsxwsdxg/onnx
def make_tensor(name, data_type, dims, vals):
    """Build a named TensorProto holding vals in its typed data field.

    FLOAT values go to float_data, INT64 to int64_data, STRING to
    string_data, and UINT8/INT8/UINT16/INT16/INT32/FLOAT16/BOOL all go to
    int32_data. For any other data_type no values are stored; the tensor
    still gets its name, data_type and dims.

    Inputs:
        name: the name assigned to the tensor.
        data_type: a TensorProto element type.
        dims: the dimensions, appended to tensor.dims.
        vals: the values to store.
    Returns:
        the populated TensorProto.
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    # Element types that share the int32_data field.
    int32_backed = (
        TensorProto.UINT8, TensorProto.INT8, TensorProto.UINT16,
        TensorProto.INT16, TensorProto.INT32, TensorProto.FLOAT16,
        TensorProto.BOOL,
    )

    if data_type == TensorProto.FLOAT:
        storage = tensor.float_data
    elif data_type in int32_backed:
        storage = tensor.int32_data
    elif data_type == TensorProto.INT64:
        storage = tensor.int64_data
    elif data_type == TensorProto.STRING:
        storage = tensor.string_data
    else:
        # Unhandled type: nothing is written.
        storage = None

    if storage is not None:
        storage.extend(vals)

    tensor.dims.extend(dims)
    return tensor
コード例 #7
0
ファイル: numpy_helper.py プロジェクト: zgsxwsdxg/onnx
def from_array(arr, name=None):
    """Converts a numpy array to a tensor def.

    Inputs:
        arr: a numpy array.
        name: (optional) the name of the tensor.
    Returns:
        tensor_def: the converted tensor def.
    Raises:
        NotImplementedError: if arr has object dtype (e.g. strings).
        RuntimeError: if arr.dtype is not one of the dtypes listed in
            dtype_map below.
    """
    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # Supported numpy dtypes and the tensor element type recorded for each.
    dtype_map = {
        np.dtype("float32"): TensorProto.FLOAT,
        np.dtype("uint8"): TensorProto.UINT8,
        np.dtype("int8"): TensorProto.INT8,
        np.dtype("uint16"): TensorProto.UINT16,
        np.dtype("int16"): TensorProto.INT16,
        np.dtype("int32"): TensorProto.INT32,
        np.dtype("int64"): TensorProto.INT64,
        np.dtype("bool"): TensorProto.BOOL,
        np.dtype("float16"): TensorProto.FLOAT16,
    }

    # Compare against the builtin `object`: the `np.object` alias was
    # deprecated in NumPy 1.20 and removed in 1.24.
    if arr.dtype == object:
        # Special care for strings.
        raise NotImplementedError("Need to properly implement string.")
    # For numerical types, directly use numpy raw bytes.
    try:
        dtype = dtype_map[arr.dtype]
    except KeyError:
        raise RuntimeError(
            "Numpy data type not understood yet: {}".format(str(arr.dtype)))
    tensor.data_type = dtype
    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.

    return tensor