def test_is_attr_legal_verbose(self):
    """An AttributeProto is legal only when exactly one value field is set."""
    # Each setter populates a single, distinct value field of the attribute.
    setters = [
        lambda a: setattr(a, "f", 1.0),
        lambda a: setattr(a, "i", 1),
        lambda a: setattr(a, "s", b"str"),
        lambda a: a.floats.extend([1.0, 2.0]),
        lambda a: a.ints.extend([1, 2]),
        lambda a: a.strings.extend([b"a", b"b"]),
        lambda a: a.tensors.extend([TensorProto(), TensorProto()]),
        lambda a: a.graphs.extend([GraphProto(), GraphProto()]),
    ]

    def named_attr():
        proto = AttributeProto()
        proto.name = "test"
        return proto

    # One randomly chosen field: the helper must accept it.
    for _ in range(100):
        candidate = named_attr()
        random.choice(setters)(candidate)
        self.assertTrue(helper.is_attribute_legal(candidate))

    # Two distinct randomly chosen fields: the helper must reject it.
    for _ in range(100):
        candidate = named_attr()
        first, second = random.sample(setters, 2)
        first(candidate)
        second(candidate)
        self.assertFalse(helper.is_attribute_legal(candidate))
def make_tensor(name, data_type, dims, vals, raw=False):
    '''
    Make a TensorProto with specified arguments.  If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Args:
        name: name assigned to the tensor.
        data_type: a ``TensorProto`` element type (e.g. ``TensorProto.FLOAT``).
        dims: shape of the tensor, appended to ``tensor.dims``.
        vals: flat sequence of values, or ``bytes`` when ``raw`` is True.
        raw: store ``vals`` verbatim in ``raw_data`` instead of a typed field.

    Returns:
        The populated TensorProto.
    '''
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING:
        # Strings have no raw_data encoding. They are written here exactly
        # once; the generic path below must not run for them, otherwise the
        # values would be appended a second time.
        assert not raw, "Can not use raw_data to store string type"
        tensor.string_data.extend(vals)
    else:
        if (data_type == TensorProto.COMPLEX64 or
                data_type == TensorProto.COMPLEX128):
            vals = split_complex_to_pairs(vals)
        if raw:
            tensor.raw_data = vals
        else:
            # Resolve the repeated field that stores this element type.
            field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
                mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]]
            getattr(tensor, field).extend(vals)

    tensor.dims.extend(dims)
    return tensor
def make_tensor(name, data_type, dims, vals, raw=False):
    '''
    Make a TensorProto with specified arguments.  If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Args:
        name: name assigned to the tensor.
        data_type: a ``TensorProto`` element type (e.g. ``TensorProto.INT64``).
        dims: shape of the tensor, appended to ``tensor.dims``.
        vals: flat sequence of values, or ``bytes`` when ``raw`` is True.
        raw: store ``vals`` verbatim in ``raw_data`` instead of a typed field.

    Returns:
        The populated TensorProto.

    Raises:
        RuntimeError: if ``data_type`` is not a supported element type.
    '''
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    # Typed repeated field used for each non-string element type when
    # raw=False. Narrow integers, bool and float16 share int32_data (the
    # caller presumably supplies float16 as uint16 bit patterns); INT64 was
    # previously missing here, which made its branch unreachable.
    field_for_type = {
        TensorProto.FLOAT: 'float_data',
        TensorProto.UINT8: 'int32_data',
        TensorProto.INT8: 'int32_data',
        TensorProto.UINT16: 'int32_data',
        TensorProto.INT16: 'int32_data',
        TensorProto.INT32: 'int32_data',
        TensorProto.FLOAT16: 'int32_data',
        TensorProto.BOOL: 'int32_data',
        TensorProto.INT64: 'int64_data',
        TensorProto.DOUBLE: 'double_data',
        TensorProto.UINT32: 'uint64_data',
        TensorProto.UINT64: 'uint64_data',
    }

    if data_type == TensorProto.STRING:
        assert not raw, "Can not use raw_data to store string type"
        tensor.string_data.extend(vals)
    elif data_type in field_for_type:
        if raw:
            tensor.raw_data = vals
        else:
            getattr(tensor, field_for_type[data_type]).extend(vals)
    else:
        raise RuntimeError('Unrecognized data_type: {}'.format(data_type))

    tensor.dims.extend(dims)
    return tensor
def from_array(arr, name=None):
    """Converts a numpy array to a tensor def.

    Inputs:
      arr: a numpy array.
      name: (optional) the name of the tensor.
    Returns:
      tensor_def: the converted tensor def.
    Raises:
      NotImplementedError: for object-dtype (string) arrays.
      RuntimeError: if the numpy dtype has no TensorProto equivalent.
    """
    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # `np.object` was a deprecated alias (NumPy 1.20) removed in NumPy 1.24;
    # `np.object_` is the object dtype's scalar type on every version.
    if arr.dtype == np.object_:
        # Special care for strings.
        raise NotImplementedError("Need to properly implement string.")

    # For numerical types, directly use numpy raw bytes.
    try:
        dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype]
    except KeyError:
        raise RuntimeError("Numpy data type not understood yet: {}".format(
            str(arr.dtype)))
    tensor.data_type = dtype
    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.

    return tensor
def test_attr_repeated_tensor_proto(self):
    """make_attribute with a list of TensorProtos yields a legal 'tensors' attr."""
    # Two distinctly named tensors so the round-trip equality is meaningful.
    tensors = []
    for label in ("a", "b"):
        proto = TensorProto()
        proto.name = label
        tensors.append(proto)

    attr = helper.make_attribute("tensors", tensors)

    self.assertEqual(attr.name, "tensors")
    self.assertEqual(list(attr.tensors), tensors)
    self.assertTrue(helper.is_attribute_legal(attr))
def test_check_string_tensor(self):
    """String tensors pass the checker only when data lives in string_data."""
    payload = 'Test'.encode('utf-8')

    tensor = TensorProto()
    tensor.data_type = TensorProto.STRING
    tensor.string_data.append(payload)
    checker.check_tensor(tensor)

    # Move the same bytes into raw_data: string data should not be stored
    # in the raw_data field, so the checker must reject it.
    del tensor.string_data[:]
    tensor.raw_data = payload
    with self.assertRaises(ValueError):
        checker.check_tensor(tensor)
def tfwrite_tensor(ox: onnx_pb2.TensorProto, tens: tf.Tensor):
    """Serialize an eager TensorFlow tensor into an ONNX TensorProto.

    The tensor is materialized with ``tens.numpy()``. Its shape is appended to
    ``ox.dims``, ``ox.data_type`` is set from the NumPy dtype, and the
    flattened values are appended to the typed repeated field for that type.
    Arrays whose elements are ``bytes``/``str`` are written as STRING.

    Args:
        ox: destination TensorProto, mutated in place.
        tens: source tensor; only its ``numpy()`` method is used.

    Raises:
        AssertionError: if the dtype is neither a supported numeric type nor
            string-like. Raised explicitly so the check survives ``python -O``.
    """
    DataType = onnx_pb2.TensorProto.DataType
    data = tens.numpy()

    # NumPy dtype -> ONNX element type, compared with `==` like before.
    # float32 (np.single) is FLOAT; it was previously mislabelled FLOAT16.
    np_to_onnx = (
        (np.int8, DataType.INT8),
        (np.uint8, DataType.UINT8),
        (np.int16, DataType.INT16),
        (np.uint16, DataType.UINT16),
        (np.int32, DataType.INT32),
        (np.uint32, DataType.UINT32),
        (np.int64, DataType.INT64),
        (np.uint64, DataType.UINT64),
        (np.float16, DataType.FLOAT16),
        (np.single, DataType.FLOAT),
        (np.double, DataType.DOUBLE),
        (np.bool_, DataType.BOOL),
    )
    dtype = next(
        (onnx_type for np_type, onnx_type in np_to_onnx
         if data.dtype == np_type),
        DataType.UNDEFINED)

    ox.dims.extend(data.shape)
    ox.data_type = dtype
    flat = data.flatten()

    # Field placement follows onnx.proto: narrow ints and bool in int32_data,
    # float16 as its uint16 bit pattern in int32_data, uint32/uint64 in
    # uint64_data, and double in double_data (not float_data).
    if dtype in (DataType.INT8, DataType.UINT8, DataType.INT16,
                 DataType.UINT16, DataType.INT32, DataType.BOOL):
        ox.int32_data.extend(flat.astype(np.int32).tolist())
    elif dtype == DataType.FLOAT16:
        ox.int32_data.extend(flat.view(np.uint16).astype(np.int32).tolist())
    elif dtype == DataType.INT64:
        ox.int64_data.extend(flat.tolist())
    elif dtype in (DataType.UINT32, DataType.UINT64):
        ox.uint64_data.extend(flat.tolist())
    elif dtype == DataType.FLOAT:
        ox.float_data.extend(flat.tolist())
    elif dtype == DataType.DOUBLE:
        ox.double_data.extend(flat.tolist())
    elif flat.size and isinstance(flat[0], (bytes, str)):
        ox.data_type = DataType.STRING
        ox.string_data.extend(flat)
    elif flat.size == 0 and data.dtype.kind in 'OSU':
        # Empty string/object tensor: nothing to inspect or write, but it is
        # still a STRING tensor (previously this crashed on flat[0]).
        ox.data_type = DataType.STRING
    elif flat.size == 0:
        raise AssertionError(
            'failed to serialize unknown datatype {}'.format(data.dtype))
    else:
        raise AssertionError(
            'failed to serialize unknown datatype {}: {}...'.format(
                type(flat[0]), flat[0]))
def make_tensor(name, data_type, dims, vals):
    """Make a TensorProto holding ``vals`` in the field matching ``data_type``.

    Args:
        name: name assigned to the tensor.
        data_type: a ``TensorProto`` element type (e.g. ``TensorProto.FLOAT``).
        dims: shape of the tensor, appended to ``tensor.dims``.
        vals: flat sequence of values for the chosen field.

    Returns:
        The populated TensorProto.

    Raises:
        RuntimeError: if ``data_type`` has no supported storage field.
            Previously such types (e.g. DOUBLE) silently dropped all values.
    """
    # Repeated field that stores each element type. Narrow integers, bool and
    # float16 share int32_data (the caller presumably supplies float16 as
    # uint16 bit patterns); uint32/uint64 share uint64_data.
    field_for_type = {
        TensorProto.FLOAT: 'float_data',
        TensorProto.UINT8: 'int32_data',
        TensorProto.INT8: 'int32_data',
        TensorProto.UINT16: 'int32_data',
        TensorProto.INT16: 'int32_data',
        TensorProto.INT32: 'int32_data',
        TensorProto.FLOAT16: 'int32_data',
        TensorProto.BOOL: 'int32_data',
        TensorProto.INT64: 'int64_data',
        TensorProto.STRING: 'string_data',
        TensorProto.DOUBLE: 'double_data',
        TensorProto.UINT32: 'uint64_data',
        TensorProto.UINT64: 'uint64_data',
    }

    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    field = field_for_type.get(data_type)
    if field is None:
        raise RuntimeError('Unrecognized data_type: {}'.format(data_type))
    getattr(tensor, field).extend(vals)

    tensor.dims.extend(dims)
    return tensor
def from_array(arr, name=None):
    """Converts a numpy array to a tensor def.

    Inputs:
      arr: a numpy array.
      name: (optional) the name of the tensor.
    Returns:
      tensor_def: the converted tensor def.
    Raises:
      NotImplementedError: for object-dtype (string) arrays.
      RuntimeError: if the numpy dtype has no TensorProto equivalent.
    """
    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # Only types whose native NumPy byte layout is a valid raw_data payload.
    dtype_map = {
        np.dtype("float32"): TensorProto.FLOAT,
        np.dtype("float64"): TensorProto.DOUBLE,
        np.dtype("uint8"): TensorProto.UINT8,
        np.dtype("int8"): TensorProto.INT8,
        np.dtype("uint16"): TensorProto.UINT16,
        np.dtype("int16"): TensorProto.INT16,
        np.dtype("int32"): TensorProto.INT32,
        np.dtype("uint32"): TensorProto.UINT32,
        np.dtype("int64"): TensorProto.INT64,
        np.dtype("uint64"): TensorProto.UINT64,
        np.dtype("bool"): TensorProto.BOOL,
        np.dtype("float16"): TensorProto.FLOAT16,
    }

    # `np.object` was a deprecated alias (NumPy 1.20) removed in NumPy 1.24;
    # `np.object_` is the object dtype's scalar type on every version.
    if arr.dtype == np.object_:
        # Special care for strings.
        raise NotImplementedError("Need to properly implement string.")

    # For numerical types, directly use numpy raw bytes.
    try:
        dtype = dtype_map[arr.dtype]
    except KeyError:
        raise RuntimeError(
            "Numpy data type not understood yet: {}".format(str(arr.dtype)))
    tensor.data_type = dtype
    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.

    return tensor