Example #1
def convert_to_mindrecord(embed_size, data_path, preprocess_path, glove_path):
    """Convert the IMDB dataset to MindRecord format."""

    num_shard = 4
    train_features, train_labels, test_features, test_labels, weight_np, _ = \
        preprocess(data_path, glove_path, embed_size)
    np.savetxt(os.path.join(preprocess_path, 'weight.txt'), weight_np)

    print("train_features.shape:", train_features.shape,
          "train_labels.shape:", train_labels.shape,
          "weight_np.shape:", weight_np.shape,
          "type:", train_labels.dtype)
    # write mindrecord
    schema_json = {
        "id": {
            "type": "int32"
        },
        "label": {
            "type": "int32"
        },
        "feature": {
            "type": "int32",
            "shape": [-1]
        }
    }

    writer = FileWriter(
        os.path.join(preprocess_path, 'aclImdb_train.mindrecord'), num_shard)
    data = get_imdb_data(train_labels, train_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(
        os.path.join(preprocess_path, 'aclImdb_test.mindrecord'), num_shard)
    data = get_imdb_data(test_labels, test_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()
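
For reference, a minimal sketch of reading the shards back. This is an assumption sketch: it presumes mindspore.dataset is importable as ds and that MindDataset, given the first shard file, locates the remaining shards, as the multi-shard reads later on this page suggest:

import mindspore.dataset as ds

# sketch only, not part of the original example; preprocess_path is the
# same directory argument passed to convert_to_mindrecord above
data_set = ds.MindDataset(
    dataset_file=os.path.join(preprocess_path, 'aclImdb_train.mindrecord0'))
print("train dataset size:", data_set.get_dataset_size())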
Example #2
def test_cv_file_append_writer():
    """tutorial for cv dataset append writer."""
    writer = FileWriter(CV3_FILE_NAME, 4)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "int64"
        },
        "data": {
            "type": "bytes"
        }
    }
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data[0:5])
    writer.commit()
    write_append = FileWriter.open_for_append(CV3_FILE_NAME + "0")
    write_append.write_raw_data(data[5:10])
    write_append.commit()
    reader = FileReader(CV3_FILE_NAME + "0")
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 10
    reader.close()

    paths = [
        "{}{}".format(CV3_FILE_NAME,
                      str(x).rjust(1, '0')) for x in range(4)
    ]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
Example #3
def test_write_read_process():
    mindrecord_file_name = "test.mindrecord"
    data = [{
        "file_name": "001.jpg",
        "label": 43,
        "score": 0.8,
        "mask": np.array([3, 6, 9], dtype=np.int64),
        "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
        "data": bytes("image bytes abc", encoding='UTF-8')
    }, {
        "file_name": "002.jpg",
        "label": 91,
        "score": 5.4,
        "mask": np.array([1, 4, 7], dtype=np.int64),
        "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
        "data": bytes("image bytes def", encoding='UTF-8')
    }, {
        "file_name": "003.jpg",
        "label": 61,
        "score": 6.4,
        "mask": np.array([7, 6, 3], dtype=np.int64),
        "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
        "data": bytes("image bytes ghi", encoding='UTF-8')
    }, {
        "file_name": "004.jpg",
        "label": 29,
        "score": 8.1,
        "mask": np.array([2, 8, 0], dtype=np.int64),
        "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
        "data": bytes("image bytes jkl", encoding='UTF-8')
    }, {
        "file_name": "005.jpg",
        "label": 78,
        "score": 7.7,
        "mask": np.array([3, 1, 2], dtype=np.int64),
        "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
        "data": bytes("image bytes mno", encoding='UTF-8')
    }, {
        "file_name": "006.jpg",
        "label": 37,
        "score": 9.4,
        "mask": np.array([7, 6, 7], dtype=np.int64),
        "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
        "data": bytes("image bytes pqr", encoding='UTF-8')
    }]
    writer = FileWriter(mindrecord_file_name)
    schema = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "int32"
        },
        "score": {
            "type": "float64"
        },
        "mask": {
            "type": "int64",
            "shape": [-1]
        },
        "segments": {
            "type": "float32",
            "shape": [2, 2]
        },
        "data": {
            "type": "bytes"
        }
    }
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 6
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
Example #4
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    paths = [
        "{}{}".format(NLP_FILE_NAME,
                      str(x).rjust(1, '0')) for x in range(FILES_NUM)
    ]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
        data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
        nlp_schema_json = {
            "id": {
                "type": "string"
            },
            "label": {
                "type": "int32"
            },
            "rating": {
                "type": "float32"
            },
            "input_ids": {
                "type": "int64",
                "shape": [-1]
            },
            "input_mask": {
                "type": "int64",
                "shape": [1, -1]
            },
            "segment_ids": {
                "type": "int64",
                "shape": [2, -1]
            }
        }
        writer.set_header_size(1 << 14)
        writer.set_page_size(1 << 15)
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.add_index(["id", "rating"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_nlp_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
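
The generator above follows the pytest fixture pattern (set up, yield, tear down). A hedged sketch of how it might be wired into a test; the @pytest.fixture registration and the consuming test are assumptions, not shown in the source:

import pytest

@pytest.fixture(name="yield_nlp_data")
def fixture_add_and_remove_nlp_file():
    # delegate setup and teardown to the generator above (assumed wiring)
    yield from add_and_remove_nlp_file()

def test_nlp_file_reader(yield_nlp_data):
    # 10 rows were requested from get_nlp_data; reading the first shard is
    # assumed to expose all rows, as in the append-writer example above
    reader = FileReader(NLP_FILE_NAME + "0")
    assert sum(1 for _ in reader.get_next()) == 10
    reader.close()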
Example #5
if __name__ == '__main__':
    args = parse_args()

    data_list = []
    with open(args.data_lst) as f:
        lines = f.readlines()
    if args.shuffle:
        np.random.shuffle(lines)

    dst_dir = os.path.dirname(args.dst_path)
    if dst_dir and not os.path.exists(dst_dir):
        os.makedirs(dst_dir)

    print('number of samples:', len(lines))
    writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
    writer.add_schema(seg_schema, "seg_schema")
    cnt = 0

    for line in lines:
        img_name = line.strip('\n')

        img_path = 'img/' + img_name + '.jpg'
        label_path = 'cls_png/' + img_name + '.png'

        sample_ = {"file_name": img_path.split('/')[-1]}

        with open(os.path.join(args.data_root, img_path), 'rb') as f:
            sample_['data'] = f.read()
        with open(os.path.join(args.data_root, label_path), 'rb') as f:
            sample_['label'] = f.read()
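        # The listing breaks off here; a hedged continuation following the
        # pattern above (the batch size and final bookkeeping are assumptions):
        data_list.append(sample_)
        cnt += 1
        if len(data_list) >= 1000:  # flush in batches to bound memory
            writer.write_raw_data(data_list)
            data_list = []

    if data_list:
        writer.write_raw_data(data_list)
    writer.commit()
    print('number of samples written:', cnt)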
Example #6
def test_write_read_process_with_multi_array():
    mindrecord_file_name = "test.mindrecord"
    data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 8
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["source_eos_ids", "source_eos_mask",
                                                                 "target_sos_ids", "target_sos_mask",
                                                                 "target_eos_ids", "target_eos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 6
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids",
                                                                 "target_sos_ids",
                                                                 "target_eos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask",
                                                                 "source_eos_mask",
                                                                 "source_sos_mask"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 1
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
Example #7
def gen_mindrecord(data_type):
    """gen mindreocrd according exactly schema"""
    if data_type == "train":
        fw = FileWriter(MINDRECORD_FILE_NAME_TRAIN)
    else:
        fw = FileWriter(MINDRECORD_FILE_NAME_TEST)

    schema = {
        "id": {
            "type": "int32"
        },
        "label": {
            "type": "int32"
        },
        "score": {
            "type": "int32"
        },
        "review": {
            "type": "string"
        }
    }
    fw.add_schema(schema, "aclImdb dataset")
    fw.add_index(["id", "label", "score"])

    get_data_iter = get_data_as_dict(os.path.join(ACLIMDB_DIR, data_type))

    batch_size = 256
    transform_count = 0
    while True:
        data_list = []
        try:
            for _ in range(batch_size):
                data_list.append(next(get_data_iter))
                transform_count += 1
            fw.write_raw_data(data_list)
            print(">> transformed {} records...".format(transform_count))
        except StopIteration:
            if data_list:
                fw.write_raw_data(data_list)
                print(">> transformed {} records...".format(transform_count))
            break

    fw.commit()
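
A likely driver for the two splits (assumed; not part of the listing):

if __name__ == "__main__":
    gen_mindrecord("train")
    gen_mindrecord("test")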
Example #8
    def transfer_coco_to_mindrecord(self,
                                    mindrecord_dir,
                                    file_name="coco_det.train.mind",
                                    shard_num=1):
        """Create MindRecord file by image_dir and anno_path."""
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if os.path.isdir(self.image_path) and os.path.exists(self.annot_path):
            logger.info("Create MindRecord based on COCO_HP dataset")
        else:
            raise ValueError(
                'data_dir {} or anno_path {} does not exist'.format(
                    self.image_path, self.annot_path))

        mindrecord_path = os.path.join(mindrecord_dir, file_name)
        writer = FileWriter(mindrecord_path, shard_num)

        centernet_json = {
            "img_id": {
                "type": "int32",
                "shape": [1]
            },
            "image": {
                "type": "bytes"
            },
            "num_objects": {
                "type": "int32"
            },
            "bboxes": {
                "type": "float32",
                "shape": [-1, 4]
            },
            "category_id": {
                "type": "int32",
                "shape": [-1]
            },
        }

        writer.add_schema(centernet_json, "centernet_json")

        for img_id in self.images:
            image_info = self.coco.loadImgs([img_id])
            annos = self.coco.loadAnns(self.anns[img_id])
            # get image
            img_name = image_info[0]['file_name']
            img_name = os.path.join(self.image_path, img_name)
            with open(img_name, 'rb') as f:
                image = f.read()

            bboxes = []
            category_id = []
            num_objects = len(annos)
            for anno in annos:
                bbox = self._coco_box_to_bbox(anno['bbox'])
                class_name = self.classs_dict[anno["category_id"]]
                if class_name in self.train_cls:
                    x_min, x_max = bbox[0], bbox[2]
                    y_min, y_max = bbox[1], bbox[3]
                    bboxes.append([x_min, y_min, x_max, y_max])
                    category_id.append(self.train_cls_dict[class_name])

            row = {
                "img_id": np.array([img_id], dtype=np.int32),
                "image": image,
                "num_objects": num_objects,
                "bboxes": np.array(bboxes, np.float32),
                "category_id": np.array(category_id, np.int32)
            }
            writer.write_raw_data([row])

        writer.commit()
        logger.info("Create Mindrecord Done, at {}".format(mindrecord_dir))
Example #9
def create_diff_page_size_cv_mindrecord(files_num):
    """tutorial for cv dataset writer."""
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
        os.remove("{}.db".format(CV1_FILE_NAME))
    writer = FileWriter(CV1_FILE_NAME, files_num)
    writer.set_page_size(1 << 26)  # 64MB
    cv_schema_json = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "int32"
        },
        "data": {
            "type": "bytes"
        }
    }
    data = [{
        "file_name": "001.jpg",
        "label": 43,
        "data": bytes('0xffsafdafda', encoding='utf-8')
    }]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
Example #10
def test_cv_file_writer_shard_num_str():
    """test cv file writer when shard num is string."""
    with pytest.raises(Exception, match="Shard num is illegal."):
        FileWriter("/tmp/123454321", "20")
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True,
                        help='Input raw text file (or comma-separated list of files).')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=16,
                        help='The MindRecord file will be split into the number of partition.')
    parser.add_argument("--vocab_file", type=str, required=True,
                        help='The vocabulary file that the Transformer model was trained on.')
    parser.add_argument("--clip_to_max_len", type=bool, default=False,
                        help='clip sequences to maximum sequence length.')
    parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
    parser.add_argument("--bucket", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
                        help='bucket sequence length')

    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    input_files = args.input_file.split(",")

    logging.info("*** Read from input files ***")
    for input_file in input_files:
        logging.info("  %s", input_file)

    output_file = args.output_file
    logging.info("*** Write to output files ***")
    logging.info("  %s", output_file)

    total_written = 0
    total_read = 0

    feature_dict = {i: [] for i in args.bucket}

    for input_file in input_files:
        logging.info("*** Reading from   %s ***", input_file)
        with open(input_file, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break

                total_read += 1
                if total_read % 100000 == 0:
                    logging.info("Read %d ...", total_read)

                source_line, target_line = line.strip().split("\t")
                source_tokens = tokenizer.tokenize(source_line)
                target_tokens = tokenizer.tokenize(target_line)

                if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
                    logging.info("ignore long sentence!")
                    continue

                instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
                                                    clip_to_max_len=args.clip_to_max_len)
                if instance is None:
                    continue

                features, seq_max_bucket_length = get_instance_features(instance, tokenizer, args.max_seq_length,
                                                                        args.bucket)
                if seq_max_bucket_length in feature_dict:
                    feature_dict[seq_max_bucket_length].append(features)

                if total_read <= 10:
                    logging.info("*** Example ***")
                    logging.info("source tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
                    logging.info("target tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.target_sos_tokens]))

                    for feature_name in features.keys():
                        feature = features[feature_name]
                        logging.info("%s: %s", feature_name, feature)

    for i in args.bucket:
        if args.num_splits == 1:
            output_file_name = output_file
        else:
            output_file_name = output_file + '_' + str(i) + '_'
        writer = FileWriter(output_file_name, args.num_splits)
        data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
                       "source_sos_mask": {"type": "int64", "shape": [-1]},
                       "source_eos_ids": {"type": "int64", "shape": [-1]},
                       "source_eos_mask": {"type": "int64", "shape": [-1]},
                       "target_sos_ids": {"type": "int64", "shape": [-1]},
                       "target_sos_mask": {"type": "int64", "shape": [-1]},
                       "target_eos_ids": {"type": "int64", "shape": [-1]},
                       "target_eos_mask": {"type": "int64", "shape": [-1]}
                       }
        writer.add_schema(data_schema, "transformer")
        features_ = feature_dict[i]
        logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))

        for item in features_:
            writer.write_raw_data([item])
            total_written += 1
        writer.commit()

    logging.info("Wrote %d total instances", total_written)
Example #12
def test_add_index_without_add_schema():
    with pytest.raises(MRMGetMetaError) as err:
        fw = FileWriter(CV_FILE_NAME)
        fw.add_index(["label"])
    assert 'Failed to get meta info' in str(err.value)
Example #13
def test_cv_file_writer_shard_num_greater_than_1000():
    """test cv file writer shard number greater than 1000."""
    with pytest.raises(ParamValueError) as err:
        FileWriter(CV_FILE_NAME, 1001)
    assert 'Shard number should between' in str(err.value)
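
Taken together, the two failure cases suggest shard_num must be an int no greater than 1000; a passing construction for contrast (the exact bounds are inferred from the error messages, not documented here):

writer = FileWriter(CV_FILE_NAME, 1000)  # int within the accepted range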
Example #14
def test_write_read_process_with_multi_bytes_and_array():
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 4,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8'),
             "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
            {"file_name": "002.jpg", "label": 5,
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8'),
             "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
            {"file_name": "003.jpg", "label": 6,
             "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8'),
             "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
            {"file_name": "004.jpg", "label": 7,
             "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8'),
             "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
            {"file_name": "005.jpg", "label": 8,
             "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8'),
             "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
            {"file_name": "006.jpg", "label": 9,
             "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8'),
             "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
            ]

    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image3": {"type": "bytes"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]},
              "label": {"type": "int32"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 13
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", "source_sos_mask",
                                                                 "target_sos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["image2", "source_sos_mask",
                                                                 "image3", "target_sos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 4
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image4",
                                                                 "source_sos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image5",
                                                                 "image4", "image3", "source_sos_ids"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 5
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", "image5", "image2",
                                                                 "source_sos_mask", "label"])
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 5
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
Example #15
def test_case_02(add_and_remove_cv_file):  # multi-bytes
    data = [{
        "file_name":
        "001.jpg",
        "label":
        43,
        "float32_array":
        np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12345,
        "float64":
        1987654321.123456785,
        "source_sos_ids":
        np.array([1, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image1 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image1 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image1 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image1 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image1 bytes mno", encoding='UTF-8')
    }, {
        "file_name":
        "002.jpg",
        "label":
        91,
        "float32_array":
        np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12445,
        "float64":
        1987654321.123456786,
        "source_sos_ids":
        np.array([11, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image2 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image2 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image2 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image2 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image2 bytes mno", encoding='UTF-8')
    }, {
        "file_name":
        "003.jpg",
        "label":
        61,
        "float32_array":
        np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12545,
        "float64":
        1987654321.123456787,
        "source_sos_ids":
        np.array([21, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image3 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image3 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image3 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image3 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image3 bytes mno", encoding='UTF-8')
    }, {
        "file_name":
        "004.jpg",
        "label":
        29,
        "float32_array":
        np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12645,
        "float64":
        1987654321.123456788,
        "source_sos_ids":
        np.array([31, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image4 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image4 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image4 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image4 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image4 bytes mno", encoding='UTF-8')
    }, {
        "file_name":
        "005.jpg",
        "label":
        78,
        "float32_array":
        np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12745,
        "float64":
        1987654321.123456789,
        "source_sos_ids":
        np.array([41, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image5 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image5 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image5 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image5 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image5 bytes mno", encoding='UTF-8')
    }, {
        "file_name":
        "006.jpg",
        "label":
        37,
        "float32_array":
        np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
        "float64_array":
        np.array([
            48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
            123414314.2141243, 87.1212122
        ],
                 dtype=np.float64),
        "float32":
        3456.12745,
        "float64":
        1987654321.123456789,
        "source_sos_ids":
        np.array([51, 2, 3, 4, 5], dtype=np.int32),
        "source_sos_mask":
        np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
        "image1":
        bytes("image6 bytes abc", encoding='UTF-8'),
        "image2":
        bytes("image6 bytes def", encoding='UTF-8'),
        "image3":
        bytes("image6 bytes ghi", encoding='UTF-8'),
        "image4":
        bytes("image6 bytes jkl", encoding='UTF-8'),
        "image5":
        bytes("image6 bytes mno", encoding='UTF-8')
    }]
    schema = {
        "file_name": {
            "type": "string"
        },
        "float32_array": {
            "type": "float32",
            "shape": [-1]
        },
        "float64_array": {
            "type": "float64",
            "shape": [-1]
        },
        "float32": {
            "type": "float32"
        },
        "float64": {
            "type": "float64"
        },
        "source_sos_ids": {
            "type": "int32",
            "shape": [-1]
        },
        "source_sos_mask": {
            "type": "int64",
            "shape": [-1]
        },
        "image1": {
            "type": "bytes"
        },
        "image2": {
            "type": "bytes"
        },
        "image3": {
            "type": "bytes"
        },
        "label": {
            "type": "int32"
        },
        "image4": {
            "type": "bytes"
        },
        "image5": {
            "type": "bytes"
        }
    }
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)
    data_value_to_list = []

    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['label'] = np.asarray([item["label"]], dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] == np.array(
                        data_value_to_list[num_iter][field],
                        np.float32)).all()
                else:
                    assert (item[field] == data_value_to_list[num_iter][field]
                            ).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
Example #16
    parser.add_argument('--file_partition', type=int, default=1)
    parser.add_argument('--file_batch_size', type=int, default=1024)
    parser.add_argument('--num_process', type=int, default=64)

    args = parser.parse_args()
    ###
    out_dir, out_file = os.path.split(os.path.abspath(args.output_file))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    schema = {
        "input_ids": {
            "type": "int32",
            "shape": [-1]
        },
    }
    writer = FileWriter(file_name=args.output_file,
                        shard_num=args.file_partition)
    writer.add_schema(schema, args.dataset_type)
    writer.open_and_set_header()
    ###
    transforms_count = 0
    if args.dataset_type == 'wiki':
        for x in tokenize_wiki(args.input_glob):
            transforms_count += 1
            writer.write_raw_data([x])
        print("Transformed {} records.".format(transforms_count))
    elif args.dataset_type == 'lambada':
        for x in tokenize_lambada(args.input_glob):
            transforms_count += 1
            writer.write_raw_data([x])
        print("Transformed {} records.".format(transforms_count))
    elif args.dataset_type == 'openwebtext':
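        # The listing cuts off here; by symmetry with the branches above it
        # presumably continues like this (tokenize_openwebtext is assumed):
        for x in tokenize_openwebtext(args.input_glob):
            transforms_count += 1
            writer.write_raw_data([x])
        print("Transformed {} records.".format(transforms_count))

    writer.commit()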
Example #17
def test_case_00(add_and_remove_cv_file):  # only bin data
    data = [{
        "image1": bytes("image1 bytes abc", encoding='UTF-8'),
        "image2": bytes("image1 bytes def", encoding='UTF-8'),
        "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
        "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
        "image5": bytes("image1 bytes mno", encoding='UTF-8')
    }, {
        "image1": bytes("image2 bytes abc", encoding='UTF-8'),
        "image2": bytes("image2 bytes def", encoding='UTF-8'),
        "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
        "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
        "image5": bytes("image2 bytes mno", encoding='UTF-8')
    }, {
        "image1": bytes("image3 bytes abc", encoding='UTF-8'),
        "image2": bytes("image3 bytes def", encoding='UTF-8'),
        "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
        "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
        "image5": bytes("image3 bytes mno", encoding='UTF-8')
    }, {
        "image1": bytes("image5 bytes abc", encoding='UTF-8'),
        "image2": bytes("image5 bytes def", encoding='UTF-8'),
        "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
        "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
        "image5": bytes("image5 bytes mno", encoding='UTF-8')
    }, {
        "image1": bytes("image6 bytes abc", encoding='UTF-8'),
        "image2": bytes("image6 bytes def", encoding='UTF-8'),
        "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
        "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
        "image5": bytes("image6 bytes mno", encoding='UTF-8')
    }]
    schema = {
        "image1": {
            "type": "bytes"
        },
        "image2": {
            "type": "bytes"
        },
        "image3": {
            "type": "bytes"
        },
        "image4": {
            "type": "bytes"
        },
        "image5": {
            "type": "bytes"
        }
    }
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)
    data_value_to_list = []

    for item in data:
        new_data = {}
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 5
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (
                    item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
Example #18
def test_write_read_process_with_multi_bytes():
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91,
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61,
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29,
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78,
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37,
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(mindrecord_file_name)
    count = 0
    for index, x in enumerate(reader.get_next()):
        assert len(x) == 7
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader.close()

    reader2 = FileReader(file_name=mindrecord_file_name, columns=["image1", "image2", "image5"])
    count = 0
    for index, x in enumerate(reader2.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader2.close()

    reader3 = FileReader(file_name=mindrecord_file_name, columns=["image2", "image4"])
    count = 0
    for index, x in enumerate(reader3.get_next()):
        assert len(x) == 2
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader3.close()

    reader4 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2"])
    count = 0
    for index, x in enumerate(reader4.get_next()):
        assert len(x) == 2
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader4.close()

    reader5 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2", "label"])
    count = 0
    for index, x in enumerate(reader5.get_next()):
        assert len(x) == 3
        for field in x:
            if isinstance(x[field], np.ndarray):
                assert (x[field] == data[count][field]).all()
            else:
                assert x[field] == data[count][field]
        count = count + 1
        logger.info("#item{}: {}".format(index, x))
    assert count == 6
    reader5.close()

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))