def run(**kwargs):
    """Train and evaluate the network on the CelebA dataset.

    Keyword arguments may override the module-level defaults (ALLOW_PLOTS,
    CONTINUE_TRAINING, INFERENCE_CPU_ONLY, USE_BIASES, NUM_ITER).
    """
    allow_plots = ALLOW_PLOTS
    continue_training = CONTINUE_TRAINING
    inference_cpu_only = INFERENCE_CPU_ONLY
    use_biases = USE_BIASES
    num_iter = NUM_ITER

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'continue_training':
            continue_training = v
        elif k == 'inference_cpu_only':
            inference_cpu_only = v
        elif k == 'use_biases':
            use_biases = v
        elif k == 'num_iter':
            num_iter = v
        else:
            logger.warn('Keyword \'%s\' is unknown.' % k)

    logger.info('### Loading dataset ...')
    data = CelebAData(config.dataset_path, shape=[64, 64])
    # Important! Let the network know which dataset to use.
    shared.data = data
    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')
    train_net = setup_network(allow_plots, continue_training,
                              inference_cpu_only, use_biases, mode='train')
    train_net.train(num_iter=num_iter, val_bs=100)

    test_net = setup_network(allow_plots, continue_training,
                             inference_cpu_only, use_biases,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(
            np.empty((0, 0)), latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples', np.empty((0, 0)),
                             fake_imgs, fake_dis_outputs=fake_dis_outs,
                             shuffle=True, interactive=True)
    logger.info('### Build, train and test network ... Done')
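# The `setup_network` helper called above is not defined in this section. A
# minimal sketch of what such a helper presumably does, modeled on the
# explicit network construction in the CIFAR-10 script below; the `DCGAN`
# class name and its attribute interface are assumptions here:
def setup_network(allow_plots, continue_training, inference_cpu_only,
                  use_biases, mode='train'):
    """Construct and build a network in the given mode (sketch only)."""
    net = DCGAN(mode=mode)  # Assumed class; cf. the test method below.
    net.continue_training = continue_training
    net.allow_plots = allow_plots
    net.inference_cpu_only = inference_cpu_only
    net.use_biases = use_biases
    net.build()
    return net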
def run(**kwargs):
    """Train and evaluate a Wasserstein GAN on the CIFAR-10 dataset.

    Keyword arguments may override the module-level defaults (ALLOW_PLOTS,
    USE_BIASES).
    """
    allow_plots = ALLOW_PLOTS
    use_biases = USE_BIASES

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'use_biases':
            use_biases = v
        else:
            logger.warn('Keyword \'%s\' is unknown.' % k)

    logger.info('### Loading dataset ...')
    data = CIFAR10Data(config.dataset_path)
    # Important! Let the network know which dataset to use.
    shared.data = data
    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')
    train_net = WassersteinGAN(mode='train')
    train_net.continue_training = False
    train_net.allow_plots = allow_plots
    train_net.use_biases = use_biases
    train_net.build()
    train_net.train(num_iter=100001)

    test_net = WassersteinGAN(mode='inference')
    test_net.allow_plots = allow_plots
    test_net.use_biases = use_biases
    test_net.build()
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(
            np.empty((0, 0)), latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples', np.empty((0, 0)),
                             fake_imgs, fake_dis_outputs=fake_dis_outs,
                             shuffle=True, interactive=True)
    logger.info('### Build, train and test network ... Done')
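# The WassersteinGAN used above has no meaningful accuracy (cf. the comments
# in the validation and test methods below): its critic outputs an unbounded
# score rather than a probability. A minimal sketch of the losses such a
# class might use; this is an illustration of the standard WGAN objective,
# not the repository's actual loss code:
import tensorflow as tf

def wgan_losses(d_outputs_real, d_outputs_fake):
    """Return (generator loss, critic loss) for a Wasserstein GAN."""
    # The critic widens the score gap between real and fake samples, i.e.,
    # it minimizes E[critic(fake)] - E[critic(real)].
    d_loss = tf.reduce_mean(d_outputs_fake) - tf.reduce_mean(d_outputs_real)
    # The generator raises the critic score of its samples.
    g_loss = -tf.reduce_mean(d_outputs_fake)
    return g_loss, d_loss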
def run(**kwargs):
    """Train and evaluate the network on the MNIST dataset.

    Keyword arguments may override the module-level defaults (ALLOW_PLOTS,
    CONTINUE_TRAINING).
    """
    allow_plots = ALLOW_PLOTS
    continue_training = CONTINUE_TRAINING

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'continue_training':
            continue_training = v
        else:
            logger.warn('Keyword \'%s\' is unknown.' % k)

    logger.info('### Loading dataset ...')
    data = MNISTData(config.dataset_path)
    # Important! Let the network know which dataset to use.
    shared.data = data
    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')
    train_net = setup_network(allow_plots, continue_training, mode='train')
    train_net.train(num_iter=10001)

    test_net = setup_network(allow_plots, continue_training,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(
            np.empty((0, 0)), latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples', np.empty((0, 0)),
                             fake_imgs, fake_dis_outputs=fake_dis_outs,
                             shuffle=True, interactive=True)
    logger.info('### Build, train and test network ... Done')
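# A hypothetical invocation of one of these scripts; the keyword arguments
# override the module-level defaults, and unknown keywords only trigger a
# warning:
if __name__ == '__main__':
    run(allow_plots=False, continue_training=True)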
def _validate_training_process(self, sess, epoch):
    """Validate the current training process on the validation batch.

    Note that the validation uses the same graph and session as the
    training, but the training-mode tensor (received as an input tensor
    of the network) is different.

    Args:
        sess: The current training session.
        epoch: The current training iteration.
    """
    logger.info('Epoch %d: validating training process ...' % epoch)

    if self.val_cpu_only:
        logger.warn('The option \'val_cpu_only\' is enabled, but not ' +
                    'supported by this class. Option will be ignored.')

    val_handle = sess.run(self._val_iter.string_handle())
    sess.run(self._val_iter.initializer, feed_dict={
        self._t_val_raw_in: self._val_batch[0],
        self._t_val_raw_out: self._val_batch[1],
        self._t_val_batch_size: self._val_batch[0].shape[0]
    })

    # Note that subclasses (such as a WassersteinGAN) don't have a
    # meaningful accuracy.
    if self._t_accuracy is None:
        g_loss, d_loss, summary = sess.run(
            [self._g_loss, self._d_loss, self._t_summaries],
            feed_dict={self._g_inputs: self._val_latent_input,
                       self._t_handle: val_handle})
    else:
        acc, g_loss, d_loss, summary = sess.run(
            [self._t_accuracy, self._g_loss, self._d_loss,
             self._t_summaries],
            feed_dict={self._g_inputs: self._val_latent_input,
                       self._t_handle: val_handle})
        logger.info('Validation Accuracy: %f' % acc)

    logger.info('Generator loss on validation batch: %f' % g_loss)
    logger.info('Discriminator loss on validation batch: %f' % d_loss)

    self._val_summary_writer.add_summary(summary, epoch)
    self._val_summary_writer.flush()

    if self.allow_plots:
        num_plots = min(4, self._val_latent_input.shape[0])
        # We have to reinitialize to change the batch size (seems to be a
        # cleaner solution than processing the whole validation set).
        sess.run(self._val_iter.initializer, feed_dict={
            self._t_val_raw_in: self._val_batch[0][:num_plots, :],
            self._t_val_raw_out: self._val_batch[1][:num_plots, :],
            self._t_val_batch_size: num_plots
        })
        fake_imgs, fake_dis_outs = sess.run(
            [self._g_outputs, self._d_outputs_fake],
            feed_dict={self._g_inputs:
                           self._val_latent_input[:num_plots, :],
                       self._t_handle: val_handle})
        dplt.plot_gan_images('Validation Samples', np.empty((0, 0)),
                             fake_imgs, fake_dis_outputs=fake_dis_outs,
                             interactive=True)

    logger.info('Epoch %d: validating training process ... Done' % epoch)
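# The validation above relies on TF1's feedable-iterator pattern: every
# dataset iterator exposes a string handle, and a generic iterator fed
# through a `tf.string` placeholder decides at run time which pipeline to
# draw from. A self-contained sketch of that mechanism (a standalone
# example, not part of the original class; all names are illustrative):
import numpy as np
import tensorflow as tf

raw_in = tf.placeholder(tf.float32, shape=[None, 4])
batch_size = tf.placeholder(tf.int64, shape=[])
ds = tf.data.Dataset.from_tensor_slices(raw_in).batch(batch_size)
val_iter = ds.make_initializable_iterator()

handle = tf.placeholder(tf.string, shape=[])
generic_iter = tf.data.Iterator.from_string_handle(
    handle, ds.output_types, ds.output_shapes)
next_batch = generic_iter.get_next()

with tf.Session() as sess:
    val_handle = sess.run(val_iter.string_handle())
    sess.run(val_iter.initializer,
             feed_dict={raw_in: np.random.rand(10, 4), batch_size: 10})
    # Feeding a different iterator's handle would switch pipelines without
    # rebuilding the graph.
    print(sess.run(next_batch, feed_dict={handle: val_handle}).shape)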
def test(self):
    """Evaluate the trained network on the whole test set.

    Note that we sample random latent inputs for the generator.
    """
    if not self._is_build:
        raise CustomException('Network has not been built yet.')

    logger.info('Testing DCGAN ...')

    sess = self._get_inference_session()
    if sess is None:
        logger.error('Could not create session. Testing aborted.')
        return

    test_ins = shared.data.get_test_inputs()
    test_outs = shared.data.get_test_outputs()
    test_latent_inputs = self.sample_latent(shared.data.num_test_samples)

    test_handle = sess.run(self._test_iter.string_handle())
    sess.run(self._test_iter.initializer, feed_dict={
        self._t_test_raw_in: test_ins,
        self._t_test_raw_out: test_outs,
        self._t_test_batch_size: shared.data.num_test_samples
    })

    ckpt_epoch = tf.train.global_step(sess, self._t_global_step)
    logger.info('The network has been trained for %d epochs.' % ckpt_epoch)

    # Note that subclasses (such as a WassersteinGAN) don't have a
    # meaningful accuracy.
    if self._t_accuracy is None:
        g_loss, d_loss = sess.run(
            [self._g_loss, self._d_loss],
            feed_dict={self._g_inputs: test_latent_inputs,
                       self._t_handle: test_handle})
    else:
        acc, g_loss, d_loss = sess.run(
            [self._t_accuracy, self._g_loss, self._d_loss],
            feed_dict={self._g_inputs: test_latent_inputs,
                       self._t_handle: test_handle})
        logger.info('Test Accuracy: %f' % acc)

    logger.info('Generator loss on test set: %f' % g_loss)
    logger.info('Discriminator loss on test set: %f' % d_loss)

    if self.allow_plots:
        num_plots = min(8, test_latent_inputs.shape[0])
        Z_in = test_latent_inputs[:num_plots, :]
        # We have to reinitialize to change the batch size (seems to be a
        # cleaner solution than processing the whole test set).
        sess.run(self._test_iter.initializer, feed_dict={
            self._t_test_raw_in: test_ins[:num_plots, :],
            self._t_test_raw_out: test_outs[:num_plots, :],
            self._t_test_batch_size: num_plots
        })
        real_imgs, real_lbls, fake_imgs, fake_dis_outs, real_dis_outs = \
            sess.run([self._t_ds_inputs, self._t_ds_outputs,
                      self._g_outputs, self._d_outputs_fake,
                      self._d_outputs_real],
                     feed_dict={self._g_inputs: Z_in,
                                self._t_handle: test_handle})
        dplt.plot_gan_images('Test Samples', real_imgs, fake_imgs,
                             real_outputs=real_lbls,
                             real_dis_outputs=real_dis_outs,
                             fake_dis_outputs=fake_dis_outs, shuffle=True,
                             interactive=True, figsize=(10, 12))

    logger.info('Testing DCGAN ... Done')
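# `sample_latent` is called throughout but not shown in this section. A
# minimal sketch of such a method, assuming a uniform prior on [-1, 1] and a
# `latent_dim` attribute (both assumptions; DCGAN implementations commonly
# use this prior):
def sample_latent(self, batch_size):
    """Draw `batch_size` random latent vectors for the generator.

    Returns:
        A numpy array of shape [batch_size, self.latent_dim].
    """
    return np.random.uniform(-1., 1., size=[batch_size, self.latent_dim])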