Exemplo n.º 1
0
    def _build_GAN(self):
        """Build the GAN graph that uses an additional "inner" discriminator.

        Graph construction proceeds in this order:

        * real inputs plus three uniform noise tensors (an oversized batch of
          ``3 * batch_size`` rows, a ``batch_size`` batch and a 20-row batch);
        * generator outputs for the oversized batch and the 20-row batch;
        * inner-discriminator outputs, used to keep the ``batch_size``
          presamples whose first output column is largest;
        * outer-discriminator outputs on the real data and the kept samples;
        * discriminator, inner-discriminator and generator losses together
          with their scalar summaries;
        * MNIST score and Frechet distance on images resized to 28x28.

        Everything later code needs is stored on ``self`` (``_X``, ``_preZ``,
        ``_Z``, ``_Z_sample``, the three networks, ``G_presample``,
        ``G_sample_test``, ``G_selected_samples``, ``_D_loss``,
        ``_D_in_loss``, ``_G_loss``, ``eval_score``, ``frechet_distance``).
        ``self._Z`` is created here but not consumed inside this method.
        """
        self.initializer = tf.contrib.layers.xavier_initializer

        hps = self._hps
        eval_rows = 20  # number of rows in the evaluation batches

        def uniform_noise(rows):
            # Generator input: `rows` noise vectors sampled between -1 and 1.
            return tf.random_uniform([rows, hps.gen_input_size],
                                     minval=-1.0,
                                     maxval=1.0)

        def mean_sigmoid_xent(logits, make_labels):
            # Mean sigmoid cross-entropy against constant labels shaped like
            # the logits (make_labels is tf.ones_like or tf.zeros_like).
            target = make_labels(logits)
            per_example = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=target)
            return tf.reduce_mean(per_example)

        def mean_mse(predicted, target):
            # Mean squared error between two logit tensors.
            return tf.reduce_mean(
                tf.losses.mean_squared_error(predictions=predicted,
                                             labels=target))

        with tf.variable_scope('gan'):
            # Real data fed to both discriminators.
            self._X = self.inputs(hps.batch_size, self.s_size)

            # Noise tensors (generator inputs).
            self._preZ = uniform_noise(hps.batch_size * 3)
            self._Z = uniform_noise(hps.batch_size)
            self._Z_sample = uniform_noise(eval_rows)

            # Networks.
            self.discriminator_inner = Discriminator(
                hps, scope='discriminator_inner')
            self.discriminator = Discriminator(hps)
            self.generator = Generator(hps)

            # Generator outputs.
            self.G_presample = self.generator.generate(self._preZ)
            self.G_sample_test = self.generator.generate(self._Z_sample)

            # Inner discriminator on the candidates and on the real data.
            inner = self.discriminator_inner
            presample_out, _ = inner.discriminate(self.G_presample)
            _, inner_logit_real = inner.discriminate(self._X)

            # Keep the batch_size candidates with the largest first-column
            # inner-discriminator output.
            presample_scores = presample_out[:, 0]
            _, kept_idx = tf.nn.top_k(presample_scores, hps.batch_size)
            tf.logging.info(kept_idx)
            self.G_selected_samples = tf.gather(self.G_presample, kept_idx)
            tf.logging.info(self.G_selected_samples)

            _, inner_logit_fake = inner.discriminate(self.G_selected_samples)

            # Outer discriminator on the real data and the kept samples.
            outer = self.discriminator
            _, outer_logit_real = outer.discriminate(self._X)
            _, outer_logit_fake = outer.discriminate(self.G_selected_samples)

        with tf.variable_scope('D_loss'):
            loss_on_real = mean_sigmoid_xent(outer_logit_real, tf.ones_like)
            loss_on_fake = mean_sigmoid_xent(outer_logit_fake, tf.zeros_like)
            self._D_loss = loss_on_real + loss_on_fake
            for tag, tensor in (('D_loss_real', loss_on_real),
                                ('D_loss_fake', loss_on_fake),
                                ('D_loss', self._D_loss)):
                tf.summary.scalar(tag, tensor, collections=['Dis'])
            mean_fake_logit = tf.reduce_mean(outer_logit_fake)
            tf.summary.scalar('D_out', mean_fake_logit, collections=['Dis'])

        with tf.variable_scope('D_in_loss'):
            # The inner discriminator is regressed onto the outer
            # discriminator's logits, for the kept samples and the real data.
            inner_fake_err = mean_mse(inner_logit_fake, outer_logit_fake)
            inner_real_err = mean_mse(inner_logit_real, outer_logit_real)
            self._D_in_loss = inner_fake_err + inner_real_err
            tf.summary.scalar('D_in_loss',
                              self._D_in_loss,
                              collections=['Dis_in'])
            mean_inner_fake_logit = tf.reduce_mean(inner_logit_fake)
            tf.summary.scalar('D_in_out',
                              mean_inner_fake_logit,
                              collections=['Dis_in'])

        with tf.variable_scope('G_loss'):
            # The generator is scored by the inner discriminator.
            self._G_loss = mean_sigmoid_xent(inner_logit_fake, tf.ones_like)
            tf.summary.scalar('G_loss', self._G_loss, collections=['Gen'])

        with tf.variable_scope('GAN_Eval'):
            tf.logging.info(self.G_sample_test.shape)
            fake_eval = tf.image.resize_images(self.G_sample_test, [28, 28])
            real_eval = tf.image.resize_images(self._X[:eval_rows], [28, 28])
            self.eval_score = util.mnist_score(fake_eval,
                                               MNIST_CLASSIFIER_FROZEN_GRAPH)
            self.frechet_distance = util.mnist_frechet_distance(
                real_eval, fake_eval, MNIST_CLASSIFIER_FROZEN_GRAPH)

            for tag, tensor in (('MNIST_Score', self.eval_score),
                                ('frechet_distance', self.frechet_distance)):
                tf.summary.scalar(tag, tensor, collections=['All'])
Exemplo n.º 2
0
    def _build_GAN(self):
        """Build the plain GAN graph: generator, discriminator, losses, eval.

        The real batch comes from ``self.inputs`` and the generator input is
        the placeholder ``self._Z`` (named ``'Z'``), which must be fed by the
        caller. The method stores on ``self`` the networks, ``_X``, ``_Z``,
        ``G_sample``, ``_D_loss``, ``_G_loss``, ``eval_score`` and
        ``frechet_distance``, and registers scalar summaries in the
        ``'Dis'``, ``'Gen'`` and ``'All'`` collections.
        """
        self.initializer = tf.contrib.layers.xavier_initializer

        hps = self._hps
        self.discriminator = Discriminator(hps)
        self.generator = Generator(hps)

        def mean_sigmoid_xent(logits, make_labels):
            # Mean sigmoid cross-entropy against constant labels shaped like
            # the logits (make_labels is tf.ones_like or tf.zeros_like).
            target = make_labels(logits)
            per_example = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=target)
            return tf.reduce_mean(per_example)

        with tf.variable_scope('gan'):
            # Real data for the discriminator.
            real_batch = self.inputs(hps.batch_size, self.s_size)
            tf.logging.info("image input")
            tf.logging.info(real_batch)
            self._X = real_batch

            # Noise vector (generator input), supplied through feed_dict.
            self._Z = tf.placeholder(shape=[None, hps.gen_input_size],
                                     dtype="float32",
                                     name='Z')

            # Generator output and discriminator logits on real / fake data.
            self.G_sample = self.generator.generate(self._Z)
            _, logit_real = self.discriminator.discriminate(self._X)
            _, logit_fake = self.discriminator.discriminate(self.G_sample)

        with tf.variable_scope('D_loss'):
            loss_on_real = mean_sigmoid_xent(logit_real, tf.ones_like)
            loss_on_fake = mean_sigmoid_xent(logit_fake, tf.zeros_like)
            self._D_loss = loss_on_real + loss_on_fake
            for tag, tensor in (('D_loss_real', loss_on_real),
                                ('D_loss_fake', loss_on_fake),
                                ('D_loss', self._D_loss)):
                tf.summary.scalar(tag, tensor, collections=['Dis'])

        with tf.variable_scope('G_loss'):
            self._G_loss = mean_sigmoid_xent(logit_fake, tf.ones_like)
            tf.summary.scalar('G_loss', self._G_loss, collections=['Gen'])

        with tf.variable_scope('GAN_Eval'):
            classifier_graph_path = '../../models-master/research/gan/mnist/data/classify_mnist_graph_def.pb'
            image_shape = [-1, 28, 28, 1]

            tf.logging.info(self.G_sample.shape)
            fake_eval = tf.reshape(self.G_sample, image_shape)
            tf.logging.info(fake_eval.shape)

            self.eval_score = util.mnist_score(fake_eval,
                                               classifier_graph_path)
            # The first 20 real rows are compared against the generated batch.
            real_eval = tf.reshape(self._X[:20], image_shape)
            self.frechet_distance = util.mnist_frechet_distance(
                real_eval, fake_eval, classifier_graph_path)

            for tag, tensor in (('MNIST_Score', self.eval_score),
                                ('frechet_distance', self.frechet_distance)):
                tf.summary.scalar(tag, tensor, collections=['All'])
# NOTE(review): this is the tail of a longer script. It relies on gan_model,
# improved_wgan_loss, generator_optimizer, noise_dims and MNIST_DATA_DIR being
# defined earlier; those definitions are not visible in this fragment.

# Adam optimizer for the discriminator (learning rate 0.0001, beta1 0.5).
discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
# Build the train ops from the model, its loss and the two optimizers.
gan_train_ops = tfgan.gan_train_ops(gan_model, improved_wgan_loss,
                                    generator_optimizer,
                                    discriminator_optimizer)

# Number of generated (and real) images used by the evaluation metrics below.
num_images_to_eval = 500
# Path of the frozen MNIST classifier graph passed to the util metrics.
MNIST_CLASSIFIER_FROZEN_GRAPH = './mnist/data/classify_mnist_graph_def.pb'

# For variables to load, use the same variable scope as in the train job.
# reuse=True makes the generator call below share the existing variables
# instead of creating new ones.
with tf.variable_scope('Generator', reuse=True):
    eval_images = gan_model.generator_fn(tf.random_normal(
        [num_images_to_eval, noise_dims]),
                                         is_training=False)

# Score the generated images with util.mnist_score and the frozen classifier.
eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Frechet distance between real training images and the generated images.
# The real-image input pipeline is pinned to the CPU.
with tf.device('/cpu:0'):
    real_images, _, _ = data_provider.provide_data('train', num_images_to_eval,
                                                   MNIST_DATA_DIR)
frechet_distance = util.mnist_frechet_distance(real_images, eval_images,
                                               MNIST_CLASSIFIER_FROZEN_GRAPH)

# Reshape the first 20 eval images into a grid (10 columns) for viewing.
generated_data_to_visualize = tfgan.eval.image_reshaper(eval_images[:20, ...],
                                                        num_cols=10)

# Step function that runs the GAN train ops; obtained from TFGAN.
train_step_fn = tfgan.get_sequential_train_steps()

# Global step tensor, created if the graph does not have one yet.
global_step = tf.train.get_or_create_global_step()
Exemplo n.º 4
0
def train(is_train):
    """Build a TFGAN Wasserstein GAN for MNIST, then train it or sample it.

    The function creates ``MNIST_DATA_DIR`` if it is missing, resets the
    default graph and builds: the input pipeline, the GAN model, its loss
    (Wasserstein generator/discriminator losses with a gradient penalty
    weight of 1.0), the RMSProp train ops, and the evaluation tensors (MNIST
    score, Frechet distance and a grid of generated digits). It then opens a
    ``tf.train.SingularMonitoredSession`` using ``./models`` as checkpoint
    directory, with a checkpoint saver hook that saves every 1000 steps.

    Args:
        is_train: If truthy, run 2000 training steps, writing the merged
            summaries to ``./summary`` every 10 steps, and afterwards print
            the final loss, MNIST score and Frechet distance and visualize
            the generated digits. Otherwise run the evaluation generator
            once and display the first generated image with matplotlib.

    Returns:
        None.
    """
    if not tf.gfile.Exists(MNIST_DATA_DIR):
        tf.gfile.MakeDirs(MNIST_DATA_DIR)

    # NOTE: the dataset download/convert step is disabled:
    # download_and_convert_mnist.run(MNIST_DATA_DIR)

    tf.reset_default_graph()

    # Define our input pipeline. Pin it to the CPU so that the GPU can be
    # reserved for forward and backwards propagation.
    batch_size = 32
    with tf.device('/cpu:0'):
        real_images, _, _ = data_provider.provide_data(
            'train', batch_size, MNIST_DATA_DIR)

    gan_model = tfgan.gan_model(
        generator_fn,
        discriminator_fn,
        real_data=real_images,
        generator_inputs=tf.random_normal([batch_size, noise_dims]))

    improved_wgan_loss = tfgan.gan_loss(
        gan_model,
        # We make the loss explicit for demonstration, even though the
        # default is Wasserstein loss.
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        gradient_penalty_weight=1.0)

    # Sanity check that we can evaluate our losses.
    print("Sanity check that we can evaluate our losses")
    evaluate_tfgan_loss(improved_wgan_loss, 'improved wgan loss')

    # RMSProp is used here; an Adam alternative was
    # AdamOptimizer(0.001, beta1=0.5) / AdamOptimizer(0.0001, beta1=0.5).
    generator_optimizer = tf.train.RMSPropOptimizer(0.001)
    discriminator_optimizer = tf.train.RMSPropOptimizer(0.0001)
    gan_train_ops = tfgan.gan_train_ops(
        gan_model,
        improved_wgan_loss,
        generator_optimizer,
        discriminator_optimizer)

    # ### Evaluation

    num_images_to_eval = 500
    MNIST_CLASSIFIER_FROZEN_GRAPH = os.path.join(
        RESEARCH_FOLDER,
        'gan/mnist/data/classify_mnist_graph_def.pb')

    # For variables to load, use the same variable scope as in the train job.
    with tf.variable_scope('Generator', reuse=True):
        eval_images = gan_model.generator_fn(
            tf.random_normal([num_images_to_eval, noise_dims]),
            is_training=False)

    # Score the generated images with the frozen MNIST classifier.
    eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

    # Frechet distance between real training images and generated images.
    with tf.device('/cpu:0'):
        real_images, _, _ = data_provider.provide_data(
            'train', num_images_to_eval, MNIST_DATA_DIR)
    frechet_distance = util.mnist_frechet_distance(
        real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

    # Reshape eval images for viewing.
    generated_data_to_visualize = tfgan.eval.image_reshaper(
        eval_images[:20, ...], num_cols=10)

    # This code block should take about **1 minute** to run on a GPU kernel,
    # and about **8 minutes** on CPU.

    train_step_fn = tfgan.get_sequential_train_steps()

    global_step = tf.train.get_or_create_global_step()
    loss_values, mnist_scores, frechet_distances = [], [], []

    # Reference the loss tuple directly rather than a variable leaked from a
    # loop, so these summaries do not depend on earlier control flow.
    tf.summary.scalar('dis_loss', improved_wgan_loss.discriminator_loss)
    tf.summary.scalar('gen_loss', improved_wgan_loss.generator_loss)
    merged = tf.summary.merge_all()

    saver = tf.train.Saver()
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir="./models",
        save_steps=1000,
        saver=saver)

    print("Graph trainable nodes:")
    for v in tf.trainable_variables():
        print(v.name)

    num_train_steps = 2000

    with tf.train.SingularMonitoredSession(
            hooks=[saver_hook], checkpoint_dir="./models") as sess:
        start_time = time.time()
        train_writer = tf.summary.FileWriter("./summary", sess.graph)
        try:
            if is_train:
                # `range` works on both Python 2 and 3 (`xrange` does not).
                for i in range(num_train_steps):
                    cur_loss, _ = train_step_fn(
                        sess, gan_train_ops, global_step,
                        train_step_kwargs={})
                    loss_values.append((i, cur_loss))

                    if i % 10 == 0:
                        merged_val = sess.run(merged)
                        train_writer.add_summary(merged_val, i)
                        print("Step:{}".format(i))

                # Evaluate once after the final training step.
                mnist_score, f_distance, digits_np = sess.run(
                    [eval_score, frechet_distance,
                     generated_data_to_visualize])
                mnist_scores.append((i, mnist_score))
                frechet_distances.append((i, f_distance))
                print('Current loss: %f' % cur_loss)
                print('Current MNIST score: %f' % mnist_scores[-1][1])
                print('Current Frechet distance: %f'
                      % frechet_distances[-1][1])
                visualize_training_generator(i, start_time, digits_np)
            else:
                # Generate from the trained model and show the first image.
                generated = sess.run(eval_images)
                print("generated[0] shape:{}".format(generated[0].shape))
                plt.imshow(np.squeeze(generated[0]), cmap='gray')
                plt.show()
        finally:
            # Flush pending summary events and release the event file even
            # if training or evaluation raises.
            train_writer.close()