コード例 #1
0
def test_shard_4_raw_data_1():
    """test file writer when shard_num equals 4 and number of sample equals 1."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "number"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["label"])
    writer.write_raw_data([{"file_name": "001.jpg", "label": 1}])
    writer.commit()

    # the single record must come back from the first shard
    reader = FileReader(CV_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 2
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 1
    reader.close()

    # remove every generated shard file and its index database
    for shard in [CV_FILE_NAME + str(n) for n in range(FILES_NUM)]:
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #2
0
def skip_test_issue_155():
    """test file writer loop."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name"])
    # write the same batch 1000 times to stress the writer
    for _ in range(1000):
        writer.write_raw_data(raw_data)
    writer.commit()
    reader = FileReader(CV_FILE_NAME + "0")
    total = sum(1 for _ in reader.get_next())
    assert total == 10000, "Failed to read multiple writed data."
コード例 #3
0
def test_issue_34():
    """test file writer"""
    writer = FileWriter(CV_FILE_NAME)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "cv_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # read everything back and log the running count
    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
        total += 1
    logger.info("count: {}".format(total))
    reader.close()
    os.remove(CV_FILE_NAME)
    os.remove(CV_FILE_NAME + ".db")
コード例 #4
0
def test_cv_file_writer_shard_num_10():
    """test cv dataset writer when shard_num equals 10."""
    shard_num = 10
    writer = FileWriter(CV_FILE_NAME, shard_num)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    # drop all ten shards plus their .db index files
    for n in range(shard_num):
        shard = CV_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #5
0
def test_issue_87():
    """test file writer when data(bytes) do not match field type(string)."""
    shard_num = 4
    writer = FileWriter(CV_FILE_NAME, shard_num)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    # "data" is deliberately declared as string while the samples carry bytes
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "string"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["label"])
    with pytest.raises(Exception, match="data is wrong"):
        writer.write_raw_data(raw_data, False)
        writer.commit()

    # no commit succeeded, so only the bare shard files need removing
    for n in range(shard_num):
        os.remove(CV_FILE_NAME + str(n))
コード例 #6
0
def test_issue_39():
    """test cv dataset writer when schema fields' datatype does not match raw data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    # "file_name" is declared as number although the samples provide strings
    cv_schema_json = {"file_name": {"type": "number"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # mismatched rows are dropped, so the file must read back empty
    reader = FileReader(CV_FILE_NAME)
    total = sum(1 for _ in reader.get_next())
    assert total == 0, "failed on reading data!"
    reader.close()
    os.remove(CV_FILE_NAME)
    os.remove(CV_FILE_NAME + ".db")
コード例 #7
0
def test_issue_73():
    """test file reader by column name."""
    writer = FileWriter(MKV_FILE_NAME, FILES_NUM)
    raw_data = get_mkv_data("../data/mindrecord/testVehPerData/")
    mkv_schema_json = {"file_name": {"type": "string"},
                       "id": {"type": "number"},
                       "prelabel": {"type": "string"},
                       "data": {"type": "bytes"}}
    writer.add_schema(mkv_schema_json, "mkv_schema")
    writer.add_index(["file_name", "prelabel"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # read back only the "file_name" column using four consumers
    reader = FileReader(MKV_FILE_NAME + "1", 4, ["file_name"])
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    for n in range(FILES_NUM):
        shard = MKV_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #8
0
def test_write_raw_data_with_empty_list():
    """test write raw data with empty list."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # writing an empty batch must succeed and leave the files readable
    assert writer.write_raw_data([]) == SUCCESS
    writer.commit()

    reader = FileReader(CV_FILE_NAME + "0")
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    for n in range(FILES_NUM):
        shard = CV_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #9
0
def test_mkv_file_reader_with_negative_num_consumer():
    """test mkv file reader when the number of consumer is negative."""
    writer = FileWriter(MKV_FILE_NAME, FILES_NUM)
    raw_data = get_mkv_data("../data/mindrecord/testVehPerData/")
    mkv_schema_json = {"file_name": {"type": "string"},
                       "id": {"type": "number"},
                       "prelabel": {"type": "string"},
                       "data": {"type": "bytes"}}
    writer.add_schema(mkv_schema_json, "mkv_schema")
    writer.add_index(["file_name", "prelabel"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # a negative consumer count must be rejected at construction time
    with pytest.raises(Exception) as e:
        FileReader(MKV_FILE_NAME + "1", -1)
    assert "Consumer number should between 1 and" in str(e.value)

    for n in range(FILES_NUM):
        shard = MKV_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #10
0
def test_cv_file_writer_default_shard_num():
    """test cv dataset writer when shard_num is default value."""
    writer = FileWriter(CV_FILE_NAME)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # with the default shard count the output file carries no numeric suffix
    reader = FileReader(CV_FILE_NAME)
    for idx, item in enumerate(reader.get_next()):
        logger.info("#item{}: {}".format(idx, item))
    reader.close()

    os.remove(CV_FILE_NAME)
    os.remove(CV_FILE_NAME + ".db")
コード例 #11
0
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    paths = [NLP_FILE_NAME + str(n) for n in range(FILES_NUM)]
    # start clean: drop leftovers from earlier runs before writing
    for path in paths:
        if os.path.exists(path):
            os.remove(path)
        if os.path.exists(path + ".db"):
            os.remove(path + ".db")
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
    nlp_schema_json = {"id": {"type": "string"},
                       "label": {"type": "int32"},
                       "rating": {"type": "float32"},
                       "input_ids": {"type": "int64", "shape": [-1]},
                       "input_mask": {"type": "int64", "shape": [1, -1]},
                       "segment_ids": {"type": "int64", "shape": [2, -1]}}
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    # teardown: remove the shards created above
    for path in paths:
        os.remove(path)
        os.remove(path + ".db")
コード例 #12
0
def create_diff_page_size_cv_mindrecord(files_num):
    """tutorial for cv dataset writer."""
    # clear any previous output so the writer starts fresh
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists(CV1_FILE_NAME + ".db"):
        os.remove(CV1_FILE_NAME + ".db")
    writer = FileWriter(CV1_FILE_NAME, files_num)
    writer.set_page_size(1 << 26)  # 64MB
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    sample = {"file_name": "001.jpg",
              "label": 43,
              "data": bytes('0xffsafdafda', encoding='utf-8')}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data([sample])
    writer.commit()
コード例 #13
0
def test_cv_file_append_writer_absolute_path():
    """tutorial for cv dataset append writer."""
    writer = FileWriter(CV4_FILE_NAME, 4)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # write the first half, then append the second half via open_for_append
    writer.write_raw_data(raw_data[0:5])
    writer.commit()
    appender = FileWriter.open_for_append(CV4_FILE_NAME + "0")
    appender.write_raw_data(raw_data[5:10])
    appender.commit()
    reader = FileReader(CV4_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 3
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    for n in range(4):
        shard = CV4_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #14
0
def test_cv_file_writer_no_blob():
    """test cv file writer without blob data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    # schema deliberately omits the "data" (bytes/blob) field
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"}}
    writer.add_schema(cv_schema_json, "no_blob_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()
    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        assert len(item) == 2
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()
    os.remove(CV_FILE_NAME)
    os.remove(CV_FILE_NAME + ".db")
コード例 #15
0
def test_cv_file_writer_absolute_path():
    """test cv file writer when file name is absolute path."""
    # use a random absolute path under /tmp to avoid collisions
    file_name = "/tmp/" + str(uuid.uuid4())
    writer = FileWriter(file_name, FILES_NUM)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    for n in range(FILES_NUM):
        shard = file_name + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #16
0
def test_cv_file_writer_without_data():
    """test cv file writer without data."""
    writer = FileWriter(CV_FILE_NAME, 1)
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # commit without ever calling write_raw_data
    writer.commit()
    reader = FileReader(CV_FILE_NAME)
    total = 0
    for idx, item in enumerate(reader.get_next()):
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 0
    reader.close()
    os.remove(CV_FILE_NAME)
    os.remove(CV_FILE_NAME + ".db")
コード例 #17
0
def test_nlp_file_writer_tutorial():
    """tutorial for nlp file writer."""
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                             "../data/mindrecord/testAclImdbData/vocab.txt",
                             10))
    nlp_schema_json = {"id": {"type": "string"},
                       "label": {"type": "int32"},
                       "rating": {"type": "float32"},
                       "input_ids": {"type": "int64", "shape": [1, -1]},
                       "input_mask": {"type": "int64", "shape": [1, -1]},
                       "segment_ids": {"type": "int64", "shape": [1, -1]}}
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()
コード例 #18
0
def test_cv_file_writer_shard_num_10():
    """test file writer when shard num equals 10."""
    shard_count = 10
    writer = FileWriter(CV_FILE_NAME, shard_count)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()

    # remove the ten generated shards and their index databases
    for n in range(shard_count):
        shard = CV_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #19
0
def add_and_remove_cv_file():
    """Generator fixture: write FILES_NUM cv mindrecord shards, yield, then clean up.

    Yields the marker string "yield_cv_data" after the shards are committed.
    If an exception occurs during setup (or is thrown into the yield point),
    the except branch removes the shard files and re-raises; on normal
    completion the else branch performs the same cleanup.
    """
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        # drop leftovers from a previous run before writing fresh shards
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        # NOTE(review): this cleanup assumes every shard file already exists;
        # os.remove will raise FileNotFoundError if setup failed early.
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
コード例 #20
0
def test_cv_file_writer_loop_and_read():
    """tutorial for cv dataset loop writer."""
    writer = FileWriter(CV2_FILE_NAME, FILES_NUM)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int64"}, "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    # write one sample per call instead of a single batch
    for sample in raw_data:
        writer.write_raw_data([sample])
    writer.commit()

    reader = FileReader(CV2_FILE_NAME + "0")
    total = 0
    for idx, item in enumerate(reader.get_next()):
        assert len(item) == 3
        total += 1
        logger.info("#item{}: {}".format(idx, item))
    assert total == 10
    reader.close()

    for n in range(FILES_NUM):
        shard = CV2_FILE_NAME + str(n)
        os.remove(shard)
        os.remove(shard + ".db")
コード例 #21
0
def test_file_writer_raw_data_038():
    """test write raw data without verify."""
    shard_num = 11
    writer = FileWriter("test_file_writer_raw_data_", shard_num)
    data_raw = get_data("../data/mindrecord/testImageNetData/")
    schema_json = {"file_name": {"type": "string"}, "label": {"type": "number"},
                   "data": {"type": "bytes"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["file_name"])
    # write the same batch once per shard with verification disabled (False)
    for _ in range(shard_num):
        writer.write_raw_data(data_raw, False)
    writer.commit()

    # Build the suffix of the last shard to open for reading.
    # NOTE(review): this caps the suffix at '99'; presumably the writer never
    # produces more than 100 shards — confirm against FileWriter's limits.
    file_name = ""
    if shard_num > 1:
        file_name = '99' if shard_num > 99 else str(shard_num - 1)
    reader = FileReader("test_file_writer_raw_data_" + file_name)
    i = 0
    for _, _ in enumerate(reader.get_next()):
        i = i + 1
    # each of the shard_num writes contributed the full batch of samples
    assert i == shard_num * 10
    reader.close()
    # single-shard output has no numeric suffix at all
    if shard_num == 1:
        os.remove("test_file_writer_raw_data_")
        os.remove("test_file_writer_raw_data_.db")
        return
    # with more than 10 shards, suffixes are zero-padded to two digits
    for x in range(shard_num):
        n = str(x)
        if shard_num > 10:
            n = '0' + str(x) if x < 10 else str(x)
        if os.path.exists("test_file_writer_raw_data_{}".format(n)):
            os.remove("test_file_writer_raw_data_{}".format(n))
        if os.path.exists("test_file_writer_raw_data_{}.db".format(n)):
            os.remove("test_file_writer_raw_data_{}.db".format(n))
コード例 #22
0
def convert_to_mindrecord(features, labels, mindrecord_path):
    """Write (id, label, feature) records into a 4-shard mindrecord file."""
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}
    if not os.path.exists(mindrecord_path):
        os.makedirs(mindrecord_path)
    else:
        print(mindrecord_path, 'exists. Please make sure it is empty!')
    file_name = os.path.join(mindrecord_path, 'style.mindrecord')
    print('writing mindrecord into', file_name)

    def get_imdb_data(features, labels):
        # pair each (label, feature) with a running integer id
        return [{"id": i,
                 "label": int(label),
                 "feature": feature.reshape(-1)}
                for i, (label, feature) in enumerate(zip(labels, features))]

    writer = FileWriter(file_name, shard_num=4)
    data = get_imdb_data(features, labels)
    writer.add_schema(schema_json, "style_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()
    print('done')
コード例 #23
0
ファイル: main.py プロジェクト: shirley18411/course
def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """
    convert imdb dataset to mindrecoed dataset
    """
    # persist the embedding weights alongside the dataset, if provided
    if weight_np is not None:
        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)

    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    # train and test splits go to differently named files
    split = "aclImdb_train.mindrecord" if training else "aclImdb_test.mindrecord"
    data_dir = os.path.join(data_home, split)

    def get_imdb_data(features, labels):
        # pair each (label, feature) with a running integer id
        return [{"id": i,
                 "label": int(label),
                 "feature": feature.reshape(-1)}
                for i, (label, feature) in enumerate(zip(labels, features))]

    writer = FileWriter(data_dir, shard_num=4)
    data = get_imdb_data(features, labels)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()
コード例 #24
0
ファイル: test_minddataset.py プロジェクト: zimaxeg/mindspore
def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer.

    Writes FILES_NUM mindrecord shards from the test ImageNet data and
    removes them afterwards.
    """
    paths = [
        "{}{}".format(CV_FILE_NAME,
                      str(x).rjust(1, '0')) for x in range(FILES_NUM)
    ]
    # Idiom fix: the original used a conditional expression as a statement
    # (`os.remove(x) if os.path.exists(x) else None`); plain `if` statements
    # express the same "remove leftovers if present" intent clearly.
    for x in paths:
        if os.path.exists(x):
            os.remove(x)
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {
        "file_name": {
            "type": "string"
        },
        "label": {
            "type": "int32"
        },
        "data": {
            "type": "bytes"
        }
    }
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    # clean up the shards created by this test
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
コード例 #25
0
def convert_to_mindrecord(info_name, file_path, store_path, mr_name,
                          num_classes):
    """ convert dataset to mindrecord """
    num_shard = 4
    data, label = generator_md(info_name, file_path, num_classes)
    # fixed-shape float feature plus a one-hot style int label vector
    schema_json = {"id": {"type": "int32"},
                   "feature": {"type": "float32", "shape": [1, 96, 1366]},
                   "label": {"type": "int32", "shape": [num_classes]}}

    out_file = os.path.join(store_path, '{}.mindrecord'.format(mr_name))
    writer = FileWriter(out_file, num_shard)
    datax = get_data(data, label)
    writer.add_schema(schema_json, "music_tagger_schema")
    writer.add_index(["id"])
    writer.write_raw_data(datax)
    writer.commit()
コード例 #26
0
def write_mindrecord_tutorial():
    """Write simulated ImageNet samples to mindrecord, then verify the count."""
    writer = FileWriter(MINDRECORD_FILE_NAME)
    data = get_data("./ImageNetDataSimulation")
    schema_json = {"file_name": {"type": "string"},
                   "label": {"type": "int64"},
                   "data": {"type": "bytes"}}
    writer.add_schema(schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(MINDRECORD_FILE_NAME)
    total = 0
    for item in reader.get_next():
        assert len(item) == 3
        total += 1
    assert total == 20
    reader.close()
コード例 #27
0
def json到minecord(路径,numpy路径,最终文件名):
    """Convert JSON training data to a mindrecord (.minecord) file.

    Args (identifiers are Chinese):
        路径: path to the JSON training-data file.
        numpy路径: path of the .npz cache holding the encoded input/output arrays.
        最终文件名: base name for the output file under ../data/mindrecord/.
    """
    # example input path: 路径 = "../data/预训练数据.json"
    输入表单 = 读取训练数据_A(路径)
    # word->index and index->word vocabulary files (50000 entries)
    词_数表路径 = "../data/词_数50000.json"
    数_词表路径 = "../data/数_词50000.json"

    if os.path.isfile(词_数表路径) and os.path.isfile(数_词表路径):
        词_数表, 数_词表 = 读出引索(词_数表路径, 数_词表路径)



        numpy数组路径 = numpy路径
        # load the cached arrays, or generate them first if the cache is missing
        if os.path.isfile(numpy数组路径):
            npz文件 = np.load(numpy数组路径, allow_pickle=True)
            输出np, 输入np = npz文件["输出np"], npz文件["输入np"]
        else:

            生成训练用numpy数组_B(输入表单, 词_数表, numpy数组路径)
            npz文件 = np.load(numpy数组路径, allow_pickle=True)
            输出np, 输入np = npz文件["输出np"], npz文件["输入np"]

        # NOTE(review): this second load is redundant — the arrays were already
        # loaded above — and it omits allow_pickle=True, so it would fail for
        # object arrays; confirm whether this block can be removed.
        if os.path.isfile(numpy数组路径):
            npz文件 = np.load(numpy数组路径)
            输出np, 输入np = npz文件["输出np"], npz文件["输入np"]
        else:
            print("训练用numpy数组不存在")
        数据_表 = []
        print("正在打包numpy数组为mindspore所需json格式......")
        # pack each row pair into the {"id", "input", "output"} record layout
        for i in range(输入np.shape[0]):

            输入_分 = 输入np[i:i+1, :]
            输入_分 = 输入_分.reshape(-1)
            输出_分 = 输出np[i:i+1, :]
            输出_分 = 输出_分.reshape(-1)
            数据_json = {"id": i, "input": 输入_分.astype(np.int32), "output": 输出_分.astype(np.int32)}
            数据_表.append(数据_json)

        纲要_json = {"id": {"type": "int32"},
                      "input": {"type": "int32", "shape": [-1]},
                      "output": {"type": "int32", "shape": [-1]}}
        # remove any previous output (data file and its .db index) before writing
        if os.path.isfile("../data/mindrecord/"+最终文件名+".minecord.db"):
            os.remove("../data/mindrecord/"+最终文件名+".minecord.db")
        if os.path.isfile("../data/mindrecord/"+最终文件名+".minecord"):
            os.remove("../data/mindrecord/"+最终文件名+".minecord")
        print("正在写入mindspore格式......")
        writer = FileWriter("../data/mindrecord/"+最终文件名+".minecord", shard_num=1)
        writer.add_schema(纲要_json, "nlp_1")
        writer.add_index(["id"])
        writer.write_raw_data(数据_表)
        writer.commit()
        print("写入mindspore格式完成。")

    else:

        print('词_数表路径或数_词表路径不存在')
コード例 #28
0
def test_cv_file_writer_tutorial():
    """tutorial for cv dataset writer."""
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    raw_data = get_data("../data/mindrecord/testImageNetData/")
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int64"},
              "data": {"type": "bytes"}}
    writer.add_schema(schema, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(raw_data)
    writer.commit()
コード例 #29
0
ファイル: gen_mindrecord.py プロジェクト: xyg320/mindspore
def gen_mindrecord(data_type):
    """gen mindreocrd according exactly schema

    Streams preprocessed aclImdb samples into a mindrecord file in batches
    of 256, choosing the train or test output file based on `data_type`.
    """
    # pick the output file by split name
    if data_type == "train":
        fw = FileWriter(MINDRECORD_FILE_NAME_TRAIN)
    else:
        fw = FileWriter(MINDRECORD_FILE_NAME_TEST)

    schema = {
        "id": {
            "type": "int32"
        },
        "label": {
            "type": "int32"
        },
        "score": {
            "type": "int32"
        },
        "input_ids": {
            "type": "int32",
            "shape": [-1]
        },
        "input_mask": {
            "type": "int32",
            "shape": [-1]
        },
        "segment_ids": {
            "type": "int32",
            "shape": [-1]
        }
    }
    fw.add_schema(schema, "aclImdb preprocessed dataset")
    fw.add_index(["id", "label", "score"])

    vocab_dict = load_vocab(os.path.join(ACLIMDB_DIR, "imdb.vocab"))

    get_data_iter = get_nlp_data(os.path.join(ACLIMDB_DIR, data_type),
                                 vocab_dict)

    batch_size = 256
    transform_count = 0
    # drain the iterator batch by batch; StopIteration ends the stream and the
    # except branch flushes whatever partial batch was collected
    while True:
        data_list = []
        try:
            for _ in range(batch_size):
                data_list.append(get_data_iter.__next__())
                transform_count += 1
            fw.write_raw_data(data_list)
            print(">> transformed {} record...".format(transform_count))
        except StopIteration:
            if data_list:
                fw.write_raw_data(data_list)
                print(">> transformed {} record...".format(transform_count))
            break

    fw.commit()
コード例 #30
0
def test_issue_84():
    """test file reader when db does not match.

    Writes a cv dataset and an nlp dataset, then swaps an nlp .db index file
    under a cv shard's name so the reader sees a mismatched index, and
    asserts that opening the shard fails. The rename order below is
    load-bearing — do not reorder.
    """
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data("../data/mindrecord/testImageNetData/")
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "number"}, "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                             "../data/mindrecord/testAclImdbData/vocab.txt",
                             10))
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "number"},
                       "rating": {"type": "number"},
                       "input_ids": {"type": "array",
                                     "items": {"type": "number"}},
                       "input_mask": {"type": "array",
                                      "items": {"type": "number"}},
                       "segment_ids": {"type": "array",
                                       "items": {"type": "number"}}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()

    reader = ShardReader()
    # back up the real cv index, then put the nlp index in its place
    os.rename("imagenet.mindrecord1.db", "imagenet.mindrecord1.db.bk")
    os.rename("aclImdb.mindrecord1.db", "imagenet.mindrecord1.db")
    file_name = os.path.join(os.getcwd(), "imagenet.mindrecord1")
    with pytest.raises(Exception) as e:
        reader.open(file_name)
    assert str(e.value) == "[MRMOpenError]: error_code: 1347690596, " \
                           "error_msg: " \
                           "MindRecord File could not open successfully."

    # restore the nlp index to its own name, then delete the nlp shards
    os.rename("imagenet.mindrecord1.db", "aclImdb.mindrecord1.db")
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for item in paths:
        os.remove("{}".format(item))
        os.remove("{}.db".format(item))

    # restore the cv index from backup, then delete the cv shards
    os.rename("imagenet.mindrecord1.db.bk", "imagenet.mindrecord1.db")
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for item in paths:
        os.remove("{}".format(item))
        os.remove("{}.db".format(item))