Code example #1
def run(**kwargs):
    allow_plots = ALLOW_PLOTS
    continue_training = CONTINUE_TRAINING
    inference_cpu_only = INFERENCE_CPU_ONLY
    use_biases = USE_BIASES
    num_iter = NUM_ITER

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'continue_training':
            continue_training = v
        elif k == 'inference_cpu_only':
            inference_cpu_only = v
        elif k == 'use_biases':
            use_biases = v
        elif k == 'num_iter':
            num_iter = v
        else:
            logger.warning("Keyword '%s' is unknown." % k)

    logger.info('### Loading dataset ...')

    data = CelebAData(config.dataset_path, shape=[64, 64])

    # Important! Let the network know which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    train_net = setup_network(allow_plots,
                              continue_training,
                              inference_cpu_only,
                              use_biases,
                              mode='train')
    train_net.train(num_iter=num_iter, val_bs=100)

    test_net = setup_network(allow_plots,
                             continue_training,
                             inference_cpu_only,
                             use_biases,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(np.empty((0, 0)),
                                                   latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples',
                             np.empty((0, 0)),
                             fake_imgs,
                             fake_dis_outputs=fake_dis_outs,
                             shuffle=True,
                             interactive=True)

    logger.info('### Build, train and test network ... Done')
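The keyword-override loop above can be written more compactly against a dict of defaults. A minimal sketch of an equivalent pattern (the constants and logger are those assumed by the example; the dict-based variant itself is not part of the source):

def run(**kwargs):
    # Defaults come from the module-level constants; recognized keywords
    # override them, unknown keywords are only reported.
    options = {
        'allow_plots': ALLOW_PLOTS,
        'continue_training': CONTINUE_TRAINING,
        'inference_cpu_only': INFERENCE_CPU_ONLY,
        'use_biases': USE_BIASES,
        'num_iter': NUM_ITER,
    }
    for k, v in kwargs.items():
        if k in options:
            options[k] = v
        else:
            logger.warning("Keyword '%s' is unknown." % k)
    # The body of run() would then read options['num_iter'], etc.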
Code example #2
def run(**kwargs):
    allow_plots = ALLOW_PLOTS
    use_biases = USE_BIASES

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'use_biases':
            use_biases = v
        else:
            logger.warning("Keyword '%s' is unknown." % k)

    logger.info('### Loading dataset ...')

    data = CIFAR10Data(config.dataset_path)

    # Important! Let the network know which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    train_net = WassersteinGAN(mode='train')
    train_net.continue_training = False
    train_net.allow_plots = allow_plots
    train_net.use_biases = use_biases
    train_net.build()
    train_net.train(num_iter=100001)

    test_net = WassersteinGAN(mode='inference')
    test_net.allow_plots = allow_plots
    test_net.use_biases = use_biases
    test_net.build()
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(np.empty((0, 0)),
                                                   latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples',
                             np.empty((0, 0)),
                             fake_imgs,
                             fake_dis_outputs=fake_dis_outs,
                             shuffle=True,
                             interactive=True)

    logger.info('### Build, train and test network ... Done')
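Unlike examples #1 and #3, this script configures and builds the WassersteinGAN inline instead of calling a setup_network helper. A minimal sketch of what such a helper could look like here (hypothetical, not part of the source; the attribute names are taken from the code above):

def setup_network(allow_plots, use_biases, mode='train'):
    # Hypothetical helper mirroring the inline construction above.
    net = WassersteinGAN(mode=mode)
    net.allow_plots = allow_plots
    net.use_biases = use_biases
    if mode == 'train':
        # The example disables checkpoint resumption for training.
        net.continue_training = False
    net.build()
    return net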
Code example #3
def run(**kwargs):
    allow_plots = ALLOW_PLOTS
    continue_training = CONTINUE_TRAINING

    for k, v in kwargs.items():
        if k == 'allow_plots':
            allow_plots = v
        elif k == 'continue_training':
            continue_training = v
        else:
            logger.warning("Keyword '%s' is unknown." % k)

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path)

    # Important! Let the network know which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    train_net = setup_network(allow_plots, continue_training, mode='train')
    train_net.train(num_iter=10001)

    test_net = setup_network(allow_plots, continue_training, mode='inference')
    test_net.test()

    if allow_plots:
        # Generate some fake images.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(np.empty((0, 0)),
                                                   latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples',
                             np.empty((0, 0)),
                             fake_imgs,
                             fake_dis_outputs=fake_dis_outs,
                             shuffle=True,
                             interactive=True)

    logger.info('### Build, train and test network ... Done')
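All three scripts publish the loaded dataset through a shared module, so the network classes can look it up instead of receiving it as an argument. A sketch of what such a module might contain (hypothetical; the source only shows the assignment shared.data = data):

# shared.py -- hypothetical process-wide registry.
# Network code reads shared.data instead of taking the dataset as a
# constructor argument.
data = None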
Code example #4
    def _validate_training_process(self, sess, epoch):
        """Validate the current training process on the validation batch. Note,
        that the validation uses the same graph and session as the training,
        but the training mode tensor (received as an input tensor of the
        network) is different.

        Args:
            sess: The current training session.
            epoch: The current training iteration.
        """
        logger.info('Epoch %d: validating training process ...' % epoch)

        if self.val_cpu_only:
            logger.warning("The option 'val_cpu_only' is enabled, but not "
                           "supported by this class. Option will be ignored.")

        val_handle = sess.run(self._val_iter.string_handle())
        sess.run(self._val_iter.initializer,
                 feed_dict={
                     self._t_val_raw_in: self._val_batch[0],
                     self._t_val_raw_out: self._val_batch[1],
                     self._t_val_batch_size: self._val_batch[0].shape[0]
                 })

        # Note that subclasses (such as a WassersteinGAN) don't have a
        # meaningful accuracy.
        if self._t_accuracy is None:
            g_loss, d_loss, summary = sess.run( \
                    [self._g_loss, self._d_loss, self._t_summaries],
                    feed_dict={self._g_inputs: self._val_latent_input,
                               self._t_handle: val_handle})
        else:
            acc, g_loss, d_loss, summary = sess.run( \
                    [self._t_accuracy, self._g_loss, self._d_loss,
                     self._t_summaries],
                    feed_dict={self._g_inputs: self._val_latent_input,
                               self._t_handle: val_handle})

            logger.info('Validation Accuracy: %f' % acc)
        logger.info('Generator loss on validation batch: %f' % g_loss)
        logger.info('Discriminator loss on validation batch: %f' % d_loss)

        self._val_summary_writer.add_summary(summary, epoch)
        self._val_summary_writer.flush()

        if self.allow_plots:
            num_plots = min(4, self._val_latent_input.shape[0])

            # We have to reinitialize to change the batch size (seems to be
            # a cleaner solution than processing the whole validation set).
            sess.run(self._val_iter.initializer,
                     feed_dict={
                         self._t_val_raw_in: self._val_batch[0][:num_plots, :],
                         self._t_val_raw_out:
                         self._val_batch[1][:num_plots, :],
                         self._t_val_batch_size: num_plots
                     })
            fake_imgs, fake_dis_outs = sess.run( \
                [self._g_outputs, self._d_outputs_fake],
                feed_dict={self._g_inputs:
                                self._val_latent_input[:num_plots, :],
                           self._t_handle: val_handle})

            dplt.plot_gan_images('Validation Samples',
                                 np.empty((0, 0)),
                                 fake_imgs,
                                 fake_dis_outputs=fake_dis_outs,
                                 interactive=True)

        logger.info('Epoch %d: validating training process ... Done' % epoch)
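The method above switches between input pipelines by feeding an iterator string handle (self._t_handle) instead of rebuilding any graph nodes. A self-contained sketch of this TensorFlow 1.x feedable-iterator pattern (toy datasets, not the ones used above):

import tensorflow as tf  # TensorFlow 1.x

# Two pipelines that will share a single get_next() op.
train_ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).repeat().batch(2)
val_ds = tf.data.Dataset.from_tensor_slices(tf.range(100, 110)).batch(2)

# The generic iterator is driven by whichever handle is fed at runtime.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_batch = iterator.get_next()

train_iter = train_ds.make_one_shot_iterator()
val_iter = val_ds.make_initializable_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iter.string_handle())
    val_handle = sess.run(val_iter.string_handle())

    # Training step: pull from the training pipeline.
    print(sess.run(next_batch, feed_dict={handle: train_handle}))

    # Validation step: initialize and feed the validation handle; the
    # graph itself is untouched.
    sess.run(val_iter.initializer)
    print(sess.run(next_batch, feed_dict={handle: val_handle}))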
Code example #5
    def test(self):
        """Evaluate the trained network using the whole test set.
        
        Note, the we sample random latent input for the generator.
        """
        if not self._is_build:
            raise CustomException('Network has not been built yet.')

        logger.info('Testing DCGAN ...')

        sess = self._get_inference_session()
        if sess is None:
            logger.error('Could not create session. Testing aborted.')
            return

        test_ins = shared.data.get_test_inputs()
        test_outs = shared.data.get_test_outputs()
        test_latent_inputs = self.sample_latent(shared.data.num_test_samples)

        test_handle = sess.run(self._test_iter.string_handle())
        sess.run(self._test_iter.initializer,
                 feed_dict={
                     self._t_test_raw_in: test_ins,
                     self._t_test_raw_out: test_outs,
                     self._t_test_batch_size: shared.data.num_test_samples
                 })

        ckpt_epoch = tf.train.global_step(sess, self._t_global_step)
        logger.info('The network has been trained for %d epochs.' % ckpt_epoch)

        # Note that subclasses (such as a WassersteinGAN) don't have a
        # meaningful accuracy.
        if self._t_accuracy is None:
            g_loss, d_loss = sess.run([self._g_loss, self._d_loss],
                                      feed_dict={
                                          self._g_inputs: test_latent_inputs,
                                          self._t_handle: test_handle
                                      })
        else:
            acc, g_loss, d_loss = sess.run( \
                [self._t_accuracy, self._g_loss, self._d_loss],
                feed_dict={self._g_inputs: test_latent_inputs,
                           self._t_handle: test_handle})

            logger.info('Test Accuracy: %f' % acc)
        logger.info('Generator loss on test set: %f' % g_loss)
        logger.info('Discriminator loss on test set: %f' % d_loss)

        if self.allow_plots:
            num_plots = min(8, test_latent_inputs.shape[0])

            Z_in = test_latent_inputs[:num_plots, :]

            # We have to reinitialize to change the batch size (seems to be
            # a cleaner solution than processing the whole test set).
            sess.run(self._test_iter.initializer,
                     feed_dict={
                         self._t_test_raw_in: test_ins[:num_plots, :],
                         self._t_test_raw_out: test_outs[:num_plots, :],
                         self._t_test_batch_size: num_plots
                     })
            real_imgs, real_lbls, fake_imgs, fake_dis_outs, real_dis_outs = \
                sess.run([self._t_ds_inputs, self._t_ds_outputs,
                          self._g_outputs, self._d_outputs_fake,
                          self._d_outputs_real],
                feed_dict={self._g_inputs: Z_in,
                           self._t_handle: test_handle})

            dplt.plot_gan_images('Test Samples',
                                 real_imgs,
                                 fake_imgs,
                                 real_outputs=real_lbls,
                                 real_dis_outputs=real_dis_outs,
                                 fake_dis_outputs=fake_dis_outs,
                                 shuffle=True,
                                 interactive=True,
                                 figsize=(10, 12))

        logger.info('Testing DCGAN ... Done')
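test() draws fresh latent inputs through sample_latent, whose implementation is not shown. For a DCGAN-style generator this usually means sampling from a fixed prior; a minimal sketch assuming a uniform prior and a latent-size attribute (both assumptions, not taken from the source):

import numpy as np

def sample_latent(self, num_samples):
    # Assumption: latent inputs drawn i.i.d. from U(-1, 1);
    # self._latent_size is a hypothetical attribute.
    return np.random.uniform(-1., 1.,
                             size=(num_samples, self._latent_size))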