Example no. 1
0
def get_dataset(gen_config):
    """Build the vocabulary and bucketed datasets for generator training.

    Returns:
        vocab: dict mapping word -> id.
        rev_vocab: list mapping id -> word.
        dev_set: bucketed development pairs from read_data.
        train_set: bucketed training pairs from read_data.
    """
    train_path = os.path.join(gen_config.train_dir, "chitchat.train")
    vocab_sources = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(
        gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, vocab_sources,
                                 gen_config.vocab_size)
    # e.g. vocab = {dog: 0, cat: 1}, rev_vocab = [dog, cat]
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print(just("Preparing Chitchat gen_data in %s" % gen_config.train_dir))
    (train_query, train_answer,
     dev_query, dev_answer) = data_utils.prepare_chitchat_data(
         gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print(just("Reading development and training gen_data (limit: %d)." %
               gen_config.max_train_data_size))
    dev_set = read_data(gen_config, dev_query, dev_answer)
    train_set = read_data(gen_config, train_query, train_answer,
                          gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
Example no. 2
0
def prepare_data(gen_config):
    """Return (vocab, rev_vocab, dev_set, train_set), using pickle caches.

    If all four cache files exist in the working directory they are
    unpickled and returned; otherwise the vocabulary and bucketed
    datasets are built from gen_config.train_dir and cached for the
    next run.

    NOTE(review): pickle.load on pre-existing files executes arbitrary
    code if the files are untrusted — acceptable only for local caches.
    """
    # Cache file names, in the same order as the returned tuple.
    cache_names = ('vocab', 'rev_vocab', 'dev_set', 'train_set')

    if all(os.path.exists(name) for name in cache_names):
        # Fast path: everything was cached by a previous run.
        # `with` guarantees the files are closed even on error
        # (the original opened/closed all eight handles by hand).
        loaded = []
        for name in cache_names:
            with open(name, 'rb') as fr:
                loaded.append(pickle.load(fr))
        vocab, rev_vocab, dev_set, train_set = loaded
    else:
        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
        (train_query, train_answer,
         dev_query, dev_answer) = data_utils.prepare_chitchat_data(
             gen_config.train_dir, vocab, gen_config.vocab_size)

        # Read disc_data into buckets and compute their sizes.
        print("Reading development and training gen_data (limit: %d)." %
              gen_config.max_train_data_size)
        dev_set = read_data(gen_config, dev_query, dev_answer)
        train_set = read_data(gen_config, train_query, train_answer,
                              gen_config.max_train_data_size)

        # Cache each object so subsequent runs take the fast path.
        for name, obj in zip(cache_names,
                             (vocab, rev_vocab, dev_set, train_set)):
            with open(name, 'wb') as fw:
                pickle.dump(obj, fw)
    return vocab, rev_vocab, dev_set, train_set
def prepare_data(gen_config):
    """Create the vocabulary, tokenize the chitchat corpus, and bucket it.

    Returns:
        vocab: dict mapping word -> id.
        rev_vocab: list mapping id -> word.
        dev_set: bucketed development pairs.
        train_set: bucketed training pairs (capped at max_train_data_size).
    """
    train_path = os.path.join(gen_config.data_dir, "chitchat.train")
    source_files = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(
        gen_config.data_dir, "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, source_files,
                                 gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat data in %s" % gen_config.data_dir)
    (train_query, train_answer,
     dev_query, dev_answer) = data_utils.prepare_chitchat_data(
         gen_config.data_dir, vocab, gen_config.vocab_size)

    # Read data into buckets and compute their sizes.
    print("Reading development and training data (limit: %d)."
          % gen_config.max_train_data_size)
    dev_set = read_data(dev_query, dev_answer)
    train_set = read_data(train_query, train_answer,
                          gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
Example no. 4
0
def prepare_data(gen_config):
    """Build a vocabulary over train/test/dev files and bucket every split.

    Side effect: overwrites gen_config.vocab_size with the actual
    vocabulary size after initialization.

    Returns:
        vocab, rev_vocab, test_set, dev_set, train_set
    """
    train_path = os.path.join(gen_config.data_dir, "train")
    test_path = os.path.join(gen_config.data_dir, "test")
    dev_path = os.path.join(gen_config.data_dir, "dev")

    # Vocabulary sources, ordered: train, test, dev; .answer before .query.
    voc_file_path = []
    for prefix in (train_path, test_path, dev_path):
        voc_file_path.append(prefix + ".answer")
        voc_file_path.append(prefix + ".query")

    vocab_path = os.path.join(gen_config.data_dir,
                              "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path,
                                 gen_config.vocab_size)

    # vocab maps word -> id; rev_vocab is a list of all words (id -> word).
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    gen_config.vocab_size = len(vocab)

    # prepare_chitchat_data returns the paths of the tokenized files.
    (train_query, train_answer, dev_query, dev_answer,
     test_query, test_answer) = data_utils.prepare_chitchat_data(
         gen_config.data_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print("Reading development and training gen_data (limit: %d)." %
          gen_config.max_train_data_size)

    # unique_list is threaded through the three calls so duplicates
    # seen in an earlier split are filtered from later ones.
    unique_list = []
    train_set, unique_list = read_data(gen_config, train_query, train_answer,
                                       unique_list=unique_list)
    dev_set, unique_list = read_data(gen_config, dev_query, dev_answer,
                                     unique_list=unique_list)
    test_set, unique_list = read_data(gen_config, test_query, test_answer,
                                      unique_list=unique_list)

    return vocab, rev_vocab, test_set, dev_set, train_set
Example no. 5
0
def prepare_data(gen_config):
    """Create the vocabulary and bucketed train/dev sets from train_dir.

    Returns:
        vocab: dict mapping word -> id.
        rev_vocab: list mapping id -> word.
        dev_set: bucketed development pairs.
        train_set: bucketed training pairs; layout is
            [[[source, target], ...], ...] with one outer entry per bucket.
    """
    train_path = os.path.join(gen_config.train_dir, "train")
    vocab_sources = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(
        gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, vocab_sources,
                                 gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
    (train_query, train_answer,
     dev_query, dev_answer) = data_utils.prepare_chitchat_data(
         gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print("Reading development and training gen_data (limit: %d)." %
          gen_config.max_train_data_size)
    dev_set = read_data(gen_config, dev_query, dev_answer)
    train_set = read_data(gen_config, train_query, train_answer,
                          gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
Example no. 6
0
def prepare_data(gen_config):
    """Create the vocabulary and bucketed train/dev sets, printing dev_set.

    Steps:
    1. data_utils.create_vocabulary: build the vocabulary file.
    2. data_utils.initialize_vocabulary: load word->id dict and id->word list.
    3. read_data: place each (query, answer) id pair into a length bucket,
       e.g. with buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]:
           Q: "How old are you?" -> [5, 7, 4, 2, 3]  (len 5)
           A: "I'm six."         -> [44, 6, 8]       (len 3)
           fits the first bucket: [[[5,7,4,2,3], [44,6,8]]], [], [], []
       i.e. every QA pair lands in the bucket matching its length range.
    """
    train_path = os.path.join(gen_config.train_dir, "chitchat.train")
    vocab_sources = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(
        gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size)
    # Vocabulary is capped at gen_config.vocab_size words.
    data_utils.create_vocabulary(vocab_path, vocab_sources,
                                 gen_config.vocab_size)
    # vocab (word -> id) and its reverse list (id -> word).
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
    (train_query, train_answer,
     dev_query, dev_answer) = data_utils.prepare_chitchat_data(
         gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print("Reading development and training gen_data (limit: %d)."
          % gen_config.max_train_data_size)
    dev_set = read_data(gen_config, dev_query, dev_answer)
    train_set = read_data(gen_config, train_query, train_answer,
                          gen_config.max_train_data_size)
    # Debug output left in by the original author.
    print("see what bucket is:")
    print("\n")
    print(dev_set)

    return vocab, rev_vocab, dev_set, train_set