def test(self, sess, ckpt, sample_size):
    """Restore a trained model and dump generated samples to ``self.test_dir``.

    For each batch of latent vectors this writes one grid of raw generator
    samples and one grid of latent-space interpolations.

    Args:
        sess: live ``tf.Session`` to run the generator in.
        ckpt: path to the checkpoint to restore (must not be ``None``).
        sample_size: total number of samples to generate across all batches.
    """
    assert ckpt is not None, 'no checkpoint provided.'

    gen_res = self.generator(self.z)
    num_batches = int(math.ceil(sample_size / self.num_chain))

    saver = tf.train.Saver()
    # Initialize first so any variables absent from the checkpoint have values;
    # restore then overwrites everything stored in the checkpoint.
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt)
    print('Loading checkpoint {}.'.format(ckpt))

    remaining = sample_size
    for batch_idx in range(num_batches):
        # The final batch may need fewer chains than self.num_chain.
        batch_n = min(remaining, self.num_chain)
        z_vec = np.random.randn(batch_n, self.z_size)

        samples = sess.run(gen_res, feed_dict={self.z: z_vec})
        save_sample_results(samples,
                            "%s/gen%03d.png" % (self.test_dir, batch_idx),
                            col_num=self.n_tile_col)

        # Output interpolation results: walk straight lines between pairs of
        # the latent vectors just sampled.
        interp_z = linear_interpolator(z_vec,
                                       npairs=self.n_tile_row,
                                       ninterp=self.n_tile_col)
        interp_imgs = sess.run(gen_res, feed_dict={self.z: interp_z})
        save_sample_results(interp_imgs,
                            "%s/interp%03d.png" % (self.test_dir, batch_idx),
                            col_num=self.n_tile_col)

        remaining -= self.num_chain
def train(self, sess):
    """Run cooperative training of the descriptor (D) and generator (G) nets.

    Each batch performs: G0 (sample latents, run generator), D1 (Langevin
    revision of the synthesized images by the descriptor), G1 (Langevin
    inference of latents from the revised images), then parameter updates
    D2 and G2 plus a reconstruction-error metric. Summaries, sample grids
    and checkpoints are written every ``self.log_step`` epochs.

    Args:
        sess: live ``tf.Session`` used for all graph executions.
    """
    self.build_model()

    # Prepare training data
    train_data = DataSet(self.data_path, image_size=self.image_size)
    num_batches = int(math.ceil(len(train_data) / self.batch_size))

    # Initialize training
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Buffer holding the latest synthesized images for every batch of chains.
    sample_results = np.random.randn(self.num_chain * num_batches,
                                     self.image_size, self.image_size, 3)

    saver = tf.train.Saver(max_to_keep=50)
    writer = tf.summary.FileWriter(self.log_dir, sess.graph)

    # Make graph immutable: any accidental op creation inside the training
    # loop (a classic TF1 memory leak) now raises immediately.
    tf.get_default_graph().finalize()

    # Store graph in protobuf for later inspection/serving.
    with open(self.model_dir + '/graph.proto', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))

    # Train
    for epoch in range(self.num_epochs):
        start_time = time.time()
        for i in range(num_batches):
            obs_data = train_data[i * self.batch_size:
                                  min(len(train_data), (i + 1) * self.batch_size)]

            # Step G0: generate X ~ N(0, 1)
            z_vec = np.random.randn(self.num_chain, self.z_size)
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})

            # Step D1: obtain synthesized images Y
            if self.t1 > 0:
                syn = sess.run(self.langevin_descriptor,
                               feed_dict={self.syn: g_res})
            else:
                # BUGFIX: `syn` was previously unbound when t1 == 0, causing a
                # NameError in the feed dicts below. With zero descriptor
                # Langevin steps the synthesized images are simply the raw
                # generator output.
                syn = g_res

            # Step G1: update X using Y as training image
            if self.t2 > 0:
                z_vec = sess.run(self.langevin_generator, feed_dict={
                    self.z: z_vec,
                    self.obs: syn
                })

            # Step D2: update D net
            d_loss = sess.run(
                [self.des_loss, self.des_loss_update, self.apply_d_grads],
                feed_dict={
                    self.obs: obs_data,
                    self.syn: syn
                })[0]

            # Step G2: update G net
            g_loss = sess.run(
                [self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                feed_dict={
                    self.obs: syn,
                    self.z: z_vec
                })[0]

            # Metrics
            mse = sess.run([self.recon_err, self.recon_err_update],
                           feed_dict={
                               self.obs: obs_data,
                               self.syn: syn
                           })[0]

            sample_results[i * self.num_chain:(i + 1) * self.num_chain] = syn
            print(
                'Epoch #{:d}, [{:2d}]/[{:2d}], des loss: {:.4f}, gen loss: {:.4f}, '
                'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches,
                                              d_loss.mean(), g_loss.mean(),
                                              mse))

            if i == 0 and epoch % self.log_step == 0:
                save_sample_results(syn,
                                    "%s/des%03d.png" % (self.sample_dir, epoch),
                                    col_num=self.n_tile_col)
                save_sample_results(g_res,
                                    "%s/gen%03d.png" % (self.sample_dir, epoch),
                                    col_num=self.n_tile_col)

        # Epoch-level streaming means (accumulated by the *_update ops above).
        [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
            self.des_loss_mean, self.gen_loss_mean, self.recon_err_mean,
            self.summary_op
        ])

        end_time = time.time()
        print(
            'Epoch #{:d}, avg.des loss: {:.4f}, avg.gen loss: {:.4f}, avg.L2 distance: {:4.4f}, '
            'lr.des: {:f} lr.gen: {:f} time: {:.2f}s'.format(
                epoch, des_loss_avg, gen_loss_avg, mse_avg,
                self.lr_des.eval(), self.lr_gen.eval(),
                end_time - start_time))
        writer.add_summary(summary, epoch)
        writer.flush()

        if epoch % self.log_step == 0:
            make_dir(self.model_dir)
            saver.save(sess,
                       "%s/%s" % (self.model_dir, 'model.ckpt'),
                       global_step=epoch)