Example #1
def test_decoder(gen_config):
    with tf.Session() as sess:
        model = create_model(sess,
                             gen_config,
                             forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1

        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)
            bucket_id = len(gen_config.buckets) - 1
            for i, bucket in enumerate(gen_config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s", sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]},
                bucket_id,
                model.batch_size,
                type=0)

            print("bucket_id: ", bucket_id)
            print("encoder_inputs:", encoder_inputs)
            print("decoder_inputs:", decoder_inputs)
            print("target_weights:", target_weights)

            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)

            print("output_logits", np.shape(output_logits))

            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]

            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            print(" ".join(
                [tf.compat.as_str(rev_vocab[output]) for output in outputs]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
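The bucket search in test_decoder picks the first bucket whose source-length limit can hold the tokenized sentence, and only falls back to the largest bucket (with a truncation warning) when none fits. A standalone sketch of that selection rule, using the bucket sizes that appear in the other examples in this listing and made-up token ids:

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
token_ids = [5, 7, 4, 2, 3, 9, 12]      # 7 source tokens, illustrative values

bucket_id = len(buckets) - 1            # default to the largest bucket
for i, bucket in enumerate(buckets):
    if bucket[0] >= len(token_ids):     # first bucket long enough on the source side
        bucket_id = i
        break
else:
    # Reached only when no bucket fits; the sentence would be truncated.
    print("Sentence truncated: %s" % token_ids)

print(bucket_id)                        # -> 1, i.e. bucket (10, 15)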
Example #2
def get_dataset(gen_config):
    """
    获取训练数据
    :return: vocab, rev_vocab, dev_set, train_set
    """
    train_path = os.path.join(gen_config.train_dir, "chitchat.train")
    voc_file_path = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(gen_config.train_dir,
                              "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path,
                                 gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(
        vocab_path)  # {dog: 0, cat: 1} [dog, cat]

    print(just("Preparing Chitchat gen_data in %s" % gen_config.train_dir))
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print(
        just("Reading development and training gen_data (limit: %d)." %
             gen_config.max_train_data_size))
    dev_set = read_data(gen_config, dev_query, dev_answer)
    train_set = read_data(gen_config, train_query, train_answer,
                          gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
def prepare_data(gen_config):
    train_path = gen_config.train_dir
    voc_file_path = [train_path + _incorpus, train_path + _outcorpus]
    vocab_path = os.path.join(gen_config.train_dir,
                              "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path,
                                 gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    #
    #    print("Preparing Chitchat disc_data in %s" % gen_config.data_dir)
    #    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
    #        gen_config.data_dir, vocab, gen_config.vocab_size)
    #
    #    # Read disc_data into buckets and compute their sizes.
    #    print ("Reading development and training disc_data (limit: %d)."
    #               % gen_config.max_train_data_size)
    query_path = os.path.join(train_path + _incorpus)
    answer_path = os.path.join(train_path + _outcorpus)
    null_path = os.path.join(train_path + _nullcorpus)
    gen_path = os.path.join(train_path + _gencorpus)
    dev_set = read_data(gen_config, query_path, answer_path)
    train_set = read_data(gen_config, query_path, answer_path,
                          gen_config.max_train_data_size)
    negative_train_set = read_data(gen_config, null_path, gen_path,
                                   gen_config.max_train_data_size)
    null_train_set = read_data(gen_config, null_path, answer_path,
                               gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set, negative_train_set, null_train_set
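The inline comment in get_dataset ("# {dog: 0, cat: 1} [dog, cat]") describes the pair returned by data_utils.initialize_vocabulary: a word-to-id dict and an id-to-word list. A toy sketch of how the examples move between the two (the contents are made up for illustration; the real ones come from the vocab%d.all file):

vocab = {"dog": 0, "cat": 1}        # word -> id
rev_vocab = ["dog", "cat"]          # id -> word, i.e. rev_vocab[vocab[w]] == w

token_ids = [vocab[w] for w in ["cat", "dog"]]     # -> [1, 0]
words = [rev_vocab[i] for i in token_ids]          # -> ["cat", "dog"]
print(token_ids, words)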
Example #4
def prepare_data(config):
    train_path = os.path.join(config.train_dir, "train")
    voc_file_path = [train_path + ".query", train_path + ".answer", train_path + ".gen"]
    vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path, config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing train disc_data in %s" % config.train_dir)
    train_query_path, train_answer_path, train_gen_path = data_utils.hier_prepare_disc_data(
        config.train_dir, vocab, config.vocab_size)
    query_set, answer_set, gen_set = hier_read_data(config, train_query_path, train_answer_path, train_gen_path)
    return query_set, answer_set, gen_set
def test_file_decoder(gen_config, input_file, output_file):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = create_model(sess,
                             gen_config,
                             forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1
        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
        with open(output_file, 'w') as fout:
            with open(input_file, 'r') as fin:
                for sent in fin:
                    print(sent)
                    token_ids = data_utils.sentence_to_token_ids(
                        tf.compat.as_str(sent), vocab)
                    print("token_id: ", token_ids)
                    bucket_id = len(gen_config.buckets) - 1
                    for i, bucket in enumerate(gen_config.buckets):
                        if bucket[0] >= len(token_ids):
                            bucket_id = i
                            break
                    else:
                        print("Sentence truncated: %s", sentence)
                    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                        {bucket_id: [(token_ids, [1])]},
                        bucket_id,
                        model.batch_size,
                        type=0)
                    _, _, output_logits = model.step(sess, encoder_inputs,
                                                     decoder_inputs,
                                                     target_weights, bucket_id,
                                                     True)

                    outputs = [
                        int(np.argmax(logit, axis=1))
                        for logit in output_logits
                    ]
                    if data_utils.EOS_ID in outputs:
                        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                    out_sent = " ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in outputs
                    ])
                    fout.write(out_sent + '\n')
                    print(out_sent)
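Both decoders above turn output_logits into text the same way: take the argmax of each time step's logits and cut the sequence at the first EOS id. A self-contained sketch of that post-processing on dummy values (the EOS id of 2 and the five-word vocabulary are assumptions made for illustration; the real ones come from data_utils and rev_vocab):

import numpy as np

EOS_ID = 2                                       # assumed; the examples use data_utils.EOS_ID
rev_vocab = ["_PAD", "_GO", "_EOS", "hi", "there"]

# One logits array per decoder time step, each of shape [batch_size=1, vocab_size].
output_logits = [
    np.array([[0.1, 0.0, 0.2, 2.5, 0.3]]),      # argmax -> 3 ("hi")
    np.array([[0.1, 0.0, 0.2, 0.3, 2.5]]),      # argmax -> 4 ("there")
    np.array([[0.1, 0.0, 3.0, 0.2, 0.3]]),      # argmax -> 2 (EOS)
]

outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
if EOS_ID in outputs:
    outputs = outputs[:outputs.index(EOS_ID)]    # drop EOS and everything after it

print(" ".join(rev_vocab[o] for o in outputs))   # -> "hi there"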
Example #6
def prepare_data(gen_config):
    if os.path.exists('vocab') and os.path.exists(
            'rev_vocab') and os.path.exists('dev_set') and os.path.exists(
                'train_set'):
        fr_vocab = open('vocab', 'rb')
        fr_rev_vocab = open('rev_vocab', 'rb')
        fr_dev_set = open('dev_set', 'rb')
        fr_train_set = open('train_set', 'rb')
        vocab = pickle.load(fr_vocab)
        rev_vocab = pickle.load(fr_rev_vocab)
        dev_set = pickle.load(fr_dev_set)
        train_set = pickle.load(fr_train_set)
        fr_vocab.close()
        fr_rev_vocab.close()
        fr_dev_set.close()
        fr_train_set.close()
    else:
        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
        train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
            gen_config.train_dir, vocab, gen_config.vocab_size)

        # Read disc_data into buckets and compute their sizes.
        print("Reading development and training gen_data (limit: %d)." %
              gen_config.max_train_data_size)
        dev_set = read_data(gen_config, dev_query, dev_answer)
        train_set = read_data(gen_config, train_query, train_answer,
                              gen_config.max_train_data_size)

        fw_vocab = open('vocab', 'wb')
        fw_rev_vocab = open('rev_vocab', 'wb')
        fw_dev_set = open('dev_set', 'wb')
        fw_train_set = open('train_set', 'wb')
        pickle.dump(vocab, fw_vocab)
        pickle.dump(rev_vocab, fw_rev_vocab)
        pickle.dump(dev_set, fw_dev_set)
        pickle.dump(train_set, fw_train_set)
        fw_vocab.close()
        fw_rev_vocab.close()
        fw_dev_set.close()
        fw_train_set.close()
    return vocab, rev_vocab, dev_set, train_set
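The function above caches vocab, rev_vocab, dev_set and train_set in four separate pickle files so later runs can skip the preprocessing. A minimal sketch of the same cache-or-rebuild idea, compressed into one cache file and with-statements; the cache file name and the build_data stand-in are illustrative, not part of the original code:

import os
import pickle

CACHE_PATH = "gen_data.cache"       # illustrative file name

def build_data():
    # Stand-in for the real preprocessing (create_vocabulary, prepare_chitchat_data, read_data).
    vocab, rev_vocab = {"dog": 0, "cat": 1}, ["dog", "cat"]
    dev_set, train_set = [], []
    return vocab, rev_vocab, dev_set, train_set

if os.path.exists(CACHE_PATH):
    with open(CACHE_PATH, "rb") as f:
        vocab, rev_vocab, dev_set, train_set = pickle.load(f)
else:
    vocab, rev_vocab, dev_set, train_set = build_data()
    with open(CACHE_PATH, "wb") as f:
        pickle.dump((vocab, rev_vocab, dev_set, train_set), f)
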
def prepare_data(gen_config):
    train_path = os.path.join(gen_config.data_dir, "chitchat.train")
    voc_file_path = [train_path+".answer", train_path+".query"]
    vocab_path = os.path.join(gen_config.data_dir, "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path, gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat data in %s" % gen_config.data_dir)
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        gen_config.data_dir, vocab, gen_config.vocab_size)

    # Read data into buckets and compute their sizes.
    print ("Reading development and training data (limit: %d)."
               % gen_config.max_train_data_size)
    dev_set = read_data(dev_query, dev_answer)
    train_set = read_data(train_query, train_answer, gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
Example #8
def prepare_data(gen_config):
    train_path = os.path.join(gen_config.data_dir, "train")
    test_path = os.path.join(gen_config.data_dir, "test")
    dev_path = os.path.join(gen_config.data_dir, "dev")
    voc_file_path = [
        train_path + ".answer", train_path + ".query", test_path + ".answer",
        test_path + ".query", dev_path + ".answer", dev_path + ".query"
    ]
    vocab_path = os.path.join(gen_config.data_dir,
                              "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path,
                                 gen_config.vocab_size)
    # vocab_path = os.path.join(gen_config.data_dir, "vocab%d.all" % 30000)
    # TODO: change 30000 to 2500

    # vocab maps each word to its id; rev_vocab is a list holding all of the words
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    gen_config.vocab_size = len(vocab)
    # print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)

    # The return values are the paths of the corresponding token-id files
    train_query, train_answer, dev_query, dev_answer, test_query, test_answer = data_utils.prepare_chitchat_data(
        gen_config.data_dir, vocab, gen_config.vocab_size)
    # train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data_OpenSub(gen_config.data_dir)

    # Read disc_data into buckets and compute their sizes.
    print("Reading development and training gen_data (limit: %d)." %
          gen_config.max_train_data_size)

    unique_list = []
    train_set, unique_list = read_data(gen_config,
                                       train_query,
                                       train_answer,
                                       unique_list=unique_list)
    dev_set, unique_list = read_data(gen_config,
                                     dev_query,
                                     dev_answer,
                                     unique_list=unique_list)
    test_set, unique_list = read_data(gen_config,
                                      test_query,
                                      test_answer,
                                      unique_list=unique_list)

    return vocab, rev_vocab, test_set, dev_set, train_set
Example #9
def prepare_data(gen_config):
    train_path = os.path.join(gen_config.train_dir, "train")
    voc_file_path = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(gen_config.train_dir,
                              "vocab%d.all" % gen_config.vocab_size)
    data_utils.create_vocabulary(vocab_path, voc_file_path,
                                 gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print("Reading development and training gen_data (limit: %d)." %
          gen_config.max_train_data_size)
    dev_set = read_data(gen_config, dev_query, dev_answer)
    # Data format: train_set = [ [ [[source],[target]], [[source],[target]] ], ... ];
    # the outermost dimension is the number of buckets
    train_set = read_data(gen_config, train_query, train_answer,
                          gen_config.max_train_data_size)

    return vocab, rev_vocab, dev_set, train_set
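The comment above describes train_set as one list per bucket, each holding [source, target] id pairs. A small sketch of walking that structure to "compute their sizes", as the read_data comment puts it (toy data and an illustrative bucket list):

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
train_set = [
    [[[5, 7, 4, 2, 3], [44, 6, 8]]],   # bucket (5, 10): one [source, target] pair
    [],                                # bucket (10, 15): empty
    [],                                # bucket (20, 25): empty
    [],                                # bucket (40, 50): empty
]

train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
train_total_size = float(sum(train_bucket_sizes))
print(train_bucket_sizes, train_total_size)     # -> [1, 0, 0, 0] 1.0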
Example #10
def prepare_data(gen_config):
    """
    1. data_utils.create_vocabulary : 創字典
    2. data_utils.initialize_vocabulary: 回傳 train data answer and query & dev data answer and query 的 id path
    3. train set  =  read_data : 製造 bucket (bucket就是每行一問一答句子的id對應),dev & train 都製造bucket
        buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
        ex:
            Q:How old are you? = [5,7,4,2,3] ; len = 5
            A:I'm six. = [44,6,8] ; len = 3
            bucket = [ [[5,7,4,2,3],[44,6,8]] ],[],[],[]]
            
        也就是說會放把QA放在固定長度範圍的bucket
    """
    
    
    train_path = os.path.join(gen_config.train_dir, "chitchat.train")
    voc_file_path = [train_path+".answer", train_path+".query"]
    vocab_path = os.path.join(gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size)
    # 35000 words
    data_utils.create_vocabulary(vocab_path, voc_file_path, gen_config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    # vocab & reverse vocab

    print("Preparing Chitchat gen_data in %s" % gen_config.train_dir)
    train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data(
        gen_config.train_dir, vocab, gen_config.vocab_size)

    # Read disc_data into buckets and compute their sizes.
    print ("Reading development and training gen_data (limit: %d)."
               % gen_config.max_train_data_size)
    dev_set = read_data(gen_config, dev_query, dev_answer)
    train_set = read_data(gen_config, train_query, train_answer, gen_config.max_train_data_size)
    print("see what bucket is:")
    print("\n")
    print(dev_set)
    
    return vocab, rev_vocab, dev_set, train_set
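The docstring's example can be traced end to end: the pair Q = [5,7,4,2,3], A = [44,6,8] goes into the first bucket whose limits cover both lengths. A minimal sketch of that placement rule under the docstring's bucket list; the real read_data also appends EOS/GO symbols, so its exact boundary condition may differ:

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
data_set = [[] for _ in buckets]

source_ids = [5, 7, 4, 2, 3]    # "How old are you?"  -> 5 token ids
target_ids = [44, 6, 8]         # "I'm six."          -> 3 token ids

# Put the pair in the first bucket that can hold both sides.
for bucket_id, (source_size, target_size) in enumerate(buckets):
    if len(source_ids) <= source_size and len(target_ids) <= target_size:
        data_set[bucket_id].append([source_ids, target_ids])
        break

print(data_set)
# -> [[[[5, 7, 4, 2, 3], [44, 6, 8]]], [], [], []]   (the pair sits in bucket (5, 10))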
Example #11
import os, time
import tensorflow as tf

from model_utils import online_lstm_model
from utils import data_utils, data_process
from lib import config
# Load the pretrained word vectors and build the vocabulary
embedding_matrix, word_list = data_utils.load_pretained_vector(
    config.WORD_VECTOR_PATH)
data_utils.create_vocabulary(config.VOCABULARY_PATH, word_list)
# Preprocess the data
for file_name in os.listdir(config.ORGINAL_PATH):
    data_utils.data_to_token_ids(config.ORGINAL_PATH + file_name,
                                 config.TOCKEN_PATN + file_name,
                                 config.VOCABULARY_PATH)
vocabulary_size = len(word_list) + 2
# Get the training data
x_train, y_train, x_test, y_test = data_process.data_split(
    choose_path=config.choose_car_path,
    buy_path=config.buy_car_path,
    no_path=config.no_car_path)

with tf.Graph().as_default():
    #build graph
    model = online_lstm_model.RNN_Model(vocabulary_size, config.BATCH_SIZE,
                                        embedding_matrix)
    logits = model.logits
    loss = model.loss
    cost = model.cost
    acu = model.accuracy
    prediction = model.prediction