示例#1
0
def sampling():
    """Restore the trained VAE and log ``SAMPLE_NUM`` randomly generated texts.

    Latent vectors are drawn from a standard normal distribution and fed to the
    restored model's ``latent_variables``. The resulting ``logits`` are
    converted to strings with ``batchloader.logits2str``. The first
    ``SAMPLE_NUM`` generated texts are written with
    ``log_and_print(SAVE_FILE, ...)``.

    The checkpoint is read from ``MODEL_DIR + "/model50.ckpt"``.

    Relies on names defined elsewhere in the project: ``BatchLoader``,
    ``VAE``, ``FLAGS``, ``MODEL_DIR``, ``SAMPLE_NUM``, ``SAVE_FILE`` and
    ``log_and_print``.
    """
    batchloader = BatchLoader(with_label=True)

    # GPU options; allow_growth is left disabled.
    sess_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(
        # allow_growth = True
    ))

    with tf.Graph().as_default():
        with tf.Session(config=sess_conf) as sess:
            with tf.variable_scope("VAE"):
                vae_restored = VAE[FLAGS.VAE_NAME](batchloader,
                                                   is_training=False,
                                                   ru=False)

            saver = tf.train.Saver()
            saver.restore(sess, MODEL_DIR + "/model50.ckpt")

            generated_texts = []
            # One iteration per batch; the last one may be partial. Bounding
            # the loop by SAMPLE_NUM means no empty trailing batch is run when
            # SAMPLE_NUM is an exact multiple of BATCH_SIZE (or is 0).
            for start in range(0, SAMPLE_NUM, FLAGS.BATCH_SIZE):
                sample_num = min(FLAGS.BATCH_SIZE, SAMPLE_NUM - start)
                # Always feed a full BATCH_SIZE batch of latents (presumably
                # the latent placeholder has a fixed batch dimension — TODO
                # confirm). logits2str is asked for only ``sample_num`` texts.
                z = np.random.normal(
                    loc=0.0,
                    scale=1.0,
                    size=[FLAGS.BATCH_SIZE, FLAGS.LATENT_VARIABLE_SIZE])
                sample_logits = sess.run(
                    vae_restored.logits,
                    feed_dict={vae_restored.latent_variables: z})
                generated_texts.extend(
                    batchloader.logits2str(logits=sample_logits,
                                           sample_num=sample_num))

            for i in range(SAMPLE_NUM):
                log_and_print(SAVE_FILE, generated_texts[i])
示例#2
0
def main():
    """Train the VAE selected by ``FLAGS.VAE_NAME`` and checkpoint it every epoch.

    Side effects:
      * creates ``FLAGS.LOG_DIR`` and ``FLAGS.LOG_DIR/model``. ``os.mkdir``
        raises if either already exists, so an earlier run is not overwritten.
      * copies ``config.py`` and ``README.md`` into ``FLAGS.LOG_DIR``.
      * logs progress to ``FLAGS.LOG_DIR/log.txt`` via ``log_and_print``.
      * writes summary event files to ``FLAGS.LOG_DIR``.
      * saves ``model/model<epoch>.ckpt`` (1-based epoch) after every epoch.

    Relies on names defined elsewhere in the project: ``FLAGS``,
    ``BatchLoader``, ``VAE`` and ``log_and_print``.
    """
    os.mkdir(FLAGS.LOG_DIR)
    os.mkdir(FLAGS.LOG_DIR + "/model")
    log_file = FLAGS.LOG_DIR + "/log.txt"
    # Keep a copy of the configuration next to the run's outputs.
    shutil.copyfile("config.py", FLAGS.LOG_DIR + "/config.py")
    shutil.copyfile("README.md", FLAGS.LOG_DIR + "/README.md")

    # GPU options; allow_growth is left disabled.
    sess_conf = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            # allow_growth = True
        )
    )

    with tf.Graph().as_default():
        with tf.Session(config=sess_conf) as sess:
            batchloader = BatchLoader(with_label=False)

            # Training model, then an evaluation model built in the same
            # "VAE" scope with reuse=True so that it shares the variables.
            with tf.variable_scope("VAE"):
                vae = VAE[FLAGS.VAE_NAME](batchloader, is_training=True, ru=False)

            with tf.variable_scope("VAE", reuse=True):
                vae_test = VAE[FLAGS.VAE_NAME](batchloader, is_training=False, ru=True)

            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(FLAGS.LOG_DIR, sess.graph)

            try:
                sess.run(tf.global_variables_initializer())

                log_and_print(log_file, "start training")

                # Per-batch metrics accumulated between progress reports.
                loss_sum = []
                reconst_loss_sum = []
                kld_sum = []

                lr = FLAGS.LEARNING_RATE
                step = 0
                for epoch in range(FLAGS.EPOCH):
                    log_and_print(log_file, "epoch %d" % (epoch + 1))
                    # Multiply the learning rate by 0.95 at the start of every
                    # epoch from LR_DECAY_START (0-based epoch index) onward.
                    if epoch >= FLAGS.LR_DECAY_START:
                        lr *= 0.95
                    for batch in range(FLAGS.BATCHES_PER_EPOCH):

                        step += 1

                        # KL weight schedule: (tanh(x) + 1) / 2 lies in [0, 1]
                        # and equals 0.5 at step 3500. The float divisor keeps
                        # it smooth under Python 2 as well, where int / int
                        # floors and would make the schedule a step function.
                        kld_weight = (math.tanh((step - 3500) / 1000.0) + 1) / 2

                        encoder_input, decoder_input, target = \
                            batchloader.next_batch(FLAGS.BATCH_SIZE, "train")
                        feed_dict = {vae.encoder_input: encoder_input,
                                     vae.decoder_input: decoder_input,
                                     vae.target: target,
                                     vae.kld_weight: kld_weight,
                                     vae.step: step,
                                     vae.lr: lr}

                        logits, loss, reconst_loss, kld, merged_summary, _ = sess.run(
                            [vae.logits, vae.loss, vae.reconst_loss,
                             vae.kld, vae.merged_summary, vae.train_op],
                            feed_dict=feed_dict)

                        reconst_loss_sum.append(reconst_loss)
                        kld_sum.append(kld)
                        loss_sum.append(loss)
                        summary_writer.add_summary(merged_summary, step)

                        # Report every 100 batches (``batch`` is 0-based).
                        if batch % 100 == 99:
                            log_and_print(log_file, "epoch %d batch %d" %
                                          ((epoch + 1), (batch + 1)), br=False)

                            ave_loss = np.average(loss_sum)
                            log_and_print(log_file, "\tloss: %f" % ave_loss, br=False)
                            ave_rnnloss = np.average(reconst_loss_sum)
                            log_and_print(log_file, "\treconst_loss: %f" % ave_rnnloss, br=False)
                            ave_kld = np.average(kld_sum)
                            log_and_print(log_file, "\tkld %f" % ave_kld, br=False)

                            loss_sum = []
                            reconst_loss_sum = []
                            kld_sum = []

                            # Log the first training example of this batch
                            # next to the model's output for it. Only the
                            # list form of the encoder input is needed.
                            sample_train_input_list = sess.run(
                                vae.encoder_input_list, feed_dict=feed_dict)
                            encoder_input_texts = batchloader.logits2str(
                                sample_train_input_list, 1, onehot=False, numpy=True)

                            log_and_print(log_file, "\ttrain input: %s" % encoder_input_texts[0])
                            sample_train_outputs = batchloader.logits2str(logits, 1)
                            log_and_print(log_file, "\ttrain output: %s" % sample_train_outputs[0])

                            # Validation: compute latent variables for a test
                            # batch with the evaluation model, then feed them
                            # back in to get logits and reconstruction loss.
                            sample_input, _, sample_target = batchloader.next_batch(
                                FLAGS.BATCH_SIZE, "test")
                            sample_input_list, sample_latent_variables = sess.run(
                                [vae_test.encoder_input_list,
                                 vae_test.encoder.latent_variables],
                                feed_dict={vae_test.encoder_input: sample_input})
                            sample_logits, valid_loss, merged_summary = sess.run(
                                [vae_test.logits, vae_test.reconst_loss,
                                 vae_test.merged_summary],
                                feed_dict={vae_test.target: sample_target,
                                           vae_test.latent_variables: sample_latent_variables,
                                           vae_test.kld_weight: kld_weight})

                            log_and_print(log_file, "\tvalid loss: %f" % valid_loss)
                            sample_input_texts = batchloader.logits2str(
                                sample_input_list, 1, onehot=False, numpy=True)
                            sample_output_texts = batchloader.logits2str(sample_logits, 1)
                            log_and_print(log_file, "\tsample input: %s" % sample_input_texts[0])
                            log_and_print(log_file, "\tsample output: %s" % sample_output_texts[0])

                            summary_writer.add_summary(merged_summary, step)

                    # Checkpoint after every epoch.
                    save_path = saver.save(
                        sess, FLAGS.LOG_DIR + ("/model/model%d.ckpt" % (epoch + 1)))
                    log_and_print(log_file, "Model saved in file %s" % save_path)
            finally:
                # FileWriter writes events asynchronously; close() flushes
                # them, so the last summaries survive normal exit or an error.
                summary_writer.close()