def test(self, sess, ckpt, sample_size):
    """Generate samples and latent-space interpolations from a trained generator.

    Restores the generator from `ckpt`, then for each batch draws random
    latent vectors, saves the generated images, and additionally saves a
    grid of images generated from linear interpolations between latent
    pairs.  Outputs go to `self.test_dir`.

    Args:
        sess: active tf.Session.
        ckpt: path to the checkpoint to restore (must not be None).
        sample_size: total number of samples to generate across batches.
    """
    assert ckpt is not None, 'no checkpoint provided.'
    gen_res = self.generator(self.z, reuse=False)
    # BUGFIX: under Python 2 (this file uses xrange) `/` on two ints is
    # floor division, which made math.ceil a no-op and silently dropped
    # the final partial batch.  Force true division before ceil.
    num_batches = int(math.ceil(sample_size / float(self.num_chain)))
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt)
    print('Loading checkpoint {}.'.format(ckpt))
    for i in xrange(num_batches):
        # The last batch may be smaller than num_chain; sample_size is
        # decremented at the bottom of the loop to track the remainder.
        z_vec = np.random.randn(min(sample_size, self.num_chain), self.z_size)
        g_res = sess.run(gen_res, feed_dict={self.z: z_vec})
        saveSampleResults(g_res, "%s/gen%03d.png" % (self.test_dir, i), col_num=self.nTileCol)
        # output interpolation results
        interp_z = linear_interpolator(z_vec, npairs=self.nTileRow, ninterp=self.nTileCol)
        interp = sess.run(gen_res, feed_dict={self.z: interp_z})
        saveSampleResults(interp, "%s/interp%03d.png" % (self.test_dir, i), col_num=self.nTileCol)
        sample_size = sample_size - self.num_chain
def test(self, sess, ckpt, sample_size):
    """Run the trained generator on batches of real data from `data_path3`.

    Restores the generator from `ckpt`, then feeds batches drawn from the
    test dataset as the generator input (in place of random latent
    vectors) and saves each generated batch under `self.test_dir`.

    Args:
        sess: active tf.Session.
        ckpt: path to the checkpoint to restore (must not be None).
        sample_size: kept for interface compatibility; the batch count is
            derived from the dataset size.
    """
    assert ckpt is not None, 'no checkpoint provided.'
    self.gen_res = self.generator(self.z, reuse=False)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt)
    print('Loading checkpoint {}.'.format(ckpt))
    test_data = DataSet(self.data_path3, image_size=self.image_size)
    # BUGFIX: was hard-coded to 100 batches, which could run past the end
    # of the dataset and feed empty slices.  Derive it from the dataset;
    # float() guards against Python 2 floor division before ceil.
    num_batches = int(math.ceil(len(test_data) / float(self.batch_size)))
    for i in xrange(num_batches):
        test_data_batch = test_data[i * self.batch_size:min(len(test_data), (i + 1) * self.batch_size)]
        z_vec = test_data_batch
        g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
        saveSampleResults(g_res, "%s/gen%03d.png" % (self.test_dir, i), col_num=self.nTileCol)
def sampling(self, sess, ckpt, sample_size, sample_step, calculate_inception=False):
    """Sample images via generator initialization + descriptor Langevin revision.

    Restores both networks from `ckpt`, generates `sample_size` images in
    chains of `self.num_chain`, saves both the raw generator output and
    the descriptor-revised output per batch, optionally computes the
    Inception score, and stores all revised samples as an .npy file.

    Args:
        sess: active tf.Session.
        ckpt: path to the checkpoint to restore (must not be None).
        sample_size: total number of samples to produce.
        sample_step: number of Langevin steps (stored into self.t1).
        calculate_inception: if True, compute and print the Inception score.
    """
    assert ckpt is not None, 'no checkpoint provided.'
    self.t1 = sample_step
    gen_res = self.generator(self.z, reuse=False)
    obs_res = self.descriptor(self.obs, reuse=False)
    self.langevin_descriptor = self.langevin_dynamics_descriptor(gen_res)
    # BUGFIX: under Python 2 `/` on two ints is floor division, making
    # math.ceil a no-op and losing the final partial batch.  Force true
    # division before ceil.
    num_batches = int(math.ceil(sample_size / float(self.num_chain)))
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt)
    print('Loading checkpoint {}.'.format(ckpt))
    sample_results_des = np.random.randn(self.num_chain * num_batches, self.image_size, self.image_size, self.num_channel)
    for i in xrange(num_batches):
        z_vec = np.random.randn(self.num_chain, self.z_size)
        # synthesis by generator
        g_res = sess.run(gen_res, feed_dict={self.z: z_vec})
        saveSampleResults(g_res, "%s/gen%03d_test.png" % (self.test_dir, i), col_num=self.nTileCol)
        # synthesis by descriptor and generator
        syn = sess.run(self.langevin_descriptor, feed_dict={self.z: z_vec})
        saveSampleResults(syn, "%s/des%03d_test.png" % (self.test_dir, i), col_num=self.nTileCol)
        sample_results_des[i * self.num_chain:(i + 1) * self.num_chain] = syn
        if i % 10 == 0:
            print("Sampling batches: {}, from {} to {}".format(
                i, i * self.num_chain, min((i + 1) * self.num_chain, sample_size)))
    # Trim padding from the last (possibly partial) batch, then map the
    # samples from [-1, 1] to [0, 255] for scoring/saving.
    sample_results_des = sample_results_des[:sample_size]
    sample_results_des = np.minimum(1, np.maximum(-1, sample_results_des))
    sample_results_des = (sample_results_des + 1) / 2 * 255
    if calculate_inception:
        m, s = get_inception_score(sample_results_des)
        print("Inception score: mean {}, sd {}".format(m, s))
    sampling_output_file = os.path.join(self.output_dir, 'samples_des.npy')
    np.save(sampling_output_file, sample_results_des)
    print("The results are saved in folder: {}".format(self.output_dir))
def test(self):
    """Generate test images from saved generator/descriptor checkpoints.

    Loads both networks with torch.load, runs the generator on random
    latents, optionally refines the result with descriptor Langevin
    dynamics, and saves the images to `self.opts.output_dir` — one file
    per image when `self.opts.score` is set (for scoring pipelines),
    otherwise one tiled grid per batch.
    """
    assert self.opts.ckpt_gen is not None, 'Please specify the path to the checkpoint of generator.'
    # BUGFIX: the original message said "generator" here too (copy-paste);
    # this assert actually guards the descriptor checkpoint.
    assert self.opts.ckpt_des is not None, 'Please specify the path to the checkpoint of descriptor.'
    print('===Test on ' + self.opts.ckpt_gen + ' and ' + self.opts.ckpt_des + ' ===')
    generator = torch.load(self.opts.ckpt_gen).eval()
    descriptor = torch.load(self.opts.ckpt_des).eval()
    if not os.path.exists(self.opts.output_dir):
        os.makedirs(self.opts.output_dir)
    test_batch = int(np.ceil(self.opts.test_size / self.opts.nRow / self.opts.nCol))
    print('===Generated images saved to %s ===' % (self.opts.output_dir))
    for i in range(test_batch):
        z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
        z = Variable(z.cuda())
        gen_res = generator(z)
        for s in range(self.opts.langevin_step_num_des):
            # clone it and turn x into a leaf variable so the grad won't be thrown away
            gen_res = Variable(gen_res.data, requires_grad=True)
            gen_res_feature = descriptor(gen_res)
            # NOTE(review): the grad_outputs shape (num_chain, z_size) only
            # matches if the descriptor output has that shape — confirm.
            gen_res_feature.backward(torch.ones(self.num_chain, self.opts.z_size).cuda())
            grad = gen_res.grad
            # Langevin update: gradient ascent on the descriptor energy with
            # a Gaussian reference term (no noise term in this test path).
            gen_res = gen_res - 0.5 * self.opts.langevin_step_size_des * self.opts.langevin_step_size_des * \
                (gen_res / self.opts.sigma_des / self.opts.sigma_des - grad)
        if self.opts.score:
            # Save one image per file, capped at test_size total images.
            gen_res = gen_res.detach().cpu()
            for img_no, img in enumerate(gen_res):
                if i * self.num_chain + img_no + 1 > self.opts.test_size:
                    break
                print('Generating {:05d}/{:05d}'.format(i * self.num_chain + img_no + 1, self.opts.test_size))
                saveSampleResults(img[None, :, :, :],
                                  "%s/testres_%03d.png" % (self.opts.output_dir, i * self.num_chain + img_no + 1),
                                  col_num=1, margin_syn=0)
        else:
            # Save the whole chain as a single tiled grid.
            gen_res = gen_res.detach().cpu()
            print('Generating {:05d}/{:05d}'.format(i + 1, test_batch))
            saveSampleResults(gen_res, "%s/testres_%03d.png" % (self.opts.output_dir, i + 1),
                              col_num=self.opts.nCol, margin_syn=0)
    print('===Image generation done.===')
def test(self, sess, ckpt, sample_size):
    """Run the trained generator on batches of real data from `data_path3`.

    Restores the generator from `ckpt`, then feeds batches drawn from the
    test dataset as the generator input (in place of random latent
    vectors) and saves each generated batch under `self.test_dir`.

    Args:
        sess: active tf.Session.
        ckpt: path to the checkpoint to restore (must not be None).
        sample_size: kept for interface compatibility; the batch count is
            derived from the dataset size.
    """
    assert ckpt is not None, 'no checkpoint provided.'
    self.gen_res = self.generator(self.z, reuse=False)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt)
    print('Loading checkpoint {}.'.format(ckpt))
    test_data = DataSet(self.data_path3, image_size=self.image_size)
    # BUGFIX: was hard-coded to 100 batches, which could run past the end
    # of the dataset and feed empty slices.  Derive it from the dataset;
    # float() guards against Python 2 floor division before ceil.
    # (Also removed a large amount of commented-out debugging code.)
    num_batches = int(math.ceil(len(test_data) / float(self.batch_size)))
    for i in xrange(num_batches):
        test_data_batch = test_data[i * self.batch_size:min(len(test_data), (i + 1) * self.batch_size)]
        # Real data takes the place of the random z vector in this experiment.
        z_vec = test_data_batch
        g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
        saveSampleResults(g_res, "%s/gen%03d.png" % (self.test_dir, i), col_num=self.nTileCol)
def train(self):
    """Cooperative training loop (PyTorch): alternate descriptor and generator updates.

    Loads (or builds) the descriptor and generator, then for each epoch and
    minibatch runs the cooperative scheme: G0 sample latents, D1 revise the
    generator output by descriptor Langevin dynamics, G1 revise the latents
    by generator Langevin dynamics, D2 update the descriptor, G2 update the
    generator toward the revised images.  Saves sample grids each epoch and
    checkpoints every `log_epoch` epochs.
    """
    # Load descriptor from checkpoint if given, else build per dataset type.
    if self.opts.ckpt_des != None and self.opts.ckpt_des != 'None':
        self.descriptor = torch.load(self.opts.ckpt_des)
        print('Loading Descriptor from ' + self.opts.ckpt_des + '...')
    else:
        if self.opts.set == 'scene' or self.opts.set == 'lsun':
            self.descriptor = Descriptor(self.opts).cuda()
            print('Loading Descriptor without initialization...')
        elif self.opts.set == 'cifar':
            self.descriptor = Descriptor_cifar(self.opts).cuda()
            print('Loading Descriptor_cifar without initialization...')
        else:
            raise NotImplementedError('The set should be either scene, lsun, or cifar')
    # Load generator from checkpoint if given, else build per dataset type.
    if self.opts.ckpt_gen != None and self.opts.ckpt_gen != 'None':
        self.generator = torch.load(self.opts.ckpt_gen)
        print('Loading Generator from ' + self.opts.ckpt_gen + '...')
    else:
        if self.opts.set == 'scene' or self.opts.set == 'lsun':
            self.generator = Generator(self.opts).cuda()
            print('Loading Generator without initialization...')
        elif self.opts.set == 'cifar':
            self.generator = Generator_cifar(self.opts).cuda()
            print('Loading Generator_cifar without initialization...')
        else:
            raise NotImplementedError('The set should be either scene, lsun or cifar')
    # TODO -add tensorboard & plot
    batch_size = self.opts.batch_size
    if self.opts.set == 'scene' or self.opts.set == 'cifar':
        train_data = DataSet(os.path.join(self.opts.data_path, self.opts.category),
                             image_size=self.opts.img_size)
    else:
        train_data = torchvision.datasets.LSUN(
            root=self.opts.data_path, classes=['bedroom_train'],
            transform=transforms.Compose([
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]))
    num_batches = int(math.ceil(len(train_data) / batch_size))
    # sample_results = np.random.randn(self.num_chain * num_batches, self.opts.img_size, self.opts.img_size, 3)
    des_optimizer = torch.optim.Adam(self.descriptor.parameters(), lr=self.opts.lr_des,
                                     betas=[self.opts.beta1_des, 0.999])
    gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.opts.lr_gen,
                                     betas=[self.opts.beta1_gen, 0.999])
    if not os.path.exists(self.opts.ckpt_dir):
        os.makedirs(self.opts.ckpt_dir)
    if not os.path.exists(self.opts.output_dir):
        os.makedirs(self.opts.output_dir)
    logfile = open(self.opts.ckpt_dir + '/log', 'w+')
    # NOTE(review): size_average/reduce are deprecated in modern PyTorch
    # (equivalent to reduction='sum') — confirm target torch version.
    mse_loss = torch.nn.MSELoss(size_average=False, reduce=True)
    for epoch in range(self.opts.num_epoch):
        start_time = time.time()
        gen_loss_epoch, des_loss_epoch, recon_loss_epoch = [], [], []
        for i in range(num_batches):
            # Skip the trailing partial batch.
            if (i + 1) * batch_size > len(train_data):
                continue
            obs_data = train_data[i * batch_size:min((i + 1) * batch_size, len(train_data))]
            obs_data = Variable(torch.Tensor(obs_data).cuda())  # ,requires_grad=True
            # G0: sample latent z ~ N(0, 1), shape NCHW.
            z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
            z = Variable(z.cuda(), requires_grad=True)  # NCHW
            gen_res = self.generator(z)
            # D1: revise generated images by descriptor Langevin dynamics.
            if self.opts.langevin_step_num_des > 0:
                revised = self.langevin_dynamics_descriptor(gen_res)
            # G1: revise latents by generator Langevin dynamics.
            if self.opts.langevin_step_num_gen > 0:
                z = self.langevin_dynamics_generator(z, revised)
            # D2: descriptor update — push energy of revised samples toward
            # that of observed data (difference of mean features).
            obs_feature = self.descriptor(obs_data)
            revised_feature = self.descriptor(revised)
            des_loss = (revised_feature.mean(0) - obs_feature.mean(0)).sum()
            des_optimizer.zero_grad()
            des_loss.backward()
            des_optimizer.step()
            # G2: generator update toward the revised images.
            ini_gen_res = gen_res.detach()
            if self.opts.langevin_step_num_gen > 0:
                gen_res = self.generator(z)
            # gen_res=gen_res.detach()
            # NOTE(review): scaling is 0.5 * sigma^2 here, while the sibling
            # train() in this file uses 1/(2*sigma^2) — confirm which is intended.
            gen_loss = 0.5 * self.opts.sigma_gen * self.opts.sigma_gen * mse_loss(gen_res, revised.detach())
            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()
            # Compute reconstruction loss
            recon_loss = mse_loss(revised, ini_gen_res)
            gen_loss_epoch.append(gen_loss.cpu().data)
            des_loss_epoch.append(des_loss.cpu().data)
            recon_loss_epoch.append(recon_loss.cpu().data)
        # TO-FIX (confliction between pytorch and tf)
        # if opts.incep_interval>0, compute inception score each [incep_interval] epochs.
        # if self.opts.incep_interval > 0:
        #     import inception_model
        #     if epoch % self.opts.incep_interval == 0:
        #         inception_log_file = os.path.join(self.opts.output_dir, 'inception.txt')
        #         inception_output_file = os.path.join(self.opts.output_dir, 'inception.mat')
        #         sample_results_partial = revised[:len(train_data)]
        #         sample_results_partial = np.minimum(1, np.maximum(-1, sample_results_partial))
        #         sample_results_partial = (sample_results_partial + 1) / 2 * 255
        #         # sample_results_list = sample_results.copy().swapaxes(1, 3)
        #         # sample_results_list = np.split(sample_results, len(sample_results), axis=0)
        #         m, s = get_inception_score(sample_results_partial)
        #         fo = open(inception_log_file, 'a')
        #         fo.write("Epoch {}: mean {}, sd {}".format(epoch, m, s))
        #         fo.close()
        #         inception_mean.append(m)
        #         inception_sd.append(s)
        #         sio.savemat(inception_output_file,
        #                     {'mean': np.asarray(inception_mean), 'sd': np.asarray(inception_sd)})
        # Save the observed grid (best-effort: skip the epoch's image dumps on failure).
        try:
            col_num = self.opts.nCol
            saveSampleResults(obs_data.cpu().data[:col_num * col_num],
                              "%s/observed.png" % (self.opts.output_dir), col_num=col_num)
        except:
            print('Error when saving obs_data. Skip.')
            continue
        saveSampleResults(revised.cpu().data, "%s/des_%03d.png" % (self.opts.output_dir, epoch + 1),
                          col_num=self.opts.nCol)
        saveSampleResults(gen_res.cpu().data, "%s/gen_%03d.png" % (self.opts.output_dir, epoch + 1),
                          col_num=self.opts.nCol)
        end_time = time.time()
        print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
              'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch, np.mean(des_loss_epoch),
                                     np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
                                     end_time - start_time))
        # python 3
        print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
              'time: {:.2f}s'.format(epoch, self.opts.num_epoch, np.mean(des_loss_epoch),
                                     np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
                                     end_time - start_time), file=logfile)
        # python 2.7
        # print >> logfile, ('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
        #                    'time: {:.2f}s'.format(epoch,self.opts.num_epoch, np.mean(des_loss_epoch), np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
        #                    end_time - start_time))
        # Periodic checkpointing of both networks.
        if epoch % self.opts.log_epoch == 0:
            torch.save(self.descriptor, self.opts.ckpt_dir + '/des_ckpt_{}.pth'.format(epoch))
            torch.save(self.generator, self.opts.ckpt_dir + '/gen_ckpt_{}.pth'.format(epoch))
    logfile.close()
def train(self, sess):
    """Cooperative training loop (TensorFlow) with loss-curve visualization.

    Builds the model graph, then per minibatch runs: G0 sample latents,
    D1 descriptor Langevin revision, G1 generator Langevin revision,
    D2 descriptor gradient step, G2 generator gradient step.  Tracks per-
    epoch averages, plots losses via `Visualizer`, and checkpoints every
    `log_step` epochs.
    """
    self.build_model()
    # Prepare training data
    train_data = DataSet(self.data_path, image_size=self.image_size)
    num_batches = int(math.ceil(len(train_data) / self.batch_size))
    # initialize training
    sess.run(tf.global_variables_initializer())
    sample_results = np.random.randn(self.num_chain * num_batches, self.image_size, self.image_size, 3)
    saver = tf.train.Saver(max_to_keep=50)
    # make graph immutable
    tf.get_default_graph().finalize()
    # store graph in protobuf
    with open(self.model_dir + '/graph.proto', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
    des_loss_vis = Visualizer(title='descriptor', ylabel='normalized negative log-likelihood',
                              ylim=(-200, 200), save_figpath=self.log_dir + '/des_loss.png',
                              avg_period = self.batch_size)
    gen_loss_vis = Visualizer(title='generator', ylabel='reconstruction error',
                              save_figpath=self.log_dir + '/gen_loss.png',
                              avg_period = self.batch_size)
    # train
    for epoch in xrange(self.num_epochs):
        start_time = time.time()
        des_loss_avg, gen_loss_avg, mse_avg = [], [], []
        for i in xrange(num_batches):
            obs_data = train_data[i * self.batch_size:min(len(train_data), (i + 1) * self.batch_size)]
            # Step G0: generate X ~ N(0, 1)
            z_vec = np.random.randn(self.num_chain, self.z_size)
            # NOTE(review): these two prints dump full arrays every batch —
            # likely leftover debugging; confirm before relying on the log.
            print('z_vec = {}'.format(z_vec))
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
            print('g_res = {}'.format(g_res))
            # Step D1: obtain synthesized images Y
            if self.t1 > 0:
                syn = sess.run(self.langevin_descriptor, feed_dict={self.syn: g_res})
            # Step G1: update X using Y as training image
            if self.t2 > 0:
                z_vec = sess.run(self.langevin_generator, feed_dict={self.z: z_vec, self.obs: syn})
            # Step D2: update D net
            d_loss = sess.run([self.des_loss, self.apply_d_grads],
                              feed_dict={self.obs: obs_data, self.syn: syn})[0]
            # Step G2: update G net
            g_loss = sess.run([self.gen_loss, self.apply_g_grads],
                              feed_dict={self.obs: syn, self.z: z_vec})[0]
            # Compute MSE for generator
            mse = sess.run(self.recon_err, feed_dict={self.obs: syn, self.syn: g_res})
            sample_results[i * self.num_chain:(i + 1) * self.num_chain] = syn
            des_loss_avg.append(d_loss)
            gen_loss_avg.append(g_loss)
            mse_avg.append(mse)
            # Per-pixel-normalized descriptor loss and raw MSE go to the plots.
            des_loss_vis.add_loss_val(epoch*num_batches + i,
                                      d_loss / float(self.image_size * self.image_size * 3))
            gen_loss_vis.add_loss_val(epoch*num_batches + i, mse)
            if self.debug:
                print('Epoch #{:d}, [{:2d}]/[{:2d}], descriptor loss: {:.4f}, generator loss: {:.4f}, '
                      'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches, d_loss.mean(),
                                                    g_loss.mean(), mse))
            # Save a sample grid at the first batch of every log_step-th epoch.
            if i == 0 and epoch % self.log_step == 0:
                if not os.path.exists(self.sample_dir):
                    os.makedirs(self.sample_dir)
                saveSampleResults(syn, "%s/des%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)
                saveSampleResults(g_res, "%s/gen%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)
        end_time = time.time()
        print('Epoch #{:d}, avg.descriptor loss: {:.4f}, avg.generator loss: {:.4f}, avg.L2 distance: {:4.4f}, '
              'time: {:.2f}s'.format(epoch, np.mean(des_loss_avg), np.mean(gen_loss_avg),
                                     np.mean(mse_avg), end_time - start_time))
        # Periodic checkpoint + refresh of the loss figures.
        if epoch % self.log_step == 0:
            if not os.path.exists(self.model_dir):
                os.makedirs(self.model_dir)
            saver.save(sess, "%s/%s" % (self.model_dir, 'model.ckpt'), global_step=epoch)
            if not os.path.exists(self.log_dir):
                os.makedirs(self.log_dir)
            des_loss_vis.draw_figure()
            gen_loss_vis.draw_figure()
def train(self):
    """Cooperative training loop (PyTorch), reciprocal-sigma gen-loss variant.

    Loads (or builds) the descriptor and generator, then per minibatch runs
    the cooperative scheme: G0 sample latents, D1 descriptor Langevin
    revision of the generated images, G1 generator Langevin revision of the
    latents, D2 descriptor update, G2 generator update toward the revised
    images.  Saves sample grids each epoch and checkpoints every
    `log_epoch` epochs.
    """
    # Load descriptor from checkpoint if given, else build per dataset type.
    if self.opts.ckpt_des != None and self.opts.ckpt_des != 'None':
        self.descriptor = torch.load(self.opts.ckpt_des)
        print('Loading Descriptor from ' + self.opts.ckpt_des + '...')
    else:
        if self.opts.set == 'scene' or self.opts.set == 'lsun':
            self.descriptor = Descriptor(self.opts).cuda()
            print('Loading Descriptor without initialization...')
        elif self.opts.set == 'cifar':
            self.descriptor = Descriptor_cifar(self.opts).cuda()
            print('Loading Descriptor_cifar without initialization...')
        else:
            raise NotImplementedError('The set should be either scene, lsun, or cifar')
    # Load generator from checkpoint if given, else build per dataset type.
    if self.opts.ckpt_gen != None and self.opts.ckpt_gen != 'None':
        self.generator = torch.load(self.opts.ckpt_gen)
        print('Loading Generator from ' + self.opts.ckpt_gen + '...')
    else:
        if self.opts.set == 'scene' or self.opts.set == 'lsun':
            self.generator = Generator(self.opts).cuda()
            print('Loading Generator without initialization...')
        elif self.opts.set == 'cifar':
            self.generator = Generator_cifar(self.opts).cuda()
            print('Loading Generator_cifar without initialization...')
        else:
            raise NotImplementedError('The set should be either scene, lsun or cifar')
    batch_size = self.opts.batch_size
    if self.opts.set == 'scene' or self.opts.set == 'cifar':
        train_data = DataSet(os.path.join(self.opts.data_path, self.opts.category),
                             image_size=self.opts.img_size)
    else:
        train_data = torchvision.datasets.LSUN(root=self.opts.data_path, classes=['bedroom_train'],
                                               transform=transforms.Compose([transforms.Resize(self.img_size),
                                                                             transforms.ToTensor(), ]))
    num_batches = int(math.ceil(len(train_data) / batch_size))
    # sample_results = np.random.randn(self.num_chain * num_batches, self.opts.img_size, self.opts.img_size, 3)
    des_optimizer = torch.optim.Adam(self.descriptor.parameters(), lr=self.opts.lr_des,
                                     betas=[self.opts.beta1_des, 0.999])
    gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.opts.lr_gen,
                                     betas=[self.opts.beta1_gen, 0.999])
    if not os.path.exists(self.opts.ckpt_dir):
        os.makedirs(self.opts.ckpt_dir)
    if not os.path.exists(self.opts.output_dir):
        os.makedirs(self.opts.output_dir)
    logfile = open(self.opts.ckpt_dir + '/log', 'w+')
    # NOTE(review): size_average/reduce are deprecated in modern PyTorch
    # (equivalent to reduction='sum') — confirm target torch version.
    mse_loss = torch.nn.MSELoss(size_average=False, reduce=True)
    for epoch in range(self.opts.num_epoch):
        start_time = time.time()
        gen_loss_epoch, des_loss_epoch, recon_loss_epoch = [], [], []
        for i in range(num_batches):
            # Skip the trailing partial batch.
            if (i + 1) * batch_size > len(train_data):
                continue
            obs_data = train_data[i * batch_size:min((i + 1) * batch_size, len(train_data))]
            obs_data = Variable(torch.Tensor(obs_data).cuda())  # ,requires_grad=True
            # G0: sample latent z ~ N(0, 1), shape NCHW.
            z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
            z = Variable(z.cuda(), requires_grad=True)  # NCHW
            gen_res = self.generator(z)
            # D1: revise generated images by descriptor Langevin dynamics.
            if self.opts.langevin_step_num_des > 0:
                revised = self.langevin_dynamics_descriptor(gen_res)
            # G1: revise latents by generator Langevin dynamics.
            if self.opts.langevin_step_num_gen > 0:
                z = self.langevin_dynamics_generator(z, revised)
            # D2: descriptor update via difference of mean features.
            obs_feature = self.descriptor(obs_data)
            revised_feature = self.descriptor(revised)
            des_loss = (revised_feature.mean(0) - obs_feature.mean(0)).sum()
            des_optimizer.zero_grad()
            des_loss.backward()
            des_optimizer.step()
            # G2: generator update toward the revised images.
            ini_gen_res = gen_res.detach()
            if self.opts.langevin_step_num_gen > 0:
                gen_res = self.generator(z)
            # NOTE(review): scaling is 1/(2*sigma^2) here, while the sibling
            # train() in this file uses 0.5*sigma^2 — confirm which is intended.
            gen_loss = 1.0 / (2.0 * self.opts.sigma_gen * self.opts.sigma_gen) * mse_loss(gen_res, revised.detach())
            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()
            # Compute reconstruction loss
            recon_loss = mse_loss(revised, ini_gen_res)
            gen_loss_epoch.append(gen_loss.cpu().data)
            des_loss_epoch.append(des_loss.cpu().data)
            recon_loss_epoch.append(recon_loss.cpu().data)
        # Save the observed grid (best-effort: skip the epoch's image dumps on failure).
        try:
            col_num = self.opts.nCol
            saveSampleResults(obs_data.cpu().data[:col_num * col_num],
                              "%s/observed.png" % (self.opts.output_dir), col_num=col_num)
        except:
            print('Error when saving obs_data. Skip.')
            continue
        saveSampleResults(revised.cpu().data, "%s/des_%03d.png" % (self.opts.output_dir, epoch + 1),
                          col_num=self.opts.nCol)
        saveSampleResults(gen_res.cpu().data, "%s/gen_%03d.png" % (self.opts.output_dir, epoch + 1),
                          col_num=self.opts.nCol)
        end_time = time.time()
        print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
              'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch, np.mean(des_loss_epoch),
                                     np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
                                     end_time - start_time))
        # python 3
        print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
              'time: {:.2f}s'.format(epoch, self.opts.num_epoch, np.mean(des_loss_epoch),
                                     np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
                                     end_time - start_time), file=logfile)
        # python 2.7
        # print >> logfile, ('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
        #                    'time: {:.2f}s'.format(epoch,self.opts.num_epoch, np.mean(des_loss_epoch), np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
        #                    end_time - start_time))
        # Periodic checkpointing of both networks.
        if epoch % self.opts.log_epoch == 0:
            torch.save(self.descriptor, self.opts.ckpt_dir + '/des_ckpt_{}.pth'.format(epoch))
            torch.save(self.generator, self.opts.ckpt_dir + '/gen_ckpt_{}.pth'.format(epoch))
    logfile.close()
def train(self, sess):
    """Cooperative training loop (TensorFlow) with summary ops and streaming means.

    Builds the model graph and the symbolic Langevin ops, then per
    minibatch runs: G0 sample latents, D1 descriptor Langevin revision,
    G1 generator Langevin revision, D2 descriptor step (plus streaming
    loss update), G2 generator step (plus streaming loss update).  Epoch
    averages come from the streaming-mean ops and are written to
    TensorBoard summaries; checkpoints are saved every `log_step` epochs.
    """
    self.build_model()
    # Prepare training data
    train_data = DataSet(self.data_path, image_size=self.image_size)
    num_batches = int(math.ceil(len(train_data) / self.batch_size))
    # initialize training (local variables back the streaming-mean metrics)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sample_results = np.random.randn(self.num_chain * num_batches, self.image_size, self.image_size, 3)
    saver = tf.train.Saver(max_to_keep=50)
    writer = tf.summary.FileWriter(self.log_dir, sess.graph)
    # symbolic langevins
    langevin_descriptor = self.langevin_dynamics_descriptor(self.syn)
    langevin_generator = self.langevin_dynamics_generator(self.z)
    # make graph immutable
    tf.get_default_graph().finalize()
    # store graph in protobuf
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    with open(self.model_dir + '/graph.proto', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
    # train
    for epoch in xrange(self.num_epochs):
        start_time = time.time()
        for i in xrange(num_batches):
            obs_data = train_data[i * self.batch_size:min(len(train_data), (i + 1) * self.batch_size)]
            # Step G0: generate X ~ N(0, 1)
            z_vec = np.random.randn(self.num_chain, self.z_size)
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
            # Step D1: obtain synthesized images Y
            if self.t1 > 0:
                syn = sess.run(langevin_descriptor, feed_dict={self.syn: g_res})
            # Step G1: update X using Y as training image
            if self.t2 > 0:
                z_vec = sess.run(langevin_generator, feed_dict={self.z: z_vec, self.obs: syn})
            # Step D2: update D net (also advances the streaming loss mean)
            d_loss = sess.run([self.des_loss, self.des_loss_update, self.apply_d_grads],
                              feed_dict={self.obs: obs_data, self.syn: syn})[0]
            # Step G2: update G net (also advances the streaming loss mean)
            g_loss = sess.run([self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                              feed_dict={self.obs: syn, self.z: z_vec})[0]
            # Compute MSE (also advances the streaming error mean)
            mse = sess.run([self.recon_err, self.recon_err_update],
                           feed_dict={self.obs: obs_data, self.syn: syn})[0]
            sample_results[i * self.num_chain:(i + 1) * self.num_chain] = syn
            print('Epoch #{:d}, [{:2d}]/[{:2d}], descriptor loss: {:.4f}, generator loss: {:.4f}, '
                  'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches, d_loss, g_loss, mse))
            # Save a sample grid at the first batch of every log_step-th epoch.
            if i == 0 and epoch % self.log_step == 0:
                if not os.path.exists(self.sample_dir):
                    os.makedirs(self.sample_dir)
                saveSampleResults(syn, "%s/des%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)
                saveSampleResults(g_res, "%s/gen%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)
        # Pull epoch averages from the streaming-mean ops plus the merged summary.
        [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
            self.des_loss_mean, self.gen_loss_mean, self.recon_err_mean, self.summary_op
        ])
        end_time = time.time()
        print('Epoch #{:d}, avg. descriptor loss: {:.4f}, avg. generator loss: {:.4f}, avg. L2 distance: {:4.4f}, '
              'time: {:.2f}s'.format(epoch, des_loss_avg, gen_loss_avg, mse_avg, end_time - start_time))
        writer.add_summary(summary, epoch)
        writer.flush()
        if epoch % self.log_step == 0:
            if not os.path.exists(self.model_dir):
                os.makedirs(self.model_dir)
            saver.save(sess, "%s/%s" % (self.model_dir, 'model.ckpt'), global_step=epoch)
def train(self, sess):
    """Cooperative training loop (TensorFlow) with Parzen-window and Inception evaluation.

    Builds the model graph, then per minibatch runs the cooperative scheme
    (G0 sample, D1 descriptor Langevin revision, G1 generator Langevin
    revision, D2 descriptor step, G2 generator step), keeping both the
    generator and descriptor samples of every batch.  Periodically logs
    summaries, saves sample grids and checkpoints; per epoch, optionally
    computes an Inception score and/or a Parzen-window log-likelihood and
    writes the results to .txt/.mat files.
    """
    self.build_model()
    # Prepare training data
    is_mnist = True if self.type == "mnist" else False
    dataset = DataSet(self.data_path, image_size=self.image_size, batch_sz=self.batch_size,
                      prefetch=self.prefetch, read_len=self.read_len, is_mnist=is_mnist)
    # initialize training (local variables back the streaming-mean metrics)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Buffers holding one epoch's worth of samples from each network.
    sample_results_des = np.random.randn(self.num_chain * dataset.num_batch,
                                         self.image_size, self.image_size, self.num_channel)
    sample_results_gen = np.random.randn(self.num_chain * dataset.num_batch,
                                         self.image_size, self.image_size, self.num_channel)
    saver = tf.train.Saver(max_to_keep=50)
    writer = tf.summary.FileWriter(self.log_dir, sess.graph)
    # measure 1: Parzon-window based likelihood
    if self.calculate_parzen:
        kde = ParsenDensityEsimator(sess)
        parzon_des_mean_list, parzon_des_se_list = [], []
        parzon_gen_mean_list, parzon_gen_se_list = [], []
        parzon_log_file = os.path.join(self.output_dir, 'parzon.txt')
        parzon_write_file = os.path.join(self.output_dir, 'parzon.mat')
        parzon_syn_data_file = os.path.join(self.output_dir, 'parzon_syn_dat.mat')
        parzon_max = -10000
    # measure 2: inception score
    if self.calculate_inception:
        inception_log_file = os.path.join(self.output_dir, 'inception.txt')
        inception_write_file = os.path.join(self.output_dir, 'inception.mat')
    # make graph immutable
    tf.get_default_graph().finalize()
    # store graph in protobuf
    with open(self.model_dir + '/graph.proto', 'w') as f:
        f.write(str(tf.get_default_graph().as_graph_def()))
    inception_mean, inception_sd = [], []
    # train; `minibatch` is the global step counter across epochs.
    minibatch = -1
    for epoch in xrange(self.num_epochs):
        start_time = time.time()
        for i in xrange(dataset.num_batch):
            minibatch = minibatch + 1
            obs_data = dataset.get_batch()
            # Step G0: generate X ~ N(0, 1)
            z_vec = np.random.randn(self.num_chain, self.z_size)
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
            # Step D1: obtain synthesized images Y
            if self.t1 > 0:
                syn = sess.run(self.langevin_descriptor, feed_dict={self.syn: g_res})
            # Step G1: update X using Y as training image
            if self.t2 > 0:
                z_vec = sess.run(self.langevin_generator, feed_dict={self.z: z_vec, self.obs: syn})
            # Step D2: update D net (also advances the streaming loss mean)
            d_loss = sess.run([self.des_loss, self.des_loss_update, self.apply_d_grads],
                              feed_dict={self.obs: obs_data, self.syn: syn})[0]
            # Step G2: update G net (also advances the streaming loss mean)
            g_loss = sess.run([self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                              feed_dict={self.obs: syn, self.z: z_vec})[0]
            # Compute MSE (also advances the streaming error mean)
            mse = sess.run([self.recon_err, self.recon_err_update],
                           feed_dict={self.obs: obs_data, self.syn: syn})[0]
            sample_results_gen[i * self.num_chain:(i + 1) * self.num_chain] = g_res
            sample_results_des[i * self.num_chain:(i + 1) * self.num_chain] = syn
            # Periodic logging: summaries, console stats, and sample grids.
            if minibatch % self.log_step == 0:
                end_time = time.time()
                [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
                    self.des_loss_mean, self.gen_loss_mean, self.recon_err_mean, self.summary_op
                ])
                writer.add_summary(summary, minibatch)
                writer.flush()
                print('Epoch #{:d}, minibatch #{:d}, avg.des loss: {:.4f}, avg.gen loss: {:.4f}, '
                      'avg.L2 dist: {:4.4f}, time: {:.2f}s'.format(
                          epoch, minibatch, des_loss_avg, gen_loss_avg, mse_avg,
                          end_time - start_time))
                start_time = time.time()
                # save synthesis images
                if not os.path.exists(self.sample_dir):
                    os.makedirs(self.sample_dir)
                saveSampleResults(syn, "%s/des_%06d_%06d.png" % (self.sample_dir, epoch, minibatch),
                                  col_num=self.nTileCol)
                saveSampleResults(g_res, "%s/gen_%06d_%06d.png" % (self.sample_dir, epoch, minibatch),
                                  col_num=self.nTileCol)
            if minibatch % (self.log_step * 20) == 0:
                # save check points
                if not os.path.exists(self.model_dir):
                    os.makedirs(self.model_dir)
                saver.save(sess, "%s/%s" % (self.model_dir, 'model.ckpt'), global_step=minibatch)
        # Inception score every 20 epochs on the descriptor samples,
        # rescaled from [-1, 1] to [0, 255].
        if self.calculate_inception and epoch % 20 == 0:
            sample_results_partial = sample_results_des[:len(dataset)]
            sample_results_partial = np.minimum(1, np.maximum(-1, sample_results_partial))
            sample_results_partial = (sample_results_partial + 1) / 2 * 255
            m, s = get_inception_score(sample_results_partial)
            print("Inception score: mean {}, sd {}".format(m, s))
            fo = open(inception_log_file, 'a')
            fo.write("Epoch {}: mean {}, sd {} \n".format(epoch, m, s))
            fo.close()
            inception_mean.append(m)
            inception_sd.append(s)
            sio.savemat(inception_write_file, {
                'mean': np.asarray(inception_mean),
                'sd': np.asarray(inception_sd)
            })
        # Parzen-window log-likelihood on up to 10k samples from each network;
        # the best-scoring descriptor samples are also archived to .mat.
        if self.calculate_parzen:
            samples_des = sample_results_des[:10000]
            samples_gen = sample_results_gen[:10000]
            parzon_des_mean, parzon_des_se, parzon_gen_mean, parzon_gen_se = kde.eval_parzen(
                samples_des, samples_gen)
            parzon_des_mean_list.append(parzon_des_mean)
            parzon_des_se_list.append(parzon_des_se)
            parzon_gen_mean_list.append(parzon_gen_mean)
            parzon_gen_se_list.append(parzon_gen_se)
            if parzon_des_mean > parzon_max:
                parzon_max = parzon_des_mean
                sio.savemat(parzon_syn_data_file, {
                    'samples_des': samples_des,
                    'samples_gen': samples_gen
                })
            fo = open(parzon_log_file, 'a')
            fo.write("Epoch {}: des mean {}, sd {}; gen mean {}, sd {}, max score {}. \n"
                     .format(epoch, parzon_des_mean, parzon_des_se, parzon_gen_mean,
                             parzon_gen_se, parzon_max))
            fo.close()
            sio.savemat(parzon_write_file, {
                'parzon_des_mean': np.asarray(parzon_des_mean_list),
                'parzon_des_se': np.asarray(parzon_des_se_list),
                'parzon_gen_mean': np.asarray(parzon_gen_mean_list),
                'parzon_gen_se': np.asarray(parzon_gen_se_list)
            })