def test_issue_73():
    """test file reader by column name."""
    writer = FileWriter(MKV_FILE_NAME, FILES_NUM)
    data = get_mkv_data("../data/mindrecord/testVehPerData/")
    mkv_schema_json = {"file_name": {"type": "string"},
                       "id": {"type": "number"},
                       "prelabel": {"type": "string"},
                       "data": {"type": "bytes"}}
    writer.add_schema(mkv_schema_json, "mkv_schema")
    writer.add_index(["file_name", "prelabel"])
    writer.write_raw_data(data)
    writer.commit()

    # read back only the "file_name" column from the shard suffixed "1"
    reader = FileReader(MKV_FILE_NAME + "1", 4, ["file_name"])
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    # remove every shard file and its companion .db index file
    for path in ["{}{}".format(MKV_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(FILES_NUM)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_issue_39():
    """test cv dataset writer when schema fields' datatype does not match raw data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    data = get_data("../data/mindrecord/testImageNetData/")
    # "file_name" is deliberately declared "number" to mismatch the raw data
    cv_schema_json = {"file_name": {"type": "number"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    # mismatched rows are dropped, so the file must read back empty
    reader = FileReader(CV_FILE_NAME)
    row_count = 0
    for _ in reader.get_next():
        row_count += 1
    assert row_count == 0, "failed on reading data!"
    reader.close()

    os.remove("{}".format(CV_FILE_NAME))
    os.remove("{}.db".format(CV_FILE_NAME))
def test_write_raw_data_with_empty_list():
    """test write raw data with empty list."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # writing an empty batch is a no-op but must still report SUCCESS
    ret = writer.write_raw_data([])
    assert ret == SUCCESS
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    for path in ["{}{}".format(CV_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(FILES_NUM)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_cv_file_writer_default_shard_num():
    """test cv dataset writer when shard_num is default value."""
    writer = FileWriter(CV_FILE_NAME)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    # with the default shard_num the file name carries no shard suffix
    reader = FileReader(CV_FILE_NAME)
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    os.remove("{}".format(CV_FILE_NAME))
    os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_file_writer_shard_num_10():
    """test cv dataset writer when shard_num equals 10."""
    shard_num = 10
    writer = FileWriter(CV_FILE_NAME, shard_num)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    # remove all ten shards and their .db index files
    for path in ["{}{}".format(CV_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(shard_num)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_file_writer_raw_data_038():
    """test write raw data without verify."""
    shard_num = 11
    writer = FileWriter("test_file_writer_raw_data_", shard_num)
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "number"},
                   "data": {"type": "bytes"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["file_name"])
    # write the same batch once per shard, skipping data verification
    for _ in range(shard_num):
        writer.write_raw_data(data_raw, False)
    writer.commit()

    # open the last shard; its suffix is the highest shard index (capped at 99)
    file_name = ""
    if shard_num > 1:
        file_name = '99' if shard_num > 99 else str(shard_num - 1)
    reader = FileReader("test_file_writer_raw_data_" + file_name)
    total = 0
    for _, _ in enumerate(reader.get_next()):
        total += 1
    assert total == shard_num * 10
    reader.close()

    if shard_num == 1:
        os.remove("test_file_writer_raw_data_")
        os.remove("test_file_writer_raw_data_.db")
        return
    for shard_idx in range(shard_num):
        # suffixes are zero-padded to two digits once there are >10 shards
        suffix = str(shard_idx).zfill(2) if shard_num > 10 else str(shard_idx)
        shard_path = "test_file_writer_raw_data_{}".format(suffix)
        if os.path.exists(shard_path):
            os.remove(shard_path)
        if os.path.exists(shard_path + ".db"):
            os.remove(shard_path + ".db")
def test_cv_file_append_writer_absolute_path():
    """tutorial for cv dataset append writer."""
    writer = FileWriter(CV4_FILE_NAME, 4)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data[0:5])
    writer.commit()

    # reopen the committed file and append the remaining five samples
    write_append = FileWriter.open_for_append(CV4_FILE_NAME + "0")
    write_append.write_raw_data(data[5:10])
    write_append.commit()

    reader = FileReader(CV4_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 3
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    for path in ["{}{}".format(CV4_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(4)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_cv_file_writer_loop_and_read():
    """tutorial for cv dataset loop writer."""
    writer = FileWriter(CV2_FILE_NAME, FILES_NUM)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # write one sample per call instead of the whole batch at once
    for row in data:
        writer.write_raw_data([row])
    writer.commit()

    reader = FileReader(CV2_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 3
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    for path in ["{}{}".format(CV2_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(FILES_NUM)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_issue_34():
    """test file writer"""
    writer = FileWriter(CV_FILE_NAME)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "cv_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
        total += 1
    logger.info("count: {}".format(total))
    reader.close()

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_file_writer_no_blob():
    """test cv file writer without blob data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    data = get_data("../data/mindrecord/testImageNetData/")
    # schema omits the "data" (bytes) field, so no blob section is written
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"}}
    writer.add_schema(cv_schema_json, "no_blob_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        assert len(item) == 2
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_file_writer_no_raw():
    """test cv file writer without raw data."""
    writer = FileWriter(NLP_FILE_NAME)
    data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                             "../data/mindrecord/testAclImdbData/vocab.txt",
                             10))
    # every field is an array, so there is no raw (scalar) section at all
    nlp_schema_json = {"input_ids": {"type": "int64", "shape": [1, -1]},
                       "input_mask": {"type": "int64", "shape": [1, -1]},
                       "segment_ids": {"type": "int64", "shape": [1, -1]}}
    writer.add_schema(nlp_schema_json, "no_raw_schema")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(NLP_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        assert len(item) == 3
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    os.remove(NLP_FILE_NAME)
    os.remove("{}.db".format(NLP_FILE_NAME))
def test_shard_4_raw_data_1():
    """test file writer when shard_num equals 4 and number of sample equals 1."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "number"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["label"])
    # a single sample spread over four shards: only one shard gets data
    data = [{"file_name": "001.jpg", "label": 1}]
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 2
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 1
    reader.close()

    for path in ["{}{}".format(CV_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(FILES_NUM)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def write_mindrecord_tutorial():
    """Write a simulated ImageNet mindrecord file and verify it reads back."""
    writer = FileWriter(MINDRECORD_FILE_NAME)
    data = get_data("./ImageNetDataSimulation")
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "int64"},
                   "data": {"type": "bytes"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(MINDRECORD_FILE_NAME)
    total = 0
    for _, item in enumerate(reader.get_next()):
        assert len(item) == 3
        total += 1
    assert total == 20
    reader.close()
def test_cv_file_writer_without_data():
    """test cv file writer without data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # commit without ever calling write_raw_data
    writer.commit()

    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 0
    reader.close()

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
def test_mnist_to_mindrecord_compare_data():
    """test transform mnist dataset to mindrecord and compare data."""
    mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
    mnist_transformer.transform()
    assert os.path.exists("mnist_train.mindrecord")
    assert os.path.exists("mnist_test.mindrecord")

    train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"

    def _extract_images(filename, num_images):
        """Extract the images into a 4D tensor [image index, y, x, channels]."""
        with gzip.open(filename) as bytestream:
            # skip the 16-byte idx3 image-file header
            bytestream.read(16)
            buf = bytestream.read(
                IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
            data = np.frombuffer(buf, dtype=np.uint8)
            data = data.reshape(
                num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
            return data

    def _extract_labels(filename, num_images):
        """Extract the labels into a vector of int64 label IDs."""
        with gzip.open(filename) as bytestream:
            # skip the 8-byte idx1 label-file header
            bytestream.read(8)
            buf = bytestream.read(1 * num_images)
            labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
            return labels

    train_data_filename_ = os.path.join(MNIST_DIR,
                                        'train-images-idx3-ubyte.gz')
    train_labels_filename_ = os.path.join(MNIST_DIR,
                                          'train-labels-idx1-ubyte.gz')
    test_data_filename_ = os.path.join(MNIST_DIR,
                                       't10k-images-idx3-ubyte.gz')
    test_labels_filename_ = os.path.join(MNIST_DIR,
                                         't10k-labels-idx1-ubyte.gz')
    train_data = _extract_images(train_data_filename_, 60000)
    train_labels = _extract_labels(train_labels_filename_, 60000)
    test_data = _extract_images(test_data_filename_, 10000)
    test_labels = _extract_labels(test_labels_filename_, 10000)

    # compare each mindrecord row against a freshly JPEG-encoded source image
    reader = FileReader(train_name)
    for x, data, label in zip(reader.get_next(), train_data, train_labels):
        _, img = cv2.imencode(".jpeg", data)
        assert np.array(x['data']) == img.tobytes()
        assert np.array(x['label']) == label
    reader.close()

    reader = FileReader(test_name)
    for x, data, label in zip(reader.get_next(), test_data, test_labels):
        _, img = cv2.imencode(".jpeg", data)
        assert np.array(x['data']) == img.tobytes()
        assert np.array(x['label']) == label
    reader.close()

    os.remove("{}".format("mnist_train.mindrecord"))
    os.remove("{}.db".format("mnist_train.mindrecord"))
    os.remove("{}".format("mnist_test.mindrecord"))
    os.remove("{}.db".format("mnist_test.mindrecord"))
def test_lack_partition_and_db():
    """test file reader when mindrecord file does not exist."""
    with pytest.raises(MRMOpenError) as err:
        reader = FileReader('dummy.mindrecord')
        reader.close()
    expected = ('[MRMOpenError]: error_code: 1347690596, '
                'error_msg: MindRecord File could not open successfully.')
    assert expected in str(err.value)
def read(filename, fields_num=5):
    """Read *filename* back and check each of the 5 rows has *fields_num* fields."""
    reader = FileReader(filename)
    rows = 0
    for _, item in enumerate(reader.get_next()):
        assert len(item) == fields_num
        rows += 1
        logger.info("data: {}".format(item))
    assert rows == 5
    reader.close()
def test_read_after_close(fixture_cv_file):
    """test file reader when close read."""
    create_cv_mindrecord(1)
    reader = FileReader(CV_FILE_NAME)
    reader.close()
    # iterating a closed reader must yield no rows
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 0
def test_lack_db(fixture_cv_file):
    """test file reader when db file does not exist."""
    create_cv_mindrecord(1)
    # delete only the .db index file, keeping the data file in place
    os.remove("{}.db".format(CV_FILE_NAME))
    with pytest.raises(MRMOpenError) as err:
        reader = FileReader(CV_FILE_NAME)
        reader.close()
    expected = ('[MRMOpenError]: error_code: 1347690596, '
                'error_msg: MindRecord File could not open successfully.')
    assert expected in str(err.value)
def test_nlp_file_reader_tutorial():
    """tutorial for nlp file reader."""
    reader = FileReader(NLP_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 6
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()
def read(filename):
    """Read *filename* and verify it holds 20 rows of 3 fields each."""
    reader = FileReader(filename)
    rows = 0
    for _, item in enumerate(reader.get_next()):
        assert len(item) == 3
        rows += 1
        if rows == 1:
            # log only the first row to keep the output short
            logger.info("data: {}".format(item))
    assert rows == 20
    reader.close()
def test_lack_some_db(fixture_cv_file):
    """test file reader when some db does not exist."""
    create_cv_mindrecord(4)
    shard_paths = ["{}{}".format(CV_FILE_NAME, str(n).rjust(1, '0'))
                   for n in range(FILES_NUM)]
    # drop the index db of the last shard only
    os.remove("{}.db".format(shard_paths[3]))
    with pytest.raises(MRMOpenError) as err:
        reader = FileReader(CV_FILE_NAME + "0")
        reader.close()
    expected = ('[MRMOpenError]: error_code: 1347690596, '
                'error_msg: MindRecord File could not open successfully.')
    assert expected in str(err.value)
def test_write_read_process_with_define_index_field():
    """Write mixed-type records with "label" as the index field, then read
    every row back and compare it field-by-field with the source data."""
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8,
             "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91, "score": 5.4,
             "mask": np.array([1, 4, 7], dtype=np.int64),
             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
             "data": bytes("image bytes def", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61, "score": 6.4,
             "mask": np.array([7, 6, 3], dtype=np.int64),
             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
             "data": bytes("image bytes ghi", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29, "score": 8.1,
             "mask": np.array([2, 8, 0], dtype=np.int64),
             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
             "data": bytes("image bytes jkl", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78, "score": 7.7,
             "mask": np.array([3, 1, 2], dtype=np.int64),
             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
             "data": bytes("image bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37, "score": 9.4,
             "mask": np.array([7, 6, 7], dtype=np.int64),
             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
             "data": bytes("image bytes pqr", encoding='UTF-8')}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}
    writer.add_schema(schema, "data is so cool")
    writer.add_index(["label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 6
        for field in x:
            # ndarray fields need elementwise comparison; scalars compare directly
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
def read(filename, columns, row_num):
    """Read *filename* and compare the given *columns* of each of the
    *row_num* rows against the reference CSV_FILE loaded via pandas."""
    if not pd:
        raise Exception(
            "Module pandas is not found, please use pip install it.")
    df = pd.read_csv(CSV_FILE)
    reader = FileReader(filename)
    rows = 0
    for _, item in enumerate(reader.get_next()):
        for col in columns:
            assert item[col] == df[col].iloc[rows]
        assert len(item) == len(columns)
        rows += 1
        if rows == 1:
            logger.info("data: {}".format(item))
    assert rows == row_num
    reader.close()
def test_file_read_after_read():
    """test file reader when finish read."""
    create_cv_mindrecord(1)
    reader = FileReader(CV_FILE_NAME)
    first_pass = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 3
        first_pass += 1
        logger.info("#item{}: {}".format(idx, item))
    assert first_pass == 10
    reader.close()

    # iterating again after close must produce nothing
    second_pass = 0
    for idx, item in enumerate(reader.get_next()):
        second_pass += 1
        logger.info("#item{}: {}".format(idx, item))
    assert second_pass == 0

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
def read(train_name, test_name):
    """Verify the train (60000 rows) and test (10000 rows) mindrecord files,
    checking each row carries exactly 2 fields."""
    for file_name, expected in ((train_name, 60000), (test_name, 10000)):
        reader = FileReader(file_name)
        rows = 0
        for _, item in enumerate(reader.get_next()):
            assert len(item) == 2
            rows += 1
            if rows == 1:
                logger.info("data: {}".format(item))
        assert rows == expected
        reader.close()
def read():
    """Verify the main (16 rows) and "_test" (4 rows) mindrecord files,
    checking each row carries exactly 3 fields."""
    for file_name, expected in ((MINDRECORD_FILE, 16),
                                (MINDRECORD_FILE + "_test", 4)):
        reader = FileReader(file_name)
        rows = 0
        for _, item in enumerate(reader.get_next()):
            assert len(item) == 3
            rows += 1
            if rows == 1:
                logger.info("data: {}".format(item))
        assert rows == expected
        reader.close()
def test_issue_118():
    """test file writer when raw data do not match schema."""
    shard_num = 4
    writer = FileWriter(CV_FILE_NAME, shard_num)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    # append one sample whose "data" field is a str, not bytes as declared
    data.append({"file_name": "abcdefg", "label": 11,
                 "data": str(data[0]["data"])})
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    for idx, _ in enumerate(reader.get_next()):
        logger.info(idx)
    reader.close()

    for path in ["{}{}".format(CV_FILE_NAME, str(n).rjust(1, '0'))
                 for n in range(shard_num)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_issue_95():
    """test file reader when failed on file write."""
    writer = FileWriter(__file__, FILES_NUM)
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {"file_name": {"type": "number"},
                   "label": {"type": "number"},
                   "data": {"type": "bytes"},
                   "data1": {"type": "string"}}
    writer.add_schema(schema_json, "img_schema")
    # "key" is not a schema field, so adding it as an index must raise
    with pytest.raises(MRMAddIndexError):
        writer.add_index(["key"])
    writer.write_raw_data(data_raw, True)
    writer.commit()

    reader = FileReader(__file__ + "1")
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    for path in ["{}{}".format(__file__, str(n).rjust(1, '0'))
                 for n in range(FILES_NUM)]:
        os.remove("{}".format(path))
        os.remove("{}.db".format(path))
def test_write_read_process_with_multi_array():
    """Write six rows of variable-length int64 arrays, then read them back
    with several different column selections and compare against the source."""
    mindrecord_file_name = "test.mindrecord"
    data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}]
    writer = FileWriter(mindrecord_file_name)
    # every field is a 1-D int64 array of variable length (shape [-1])
    schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    # pass 1: read all eight columns
    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 8
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    # pass 2: read a six-column projection
    reader = FileReader(file_name=mindrecord_file_name,
                        columns=["source_eos_ids", "source_eos_mask",
                                 "target_sos_ids", "target_sos_mask",
                                 "target_eos_ids", "target_eos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 6
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    # pass 3: three columns, schema order
    reader = FileReader(
        file_name=mindrecord_file_name,
        columns=["source_sos_ids", "target_sos_ids", "target_eos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    # pass 4: three columns, requested out of schema order
    reader = FileReader(
        file_name=mindrecord_file_name,
        columns=["target_eos_mask", "source_eos_mask", "source_sos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    # pass 5: a single column
    reader = FileReader(file_name=mindrecord_file_name,
                        columns=["target_eos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 1
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))