Example #1
0
    def infer(self, sess):
        """Restore the latest checkpoint (if any) and generate one batch of images.

        Feeds a fresh batch of Gaussian noise through the generator and writes
        the resulting images to ``<sample_dir>/infer-image`` via ImageHelper.

        Args:
            sess: an active tf.Session holding the model graph.
        """
        # Restore weights from the most recent checkpoint when one exists.
        # NOTE(review): if no checkpoint is found, generation proceeds with
        # whatever values the graph variables currently hold — confirm callers
        # always train before inferring.
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            self.saver.restore(sess, ckpt.model_checkpoint_path)

        helper = ImageHelper()

        # Sample one batch of noise vectors and run the generator on them.
        noise = np.random.normal(size=[self.batch_size, self.noise_dim])
        generated = sess.run(self.fake_imgs, feed_dict={self.rand_noises: noise})

        img_name = "{}/infer-image".format(self.hparams.sample_dir)
        helper.save_imgs(generated, img_name=img_name)

        tf.logging.info("====== generate images in file: {} ======".format(img_name))
Example #2
0
    def train(self, sess):
        """Run the adversarial training loop.

        Pre-trains the discriminator for ``hparams.d_pretrain`` batches, then
        alternates: a discriminator update on every batch, and a generator
        update once every ``hparams.d_schedule`` discriminator steps.
        TensorBoard summaries are written every ``hparams.log_interval``
        steps; sample images and a model checkpoint are saved at each epoch
        boundary.

        Args:
            sess: an active tf.Session holding the model graph.
        """

        def sample_noise():
            # One fresh batch of Gaussian noise for the generator input.
            return np.random.normal(size=[self.batch_size, self.noise_dim])

        # Discriminator-side summaries: real/fake probabilities and losses.
        d_summary_op = tf.summary.merge([
            tf.summary.histogram("d_real_prob", tf.sigmoid(self.real_logits)),
            tf.summary.histogram("d_fake_prob", tf.sigmoid(self.fake_logits)),
            tf.summary.scalar("d_loss_fake", self.d_loss_fake),
            tf.summary.scalar("d_loss_real", self.d_loss_real),
            tf.summary.scalar("d_loss", self.d_loss)
        ],
                                        name="discriminator_summary")
        # Generator-side summaries: fooling probability, loss, sampled images.
        g_summary_op = tf.summary.merge([
            tf.summary.histogram("g_prob", tf.sigmoid(self.fake_logits)),
            tf.summary.scalar("g_loss", self.g_loss),
            tf.summary.image("gen_images", self.fake_imgs)
        ],
                                        name="generator_summary")

        self.summary_dir = os.path.abspath(
            os.path.join(self.hparams.checkpoint_dir, "summary"))
        summary_writer = tf.summary.FileWriter(self.summary_dir, sess.graph)

        image_helper = ImageHelper()

        sess.run(tf.global_variables_initializer())

        # BUGFIX: g_loss was only assigned inside the d_schedule branch, so
        # the log_interval logging below could raise NameError when it fired
        # before the first generator update. Seed it with NaN so the log line
        # is always well-defined.
        g_loss = float("nan")

        for num_epoch, num_batch, batch_images in image_helper.iter_images(
                dirname=self.hparams.data_dir,
                batch_size=self.batch_size,
                epoches=self.epoches):
            if (num_epoch == 0) and (num_batch < self.hparams.d_pretrain):
                # Pre-train the discriminator alone on the first batches.
                # (sic: "d_accuarcy" is the attribute name the model defines.)
                _, current_step, d_loss, d_accuracy = sess.run(
                    [
                        self.d_optim, self.global_step, self.d_loss,
                        self.d_accuarcy
                    ],
                    feed_dict={
                        self.rand_noises: sample_noise(),
                        self.real_imgs: batch_images
                    })
                if current_step == self.hparams.d_pretrain:
                    tf.logging.info("==== pre-train ==== current_step:{}, d_loss:{}, d_accuarcy:{}"\
                                    .format(current_step, d_loss, d_accuracy))
            else:
                # Discriminator update on every batch.
                _, current_step, d_loss, d_accuracy = sess.run(
                    [
                        self.d_optim, self.global_step, self.d_loss,
                        self.d_accuarcy
                    ],
                    feed_dict={
                        self.rand_noises: sample_noise(),
                        self.real_imgs: batch_images
                    })

                # Generator update once every d_schedule discriminator steps.
                if current_step % self.hparams.d_schedule == 0:
                    _, g_loss = sess.run(
                        [self.g_optim, self.g_loss],
                        feed_dict={self.rand_noises: sample_noise()})

                # Periodic TensorBoard summaries and console logging.
                if current_step % self.hparams.log_interval == 0:
                    d_summary_str, g_summary_str = sess.run(
                        [d_summary_op, g_summary_op],
                        feed_dict={
                            self.rand_noises: sample_noise(),
                            self.real_imgs: batch_images
                        })
                    summary_writer.add_summary(d_summary_str, current_step)
                    summary_writer.add_summary(g_summary_str, current_step)

                    tf.logging.info("step:{}, d_loss:{}, d_accuarcy:{}, g_loss:{}"\
                                    .format(current_step, d_loss, d_accuracy, g_loss))

            if (num_epoch > 0) and (num_batch == 0):
                # At each epoch boundary, dump sample images and a checkpoint.
                tf.logging.info(
                    "epoch:{} === generate images and save checkpoint".format(
                        num_epoch))
                fake_imgs = sess.run(
                    self.fake_imgs,
                    feed_dict={self.rand_noises: sample_noise()})
                image_helper.save_imgs(fake_imgs,
                                       img_name="{}/fake-{}".format(
                                           self.hparams.sample_dir, num_epoch))
                self.saver.save(sess,
                                self.checkpoint_prefix,
                                global_step=num_epoch)