def test_cv_file_append_writer():
    """tutorial for cv dataset append writer."""
    writer = FileWriter(CV3_FILE_NAME, 4)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data[0:5])
    writer.commit()
    write_append = FileWriter.open_for_append(CV3_FILE_NAME + "0")
    write_append.write_raw_data(data[5:10])
    write_append.commit()
    reader = FileReader(CV3_FILE_NAME + "0")
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 10
    reader.close()
    paths = ["{}{}".format(CV3_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(4)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def test_add_index_with_dict():
    """test add_index when the index argument is a dict instead of a list."""
    writer = FileWriter(MKV_FILE_NAME, FILES_NUM)
    mkv_schema_json = {"file_name": {"type": "string"},
                       "id": {"type": "number"},
                       "prelabel": {"type": "string"},
                       "data": {"type": "bytes"}}
    writer.add_schema(mkv_schema_json, "mkv_schema")
    with pytest.raises(Exception) as e:
        writer.add_index({"file_name": {"type": "string"}})
    assert str(e.value) == "[ParamTypeError]: error_code: 1347686401, " \
                           "error_msg: Invalid parameter type. " \
                           "'index_fields' expect list type."
@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME, True)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_cv_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def data_to_mindrecord_byte_image(dataset="coco", is_training=True,
                                  prefix="fasterrcnn.mindrecord", file_num=8):
    """Create MindRecord file."""
    mindrecord_dir = config.mindrecord_dir
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    if dataset == "coco":
        image_files, image_anno_dict = create_coco_label(is_training)
    else:
        image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
    fasterrcnn_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(fasterrcnn_json, "fasterrcnn_json")
    for image_name in image_files:
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()
def write_to_mindrecord(self, path, train_mode, shard_num=1, desc="gnmt"):
    """
    Write mindrecord file.

    Args:
        path (str): File path.
        train_mode (bool): Whether to write the train schema or the test schema.
        shard_num (int): Shard num.
        desc (str): Description.
    """
    if not os.path.isabs(path):
        path = os.path.abspath(path)
    writer = FileWriter(file_name=path, shard_num=shard_num)
    if train_mode:
        writer.add_schema(self._SCHEMA, desc)
    else:
        writer.add_schema(self._TEST_SCHEMA, desc)
    if not self._examples:
        self._load()
    writer.write_raw_data(self._examples)
    writer.commit()
    print(f"| Wrote to {path}.")
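# Hedged usage sketch for the method above. `GNMTDataset` and its constructor
# arguments are hypothetical stand-ins for whichever class actually defines
# write_to_mindrecord(); only the call shape is illustrated.
dataset = GNMTDataset(src_file="train.src", tgt_file="train.tgt")
dataset.write_to_mindrecord("./gnmt_train.mindrecord", train_mode=True, shard_num=4)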
def test_issue_124():
    """test file writer when data (string) does not match field type (bytes)."""
    shard_num = 4
    writer = FileWriter(CV_FILE_NAME, shard_num)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    data.append({"file_name": "abcdefg", "label": 11,
                 "data": str(data[0]["data"])})
    writer.add_schema(cv_schema_json, "img_schema")
    writer.write_raw_data(data)
    writer.commit()
    reader = FileReader(CV_FILE_NAME + "0")
    for index, _ in enumerate(reader.get_next()):
        logger.info(index)
    reader.close()
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(shard_num)]
    for item in paths:
        os.remove("{}".format(item))
        os.remove("{}.db".format(item))
def test_issue_95():
    """test file reader when the file write partially failed."""
    writer = FileWriter(__file__, FILES_NUM)
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {"file_name": {"type": "number"},
                   "label": {"type": "number"},
                   "data": {"type": "bytes"},
                   "data1": {"type": "string"}}
    writer.add_schema(schema_json, "img_schema")
    with pytest.raises(MRMAddIndexError):
        writer.add_index(["key"])
    writer.write_raw_data(data_raw, True)
    writer.commit()
    reader = FileReader(__file__ + "1")
    for index, x in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(index, x))
    reader.close()
    paths = ["{}{}".format(__file__, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def test_issue_73():
    """test file reader by column name."""
    writer = FileWriter(MKV_FILE_NAME, FILES_NUM)
    data = get_mkv_data("../data/mindrecord/testVehPerData/")
    mkv_schema_json = {"file_name": {"type": "string"},
                       "id": {"type": "number"},
                       "prelabel": {"type": "string"},
                       "data": {"type": "bytes"}}
    writer.add_schema(mkv_schema_json, "mkv_schema")
    writer.add_index(["file_name", "prelabel"])
    writer.write_raw_data(data)
    writer.commit()
    reader = FileReader(MKV_FILE_NAME + "1", 4, ["file_name"])
    for index, x in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(index, x))
    reader.close()
    paths = ["{}{}".format(MKV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def test_issue_36():
    """test file writer when shard num is illegal."""
    with pytest.raises(ParamValueError, match="Shard number should between "):
        writer = FileWriter(CV_FILE_NAME, -1)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    with pytest.raises(UnboundLocalError,
                       match="local variable 'writer' referenced before assignment"):
        writer.add_schema(cv_schema_json, "cv_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
def convert_data_to_mindrecord():
    """Convert data to mindrecord."""
    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    attri_json = {"image": {"type": "bytes"},
                  "label": {"type": "int32", "shape": [-1]}}
    print('Loading train data...')
    total_data = []
    with open(dataset_txt_file, 'r') as ft:
        lines = ft.readlines()
        for line in lines:
            sline = line.strip().split(" ")
            image_file = sline[0]
            labels = []
            for item in sline[1:]:
                labels.append(int(item))
            with open(image_file, 'rb') as f:
                img = f.read()
            data = {"image": img, "label": np.array(labels, dtype='int32')}
            total_data.append(data)
    print('Writing train data to mindrecord...')
    writer.add_schema(attri_json, "attri_json")
    if not total_data:
        raise ValueError("There is no data to write to mindrecord.")
    writer.write_raw_data(total_data)
    writer.commit()
def convert_yolo_data_to_mindrecord():
    """Convert YOLO eval data to mindrecord."""
    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    yolo_json = {"image": {"type": "bytes"},
                 "annotation": {"type": "float64", "shape": [-1, 6]},
                 "image_name": {"type": "string"},
                 "image_size": {"type": "int32", "shape": [-1, 2]}}
    print('Loading eval data...')
    image_files, anno_files, image_names = prepare_file_paths()
    dataset_size = len(anno_files)
    assert dataset_size == len(image_files)
    assert dataset_size == len(image_names)
    logger.info("#size of dataset: {}".format(dataset_size))
    data = []
    for i in range(dataset_size):
        data.append(get_data(image_files[i], anno_files[i], image_names[i]))
    print('Writing eval data to mindrecord...')
    writer.add_schema(yolo_json, "yolo_json")
    if not data:
        raise ValueError("There is no data to write to mindrecord.")
    writer.write_raw_data(data)
    writer.commit()
def test_issue_117():
    """test add schema when field type is incorrect."""
    writer = FileWriter(__file__, FILES_NUM)
    schema = {"id": {"type": "string"},
              "label": {"type": "number"},
              "rating": {"type": "number"},
              "input_ids": {"type": "list", "items": {"type": "number"}},
              "input_mask": {"type": "array", "items": {"type": "number"}},
              "segment_ids": {"type": "array", "items": {"type": "number"}}}
    with pytest.raises(Exception, match="Field '{'type': 'list', "
                                        "'items': {'type': 'number'}}' "
                                        "contains illegal attributes"):
        writer.add_schema(schema, "img_schema")
def test_shard_4_raw_data_1():
    """test file writer when shard_num equals 4 and the number of samples equals 1."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "number"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["label"])
    data = [{"file_name": "001.jpg", "label": 1}]
    writer.write_raw_data(data)
    writer.commit()
    reader = FileReader(CV_FILE_NAME + "0")
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 2
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 1
    reader.close()
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def test_write_two_images_mindrecord_whole_field():
    """test writing two images per record to mindrecord."""
    if os.path.exists("{}".format(CV_FILE_NAME + ".db")):
        os.remove(CV_FILE_NAME + ".db")
    if os.path.exists("{}".format(CV_FILE_NAME)):
        os.remove(CV_FILE_NAME)
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_two_bytes_data(MAP_FILE_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label_name": {"type": "string"},
                      "img_data": {"type": "bytes"},
                      "label_data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "two_images_schema")
    writer.write_raw_data(data)
    writer.commit()
    assert os.path.exists(CV_FILE_NAME)
    assert os.path.exists(CV_FILE_NAME + ".db")
    read(CV_FILE_NAME, 5)
    if os.path.exists("{}".format(CV_FILE_NAME + ".db")):
        os.remove(CV_FILE_NAME + ".db")
    if os.path.exists("{}".format(CV_FILE_NAME)):
        os.remove(CV_FILE_NAME)
def test_issue_40():
    """test cv dataset when raw data is written twice."""
    writer = FileWriter(CV_FILE_NAME, 1)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.write_raw_data(data)
    ret = writer.commit()
    assert ret == SUCCESS, 'failed on writing data!'
    os.remove("{}".format(CV_FILE_NAME))
    os.remove("{}.db".format(CV_FILE_NAME))
def init_writer(mr_schema):
    """init writer"""
    print("Init writer ...")
    mr_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    # set the header size
    if args.mindrecord_header_size_by_bit != 24:
        header_size = 1 << args.mindrecord_header_size_by_bit
        mr_writer.set_header_size(header_size)

    # set the page size
    if args.mindrecord_page_size_by_bit != 25:
        page_size = 1 << args.mindrecord_page_size_by_bit
        mr_writer.set_page_size(page_size)

    # create the schema
    mr_writer.add_schema(mr_schema, "mindrecord_graph_schema")

    # open file and set header
    mr_writer.open_and_set_header()

    return mr_writer
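# Hedged usage sketch for init_writer(): with the defaults implied by the
# checks above, the header stays at 1 << 24 = 16 MB and each page at
# 1 << 25 = 32 MB. The schema below is an assumed example, not the real
# graph schema this tool uses.
example_schema = {"src_id": {"type": "int64"}, "dst_id": {"type": "int64"}}
mr_writer = init_writer(example_schema)
mr_writer.write_raw_data([{"src_id": 0, "dst_id": 1}])
mr_writer.commit()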
def test_add_index_without_add_schema():
    """test add_index before any schema has been added."""
    with pytest.raises(MRMGetMetaError) as err:
        fw = FileWriter(CV_FILE_NAME)
        fw.add_index(["label"])
    assert 'Failed to get meta info' in str(err.value)
def test_cv_file_writer_shard_num_greater_than_1000():
    """test cv file writer shard number greater than 1000."""
    with pytest.raises(ParamValueError) as err:
        FileWriter(CV_FILE_NAME, 1001)
    assert 'Shard number should between' in str(err.value)
def random_split_trans2mindrecord(input_file_path, output_file_path,
                                  criteo_stats_dict, part_rows=2000000,
                                  line_per_sample=1000, test_size=0.1, seed=2020):
    """Randomly split data and save as mindrecord."""
    test_size = int(TRAIN_LINE_COUNT * test_size)
    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size:{}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size:{}".format(len(test_indices_set)))
    print("-----------------------" * 10 + "\n" * 2)

    train_data_list = []
    test_data_list = []
    ids_list = []
    wts_list = []
    label_list = []

    writer_train = FileWriter(
        os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
    writer_test = FileWriter(
        os.path.join(output_file_path, "test_input_part.mindrecord"), 3)

    schema = {"label": {"type": "float32", "shape": [-1]},
              "feat_vals": {"type": "float32", "shape": [-1]},
              "feat_ids": {"type": "int32", "shape": [-1]}}
    writer_train.add_schema(schema, "CRITEO_TRAIN")
    writer_test.add_schema(schema, "CRITEO_TEST")

    with open(input_file_path, encoding="utf-8") as file_in:
        items_error_size_lineCount = []
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                items_error_size_lineCount.append(i)
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            ids, wts = criteo_stats_dict.map_cat2id(values, cats)
            ids_list.extend(ids)
            wts_list.extend(wts)
            label_list.append(label)

            # every line_per_sample lines form one sample
            if count % line_per_sample == 0:
                sample = {"feat_ids": np.array(ids_list, dtype=np.int32),
                          "feat_vals": np.array(wts_list, dtype=np.float32),
                          "label": np.array(label_list, dtype=np.float32)}
                if i not in test_indices_set:
                    train_data_list.append(sample)
                else:
                    test_data_list.append(sample)
                if train_data_list and len(train_data_list) % part_rows == 0:
                    writer_train.write_raw_data(train_data_list)
                    train_data_list.clear()
                    train_part_number += 1
                if test_data_list and len(test_data_list) % part_rows == 0:
                    writer_test.write_raw_data(test_data_list)
                    test_data_list.clear()
                    test_part_number += 1
                ids_list.clear()
                wts_list.clear()
                label_list.clear()

    if train_data_list:
        writer_train.write_raw_data(train_data_list)
    if test_data_list:
        writer_test.write_raw_data(test_data_list)
    writer_train.commit()
    writer_test.commit()

    print("------" * 5)
    print("items_error_size_lineCount.size(): {}.".format(
        len(items_error_size_lineCount)))
    print("------" * 5)
    np.save(os.path.join(output_file_path, "items_error_size_lineCount.npy"),
            items_error_size_lineCount)
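# Hedged read-back sketch (an illustration, not part of the original script):
# loading the first of the 21 committed train shards with MindDataset to
# sanity-check the split.
import mindspore.dataset as ds

train_ds = ds.MindDataset(os.path.join(output_file_path, "train_input_part.mindrecord0"),
                          columns_list=["feat_ids", "feat_vals", "label"])
print("train samples:", train_ds.get_dataset_size())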
def test_write_read_process_with_multi_array():
    mindrecord_file_name = "test.mindrecord"
    # Six records; in record k the first element of each array is offset from
    # the base value (by 10*k or 100*k), the remaining elements stay constant.
    data = []
    for k in range(6):
        data.append({
            "source_sos_ids": np.array([1 + 10 * k, 2, 3, 4, 5], dtype=np.int64),
            "source_sos_mask": np.array([6 + 10 * k, 7, 8, 9, 10, 11, 12], dtype=np.int64),
            "source_eos_ids": np.array([13 + 100 * k, 14, 15, 16, 17, 18], dtype=np.int64),
            "source_eos_mask": np.array([19 + 100 * k, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
            "target_sos_ids": np.array([28 + 100 * k, 29, 30, 31, 32], dtype=np.int64),
            "target_sos_mask": np.array([33 + 100 * k, 34, 35, 36, 37, 38], dtype=np.int64),
            "target_eos_ids": np.array([39 + 100 * k, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
            "target_eos_mask": np.array([48 + 100 * k, 49, 50, 51], dtype=np.int64)})

    writer = FileWriter(mindrecord_file_name)
    schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    def check_reader(reader, num_fields):
        """Read every record back and compare it with the source data."""
        count = 0
        for index, x in enumerate(reader.get_next()):
            assert len(x) == num_fields
            for field in x:
                if isinstance(x[field], np.ndarray):
                    assert (x[field] == data[count][field]).all()
                else:
                    assert x[field] == data[count][field]
            count = count + 1
            logger.info("#item{}: {}".format(index, x))
        assert count == 6
        reader.close()

    check_reader(FileReader(mindrecord_file_name), 8)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["source_eos_ids", "source_eos_mask",
                                     "target_sos_ids", "target_sos_mask",
                                     "target_eos_ids", "target_eos_mask"]), 6)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["source_sos_ids", "target_sos_ids",
                                     "target_eos_mask"]), 3)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["target_eos_mask", "source_eos_mask",
                                     "source_sos_mask"]), 3)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["target_eos_ids"]), 1)

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_multi_bytes():
    mindrecord_file_name = "test.mindrecord"
    # Six records; record k carries "image{k} bytes <suffix>" in each of the
    # five bytes fields, with a fixed per-record label.
    labels = [43, 91, 61, 29, 78, 37]
    data = []
    for k in range(6):
        record = {"file_name": "%03d.jpg" % (k + 1), "label": labels[k]}
        for i, suffix in enumerate(["abc", "def", "ghi", "jkl", "mno"], start=1):
            record["image%d" % i] = bytes("image{} bytes {}".format(k + 1, suffix),
                                          encoding='UTF-8')
        data.append(record)

    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    def check_reader(reader, num_fields):
        """Read every record back and compare it with the source data."""
        count = 0
        for index, x in enumerate(reader.get_next()):
            assert len(x) == num_fields
            for field in x:
                if isinstance(x[field], np.ndarray):
                    assert (x[field] == data[count][field]).all()
                else:
                    assert x[field] == data[count][field]
            count = count + 1
            logger.info("#item{}: {}".format(index, x))
        assert count == 6
        reader.close()

    check_reader(FileReader(mindrecord_file_name), 7)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["image1", "image2", "image5"]), 3)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["image2", "image4"]), 2)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["image5", "image2"]), 2)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["image5", "image2", "label"]), 3)

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process():
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8,
             "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91, "score": 5.4,
             "mask": np.array([1, 4, 7], dtype=np.int64),
             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
             "data": bytes("image bytes def", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61, "score": 6.4,
             "mask": np.array([7, 6, 3], dtype=np.int64),
             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
             "data": bytes("image bytes ghi", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29, "score": 8.1,
             "mask": np.array([2, 8, 0], dtype=np.int64),
             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
             "data": bytes("image bytes jkl", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78, "score": 7.7,
             "mask": np.array([3, 1, 2], dtype=np.int64),
             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
             "data": bytes("image bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37, "score": 9.4,
             "mask": np.array([7, 6, 7], dtype=np.int64),
             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
             "data": bytes("image bytes pqr", encoding='UTF-8')}]
    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()
    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 6
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()
    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_multi_bytes_and_array():
    mindrecord_file_name = "test.mindrecord"
    # Six records mixing five bytes fields with six int64 arrays; in record k
    # the first element of each array is offset from the base value (by 10*k
    # or 100*k) and the label is 4 + k.
    data = []
    for k in range(6):
        record = {
            "file_name": "%03d.jpg" % (k + 1),
            "label": 4 + k,
            "source_sos_ids": np.array([1 + 10 * k, 2, 3, 4, 5], dtype=np.int64),
            "source_sos_mask": np.array([6 + 10 * k, 7, 8, 9, 10, 11, 12], dtype=np.int64),
            "target_sos_ids": np.array([28 + 100 * k, 29, 30, 31, 32], dtype=np.int64),
            "target_sos_mask": np.array([33 + 100 * k, 34, 35, 36, 37, 38], dtype=np.int64),
            "target_eos_ids": np.array([39 + 100 * k, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
            "target_eos_mask": np.array([48 + 100 * k, 49, 50, 51], dtype=np.int64)}
        for i, suffix in enumerate(["abc", "def", "ghi", "jkl", "mno"], start=1):
            record["image%d" % i] = bytes("image{} bytes {}".format(k + 1, suffix),
                                          encoding='UTF-8')
        data.append(record)

    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image3": {"type": "bytes"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]},
              "label": {"type": "int32"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    def check_reader(reader, num_fields):
        """Read every record back and compare it with the source data."""
        count = 0
        for index, x in enumerate(reader.get_next()):
            assert len(x) == num_fields
            for field in x:
                if isinstance(x[field], np.ndarray):
                    assert (x[field] == data[count][field]).all()
                else:
                    assert x[field] == data[count][field]
            count = count + 1
            logger.info("#item{}: {}".format(index, x))
        assert count == 6
        reader.close()

    check_reader(FileReader(mindrecord_file_name), 13)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["source_sos_ids", "source_sos_mask",
                                     "target_sos_ids"]), 3)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["image2", "source_sos_mask", "image3",
                                     "target_sos_ids"]), 4)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["target_sos_ids", "image4",
                                     "source_sos_ids"]), 3)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["target_sos_ids", "image5", "image4",
                                     "image3", "source_sos_ids"]), 5)
    check_reader(FileReader(file_name=mindrecord_file_name,
                            columns=["target_eos_mask", "image5", "image2",
                                     "source_sos_mask", "label"]), 5)

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
train_2 = X[train_idx]  # (110000, 2, 128), float32
test_2 = X[test_idx]  # (110000, 2, 128), float32
train_label = np.array(list(map(lambda x: mods.index(lbl[x][0]),
                                train_idx))).astype(np.int32)  # (110000,), int32
test_label = np.array(list(map(lambda x: mods.index(lbl[x][0]),
                               test_idx))).astype(np.int32)  # (110000,), int32
print(train_2.shape, train_2.dtype, train_label.shape, train_label.dtype,
      test_2.shape, test_2.dtype, test_label.shape, test_label.dtype)

# get train MindRecord
MINDRECORD_FILE = "/home/huawei/data/IQ_signal/RML2016.10b_train.mindrecord"
if os.path.exists(MINDRECORD_FILE):
    os.remove(MINDRECORD_FILE)
    os.remove(MINDRECORD_FILE + ".db")

writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
npy_schema = {"data": {"type": "float32",
                       "shape": [1, train_2.shape[1], train_2.shape[2]]},
              "label": {"type": "int32"}}
writer.add_schema(npy_schema, "it is a RML2016.10b IQ signal train dataset")

data = []
for i in tqdm(range(train_2.shape[0])):
    sample = {"data": train_2[i:i + 1, :, :], "label": train_label[i]}
    data.append(sample)
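# The snippet ends after building `data`; presumably the script continues
# like the other converters in this collection (an assumption, the write
# step is not shown in the source):
writer.write_raw_data(data)
writer.commit()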
def test_cv_file_writer_shard_num_str():
    """test cv file writer when shard num is a string."""
    with pytest.raises(Exception, match="Shard num is illegal."):
        FileWriter("/tmp/123454321", "20")
def test_case_00(add_and_remove_cv_file):
    # only bin data
    # Five records with prefixes 1, 2, 3, 5, 6 (there is no record whose
    # payload reads "image4 bytes ..." in the source data).
    data = []
    for k in (1, 2, 3, 5, 6):
        record = {}
        for i, suffix in enumerate(["abc", "def", "ghi", "jkl", "mno"], start=1):
            record["image%d" % i] = bytes("image{} bytes {}".format(k, suffix),
                                          encoding='UTF-8')
        data.append(record)
    schema = {"image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)

    data_value_to_list = []
    for item in data:
        new_data = {}
        for field in ("image1", "image2", "image3", "image4", "image5"):
            new_data[field] = np.asarray(list(item[field]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 5
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
def test_case_02(add_and_remove_cv_file):
    # multi-bytes fields mixed with scalar and array number fields
    # Six records; the varying entries are listed per record, while the int
    # arrays follow the usual pattern of offsetting the first element.
    labels = [43, 91, 61, 29, 78, 37]
    float32_third = [3.1234, 4.1234, 5.1234, 6.1234, 7.1234, 7.1234]
    float64_third = [50.13514312414, 60.13514312414, 70.13514312414,
                     80.13514312414, 90.13514312414, 90.13514312414]
    float32_scalars = [3456.12345, 3456.12445, 3456.12545,
                       3456.12645, 3456.12745, 3456.12745]
    float64_scalars = [1987654321.123456785, 1987654321.123456786,
                       1987654321.123456787, 1987654321.123456788,
                       1987654321.123456789, 1987654321.123456789]
    data = []
    for k in range(6):
        record = {
            "file_name": "%03d.jpg" % (k + 1),
            "label": labels[k],
            "float32_array": np.array([1.2, 2.78, float32_third[k], 4.9871, 5.12341],
                                      dtype=np.float32),
            "float64_array": np.array([48.1234556789, 49.3251241431, float64_third[k],
                                       51.8971298471, 123414314.2141243, 87.1212122],
                                      dtype=np.float64),
            "float32": float32_scalars[k],
            "float64": float64_scalars[k],
            "source_sos_ids": np.array([1 + 10 * k, 2, 3, 4, 5], dtype=np.int32),
            "source_sos_mask": np.array([6 + 10 * k, 7, 8, 9, 10, 11, 12], dtype=np.int64)}
        for i, suffix in enumerate(["abc", "def", "ghi", "jkl", "mno"], start=1):
            record["image%d" % i] = bytes("image{} bytes {}".format(k + 1, suffix),
                                          encoding='UTF-8')
        data.append(record)

    schema = {"file_name": {"type": "string"},
              "float32_array": {"type": "float32", "shape": [-1]},
              "float64_array": {"type": "float64", "shape": [-1]},
              "float32": {"type": "float32"},
              "float64": {"type": "float64"},
              "source_sos_ids": {"type": "int32", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)

    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        for i in range(1, 6):
            field = "image%d" % i
            new_data[field] = np.asarray(list(item[field]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] == np.array(
                        data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
try:
    mr_api = import_module(args.mindrecord_script + '.mr_api')
except ModuleNotFoundError:
    raise RuntimeError(
        "Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))

num_tasks = mr_api.mindrecord_task_number()

print("Write mindrecord ...")

mindrecord_dict_data = mr_api.mindrecord_dict_data

# get number of files
writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

start_time = time.time()

# set the header size
try:
    header_size = mr_api.mindrecord_header_size
    writer.set_header_size(header_size)
except AttributeError:
    print("Default header size: {}".format(1 << 24))

# set the page size
try:
    page_size = mr_api.mindrecord_page_size
    writer.set_page_size(page_size)
except AttributeError:
    # the snippet was truncated here; by symmetry with the header-size branch
    # above, the default page size is presumably reported instead
    print("Default page size: {}".format(1 << 25))
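# Hedged sketch of the mr_api plugin module this loader imports. The attribute
# names mirror those accessed above; the function bodies and exact signatures
# the tool expects are assumptions, shown only to illustrate the plugin shape.
# my_dataset/mr_api.py
mindrecord_header_size = 1 << 24   # optional override
mindrecord_page_size = 1 << 25     # optional override

def mindrecord_task_number():
    # how many parallel conversion tasks to run
    return 1

def mindrecord_dict_data(task_id):
    # yield one raw-data dict per sample for the given task
    yield {"data": b"raw bytes", "label": 0}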
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
    anno_file_dirs = [config.annotation_file]
    images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
                                                                 anno_file_dirs=anno_file_dirs)
    vocab, _ = initialize_vocabulary(config.vocab_path)
    data_schema = {"image": {"type": "bytes"},
                   "label": {"type": "int32", "shape": [-1]},
                   "decoder_input": {"type": "int32", "shape": [-1]},
                   "decoder_mask": {"type": "int32", "shape": [-1]},
                   "decoder_target": {"type": "int32", "shape": [-1]},
                   "annotation": {"type": "string"}}
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    writer.add_schema(data_schema, "ocr")
    for img_id in images:
        image_path = image_path_dict[img_id]
        annotation = image_anno_dict[img_id]
        label_max_len = config.max_length
        text_max_len = config.max_length - 2
        if len(annotation) > text_max_len:
            continue
        label = serialize_annotation(image_path, annotation, vocab)
        if label is None:
            continue
        label_len = len(label)
        decoder_input_len = label_max_len
        if label_len <= decoder_input_len:
            label = np.concatenate((label, np.zeros(decoder_input_len - label_len,
                                                    dtype=np.int32)))
            one_mask_len = label_len - config.go_shift
            target_weight = np.concatenate((np.ones(one_mask_len, dtype=np.float32),
                                            np.zeros(decoder_input_len - one_mask_len,
                                                     dtype=np.float32)))
        else:
            continue
        decoder_input = (np.array(label).T).astype(np.int32)
        target_weight = (np.array(target_weight).T).astype(np.int32)
        if not len(decoder_input) == len(target_weight):
            continue
        target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
        target = (np.array(target)).astype(np.int32)
        with open(image_path, 'rb') as f:
            img = f.read()
        row = {"image": img, "label": label, "decoder_input": decoder_input,
               "decoder_mask": target_weight, "decoder_target": target,
               "annotation": str(annotation)}
        writer.write_raw_data([row])
    writer.commit()
def prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path, out_path):
    """Generate training data (*.mindrecord) for a single-label classification
    model; the data is randomly shuffled."""
    writer = FileWriter(out_path)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "segment_ids": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}}
    writer.add_schema(data_schema, "CLUENER2020 schema")
    example_count = 0
    with open(path) as fin:
        for line in fin:
            if not line.strip():
                continue
            _ = json.loads(line.strip())
            len_ = len(_["text"])
            # expand the span annotations into per-character BMES-style tags
            labels = ["O"] * len_
            for k, v in _["label"].items():
                for kk, vv in v.items():
                    for vvv in vv:
                        span = vvv
                        s = span[0]
                        e = span[1] + 1
                        if e - s == 1:
                            labels[s] = "S_" + k
                        else:
                            labels[s] = "B_" + k
                            for i in range(s + 1, e - 1):
                                labels[i] = "M_" + k
                            labels[e - 1] = "E_" + k
            feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
                                          max_seq_len=max_seq_len)
            features = collections.OrderedDict()
            # sequence labeling task
            features["input_ids"] = np.asarray(feature[0])
            features["input_mask"] = np.asarray(feature[1])
            features["segment_ids"] = np.asarray(feature[2])
            features["label_ids"] = np.asarray(feature[3])
            if example_count < 5:
                print("*** Example ***")
                print(_["text"])
                print(_["label"])
                print("input_ids: %s" % " ".join([str(x) for x in feature[0]]))
                print("input_mask: %s" % " ".join([str(x) for x in feature[1]]))
                print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
                print("label: %s " % " ".join([str(x) for x in feature[3]]))
            writer.write_raw_data([features])
            example_count += 1
            if example_count % 3000 == 0:
                print(example_count)
    print("total example:", example_count)
    writer.commit()
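# Hedged verification sketch (not in the original): reading back a few of the
# records just written with FileReader; `out_path` is the file committed above.
reader = FileReader(out_path)
for i, item in enumerate(reader.get_next()):
    print(i, item["input_ids"][:10], item["label_ids"][:10])
    if i >= 2:
        break
reader.close()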