Example #1
    def test(self, sess, ckpt, sample_size):
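        """Sample `sample_size` images from a trained generator.

        Restores the given checkpoint, then writes tiled sample grids and
        latent-space interpolation grids to `self.test_dir`.
        """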
        assert ckpt is not None, 'no checkpoint provided.'

        gen_res = self.generator(self.z)

        num_batches = int(math.ceil(sample_size / self.num_chain))

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        print('Loaded checkpoint {}.'.format(ckpt))

        for i in range(num_batches):
            z_vec = np.random.randn(min(sample_size, self.num_chain),
                                    self.z_size)
            g_res = sess.run(gen_res, feed_dict={self.z: z_vec})
            save_sample_results(g_res,
                                "%s/gen%03d.png" % (self.test_dir, i),
                                col_num=self.n_tile_col)

            # output interpolation results
            interp_z = linear_interpolator(z_vec,
                                           npairs=self.n_tile_row,
                                           ninterp=self.n_tile_col)
            interp = sess.run(gen_res, feed_dict={self.z: interp_z})
            save_sample_results(interp,
                                "%s/interp%03d.png" % (self.test_dir, i),
                                col_num=self.n_tile_col)
            sample_size = sample_size - self.num_chain
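The helpers `linear_interpolator` and `save_sample_results` are defined elsewhere in the repository. As a rough guide, here is a minimal sketch of the interpolation helper, assuming it pairs up the first 2 * npairs rows of z_vec and walks between each pair in ninterp evenly spaced steps (the exact pairing scheme is an assumption):

import numpy as np

def linear_interpolator(z_vec, npairs, ninterp):
    # Hypothetical reimplementation: assumes z_vec has at least
    # 2 * npairs rows; returns an (npairs * ninterp, z_dim) array
    # whose p-th block moves linearly from one endpoint to the other.
    out = []
    for p in range(npairs):
        z_start, z_end = z_vec[2 * p], z_vec[2 * p + 1]
        for t in np.linspace(0.0, 1.0, ninterp):
            out.append((1.0 - t) * z_start + t * z_end)
    return np.stack(out, axis=0)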
Example #2
    def train(self, sess):
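        """Build the model and run cooperative training.

        Alternates Langevin revision of synthesized images (descriptor) and
        latent vectors (generator) with gradient updates of both networks,
        logging per-batch losses and periodically saving samples and
        checkpoints.
        """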

        self.build_model()

        # Prepare training data
        train_data = DataSet(self.data_path, image_size=self.image_size)
        num_batches = int(math.ceil(len(train_data) / self.batch_size))

        # initialize training
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        sample_results = np.random.randn(self.num_chain * num_batches,
                                         self.image_size, self.image_size, 3)

        saver = tf.train.Saver(max_to_keep=50)

        writer = tf.summary.FileWriter(self.log_dir, sess.graph)

        # make graph immutable
        tf.get_default_graph().finalize()

        # store graph in protobuf (create the model directory first)
        make_dir(self.model_dir)
        with open(self.model_dir + '/graph.proto', 'w') as f:
            f.write(str(tf.get_default_graph().as_graph_def()))

        # train
        for epoch in range(self.num_epochs):
            start_time = time.time()
            for i in range(num_batches):

                batch_start = i * self.batch_size
                batch_end = min(len(train_data), (i + 1) * self.batch_size)
                obs_data = train_data[batch_start:batch_end]

                # Step G0: sample Z ~ N(0, I) and generate images X = G(Z)
                z_vec = np.random.randn(self.num_chain, self.z_size)
                g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
                # Step D1: revise the synthesized images Y with Langevin
                # dynamics on the descriptor; when t1 == 0, keep the raw
                # generator output so that `syn` is always defined below
                syn = g_res
                if self.t1 > 0:
                    syn = sess.run(self.langevin_descriptor,
                                   feed_dict={self.syn: g_res})
                # Step G1: revise Z with Langevin dynamics on the generator,
                # using the revised images Y as the reconstruction target
                if self.t2 > 0:
                    z_vec = sess.run(self.langevin_generator,
                                     feed_dict={
                                         self.z: z_vec,
                                         self.obs: syn
                                     })
                # Step D2: update D net
                d_loss = sess.run(
                    [self.des_loss, self.des_loss_update, self.apply_d_grads],
                    feed_dict={
                        self.obs: obs_data,
                        self.syn: syn
                    })[0]
                # Step G2: update G net
                g_loss = sess.run(
                    [self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                    feed_dict={
                        self.obs: syn,
                        self.z: z_vec
                    })[0]

                # Metrics
                mse = sess.run([self.recon_err, self.recon_err_update],
                               feed_dict={
                                   self.obs: obs_data,
                                   self.syn: syn
                               })[0]
                sample_results[i * self.num_chain:
                               (i + 1) * self.num_chain] = syn
                print(
                    'Epoch #{:d}, [{:2d}]/[{:2d}], des loss: {:.4f}, gen loss: {:.4f}, '
                    'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches,
                                                  d_loss.mean(), g_loss.mean(),
                                                  mse))
                if i == 0 and epoch % self.log_step == 0:
                    save_sample_results(syn,
                                        "%s/des%03d.png" %
                                        (self.sample_dir, epoch),
                                        col_num=self.n_tile_col)
                    save_sample_results(g_res,
                                        "%s/gen%03d.png" %
                                        (self.sample_dir, epoch),
                                        col_num=self.n_tile_col)

            [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
                self.des_loss_mean, self.gen_loss_mean, self.recon_err_mean,
                self.summary_op
            ])

            end_time = time.time()
            print(
                'Epoch #{:d}, avg.des loss: {:.4f}, avg.gen loss: {:.4f}, avg.L2 distance: {:4.4f}, '
                'lr.des: {:f} lr.gen: {:f} time: {:.2f}s'.format(
                    epoch, des_loss_avg, gen_loss_avg, mse_avg,
                    self.lr_des.eval(), self.lr_gen.eval(),
                    end_time - start_time))
            writer.add_summary(summary, epoch)
            writer.flush()

            if epoch % self.log_step == 0:
                make_dir(self.model_dir)
                saver.save(sess,
                           "%s/%s" % (self.model_dir, 'model.ckpt'),
                           global_step=epoch)
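The `self.langevin_descriptor` and `self.langevin_generator` ops used in Steps D1 and G1 are built elsewhere in the class. As a rough guide to the update they implement, here is a NumPy sketch of a generic Langevin revision chain; `grad_fn`, `num_steps`, and `delta` are illustrative names, not fields of the original class:

import numpy as np

def langevin_revision(y, grad_fn, num_steps, delta):
    # Generic Langevin update: a gradient step on a log-density plus
    # Gaussian diffusion noise. `grad_fn(y)` is assumed to return the
    # gradient of the model's log-density with respect to `y`.
    for _ in range(num_steps):
        noise = np.random.randn(*y.shape)
        y = y + 0.5 * delta ** 2 * grad_fn(y) + delta * noise
    return y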