Ejemplo n.º 1
0
def test_file_writer_fail_add_index():
    """Test that data can still be written and read after add_index_fields fails.

    Builds a schema for the ImageNet test data and checks two failures of
    ``ShardHeader.add_index_fields``:

    * calling it with no argument raises ``TypeError``;
    * calling it with an empty list raises ``MRMAddIndexError``.

    It then writes the data with ``ShardWriter`` and builds the index database
    with ``ShardIndexGenerator``. Finally it reads every row back with
    ``ShardReader``. The output file and its ``.db`` companion are removed
    afterwards, even if a step fails.
    """
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "number"
        }
    }
    header = ShardHeader()
    schema = header.build_schema(schema_json, ["data"], "img")  # create schema
    schema_id = header.add_schema(schema)  # add schema

    # Only the raising call belongs inside pytest.raises: any statement after
    # it in the same block would never execute.
    with pytest.raises(TypeError, match="missing 1 "):
        header.add_index_fields()

    index_fields = []
    with pytest.raises(MRMAddIndexError):
        header.add_index_fields(index_fields)

    file_name = os.path.join(os.getcwd(),
                             "test_001.mindrecord")  # set output filename
    try:
        writer = ShardWriter()  # test_file_writer
        ret = writer.open([file_name])
        assert ret == SUCCESS, 'failed on opening files.'
        ret = writer.set_shard_header(header)  # write header
        assert ret == SUCCESS, 'failed on setting header.'
        ret = writer.write_raw_cv_data({schema_id: data_raw})
        assert ret == SUCCESS, 'failed on writing raw data.'
        ret = writer.commit()  # commit data
        assert ret == SUCCESS, "commit failed"

        # ShardIndexGenerator
        generator = ShardIndexGenerator(os.path.realpath(file_name))
        generator.build()
        generator.write_to_db()

        reader = ShardReader()
        reader.open(file_name)
        reader.launch()
        try:
            index = 0
            _, blob_fields = reader.get_blob_fields()
            iterator = reader.get_next()
            while iterator:
                for blob, raw in iterator:
                    raw[blob_fields[0]] = bytes(blob)
                    logger.info("#item{}: {}".format(index, raw))
                    index += 1
                # Fetch the next batch only after the current one is fully
                # consumed; calling get_next() per row would skip rows of
                # multi-row batches.
                iterator = reader.get_next()
        finally:
            reader.finish()
            reader.close()
    finally:
        # Best-effort cleanup so a failed run does not leave stale files that
        # could break the next run.
        for path in (file_name, "{}.db".format(file_name)):
            if os.path.exists(path):
                os.remove(path)


def test_nlp_file_reader():
    """Test reading every row of the first NLP shard file with the shard API.

    Opens ``NLP_FILE_NAME + "0"`` with ``ShardReader``, drains it batch by
    batch via ``get_next()``, and logs each row. The reader is always
    finished and closed. The file is not created or removed here; presumably
    an earlier test or fixture creates it (confirm against the rest of the
    module).
    """
    dataset = ShardReader()
    dataset.open(NLP_FILE_NAME + "0")
    dataset.launch()
    try:
        index = 0
        iterator = dataset.get_next()
        while iterator:
            for _, raw in iterator:
                logger.info("#item{}: {}".format(index, raw))
                index += 1
            # Advance once per batch, not once per row, so no batch is skipped.
            iterator = dataset.get_next()
    finally:
        dataset.finish()
        dataset.close()


def test_cv_file_reader():
    """Test reading every row of the first CV shard file with the shard API.

    Opens ``CV_FILE_NAME + "0"`` with ``ShardReader`` and drains it batch by
    batch via ``get_next()``. For each row, the blob is stored under the first
    blob field name before the row is logged. The reader is always finished
    and closed. The file is not created or removed here; presumably it is
    produced elsewhere in the module (confirm).
    """
    dataset = ShardReader()
    dataset.open(CV_FILE_NAME + "0")
    dataset.launch()
    try:
        index = 0
        _, blob_fields = dataset.get_blob_fields()
        iterator = dataset.get_next()
        while iterator:
            for blob, raw in iterator:
                raw[blob_fields[0]] = bytes(blob)
                logger.info("#item{}: {}".format(index, raw))
                index += 1
            # Advance once per batch, not once per row, so no batch is skipped.
            iterator = dataset.get_next()
    finally:
        dataset.finish()
        dataset.close()


def test_mkv_file_reader_with_exactly_schema():
    """Test reading every row of the first MKV shard file with the shard API.

    Opens ``MKV_FILE_NAME + "0"`` with ``ShardReader`` and drains it batch by
    batch via ``get_next()``. For each row, the blob is stored under the first
    blob field name before the row is logged. Afterwards the shard file and
    its ``.db`` companion are removed. Cleanup runs even if reading fails, so
    stale files do not leak into later runs.
    """
    # The original built this from a one-element range with a one-character
    # zero-padded suffix; it was always MKV_FILE_NAME + "0".
    file_name = MKV_FILE_NAME + "0"
    try:
        dataset = ShardReader()
        dataset.open(file_name)
        dataset.launch()
        try:
            index = 0
            _, blob_fields = dataset.get_blob_fields()
            iterator = dataset.get_next()
            while iterator:
                for blob, raw in iterator:
                    raw[blob_fields[0]] = bytes(blob)
                    logger.info("#item{}: {}".format(index, raw))
                    index += 1
                # Advance once per batch, not once per row, so no batch is
                # skipped.
                iterator = dataset.get_next()
        finally:
            dataset.finish()
            dataset.close()
    finally:
        for path in (file_name, "{}.db".format(file_name)):
            # Existence check keeps a missing file from masking the real error.
            if os.path.exists(path):
                os.remove(path)