Exemplo n.º 1
0
def test_mnist_to_mindrecord_compare_data():
    """Transform the MNIST dataset to MindRecord and compare every record.

    Converts the gzipped MNIST files under ``MNIST_DIR`` with ``MnistToMR``,
    then re-reads both generated MindRecord files and checks that each stored
    record equals the JPEG-encoded source image and its label, and that the
    number of stored records equals the number of source samples.

    The generated files (and their ``.db`` companions) are removed even when
    the transform or an assertion fails, so a failing run does not leave
    artifacts behind.
    """
    train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
    num_train, num_test = 60000, 10000

    def _extract_images(filename, num_images):
        """Extract the images into a 4D tensor [image index, y, x, channels]."""
        with gzip.open(filename) as bytestream:
            bytestream.read(16)  # discard the 16-byte header before the pixels
            buf = bytestream.read(
                IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
            data = np.frombuffer(buf, dtype=np.uint8)
            return data.reshape(
                num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)

    def _extract_labels(filename, num_images):
        """Extract the labels into a vector of int64 label IDs."""
        with gzip.open(filename) as bytestream:
            bytestream.read(8)  # discard the 8-byte header before the labels
            buf = bytestream.read(num_images)
            return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)

    def _check_records(mindrecord_name, images, labels):
        """Assert ``mindrecord_name`` holds exactly ``images``/``labels``, in order."""
        reader = FileReader(mindrecord_name)
        try:
            count = 0
            for record in reader.get_next():
                assert count < len(labels), \
                    "{} has more records than source samples".format(mindrecord_name)
                ok, encoded = cv2.imencode(".jpeg", images[count])
                assert ok, "cv2.imencode failed for sample {}".format(count)
                # bytes() accepts both a bytes blob and a uint8 array.
                assert bytes(record['data']) == encoded.tobytes()
                assert np.array(record['label']) == labels[count]
                count += 1
            # Without this, an empty or truncated file would pass vacuously.
            assert count == len(labels), \
                "{} has {} records, expected {}".format(
                    mindrecord_name, count, len(labels))
        finally:
            reader.close()

    train_data_filename_ = os.path.join(MNIST_DIR,
                                        'train-images-idx3-ubyte.gz')
    train_labels_filename_ = os.path.join(MNIST_DIR,
                                          'train-labels-idx1-ubyte.gz')
    test_data_filename_ = os.path.join(MNIST_DIR,
                                       't10k-images-idx3-ubyte.gz')
    test_labels_filename_ = os.path.join(MNIST_DIR,
                                         't10k-labels-idx1-ubyte.gz')

    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
        mnist_transformer.transform()
        assert os.path.exists(train_name)
        assert os.path.exists(test_name)

        _check_records(train_name,
                       _extract_images(train_data_filename_, num_train),
                       _extract_labels(train_labels_filename_, num_train))
        _check_records(test_name,
                       _extract_images(test_data_filename_, num_test),
                       _extract_labels(test_labels_filename_, num_test))
    finally:
        # Remove only what exists so cleanup never masks the original error.
        for name in (train_name, test_name):
            for path in (name, name + ".db"):
                if os.path.exists(path):
                    os.remove(path)
Exemplo n.º 2
0
def test_mnist_to_mindrecord(fixture_file):
    """Convert MNIST to MindRecord, check both outputs exist, then read them.

    The generated files are not removed here; presumably ``fixture_file``
    takes care of cleanup (TODO confirm against the fixture definition).
    """
    train_file = "mnist_train.mindrecord"
    test_file = "mnist_test.mindrecord"

    MnistToMR(MNIST_DIR, FILE_NAME).transform()
    for generated in (train_file, test_file):
        assert os.path.exists(generated)

    read(train_file, test_file)
Exemplo n.º 3
0
def test_mnist_to_mindrecord_multi_partition():
    """Transform MNIST into ``PARTITION_NUM`` MindRecord shards and read them.

    Asserts that every train/test shard was written, then passes the first
    shard of each split to ``read``. All shards and their ``.db`` companions
    are removed afterwards, even when the transform or a check fails.
    """
    train_files = ["mnist_train.mindrecord" + str(i)
                   for i in range(PARTITION_NUM)]
    test_files = ["mnist_test.mindrecord" + str(i)
                  for i in range(PARTITION_NUM)]
    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
        mnist_transformer.transform()
        for path in train_files + test_files:
            assert os.path.exists(path)

        read(train_files[0], test_files[0])
    finally:
        # Remove only what exists so cleanup never masks the original error.
        for path in train_files + test_files:
            for name in (path, path + ".db"):
                if os.path.exists(name):
                    os.remove(name)
Exemplo n.º 4
0
def test_mnist_to_mindrecord():
    """Transform the MNIST dataset to MindRecord and read the result back.

    The generated files and their ``.db`` companions are removed afterwards,
    even when the transform, an existence check, or ``read`` fails.
    """
    train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
        mnist_transformer.transform()
        assert os.path.exists(train_name)
        assert os.path.exists(test_name)

        read(train_name, test_name)
    finally:
        # Remove only what exists so cleanup never masks the original error.
        for name in (train_name, test_name):
            for path in (name, name + ".db"):
                if os.path.exists(path):
                    os.remove(path)
Exemplo n.º 5
0
def test_mnist_to_mindrecord_multi_partition(fixture_file):
    """Transform MNIST into ``PARTITION_NUM`` MindRecord shards and read them.

    Asserts that every train/test shard was written, which matches the
    existence checks done by the single-file variant. It then passes the
    first shard of each split to ``read``. File cleanup is not done here;
    presumably ``fixture_file`` handles it (TODO confirm).
    """
    mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
    mnist_transformer.transform()
    for i in range(PARTITION_NUM):
        assert os.path.exists("mnist_train.mindrecord" + str(i))
        assert os.path.exists("mnist_test.mindrecord" + str(i))

    read("mnist_train.mindrecord0", "mnist_test.mindrecord0")