# -*- coding: utf-8 -*-
import generate_chat
import seq2seq_model
import tensorflow as tf
import numpy as np
import logging
import logging.handlers

if __name__ == '__main__':
    # Vocabulary sizes for the encoder (question) and decoder (answer) sides.
    _, _, source_vocab_size = generate_chat.get_vocabs(
        generate_chat.vocab_encode_file)
    _, _, target_vocab_size = generate_chat.get_vocabs(
        generate_chat.vocab_decode_file)

    # Vectorised question/answer pairs, already sorted into length buckets.
    train_set = generate_chat.read_data(generate_chat.train_encode_vec_file,
                                        generate_chat.train_decode_vec_file)
    test_set = generate_chat.read_data(generate_chat.test_encode_vec_file,
                                       generate_chat.test_decode_vec_file)

    # Cumulative share of training samples per bucket, used to sample a bucket
    # in proportion to its size.
    train_bucket_sizes = [
        len(train_set[i]) for i in range(len(generate_chat._buckets))
    ]
    train_total_size = float(sum(train_bucket_sizes))
    train_buckets_scale = [
        sum(train_bucket_sizes[:i + 1]) / train_total_size
        for i in range(len(train_bucket_sizes))
    ]

    with tf.Session() as sess:
        model = seq2seq_model.Seq2SeqModel(
            source_vocab_size,
            target_vocab_size,
            generate_chat._buckets,
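The training listing above breaks off inside the Seq2SeqModel constructor. What follows is a minimal sketch of how the session body would typically continue, not the original code: it assumes the remaining constructor arguments mirror the inference script below (units_num, num_layers, max_gradient_norm, learning_rate, learning_rate_decay_factor from generate_chat), that generate_chat also exposes a batch_size (hypothetical here), and that the model offers the get_batch/step interface of the standard TensorFlow translate-tutorial Seq2SeqModel.

            # Assumed remaining constructor arguments, mirroring the inference
            # script below; generate_chat.batch_size is hypothetical.
            generate_chat.units_num,
            generate_chat.num_layers,
            generate_chat.max_gradient_norm,
            generate_chat.batch_size,
            generate_chat.learning_rate,
            generate_chat.learning_rate_decay_factor,
            forward_only=False,
            use_lstm=True)
        sess.run(tf.global_variables_initializer())

        current_step = 0
        while True:
            # Pick a bucket at random, weighted by its share of the training data.
            random_number = np.random.random_sample()
            bucket_id = min(i for i in range(len(train_buckets_scale))
                            if train_buckets_scale[i] > random_number)
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                train_set, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, False)
            current_step += 1
            if current_step % 1000 == 0:
                print("step %d, loss %f" % (current_step, step_loss))
                # The inference script restores "chatbot.ckpt-317000", which
                # suggests checkpoints were saved under this name with a step suffix.
                model.saver.save(sess, "chatbot.ckpt", global_step=current_step)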
# -*- coding: utf-8 -*-
import generate_chat
import seq2seq_model
import tensorflow as tf
import numpy as np
import sys

if __name__ == '__main__':
    # Load both vocabularies; words missing from the encoder vocabulary map to UNK_ID.
    source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(
        generate_chat.vocab_encode_file)
    target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(
        generate_chat.vocab_decode_file)
    to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)

    with tf.Session() as sess:
        # Batch size 1 and forward_only=True: the model is built for decoding,
        # not for training.
        model = seq2seq_model.Seq2SeqModel(
            source_vocab_size,
            target_vocab_size,
            generate_chat._buckets,
            generate_chat.units_num,
            generate_chat.num_layers,
            generate_chat.max_gradient_norm,
            1,
            generate_chat.learning_rate,
            generate_chat.learning_rate_decay_factor,
            forward_only=True,
            use_lstm=True)
        model.saver.restore(sess, "chatbot.ckpt-317000")

        while True:
            sys.stdout.write("ask > ")
            sys.stdout.flush()
            sentence = sys.stdin.readline().strip('\n')
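The inference listing stops right after reading the question. Below is a minimal sketch of the rest of the loop, not the original code: it assumes character-level tokenisation of the input, that generate_chat defines an EOS_ID alongside the UNK_ID used above, and that the model follows the standard translate-tutorial get_batch/step API with greedy argmax decoding of the output logits.

            # Convert the question to ids (character by character, an assumption)
            # and pick the smallest bucket that can hold it.
            token_ids = [to_id(word) for word in sentence]
            bucket_id = min(i for i in range(len(generate_chat._buckets))
                            if generate_chat._buckets[i][0] > len(token_ids))
            # Build a batch of size 1 and run a single forward pass.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)
            # Greedy decoding: argmax at every step, cut at the (assumed) EOS_ID.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            if generate_chat.EOS_ID in outputs:
                outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
            answer = "".join(target_id_to_word[output] for output in outputs)
            print("answer > " + answer)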