Exemplo n.º 1
0
def prepare_data_seq(task, batch_size=100, lang=None):
    """Load the train/dev/test/test-OOV splits of a dialog bAbI task.

    Args:
        task: Task identifier. It is interpolated as-is into the split file
            names and converted with ``int()`` before being handed to
            ``entityList``.
        batch_size: Batch size passed to ``get_seq`` for the train, test and
            test-OOV splits (the dev split always uses 100).
        lang: Existing ``Lang`` vocabulary to reuse. A fresh ``Lang()`` is
            created when this is ``None``.

    Returns:
        A tuple ``(train, dev, test, testoov, lang, max_resp_len)``. The
        first four items are the ``get_seq`` results for each split, and
        ``max_resp_len`` is the largest length reported by ``read_langs``
        over the four splits, plus one.
    """
    data_path = 'data/dialog-bAbI-tasks/dialog-babi'
    kb_path = data_path + '-kb-all.txt'
    file_train = '{}-task{}trn.txt'.format(data_path, task)
    file_dev = '{}-task{}dev.txt'.format(data_path, task)
    file_test = '{}-task{}tst.txt'.format(data_path, task)
    file_test_OOV = '{}-task{}tst-OOV.txt'.format(data_path, task)

    type_dict = get_type_dict(kb_path, dstc2=False)
    # Use the single kb_path definition instead of repeating the full literal,
    # so the type dictionary and the entity list always read the same file.
    global_ent = entityList(kb_path, int(task))

    pair_train, train_max_len = read_langs(file_train, global_ent, type_dict)
    pair_dev, dev_max_len = read_langs(file_dev, global_ent, type_dict)
    pair_test, test_max_len = read_langs(file_test, global_ent, type_dict)
    pair_testoov, testoov_max_len = read_langs(file_test_OOV, global_ent,
                                               type_dict)
    max_resp_len = max(train_max_len, dev_max_len, test_max_len,
                       testoov_max_len) + 1

    if lang is None:
        lang = Lang()

    # Only the training split is passed True as the last get_seq argument;
    # the other splits share the same lang object and pass False.
    train = get_seq(pair_train, lang, batch_size, True)
    # NOTE(review): the dev batch size is fixed at 100 and ignores
    # batch_size -- confirm this is intentional.
    dev = get_seq(pair_dev, lang, 100, False)
    test = get_seq(pair_test, lang, batch_size, False)
    testoov = get_seq(pair_testoov, lang, batch_size, False)

    print("Read %s sentence pairs train" % len(pair_train))
    print("Read %s sentence pairs dev" % len(pair_dev))
    print("Read %s sentence pairs test" % len(pair_test))
    print("Vocab_size: %s " % lang.n_words)
    print("Max. length of system response: %s " % max_resp_len)
    print("USE_CUDA={}".format(USE_CUDA))

    return train, dev, test, testoov, lang, max_resp_len
Exemplo n.º 2
0
def get_data_seq(file_name, lang, max_len, task=5, batch_size=1):
    """Build a ``get_seq`` dataset from one dialog file with a given ``lang``.

    Args:
        file_name: Path of the dialog file passed to ``read_langs``.
        lang: Vocabulary object forwarded to ``get_seq``.
        max_len: Not used by this function; kept so existing callers that
            pass it keep working.
        task: Task identifier, converted with ``int()`` for ``entityList``.
        batch_size: Batch size forwarded to ``get_seq``.

    Returns:
        Whatever ``get_seq`` returns for the pairs read from ``file_name``.
    """
    data_path = 'data/dialog-bAbI-tasks/dialog-babi'
    kb_file = data_path + '-kb-all.txt'

    # The knowledge-base file feeds both the type dictionary and the
    # entity list.
    entity_types = get_type_dict(kb_file, dstc2=False)
    known_entities = entityList(kb_file, int(task))

    # read_langs returns a pair; the second element is not needed here.
    dialog_pairs, _ = read_langs(file_name, known_entities, entity_types)

    # The last argument is False, the same value prepare_data_seq uses for
    # its non-training splits.
    return get_seq(dialog_pairs, lang, batch_size, False)
Exemplo n.º 3
0
def prepare_data_seq(task, batch_size=100):
    """Load the train/dev/test/test-OOV splits of a dialog bAbI task.

    Args:
        task: Task identifier. It is interpolated as-is into the split file
            names and converted with ``int()`` before being handed to
            ``entityList``.
        batch_size: Batch size passed to ``get_seq`` for the train, test and
            test-OOV splits (the dev split always uses 100).

    Returns:
        A tuple ``(train, dev, test, testoov, lang, max_resp_len)``. The
        first four items are the ``get_seq`` results for each split,
        ``lang`` is the ``Lang`` instance created here, and ``max_resp_len``
        is the largest length reported by ``read_langs`` over the four
        splits, plus one.
    """
    data_path = 'data/dialog-bAbI-tasks/dialog-babi'
    kb_path = data_path + '-kb-all.txt'
    file_train = '{}-task{}trn.txt'.format(data_path, task)
    file_dev = '{}-task{}dev.txt'.format(data_path, task)
    file_test = '{}-task{}tst.txt'.format(data_path, task)
    # The OOV file is the out-of-vocabulary test file.
    file_test_OOV = '{}-task{}tst-OOV.txt'.format(data_path, task)

    # Original author's note: a dictionary of triples, keyed by restaurant
    # and relation with the remaining element as the value; restaurants are
    # always the Subject and everything else is an Object.
    # (Describes get_type_dict, which is not visible here -- TODO confirm.)
    type_dict = get_type_dict(kb_path, dstc2=False)
    # Original author's note: collects all Subjects and Objects.
    # Use the single kb_path definition instead of repeating the full literal,
    # so the type dictionary and the entity list always read the same file.
    global_ent = entityList(kb_path, int(task))

    pair_train, train_max_len = read_langs(file_train, global_ent, type_dict)
    pair_dev, dev_max_len = read_langs(file_dev, global_ent, type_dict)
    pair_test, test_max_len = read_langs(file_test, global_ent, type_dict)
    pair_testoov, testoov_max_len = read_langs(file_test_OOV, global_ent,
                                               type_dict)
    max_resp_len = max(train_max_len, dev_max_len, test_max_len,
                       testoov_max_len) + 1

    lang = Lang()

    # Original author's note on the last get_seq argument: lang has to be
    # built (words indexed) on the training set; dev and test reuse the same
    # lang afterwards, so they do not need to be indexed again and pass
    # False. (Presumably this is get_seq's "first" flag -- TODO confirm.)
    train = get_seq(pair_train, lang, batch_size, True)
    # NOTE(review): the dev batch size is fixed at 100 and ignores
    # batch_size -- confirm this is intentional.
    dev = get_seq(pair_dev, lang, 100, False)
    test = get_seq(pair_test, lang, batch_size, False)
    testoov = get_seq(pair_testoov, lang, batch_size, False)

    print("Read %s sentence pairs train" % len(pair_train))
    print("Read %s sentence pairs dev" % len(pair_dev))
    print("Read %s sentence pairs test" % len(pair_test))
    print("Vocab_size: %s " % lang.n_words)
    print("Max. length of system response: %s " % max_resp_len)
    print("USE_CUDA={}".format(USE_CUDA))

    return train, dev, test, testoov, lang, max_resp_len