def train(self, mode='train'):
    """Run the GAN training loop, or only restore weights in eval modes.

    Args:
        mode: 'train' (default) runs the full training loop. 'test' or
            'validation' only restores the latest checkpoint into the
            session and returns without training.

    Side effects:
        Writes sample image grids under ./samples/, appends label dumps to
        samples/labels.txt, and saves checkpoints into self.model_dir.
    """
    print('Beginning Training: ')
    sess = self.sess
    tf.global_variables_initializer().run(session=self.sess)

    if mode in ('test', 'validation'):
        # Evaluation modes: restore the newest checkpoint and stop here.
        print("loading model from checkpoint")
        checkpoint = tf.train.latest_checkpoint(self.model_dir)
        self.saver.restore(sess, checkpoint)
        return

    # Resume from a saved (step, epoch) pair if a checkpoint exists;
    # otherwise start from scratch.
    counter = 0
    ep = 0
    could_load, checkpoint_counter, checkpoint_epoch = self.load(self.model_dir)
    if could_load:
        ep = checkpoint_epoch
        counter = checkpoint_counter
        print("Successfully loaded checkpoint")
    else:
        print("Failed to load checkpoint")

    start_time = time.time()
    # 'with' guarantees the label log is flushed and closed even if training
    # raises (the original left the handle open for the process lifetime).
    with open("samples/labels.txt", "a") as labels:
        for epoch in tqdm(range(ep, self.epochs)):
            for step in tqdm(range(counter, self.max_steps)):
                # NOTE(review): leftover debug instrumentation — for relative
                # steps 2101-2115 this skips training entirely and only prints
                # the input batch shape. Kept to preserve behavior; TODO remove
                # once the shape issue it was probing is resolved.
                if step - counter in range(2101, 2116):
                    x, _ = sess.run([self.x, self.real_labels])
                    print(x.shape)
                    continue

                # Five discriminator updates per generator update.
                for _ in range(5):
                    _, disc_loss, x = self.sess.run(
                        [self.disc_step, self.d_loss, self.x])
                    if step - counter > 2110:  # debug leftover — TODO remove
                        print(x.shape)

                # Single generator update.
                _, gen_loss = self.sess.run([self.gen_step, self.g_loss])

                if step % 100 == 0:
                    print("Time: {:.4f}, Epoch: {}, Step: {}, Generator Loss: {:.4f}, Discriminator Loss: {:.4f}"
                          .format(time.time() - start_time, epoch, step, gen_loss, disc_loss))
                    # Dump grids of generated and real images plus their labels.
                    fake_im, real_im, fake_l, real_l = sess.run(
                        [self.fake_image, self.x, self.fake_labels, self.real_labels])
                    save_images(fake_im, image_manifold_size(fake_im.shape[0]),
                                './samples/train_{:02d}_{:06d}.png'.format(epoch, step))
                    save_images(real_im, image_manifold_size(real_im.shape[0]),
                                './samples/train_{:02d}_{:06d}_real.png'.format(epoch, step))
                    labels.write("{:02d}_{:06d}:\nReal Labels -\n{}\nFake Labels -\n{}\n"
                                 .format(epoch, step, str(real_l), str(fake_l)))
                    print('Translated images and saved..!')

                if step % 200 == 0:
                    self.save(self.model_dir, step, epoch)
                    print("Checkpoint saved")

            # Only the resumed epoch starts mid-way; later epochs run from step 0.
            counter = 0
def visualize(self):
    """Generate one sample image and save it beside its nearest neighbour
    from the dataset, as a two-image manifold grid.

    TODO: Solve bug with the generator which generates unmatched images.
    NOTE(review): the batch is drawn from `self.dataset.train` while the
    random index is bounded by `self.dataset.test.num_examples` — looks
    inconsistent; confirm which split is intended.
    """
    # Latent vectors, uniform in [-1, 1).
    z_batch = np.random.uniform(
        -1, 1, size=(self.model.sample_num, self.model.z_dim))

    # Pull a test batch of conditioning embeddings at a random offset.
    batch = self.dataset.train.next_batch_test(
        self.model.sample_num,
        randint(0, self.dataset.test.num_examples),
        1)
    _, embeddings, _, _captions = batch
    embeddings = np.squeeze(embeddings, axis=0)

    # Run the sampler conditioned on (z, embedding).
    feed = {
        self.model.z_sample: z_batch,
        self.model.phi_sample: embeddings,
    }
    generated = self.sess.run(self.model.sampler, feed_dict=feed)

    # Pair the first generated image with its closest real image.
    first_fake = generated[0]
    nearest_real = closest_image(first_fake, self.dataset)
    pair = np.array([first_fake, nearest_real])

    out_path = './{}/{}/{}/test5.png'.format(
        self.config.test_dir, self.model.name, self.dataset.name)
    save_images(pair, image_manifold_size(pair.shape[0]), out_path)
def train(self) -> None:
    """Train the Stage-II network: build losses/summaries, restore weights
    (Stage II checkpoint, falling back to the Stage I generator), then run
    the epoch loop of alternating D and G updates with periodic sampling
    and checkpointing.

    Side effects: writes TensorBoard summaries via self.writer, sample
    images into self.cfg.SAMPLE_DIR, and checkpoints into
    self.cfg.CHECKPOINT_DIR.
    """
    self.define_losses()
    self.define_summaries()

    # Fixed noise + embedding pair reused for every periodic sample dump,
    # so successive sample images are directly comparable.
    sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
    _, sample_embed, _, captions = self.dataset.test.next_batch_test(
        self.model.sample_num, randint(0, self.dataset.test.num_examples), 1)
    sample_embed = np.squeeze(sample_embed, axis=0)
    print(sample_embed.shape)

    # Display the captions of the sampled images
    print('\nCaptions of the sampled images:')
    for caption_idx, caption_batch in enumerate(captions):
        print('{}: {}'.format(caption_idx + 1, caption_batch[0]))
    print()

    counter = 1
    start_time = time.time()

    # Try to load the parameters of the stage II networks
    # NOTE: `load`/`save` are module-level checkpoint helpers defined
    # elsewhere in this project.
    tf.global_variables_initializer().run()
    could_load, checkpoint_counter = load(self.stageii_saver, self.sess,
                                          self.cfg.CHECKPOINT_DIR)
    if could_load:
        counter = checkpoint_counter
        print(" [*] Load SUCCESS: Stage II networks are loaded.")
    else:
        print(" [!] Load failed for stage II networks...")
        # Fall back to warm-starting only the Stage I generator weights.
        could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess,
                                              self.cfg_stage_i.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage I generator is loaded")
        else:
            print(
                " [!] WARNING!!! Failed to load the parameters for stage I generator..."
            )

    for epoch in range(self.cfg.TRAIN.EPOCH):
        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size

        for idx in range(0, updates_per_epoch):
            images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                self.model.batch_size, 4)
            batch_z = np.random.normal(
                0, 1, (self.model.batch_size, self.model.z_dim))

            # Update D network: one step on real/matching, real/mismatching
            # and synthetic losses, fetched together with the merged summary.
            _, err_d_real_match, err_d_real_mismatch, err_d_fake, err_d, summary_str = self.sess.run(
                [
                    self.D_optim, self.D_real_match_loss,
                    self.D_real_mismatch_loss, self.D_synthetic_loss,
                    self.D_loss, self.D_merged_summ
                ],
                feed_dict={
                    self.model.inputs: images,
                    self.model.wrong_inputs: wrong_images,
                    self.model.embed_inputs: embed,
                    self.model.z: batch_z
                })
            self.writer.add_summary(summary_str, counter)

            # Update G network
            _, err_g, summary_str = self.sess.run(
                [self.G_optim, self.G_loss, self.G_merged_summ],
                feed_dict={
                    self.model.z: batch_z,
                    self.model.embed_inputs: embed
                })
            self.writer.add_summary(summary_str, counter)

            counter += 1
            print(
                "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                % (epoch, idx, updates_per_epoch, time.time() - start_time,
                   err_d, err_g))

            # Every 100 global updates, dump a grid of samples from the
            # fixed (z, embedding) pair chosen above.
            if np.mod(counter, 100) == 0:
                try:
                    samples = self.sess.run(self.model.sampler,
                                            feed_dict={
                                                self.model.z_sample: sample_z,
                                                self.model.embed_sample: sample_embed,
                                            })
                    # NOTE(review): assumes SAMPLE_DIR already ends with a
                    # path separator — confirm against the config.
                    save_images(
                        samples, image_manifold_size(samples.shape[0]),
                        '{}train_{:02d}_{:04d}.png'.format(
                            self.cfg.SAMPLE_DIR, epoch, idx))
                    print("[Sample] d_loss: %.8f, g_loss: %.8f" % (err_d, err_g))

                    # Display the captions of the sampled images
                    print('\nCaptions of the sampled images:')
                    for caption_idx, caption_batch in enumerate(captions):
                        print('{}: {}'.format(caption_idx + 1,
                                              caption_batch[0]))
                    print()
                except Exception as e:
                    # Sampling is best-effort: log the failure and keep training.
                    print("Failed to generate sample image")
                    print(type(e))
                    print(e.args)
                    print(e)

            # Checkpoint every 500 updates (offset by 2 so it does not
            # coincide with the sampling step above).
            if np.mod(counter, 500) == 2:
                save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR,
                     counter)