Ejemplo n.º 1
0
def transform_single(file):
    """Rewrite one TFRecord file with flat feature keys and convert it to MindRecord.

    Reads ``os.path.join(tf_dir, file)``, whose records are parsed with the
    features ``image/encoded`` (bytes) and ``image/class/label`` (int64).
    Every record is re-serialized into an intermediate TFRecord file
    (``os.path.join(ms_dir, file + '-tf')``) under the keys ``image`` and
    ``label``. That file is converted with ``TFRecordToMR`` into the
    MindRecord file ``os.path.join(ms_dir, file)``, with ``image`` declared as
    a bytes field. The intermediate file is always deleted afterwards, even
    when writing or converting fails.

    Args:
        file: Base file name, resolved against the module-level ``tf_dir``
            (input directory) and ``ms_dir`` (output directory).
    """
    tf_filename = os.path.join(tf_dir, file)
    raw_dataset = tf.data.TFRecordDataset([tf_filename])
    image_feature_description = {
        "image/encoded": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "image/class/label": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }

    def _parse_image_function(example_proto):
        # Parse the input tf.Example proto using the dictionary above.
        return tf.io.parse_single_example(example_proto, image_feature_description)

    parsed_image_dataset = raw_dataset.map(_parse_image_function)

    def _bytes_feature(value):
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def image_example(image_string, label):
        # label is parsed with shape [] (a scalar int64 tensor); hand Int64List
        # a plain Python int rather than a platform-dependent numpy integer.
        feature = {
            'label': _int64_feature(int(label.numpy())),
            'image': _bytes_feature(image_string),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    record_file = os.path.join(ms_dir, file + '-tf')
    ms_filename = os.path.join(ms_dir, file)
    try:
        with tf.io.TFRecordWriter(record_file) as writer:
            for image_features in parsed_image_dataset:
                tf_example = image_example(image_features['image/encoded'],
                                           image_features['image/class/label'])
                writer.write(tf_example.SerializeToString())

        feature_dict = {"image": tf.io.FixedLenFeature([], tf.string),
                        "label": tf.io.FixedLenFeature([], tf.int64)}

        tfrecord_transformer = TFRecordToMR(record_file, ms_filename, feature_dict, ["image"])
        tfrecord_transformer.transform()
    finally:
        # The intermediate file is scratch data: drop it whether or not the
        # conversion succeeded (it may not exist if the writer failed early).
        if os.path.exists(record_file):
            os.remove(record_file)
Ejemplo n.º 2
0
def test_tfrecord_to_mindrecord_list_without_bytes_type():
    """Convert a TFRecord file to MindRecord without passing a bytes-field list.

    ``TFRecordToMR`` is called with only the source path, destination path and
    feature dict (no fourth argument). The MindRecord file and its ``.db``
    companion must be produced, and their content is checked with
    ``verify_data``. The test is skipped (not silently passed) when TensorFlow
    is unavailable or too old. All generated files are removed even if an
    assertion fails.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1.0"); confirm intent.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        msg = ("Module tensorflow is not found or version wrong, "
               "please use pip install it / reinstall version >= {}.".format(
                   SupportedTensorFlowVersion))
        logger.warning(msg)
        pytest.skip(msg)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_outputs = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    def _remove_mindrecord_outputs():
        # Delete the MindRecord file and its index, ignoring ones that are absent.
        for path in mindrecord_outputs:
            if os.path.exists(path):
                os.remove(path)

    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "file_name": tf.io.FixedLenFeature([], tf.string),
            "image_bytes": tf.io.FixedLenFeature([], tf.string),
            "int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
            "float_scalar": tf.io.FixedLenFeature([1], tf.float32),
            "int64_list": tf.io.FixedLenFeature([6], tf.int64),
            "float_list": tf.io.FixedLenFeature([7], tf.float32),
        }

        # Remove leftovers from a previous run before converting.
        _remove_mindrecord_outputs()

        tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME, feature_dict)
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(MINDRECORD_FILE_NAME + ".db")

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        _remove_mindrecord_outputs()
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
Ejemplo n.º 3
0
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
    """Expect ValueError when a declared list feature length is too small.

    ``float_list`` is declared as ``FixedLenFeature([2], tf.float32)``.
    Presumably this is shorter than the list written by ``generate_tfrecord()``,
    so constructing or running ``TFRecordToMR`` must raise ``ValueError``. The
    test is skipped (not silently passed) when TensorFlow is unavailable or too
    old. All generated files are removed even if an assertion fails.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1.0"); confirm intent.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        msg = ("Module tensorflow is not found or version wrong, "
               "please use pip install it / reinstall version >= {}.".format(
                   SupportedTensorFlowVersion))
        logger.warning(msg)
        pytest.skip(msg)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_outputs = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    def _remove_mindrecord_outputs():
        # Delete the MindRecord file and its index, ignoring ones that are absent.
        for path in mindrecord_outputs:
            if os.path.exists(path):
                os.remove(path)

    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "file_name": tf.io.FixedLenFeature([], tf.string),
            "image_bytes": tf.io.FixedLenFeature([], tf.string),
            "int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
            "float_scalar": tf.io.FixedLenFeature([1], tf.float32),
            "int64_list": tf.io.FixedLenFeature([6], tf.int64),
            "float_list": tf.io.FixedLenFeature([2], tf.float32),
        }

        # Remove leftovers from a previous run before converting.
        _remove_mindrecord_outputs()

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(
                tfrecord_path, MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
            tfrecord_transformer.transform()
    finally:
        _remove_mindrecord_outputs()
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
Ejemplo n.º 4
0
def test_tfrecord_to_mindrecord_with_special_field_name():
    """Convert a TFRecord file whose feature names contain '/' to MindRecord.

    The source records are produced by
    ``generate_tfrecord_with_special_field_name()`` and read with the keys
    ``image/class/label`` and ``image/encoded``; the latter is declared as a
    bytes field. The MindRecord file and its ``.db`` companion must be produced,
    and their content is checked with ``verify_data``. The test is skipped (not
    silently passed) when TensorFlow is unavailable or too old. All generated
    files are removed even if an assertion fails.
    """
    # NOTE(review): if SupportedTensorFlowVersion is a string, this is a
    # lexicographic comparison (e.g. "10.0" < "2.1.0"); confirm intent.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        msg = ("Module tensorflow is not found or version wrong, "
               "please use pip install it / reinstall version >= {}.".format(
                   SupportedTensorFlowVersion))
        logger.warning(msg)
        pytest.skip(msg)

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_outputs = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    def _remove_mindrecord_outputs():
        # Delete the MindRecord file and its index, ignoring ones that are absent.
        for path in mindrecord_outputs:
            if os.path.exists(path):
                os.remove(path)

    try:
        generate_tfrecord_with_special_field_name()
        assert os.path.exists(tfrecord_path)

        feature_dict = {
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
        }

        # Remove leftovers from a previous run before converting.
        _remove_mindrecord_outputs()

        tfrecord_transformer = TFRecordToMR(
            tfrecord_path, MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"])
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(MINDRECORD_FILE_NAME + ".db")

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        _remove_mindrecord_outputs()
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)