def train(conv, batch_size, epoch, log_every=50):
    """Train a sequence2sequence model on ``conv`` and checkpoint it.

    If a checkpoint is available, its weights are restored. Otherwise all
    variables are freshly initialized. Training runs for
    ``ceil(len(conv.conversation) / batch_size) * epoch`` steps. Every
    ``log_every`` steps it writes summaries, saves a checkpoint and prints
    the global step and cost. A final checkpoint is saved after the loop.

    Args:
        conv: Data source exposing ``voc_size``, ``conversation`` (sized)
            and ``next_batch(batch_size)``, which returns
            ``(enc_input, dec_input, dec_target)``.
        batch_size: Number of examples per training step.
        epoch: Number of passes over ``conv.conversation``.
        log_every: Interval, in steps, between summary/checkpoint/print
            events. Defaults to 50, matching the previous hard-coded value.

    Raises:
        ValueError: If ``log_every`` is not a positive integer.
    """
    if log_every <= 0:
        raise ValueError("log_every must be positive, got %r" % (log_every,))

    model = sequence2sequence(conv.voc_size)
    with tf.Session() as sess:
        # NOTE(review): `checkpoint` and `checkpoint_path` are not defined in
        # this function; presumably they are module-level globals (e.g. a
        # CheckpointState and a save-path prefix) -- confirm.
        if checkpoint and tf.train.checkpoint_exists(checkpoint.model_checkpoint_path):
            model.saver.restore(sess, checkpoint.model_checkpoint_path)
        else:
            print("Building a model")
            sess.run(tf.global_variables_initializer())

        writer = tf.summary.FileWriter(config.log_dir, sess.graph)
        try:
            total_batch = int(math.ceil(len(conv.conversation) / float(batch_size)))
            for step in range(total_batch * epoch):
                enc_input, dec_input, dec_target = conv.next_batch(batch_size)
                _, loss = model.train(sess, enc_input, dec_input, dec_target)
                if (step + 1) % log_every == 0:
                    model.logs(sess, writer, enc_input, dec_input, dec_target)
                    model.saver.save(sess, checkpoint_path, global_step=model.global_step)
                    # global_step.eval() relies on `sess` being the default
                    # session inside this `with` block.
                    print('Step:', '%06d' % model.global_step.eval(),
                          'cost =', '{:.6f}'.format(loss))
            model.saver.save(sess, checkpoint_path, global_step=model.global_step)
        finally:
            # Flush buffered events to disk; otherwise the last summaries
            # may never be written.
            writer.close()
    print("Finished")
def __init__(self, VOC_PATH, train_dir):
    """Load the vocabulary, build the model and restore its latest checkpoint.

    Args:
        VOC_PATH: Path passed to ``Conversation.Load_voc`` to load the
            vocabulary.
        train_dir: Directory searched by ``tf.train.get_checkpoint_state``
            for the checkpoint to restore.

    Raises:
        ValueError: If no checkpoint state is found in ``train_dir``.
    """
    self.conv = Conversation()
    self.conv.Load_voc(VOC_PATH)
    self.model = sequence2sequence(self.conv.voc_size)

    # get_checkpoint_state returns None when train_dir holds no checkpoint.
    # Check before opening a session so a failure does not leak it.
    ckpt = tf.train.get_checkpoint_state(train_dir)
    if not ckpt or not ckpt.model_checkpoint_path:
        raise ValueError("No checkpoint found in %r" % (train_dir,))

    # The session is kept on the instance so it outlives __init__.
    self.sess = tf.Session()
    try:
        self.model.saver.restore(self.sess, ckpt.model_checkpoint_path)
    except Exception:
        # Release the session if restore fails, then propagate the error.
        self.sess.close()
        raise