Example 1
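Note: these snippets are excerpted from MindSpore's MindRecord test suite. Each assumes the test module's shared setup, which is not shown here: imports of os, pytest, and the MindSpore logger; the mindspore.mindrecord classes and constants used below (FileWriter, ShardReader, ShardWriter, ShardHeader, ShardIndexGenerator, MRMAddIndexError, SUCCESS, FAILED); the get_data and get_nlp_data helper fixtures; and the file-name constants CV_FILE_NAME, NLP_FILE_NAME, MKV_FILE_NAME, and FILES_NUM.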
def test_issue_65():
    """test file reader when file name is illegal."""
    reader = ShardReader()
    file_name = os.path.join(os.getcwd(), "imagenet.mindrecord01qwert")
    with pytest.raises(Exception) as e:
        reader.open(file_name)
    assert str(e.value) == "[MRMOpenError]: error_code: 1347690596, " \
                           "error_msg: " \
                           "MindRecord File could not open successfully."
Example 2
def test_issue_84():
    """test file reader when db does not match."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"}, "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                             "../data/mindrecord/testAclImdbData/vocab.txt",
                             10))
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "number"},
                       "rating": {"type": "number"},
                       "input_ids": {"type": "array",
                                     "items": {"type": "number"}},
                       "input_mask": {"type": "array",
                                      "items": {"type": "number"}},
                       "segment_ids": {"type": "array",
                                       "items": {"type": "number"}}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()

    reader = ShardReader()
    os.rename("imagenet.mindrecord1.db", "imagenet.mindrecord1.db.bk")
    os.rename("aclImdb.mindrecord1.db", "imagenet.mindrecord1.db")
    file_name = os.path.join(os.getcwd(), "imagenet.mindrecord1")
    with pytest.raises(Exception) as e:
        reader.open(file_name)
    assert str(e.value) == "[MRMOpenError]: error_code: 1347690596, " \
                           "error_msg: " \
                           "MindRecord File could not open successfully."

    os.rename("imagenet.mindrecord1.db", "aclImdb.mindrecord1.db")
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for item in paths:
        os.remove("{}".format(item))
        os.remove("{}.db".format(item))

    os.rename("imagenet.mindrecord1.db.bk", "imagenet.mindrecord1.db")
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for item in paths:
        os.remove("{}".format(item))
        os.remove("{}.db".format(item))
Example 3

def test_nlp_file_reader():
    """test nlp file reader using shard api"""
    dataset = ShardReader()
    dataset.open(NLP_FILE_NAME + "0")
    dataset.launch()
    index = 0
    iterator = dataset.get_next()
    while iterator:
        for _, raw in iterator:
            logger.info("#item{}: {}".format(index, raw))
            index += 1
            iterator = dataset.get_next()
    dataset.finish()
    dataset.close()
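The iteration protocol here (and in the remaining examples) is: open() loads the shard files, launch() starts the underlying reader, each get_next() call returns the next batch of (blob, raw) pairs until an empty result ends the while loop, and finish()/close() tear the reader down.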
Example 4

def test_cv_file_reader():
    """test cv file reader using shard api"""
    dataset = ShardReader()
    dataset.open(CV_FILE_NAME + "0")
    dataset.launch()
    index = 0
    _, blob_fields = dataset.get_blob_fields()
    iterator = dataset.get_next()
    while iterator:
        for blob, raw in iterator:
            raw[blob_fields[0]] = bytes(blob)
            logger.info("#item{}: {}".format(index, raw))
            index += 1
            iterator = dataset.get_next()
    dataset.finish()
    dataset.close()
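A high-level counterpart to the blob handling above, as a sketch (assuming the documented public FileReader, whose get_next() yields each row as a dict with bytes-typed columns already materialized):

from mindspore.mindrecord import FileReader

def read_cv_file_public_api(file_name):
    """Sketch: iterate a MindRecord file with the public FileReader."""
    reader = FileReader(file_name)
    for index, row in enumerate(reader.get_next()):
        # row is a dict; in this dataset the "data" column holds image bytes.
        print("#item{}: fields={}".format(index, list(row.keys())))
    reader.close()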
Example 5
def test_file_writer_fail_add_index():
    """test file writer, read when failed on adding index."""
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "number"
        }
    }
    header = ShardHeader()
    schema = header.build_schema(schema_json, ["data"], "img")  # create schema
    schema_id = header.add_schema(schema)  # add schema
    with pytest.raises(TypeError, match="missing 1 "):
        # add_index_fields() requires a field list, so the call raises
        # TypeError and the assert below is never reached.
        ret = header.add_index_fields()
        assert ret == FAILED

    with pytest.raises(MRMAddIndexError):
        # An empty field list is rejected with MRMAddIndexError.
        index_fields = []
        ret = header.add_index_fields(index_fields)
        assert ret == FAILED

    file_name = os.path.join(os.getcwd(),
                             "test_001.mindrecord")  # set output filename
    writer = ShardWriter()  # low-level shard writer
    ret = writer.open([file_name])
    assert ret == SUCCESS, 'failed on opening files.'
    ret = writer.set_shard_header(header)  # write header
    assert ret == SUCCESS, 'failed on setting header.'
    ret = writer.write_raw_cv_data({schema_id: data_raw})
    assert ret == SUCCESS, 'failed on writing raw data.'
    ret = writer.commit()  # commit data
    assert ret == SUCCESS, "commit failed"
    # ShardIndexGenerator
    generator = ShardIndexGenerator(os.path.realpath(file_name))
    generator.build()
    generator.write_to_db()

    reader = ShardReader()
    ret = reader.open(file_name)
    reader.launch()
    index = 0
    _, blob_fields = reader.get_blob_fields()
    iterator = reader.get_next()
    while iterator:
        for blob, raw in iterator:
            raw[blob_fields[0]] = bytes(blob)
            logger.info("#item{}: {}".format(index, raw))
            index += 1
            iterator = reader.get_next()
    reader.finish()
    reader.close()

    os.remove("{}".format(file_name))
    os.remove("{}.db".format(file_name))
Example 6

def test_mkv_file_reader_with_exactly_schema():
    """test mkv file reader using shard api"""
    dataset = ShardReader()
    dataset.open(MKV_FILE_NAME + "0")
    dataset.launch()
    index = 0
    _, blob_fields = dataset.get_blob_fields()
    iterator = dataset.get_next()
    while iterator:
        for blob, raw in iterator:
            raw[blob_fields[0]] = bytes(blob)
            logger.info("#item{}: {}".format(index, raw))
            index += 1
            iterator = dataset.get_next()
    dataset.finish()
    dataset.close()

    paths = [
        "{}{}".format(MKV_FILE_NAME,
                      str(x).rjust(1, '0')) for x in range(1)
    ]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))