def test_cv_file_writer_no_raw():
    """Write ten NLP samples with an array-only schema and read them back.

    The schema declares only the three int64 array fields (``input_ids``,
    ``input_mask`` and ``segment_ids``), and every record read back is
    expected to contain exactly three fields.  Despite the ``cv`` in its
    name, this test works on the NLP sample data and ``NLP_FILE_NAME``.

    The reader is closed and the generated files are removed in ``finally``
    blocks, so a failing assertion does not leave an open reader or stale
    files behind for later tests.
    """
    writer = FileWriter(NLP_FILE_NAME)
    data = list(
        get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                     "../data/mindrecord/testAclImdbData/vocab.txt", 10))
    nlp_schema_json = {
        "input_ids": {
            "type": "int64",
            "shape": [1, -1]
        },
        "input_mask": {
            "type": "int64",
            "shape": [1, -1]
        },
        "segment_ids": {
            "type": "int64",
            "shape": [1, -1]
        }
    }
    writer.add_schema(nlp_schema_json, "no_raw_schema")
    writer.write_raw_data(data)
    writer.commit()
    try:
        reader = FileReader(NLP_FILE_NAME)
        try:
            count = 0
            for index, x in enumerate(reader.get_next()):
                count += 1
                assert len(x) == 3
                logger.info("#item{}: {}".format(index, x))
            # Ten samples were requested from get_nlp_data above.
            assert count == 10
        finally:
            # Release the reader even when an assertion above fails.
            reader.close()
    finally:
        # Always delete the data file and its companion ".db" file.
        os.remove(NLP_FILE_NAME)
        os.remove("{}.db".format(NLP_FILE_NAME))
def test_nlp_file_writer_tutorial():
    """Tutorial: write ten NLP samples through the ``FileWriter`` API.

    Registers a schema with three scalar fields (``id``, ``label``,
    ``rating``) and three int64 array fields, indexes ``id`` and ``rating``,
    then writes and commits the data.  This function does not remove the
    files it writes.
    """
    pos_dir = "../data/mindrecord/testAclImdbData/pos"
    vocab_file = "../data/mindrecord/testAclImdbData/vocab.txt"
    samples = list(get_nlp_data(pos_dir, vocab_file, 10))

    # Scalar fields first, then the array fields, to keep the field order.
    schema = {
        "id": {"type": "string"},
        "label": {"type": "int32"},
        "rating": {"type": "float32"},
    }
    for array_field in ("input_ids", "input_mask", "segment_ids"):
        schema[array_field] = {"type": "int64", "shape": [1, -1]}

    nlp_writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    nlp_writer.add_schema(schema, "nlp_schema")
    nlp_writer.add_index(["id", "rating"])
    nlp_writer.write_raw_data(samples)
    nlp_writer.commit()
def test_nlp_file_writer():
    """Write ten NLP samples through the shard-level API.

    Builds a ``ShardHeader`` with one schema and two index fields, writes
    the samples with a ``ShardWriter`` into ``FILES_NUM`` files, and then
    builds the index database for the first file with
    ``ShardIndexGenerator``.  Each status-returning call is checked
    against ``SUCCESS``.
    """
    schema_json = {
        "id": {"type": "string"},
        "label": {"type": "number"},
        "rating": {"type": "number"},
    }
    # The three array fields share the same "array of number" description.
    schema_json.update(
        (field, {"type": "array", "items": {"type": "number"}})
        for field in ("input_ids", "input_mask", "segment_ids"))

    samples = list(
        get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                     "../data/mindrecord/testAclImdbData/vocab.txt", 10))

    # Header: schema plus index fields.
    shard_header = ShardHeader()
    nlp_schema = shard_header.build_schema(
        schema_json, ["segment_ids"], "nlp_schema")
    schema_id = shard_header.add_schema(nlp_schema)
    assert schema_id == 0, 'failed on adding schema'
    status = shard_header.add_index_fields(["id", "rating"])
    assert status == SUCCESS, 'failed on adding index fields.'

    # Writer: one output path per file index.
    shard_writer = ShardWriter()
    paths = []
    for file_index in range(FILES_NUM):
        paths.append("{}{}".format(NLP_FILE_NAME, file_index))
    status = shard_writer.open(paths)
    assert status == SUCCESS, 'failed on opening files.'
    shard_writer.set_header_size(1 << 14)
    shard_writer.set_page_size(1 << 15)
    status = shard_writer.set_shard_header(shard_header)
    assert status == SUCCESS, 'failed on setting header.'
    status = shard_writer.write_raw_nlp_data({schema_id: samples})
    assert status == SUCCESS, 'failed on writing raw data.'
    status = shard_writer.commit()
    assert status == SUCCESS, 'failed on committing.'

    # Build and persist the index for the first file.
    first_file = os.path.realpath(paths[0])
    index_generator = ShardIndexGenerator(first_file)
    index_generator.build()
    index_generator.write_to_db()
# Example #4
# 0
def test_issue_84():
    """Test that opening a file fails when its ``.db`` file does not match.

    A CV dataset and an NLP dataset are written first.  The ``.db`` file of
    NLP file 1 is then moved into the place of the ``.db`` file of CV file
    1, and ``ShardReader.open`` on ``imagenet.mindrecord1`` is expected to
    raise with the exact ``MRMOpenError`` message asserted below.

    The renames are undone and all generated files are removed in
    ``finally`` blocks, so a failed expectation does not leave swapped or
    stale files on disk.

    NOTE(review): the hard-coded ``imagenet.mindrecord1`` /
    ``aclImdb.mindrecord1`` names assume ``CV_FILE_NAME`` and
    ``NLP_FILE_NAME`` resolve to those base names in the current working
    directory -- confirm against the module-level constants.
    """
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"}, "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                             "../data/mindrecord/testAclImdbData/vocab.txt",
                             10))
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "number"},
                       "rating": {"type": "number"},
                       "input_ids": {"type": "array",
                                     "items": {"type": "number"}},
                       "input_mask": {"type": "array",
                                      "items": {"type": "number"}},
                       "segment_ids": {"type": "array",
                                       "items": {"type": "number"}}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()

    reader = ShardReader()
    cv_db = "imagenet.mindrecord1.db"
    cv_db_backup = "imagenet.mindrecord1.db.bk"
    nlp_db = "aclImdb.mindrecord1.db"
    try:
        # Each rename sits just before its own try so that the matching
        # undo in "finally" only runs when the rename actually happened.
        os.rename(cv_db, cv_db_backup)
        try:
            os.rename(nlp_db, cv_db)
            try:
                file_name = os.path.join(os.getcwd(), "imagenet.mindrecord1")
                with pytest.raises(Exception) as e:
                    reader.open(file_name)
                assert str(e.value) == \
                    "[MRMOpenError]: error_code: 1347690596, " \
                    "error_msg: " \
                    "MindRecord File could not open successfully."
            finally:
                # Give the NLP dataset its own db file back.
                os.rename(cv_db, nlp_db)
        finally:
            # Restore the original CV db file from the backup.
            os.rename(cv_db_backup, cv_db)
    finally:
        # Remove every generated file and its ".db" companion.
        for base_name in (NLP_FILE_NAME, CV_FILE_NAME):
            for file_index in range(FILES_NUM):
                item = "{}{}".format(base_name, file_index)
                os.remove(item)
                os.remove("{}.db".format(item))