def transform_single(file):
    """Convert one TFRecord file from ``tf_dir`` into a MindRecord file in ``ms_dir``.

    Records are read from ``<tf_dir>/<file>`` using the feature keys
    ``image/encoded`` (bytes) and ``image/class/label`` (int64). They are
    re-serialized into an intermediate TFRecord file ``<ms_dir>/<file>-tf``
    under the keys ``image`` and ``label``. That file is then converted with
    ``TFRecordToMR`` into ``<ms_dir>/<file>``, with ``image`` declared as the
    bytes field.

    The intermediate ``-tf`` file is always removed, even if writing or the
    conversion raises.

    Args:
        file: base name of the source TFRecord file; also used as the name of
            the MindRecord output inside ``ms_dir``.
    """
    tf_filename = os.path.join(tf_dir, file)
    raw_dataset = tf.data.TFRecordDataset([tf_filename])

    image_feature_description = {
        "image/encoded": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "image/class/label": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }

    def _parse_image_function(example_proto):
        # Parse the input tf.Example proto using the dictionary above.
        return tf.io.parse_single_example(example_proto, image_feature_description)

    parsed_image_dataset = raw_dataset.map(_parse_image_function)

    def _bytes_feature(value):
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def image_example(image_string, label):
        # Re-key the parsed features into the flat 'label' / 'image' layout.
        label = label.numpy().astype(int)
        feature = {
            'label': _int64_feature(label),
            'image': _bytes_feature(image_string),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    record_file = os.path.join(ms_dir, file + '-tf')
    ms_filename = os.path.join(ms_dir, file)
    try:
        with tf.io.TFRecordWriter(record_file) as writer:
            for image_features in parsed_image_dataset:
                label = image_features['image/class/label']
                image = image_features['image/encoded']
                tf_example = image_example(image, label)
                writer.write(tf_example.SerializeToString())

        feature_dict = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.int64),
        }
        tfrecord_transformer = TFRecordToMR(record_file, ms_filename, feature_dict, ["image"])
        tfrecord_transformer.transform()
    finally:
        # The re-keyed TFRecord is only an intermediate artifact; never leave
        # it behind in ms_dir, even when the conversion fails.
        if os.path.exists(record_file):
            os.remove(record_file)
def test_tfrecord_to_mindrecord_list_without_bytes_type():
    """Test transforming a TFRecord file to MindRecord without a bytes-field list.

    ``TFRecordToMR`` is constructed with only the source path, the target path
    and ``feature_dict``; no list of bytes fields is passed, even though
    ``image_bytes`` is declared as ``tf.string``. The test checks that the
    MindRecord file and its ``.db`` index are created, then compares the
    contents with ``verify_data``.

    Generated files are removed in ``finally`` so a failing assertion does not
    leave stale output behind for later tests.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a plain string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1"); consider a real version
    # parser.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        message = ("Module tensorflow is not found or version wrong, "
                   "please use pip install it / reinstall version >= {}.").format(
                       SupportedTensorFlowVersion)
        logger.warning(message)
        pytest.skip(message)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_db = MINDRECORD_FILE_NAME + ".db"
    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "file_name": tf.io.FixedLenFeature([], tf.string),
            "image_bytes": tf.io.FixedLenFeature([], tf.string),
            "int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
            "float_scalar": tf.io.FixedLenFeature([1], tf.float32),
            "int64_list": tf.io.FixedLenFeature([6], tf.int64),
            "float_list": tf.io.FixedLenFeature([7], tf.float32),
        }

        # Drop MindRecord output left over from a previous run.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db):
            if os.path.exists(path):
                os.remove(path)

        tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME, feature_dict)
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(mindrecord_db)

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        # Guarded removal so cleanup never masks the original failure.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db, tfrecord_path):
            if os.path.exists(path):
                os.remove(path)
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
    """Test that a too-short declared list length makes the transform raise ValueError.

    ``feature_dict`` declares ``float_list`` with a fixed length of 2. The test
    expects ``TFRecordToMR`` construction or ``transform()`` to raise
    ``ValueError``. Presumably the records written by ``generate_tfrecord``
    hold a longer ``float_list``; verify this against that helper.

    Any generated files are removed in ``finally`` so a failure does not leave
    stale output behind for later tests.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a plain string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1"); consider a real version
    # parser.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        message = ("Module tensorflow is not found or version wrong, "
                   "please use pip install it / reinstall version >= {}.").format(
                       SupportedTensorFlowVersion)
        logger.warning(message)
        pytest.skip(message)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_db = MINDRECORD_FILE_NAME + ".db"
    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "file_name": tf.io.FixedLenFeature([], tf.string),
            "image_bytes": tf.io.FixedLenFeature([], tf.string),
            "int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
            "float_scalar": tf.io.FixedLenFeature([1], tf.float32),
            "int64_list": tf.io.FixedLenFeature([6], tf.int64),
            "float_list": tf.io.FixedLenFeature([2], tf.float32),
        }

        # Drop MindRecord output left over from a previous run.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db):
            if os.path.exists(path):
                os.remove(path)

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(
                tfrecord_path, MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
            tfrecord_transformer.transform()
    finally:
        # A partial MindRecord output may or may not exist after the failure,
        # so every removal is guarded.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db, tfrecord_path):
            if os.path.exists(path):
                os.remove(path)
def test_tfrecord_to_mindrecord_with_special_field_name():
    """Test transforming a TFRecord whose feature names contain '/' to MindRecord.

    The source features are ``image/class/label`` (int64) and ``image/encoded``
    (bytes, passed as the bytes-field list). The test checks that the
    MindRecord file and its ``.db`` index are created, then compares the
    contents with ``verify_data``.

    Generated files are removed in ``finally`` so a failing assertion does not
    leave stale output behind for later tests.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a plain string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1"); consider a real version
    # parser.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        message = ("Module tensorflow is not found or version wrong, "
                   "please use pip install it / reinstall version >= {}.").format(
                       SupportedTensorFlowVersion)
        logger.warning(message)
        pytest.skip(message)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_db = MINDRECORD_FILE_NAME + ".db"
    try:
        generate_tfrecord_with_special_field_name()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
        }

        # Drop MindRecord output left over from a previous run.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db):
            if os.path.exists(path):
                os.remove(path)

        tfrecord_transformer = TFRecordToMR(
            tfrecord_path, MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"])
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(mindrecord_db)

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        # Guarded removal so cleanup never masks the original failure.
        for path in (MINDRECORD_FILE_NAME, mindrecord_db, tfrecord_path):
            if os.path.exists(path):
                os.remove(path)