def test_mnist_to_mindrecord_compare_data():
    """Convert MNIST to MindRecord and compare each record with the raw data.

    The gzip archives under ``MNIST_DIR`` are decoded directly, every image is
    JPEG-encoded with OpenCV, and the bytes and label are compared with the
    records that ``FileReader`` yields from the generated files.

    Generated files are removed when the test ends, whether it passed or
    failed, so a failing run does not leave stale output behind.
    """
    train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
    generated = [train_name, train_name + ".db", test_name, test_name + ".db"]

    def _extract_images(filename, num_images):
        """Read ``num_images`` images as a uint8 array [index, y, x, channels]."""
        with gzip.open(filename) as bytestream:
            # The first 16 bytes are read and discarded before the pixel data.
            bytestream.read(16)
            buf = bytestream.read(
                IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
        data = np.frombuffer(buf, dtype=np.uint8)
        return data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)

    def _extract_labels(filename, num_images):
        """Read ``num_images`` one-byte labels as a vector of int64 IDs."""
        with gzip.open(filename) as bytestream:
            # The first 8 bytes are read and discarded before the label data.
            bytestream.read(8)
            buf = bytestream.read(1 * num_images)
        return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)

    def _assert_records_match(record_file, images, labels):
        """Compare the records of ``record_file`` with the raw images/labels."""
        reader = FileReader(record_file)
        try:
            # zip stops at the shortest input, exactly as the comparison
            # always did; no record-count check is implied here.
            for record, image, label in zip(reader.get_next(), images, labels):
                _, encoded = cv2.imencode(".jpeg", image)
                assert np.array(record['data']) == encoded.tobytes()
                assert np.array(record['label']) == label
        finally:
            # Close the reader even when a comparison fails.
            reader.close()

    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
        mnist_transformer.transform()
        assert os.path.exists(train_name)
        assert os.path.exists(test_name)

        train_data = _extract_images(
            os.path.join(MNIST_DIR, 'train-images-idx3-ubyte.gz'), 60000)
        train_labels = _extract_labels(
            os.path.join(MNIST_DIR, 'train-labels-idx1-ubyte.gz'), 60000)
        test_data = _extract_images(
            os.path.join(MNIST_DIR, 't10k-images-idx3-ubyte.gz'), 10000)
        test_labels = _extract_labels(
            os.path.join(MNIST_DIR, 't10k-labels-idx1-ubyte.gz'), 10000)

        _assert_records_match(train_name, train_data, train_labels)
        _assert_records_match(test_name, test_data, test_labels)

        # Every output, including the companion .db files, must exist after
        # a successful run.
        for path in generated:
            assert os.path.exists(path), "missing output file: {}".format(path)
    finally:
        # Guarded removal: a failure part-way through may have produced only
        # some of the files, and cleanup must not mask the real error.
        for path in generated:
            if os.path.exists(path):
                os.remove(path)
def test_mnist_to_mindrecord(fixture_file):
    """Convert MNIST into one MindRecord file per split and read both back.

    ``fixture_file`` is requested only for its side effects; its value is not
    used in the body (presumably it handles removal of the generated files --
    confirm against the fixture definition).

    NOTE(review): a later definition in this file appears to reuse this
    function name, which would shadow this version -- confirm which one is
    intended to run.
    """
    MnistToMR(MNIST_DIR, FILE_NAME).transform()

    train_file = "mnist_train.mindrecord"
    test_file = "mnist_test.mindrecord"
    for produced in (train_file, test_file):
        assert os.path.exists(produced)

    read(train_file, test_file)
def test_mnist_to_mindrecord_multi_partition():
    """Convert MNIST into ``PARTITION_NUM`` MindRecord files and read them.

    Only the files with suffix ``0`` are passed to ``read``. Every generated
    file and its companion ``.db`` file is removed when the test ends, whether
    it passed or failed, so a failing run does not leave stale output behind.

    NOTE(review): a later definition in this file appears to reuse this
    function name, which would shadow this version -- confirm which one is
    intended to run.
    """
    stems = ("mnist_train.mindrecord", "mnist_test.mindrecord")
    generated = [
        stem + str(i) + suffix
        for i in range(PARTITION_NUM)
        for stem in stems
        for suffix in ("", ".db")
    ]

    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
        mnist_transformer.transform()

        read("mnist_train.mindrecord0", "mnist_test.mindrecord0")

        # Every partition and its companion .db file must exist after a
        # successful run.
        for path in generated:
            assert os.path.exists(path), "missing output file: {}".format(path)
    finally:
        # Guarded removal: a failure part-way through may have produced only
        # some of the files, and cleanup must not mask the real error.
        for path in generated:
            if os.path.exists(path):
                os.remove(path)
def test_mnist_to_mindrecord():
    """Convert MNIST into one MindRecord file per split and read both back.

    The generated files and their companion ``.db`` files are removed when
    the test ends, whether it passed or failed, so a failing run does not
    leave stale output behind.

    NOTE(review): an earlier definition in this file appears to use the same
    function name; this later one would be the one that takes effect --
    confirm that is intended.
    """
    train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
    generated = [train_name, train_name + ".db", test_name, test_name + ".db"]

    try:
        mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
        mnist_transformer.transform()
        assert os.path.exists(train_name)
        assert os.path.exists(test_name)

        read(train_name, test_name)

        # Every output, including the companion .db files, must exist after
        # a successful run.
        for path in generated:
            assert os.path.exists(path), "missing output file: {}".format(path)
    finally:
        # Guarded removal: a failure part-way through may have produced only
        # some of the files, and cleanup must not mask the real error.
        for path in generated:
            if os.path.exists(path):
                os.remove(path)
def test_mnist_to_mindrecord_multi_partition(fixture_file):
    """Convert MNIST into ``PARTITION_NUM`` MindRecord files and read them.

    Only the files with suffix ``0`` are passed to ``read``.
    ``fixture_file`` is requested only for its side effects; its value is not
    used in the body (presumably it handles removal of the generated files --
    confirm against the fixture definition).

    NOTE(review): an earlier definition in this file appears to use the same
    function name; this later one would be the one that takes effect --
    confirm that is intended.
    """
    converter = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
    converter.transform()

    first_train, first_test = "mnist_train.mindrecord0", "mnist_test.mindrecord0"
    read(first_train, first_test)