예제 #1
0
def train(conv, batch_size, epoch, log_every=50):
    """Train a sequence2sequence model on ``conv`` and checkpoint it.

    Builds the model and restores it from ``checkpoint`` when that checkpoint
    exists; otherwise it initializes fresh variables. It then runs
    ``ceil(len(conv.conversation) / batch_size) * epoch`` training steps.
    Every ``log_every`` steps it writes summaries, saves a checkpoint and
    prints the global step and loss. A final checkpoint is saved once
    training completes.

    Args:
        conv: Data source exposing ``voc_size``, ``conversation`` and
            ``next_batch(batch_size)``, which returns
            ``(enc_input, dec_input, dec_target)``.
        batch_size: Number of examples per training step.
        epoch: Number of passes over ``conv.conversation``.
        log_every: Step interval between summary/checkpoint/log writes.
            Defaults to 50, the previously hard-coded interval.

    Raises:
        ValueError: If ``log_every`` is not positive.

    Note:
        ``checkpoint``, ``checkpoint_path`` and ``config`` are read from the
        enclosing (presumably module-level) scope, not passed in.
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")

    model = sequence2sequence(conv.voc_size)

    with tf.Session() as sess:
        if checkpoint and tf.train.checkpoint_exists(checkpoint.model_checkpoint_path):
            model.saver.restore(sess, checkpoint.model_checkpoint_path)
        else:
            print("Building a model")
            sess.run(tf.global_variables_initializer())

        writer = tf.summary.FileWriter(config.log_dir, sess.graph)
        try:
            total_batch = int(math.ceil(len(conv.conversation) / float(batch_size)))

            for step in range(total_batch * epoch):
                enc_input, dec_input, dec_target = conv.next_batch(batch_size)

                _, loss = model.train(sess, enc_input, dec_input, dec_target)

                if (step + 1) % log_every == 0:
                    model.logs(sess, writer, enc_input, dec_input, dec_target)
                    model.saver.save(sess, checkpoint_path, global_step=model.global_step)
                    print('Step:', '%06d' % model.global_step.eval(), 'cost =', '{:.6f}'.format(loss))
        finally:
            # Flush pending summary events and release the event file even if
            # training is interrupted; previously the writer was never closed.
            writer.close()

        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
        print("Finished")
예제 #2
0
    def __init__(self, VOC_PATH, train_dir):
        """Load the vocabulary, build the model and restore trained weights.

        Args:
            VOC_PATH: Vocabulary location, passed to ``Conversation.Load_voc``.
            train_dir: Directory whose latest checkpoint, as reported by
                ``tf.train.get_checkpoint_state``, is restored into the model.

        Raises:
            FileNotFoundError: If ``train_dir`` has no checkpoint state.
        """
        self.conv = Conversation()
        self.conv.Load_voc(VOC_PATH)
        self.model = sequence2sequence(self.conv.voc_size)

        ckpt = tf.train.get_checkpoint_state(train_dir)
        # get_checkpoint_state returns None when no checkpoint is recorded in
        # train_dir; fail with a clear message instead of an AttributeError.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError(
                "No checkpoint found in train_dir: {!r}".format(train_dir))

        self.sess = tf.Session()
        try:
            self.model.saver.restore(self.sess, ckpt.model_checkpoint_path)
        except Exception:
            # Release the session's resources if restoring fails, then re-raise.
            self.sess.close()
            raise