import os

import numpy as np
import tensorflow as tf


def train(words, poetry_vector, x_batches, y_batches):
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    output_targets = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(len(words),
                           input_data=input_data,
                           output_data=output_targets,
                           batch_size=batch_size)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    merge = tf.summary.merge_all()  # merge all summary ops for TensorBoard logging
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter('./logs', sess.graph)
        sess.run(init_op)
        start_epoch = 0
        model_dir = './model'
        epochs = 50
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            # Restore the model from the latest checkpoint
            saver.restore(sess, checkpoint)
            print('## restore from the checkpoint {}'.format(checkpoint))
            # Resume training from where the last run ended
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, epochs):
                n_chunk = len(poetry_vector) // batch_size
                for n in range(n_chunk):
                    loss, _, _ = sess.run(
                        [end_points['total_loss'],
                         end_points['last_state'],
                         end_points['train_op']],
                        feed_dict={input_data: x_batches[n],
                                   output_targets: y_batches[n]})
                    print('Epoch: {}, batch: {}, training loss: {}'.format(epoch, n, loss))
                if epoch % 5 == 0:
                    saver.save(sess, os.path.join(model_dir, 'poetry'), global_step=epoch)
                    result = sess.run(merge,
                                      feed_dict={input_data: x_batches[n],
                                                 output_targets: y_batches[n]})
                    writer.add_summary(result, epoch * n_chunk + n)
        except KeyboardInterrupt:
            print('## Interrupted manually, trying to save a checkpoint now ...')
            saver.save(sess, os.path.join(model_dir, 'poetry'), global_step=epoch)
            print('## Last epoch was saved; next time training will start from epoch {}'.format(epoch))
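# The functions in this file rely on a few module-level names that are not
# defined here: `batch_size`, a session `config`, and the `rnn_model` graph
# builder. A minimal sketch of plausible definitions (an assumption, not the
# original code):
batch_size = 64  # matches the default used by the second train() below
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand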
def generate(words, to_num, style_word=None):
    batch_size = 1
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    cell_model = rnn_model(len(words), input_data, batch_size=batch_size)
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        check_point = tf.train.latest_checkpoint('./model')
        saver.restore(sess, check_point)
        # Feed the start token 'B' to obtain an initial RNN state
        x = np.array(to_num('B')).reshape(1, 1)
        _, last_state = sess.run([cell_model.prediction, cell_model.last_state],
                                 feed_dict={input_data: x})
        if style_word:
            # Run the style sentence through the network so its state
            # influences the generated poem
            for i in style_word:
                x = np.array(to_num(i)).reshape(1, 1)
                predict, last_state = sess.run(
                    [cell_model.prediction, cell_model.last_state],
                    feed_dict={input_data: x,
                               cell_model.initial_state: last_state})
        start_words = list('少小离家老大回')
        start_len = len(start_words)
        result = start_words.copy()
        max_len = 200
        for i in range(max_len):
            if i < start_len:
                # Feed the given opening line character by character;
                # only the carried state matters here
                w = start_words[i]
                x = np.array(to_num(w)).reshape(1, 1)
                predict, last_state = sess.run(
                    [cell_model.prediction, cell_model.last_state],
                    feed_dict={input_data: x,
                               cell_model.initial_state: last_state})
            else:
                # Sample the next character from the model's prediction
                predict, last_state = sess.run(
                    [cell_model.prediction, cell_model.last_state],
                    feed_dict={input_data: x,
                               cell_model.initial_state: last_state})
                w = to_word(predict, words)
                x = np.array(to_num(w)).reshape(1, 1)
                if w == 'E':  # end-of-poem token
                    break
                result.append(w)
        print(''.join(result))
        return ''.join(result)
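# `to_word` is called by the generation functions but is not defined in this
# file. In samplers like this it usually draws the next character from the
# softmax distribution rather than taking the argmax, which keeps the output
# varied. A minimal sketch under that assumption (not the original helper):
def to_word(predict, vocabs):
    probs = predict[0] / np.sum(predict[0])           # normalize the softmax row
    sample = np.random.choice(len(probs), p=probs)    # weighted sampling
    return vocabs[min(sample, len(vocabs) - 1)]       # clamp to the vocabulary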
def gen_poetry(words, to_num):
    batch_size = 1
    print('Model directory: {}'.format('./model'))
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(len(words), input_data=input_data, batch_size=batch_size)
    # Saver over all global variables, i.e. everything the model created
    saver = tf.train.Saver(tf.global_variables())
    # Local variables (e.g. counters created by some ops) are initialized
    # alongside the globals; running this is harmless even if none exist
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        checkpoint = tf.train.latest_checkpoint('./model')
        saver.restore(sess, checkpoint)
        # Feed the start token 'B' to obtain an initial state
        x = np.array(to_num('B')).reshape(1, 1)
        _, last_state = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})
        word = input('Please enter a start character: ')
        poem_ = ''
        while word != 'E':
            poem_ += word
            x = np.array(to_num(word)).reshape(1, 1)
            predict, last_state = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={input_data: x,
                           end_points['initial_state']: last_state})
            word = to_word(predict, words)
        print(poem_)
        return poem_
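# `to_num` is passed into every function above as a character-to-id encoder
# but never defined here. A minimal sketch, assuming a `word_num_map` dict
# built from `words` (both names hypothetical):
def to_num(word):
    # unknown characters fall back to the last vocabulary id
    return word_num_map.get(word, len(words) - 1)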
def train(words, poetry_vector, word_num_map, batch_size=64, epochs=50):
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    output_data = tf.placeholder(tf.int32, [batch_size, None])
    cell_model = rnn_model(len(words), input_data=input_data,
                           output_data=output_data, batch_size=batch_size)
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    merge = tf.summary.merge_all()
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter('./logs', sess.graph)
        sess.run(init_op)
        start_epoch = 0
        model_dir = './model'
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print('## restore model from {}'.format(checkpoint))
            start_epoch = int(checkpoint.split('-')[-1])
        print('## start training ...')
        try:
            for epoch in range(start_epoch, epochs):
                n_bucket = len(poetry_vector) // batch_size
                for n, (x_batch, y_batch) in enumerate(
                        get_batches(poetry_vector, word_num_map, batch_size=batch_size)):
                    loss, _, _ = sess.run(
                        [cell_model.total_loss, cell_model.last_state, cell_model.train_op],
                        feed_dict={input_data: x_batch, output_data: y_batch})
                    print('Epoch: {} batch: {} training loss: {}'.format(epoch, n, loss))
                    if n % 5 == 0:
                        saver.save(sess, os.path.join(model_dir, 'poetry'), global_step=epoch)
                        result = sess.run(merge,
                                          feed_dict={input_data: x_batch, output_data: y_batch})
                        writer.add_summary(result, epoch * n_bucket + n)
        except KeyboardInterrupt:
            print('## ERROR: interrupted! Trying to save the model now ...')
            saver.save(sess, os.path.join(model_dir, 'poetry'), global_step=epoch)
            print('## Last epoch model was saved; next time training will start from epoch {}'.format(epoch))
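# `get_batches` is used by the train() above but not defined in this file. It
# presumably turns the list of encoded poems into padded (x, y) pairs where y
# is x shifted left by one character. A minimal sketch under that assumption
# (the space-character padding id is a guess):
def get_batches(poetry_vector, word_num_map, batch_size=64):
    n_chunk = len(poetry_vector) // batch_size
    for i in range(n_chunk):
        batch = poetry_vector[i * batch_size:(i + 1) * batch_size]
        max_len = max(len(poem) for poem in batch)
        # pad every poem in the batch to the same length
        x = np.full((batch_size, max_len), word_num_map[' '], dtype=np.int32)
        for row, poem in enumerate(batch):
            x[row, :len(poem)] = poem
        y = np.copy(x)
        y[:, :-1] = x[:, 1:]  # targets are the inputs shifted one step left
        yield x, y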
def generate(words, to_num, style_words="狂沙将军战燕然,大漠孤烟黄河骑。"): batch_size = 1 input_data = tf.placeholder(tf.int32, [batch_size, None]) end_points = rnn_model(len(words), input_data=input_data, batch_size=batch_size) saver = tf.train.Saver(tf.global_variables()) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session(config=config) as sess: sess.run(init_op) checkpoint = tf.train.latest_checkpoint('./model') saver.restore(sess, checkpoint) x = np.array(to_num('B')).reshape(1, 1) _, last_state = sess.run( [end_points['prediction'], end_points['last_state']], feed_dict={input_data: x}) if style_words: for word in style_words: x = np.array(to_num(word)).reshape(1, 1) last_state = sess.run(end_points['last_state'], feed_dict={ input_data: x, end_points['initial_state']: last_state }) start_words = list("少小离家老大回") # list(input("请输入起始语句:")) start_word_len = len(start_words) result = start_words.copy() max_len = 200 for i in range(max_len): if i < start_word_len: w = start_words[i] x = np.array(to_num(w)).reshape(1, 1) predict, last_state = sess.run( [end_points['prediction'], end_points['last_state']], feed_dict={ input_data: x, end_points['initial_state']: last_state }) else: predict, last_state = sess.run( [end_points['prediction'], end_points['last_state']], feed_dict={ input_data: x, end_points['initial_state']: last_state }) w = to_word(predict, words) # w = words[np.argmax(predict)] x = np.array(to_num(w)).reshape(1, 1) if w == 'E': break result.append(w) print(''.join(result))
def generate(words, to_num, style_words='狂沙将军战燕然,大漠孤烟黄河骑。'):
    batch_size = 1
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_point = rnn_model(len(words), input_data=input_data, batch_size=batch_size)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        # checkpoint = tf.train.latest_checkpoint('./model')
        # saver.restore(sess, checkpoint)
        saver.restore(sess, './model/poetry-0')
        # Also fetch last_state so it can seed the next prediction
        x = np.array(to_num('B')).reshape(1, 1)
        _, last_state = sess.run(
            [end_point['prediction'], end_point['last_state']],
            feed_dict={input_data: x})
        if style_words:
            # Compute the state of the style sentence: the effect is as if
            # start_words plus this sentence were given, and the rest of the
            # poem is generated after them
            for word in style_words:
                x = np.array(to_num(word)).reshape(1, 1)
                last_state = sess.run(end_point['last_state'],
                                      feed_dict={input_data: x,
                                                 end_point['initial_state']: last_state})
        start_words = list('少小离家老大回')
        # start_words = list(input('Please enter an opening line: '))
        start_words_len = len(start_words)
        result = start_words.copy()
        max_len = 200
        for i in range(max_len):
            if i < start_words_len:
                # No output is needed here, but the state must be carried
                # forward for the generation that follows this line
                w = start_words[i]
                x = np.array(to_num(w)).reshape(1, 1)
                predict, last_state = sess.run(
                    [end_point['prediction'], end_point['last_state']],
                    feed_dict={input_data: x,
                               end_point['initial_state']: last_state})
            else:
                predict, last_state = sess.run(
                    [end_point['prediction'], end_point['last_state']],
                    feed_dict={input_data: x,
                               end_point['initial_state']: last_state})
                w = to_word(predict, words)  # w = words[np.argmax(predict)]
                x = np.array(to_num(w)).reshape(1, 1)
                if w == 'E':
                    break
                result.append(w)
        print(''.join(result))