Example #1
# NOTE: assumes numpy plus the project-local dataloader, Generator, and Discriminator modules.
import numpy as np


def test_discriminator():

    # parameters
    file_name = "animals.txt"
    genr_hidden_size = 10
    disr_hidden_size = 11
    num_epochs = 20
    lr = 1
    alpha = 0.9
    batch_size = 100

    # load data
    char_list = dataloader.get_char_list(file_name)
    X_actual = dataloader.load_data(file_name)
    num_examples = X_actual.shape[0]
    seq_len = X_actual.shape[1]

    # generate
    genr = Generator(genr_hidden_size, char_list)
    X_generated = genr.generate_tensor(seq_len, num_examples)

    # train discriminator
    disr = Discriminator(len(char_list), disr_hidden_size)
    disr.train_RMS(X_actual,
                   X_generated,
                   num_epochs,
                   lr,
                   alpha,
                   batch_size,
                   print_progress=True)

    # print discriminator output
    outp = disr.discriminate(np.concatenate((X_actual, X_generated), axis=0))
    print(outp)

    # evaluate discriminator
    accuracy = disr.accuracy(X_actual, X_generated)
    print("accuracy: ", accuracy)
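
The hyperparameter names above (lr, alpha) suggest that Discriminator.train_RMS performs an RMSprop-style update; its actual implementation is not shown here. A minimal NumPy sketch of such an update, under that assumption:

import numpy as np

def rmsprop_step(param, grad, cache, lr=1.0, alpha=0.9, eps=1e-8):
    # keep an exponential moving average of squared gradients
    cache = alpha * cache + (1 - alpha) * grad ** 2
    # scale the step by the root of that running average
    param = param - lr * grad / (np.sqrt(cache) + eps)
    return param, cache

w, cache = rmsprop_step(np.zeros(3), np.ones(3), np.zeros(3))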
Example #2
class GAN_model(object):
    """GAN whose generator samples are pre-filtered by an inner discriminator."""
    def __init__(self, hps, s_size=4):
        self._hps = hps
        self.s_size = s_size

    def inputs(self, batch_size, s_size):
        files = [
            os.path.join(self._hps.data_path, f)
            for f in os.listdir(self._hps.data_path)
            if f.endswith('.tfrecords')
        ]
        print("tfrecord files: ", files)
        fqueue = tf.train.string_input_producer(files)
        reader = tf.TFRecordReader()
        _, value = reader.read(fqueue)
        features = tf.parse_single_example(value,
                                           features={
                                               'image_raw':
                                               tf.FixedLenFeature([],
                                                                  tf.string),
                                               'height':
                                               tf.FixedLenFeature([],
                                                                  tf.int64),
                                               'width':
                                               tf.FixedLenFeature([],
                                                                  tf.int64),
                                               'depth':
                                               tf.FixedLenFeature([], tf.int64)
                                           })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image.set_shape((mnist.IMAGE_PIXELS))
        image = tf.cast(image, tf.float32) * (1.0 / 255)  # scale pixels to [0, 1]
        #image = tf.image.resize_image_with_crop_or_pad(image, CROP_IMAGE_SIZE, CROP_IMAGE_SIZE)
        #image = tf.image.random_flip_left_right(image)

        min_queue_examples = self._hps.batch_size * 2
        images = tf.train.shuffle_batch([image],
                                        batch_size=batch_size,
                                        capacity=min_queue_examples +
                                        3 * batch_size,
                                        min_after_dequeue=min_queue_examples)
        tf.summary.image('images', images)

        return images  #.resize_images(images, [s_size * 2 ** 4, s_size * 2 ** 4])

    def _build_GAN(self):

        self.initializer = tf.contrib.layers.xavier_initializer

        with tf.variable_scope('gan'):
            # discriminator input from real data
            self._X = self.inputs(self._hps.batch_size, self.s_size)
            # tf.placeholder(dtype=tf.float32, name='X',
            #                       shape=[None, self._hps.dis_input_size])
            # noise vector (generator input)
            self._preZ = tf.random_uniform(
                [self._hps.batch_size * 3, self._hps.gen_input_size],
                minval=-1.0,
                maxval=1.0)
            self._Z = tf.random_uniform(
                [self._hps.batch_size, self._hps.gen_input_size],
                minval=-1.0,
                maxval=1.0)
            self._Z_sample = tf.random_uniform([20, self._hps.gen_input_size],
                                               minval=-1.0,
                                               maxval=1.0)

            self.discriminator_inner = Discriminator(
                self._hps, scope='discriminator_inner')
            self.discriminator = Discriminator(self._hps)
            self.generator = Generator(self._hps)

            # Generator
            self.G_presample = self.generator.generate(self._preZ)
            self.G_sample_test = self.generator.generate(self._Z_sample)

            # Inner Discriminator
            D_in_fake_presample, D_in_logit_fake_presample = self.discriminator_inner.discriminate(
                self.G_presample)
            D_in_real, D_in_logit_real = self.discriminator_inner.discriminate(
                self._X)

            values, indices = tf.nn.top_k(D_in_fake_presample[:, 0],
                                          self._hps.batch_size)
            tf.logging.info(indices)
            self.G_selected_samples = tf.gather(self.G_presample, indices)
            tf.logging.info(self.G_selected_samples)

            D_in_fake, D_in_logit_fake = self.discriminator_inner.discriminate(
                self.G_selected_samples)

            # Discriminator
            D_real, D_logit_real = self.discriminator.discriminate(self._X)
            D_fake, D_logit_fake = self.discriminator.discriminate(
                self.G_selected_samples)

        with tf.variable_scope('D_loss'):
            D_loss_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
            D_loss_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
            self._D_loss = D_loss_real + D_loss_fake
            tf.summary.scalar('D_loss_real', D_loss_real, collections=['Dis'])
            tf.summary.scalar('D_loss_fake', D_loss_fake, collections=['Dis'])
            tf.summary.scalar('D_loss', self._D_loss, collections=['Dis'])
            tf.summary.scalar('D_out',
                              tf.reduce_mean(D_logit_fake),
                              collections=['Dis'])

        with tf.variable_scope('D_in_loss'):
            D_in_loss_fake = tf.reduce_mean(
                tf.losses.mean_squared_error(predictions=D_in_logit_fake,
                                             labels=D_logit_fake))
            D_in_loss_real = tf.reduce_mean(
                tf.losses.mean_squared_error(predictions=D_in_logit_real,
                                             labels=D_logit_real))
            self._D_in_loss = D_in_loss_fake + D_in_loss_real
            tf.summary.scalar('D_in_loss',
                              self._D_in_loss,
                              collections=['Dis_in'])
            tf.summary.scalar('D_in_out',
                              tf.reduce_mean(D_in_logit_fake),
                              collections=['Dis_in'])

        with tf.variable_scope('G_loss'):
            self._G_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_in_logit_fake,
                    labels=tf.ones_like(D_in_logit_fake)))
            tf.summary.scalar('G_loss', self._G_loss, collections=['Gen'])

        with tf.variable_scope('GAN_Eval'):
            tf.logging.info(self.G_sample_test.shape)
            eval_fake_images = tf.image.resize_images(self.G_sample_test,
                                                      [28, 28])
            eval_real_images = tf.image.resize_images(self._X[:20], [28, 28])
            self.eval_score = util.mnist_score(eval_fake_images,
                                               MNIST_CLASSIFIER_FROZEN_GRAPH)
            self.frechet_distance = util.mnist_frechet_distance(
                eval_real_images, eval_fake_images,
                MNIST_CLASSIFIER_FROZEN_GRAPH)

            tf.summary.scalar('MNIST_Score',
                              self.eval_score,
                              collections=['All'])
            tf.summary.scalar('frechet_distance',
                              self.frechet_distance,
                              collections=['All'])

    def _add_train_op(self):
        """Sets the training ops for the discriminator, inner discriminator, and generator."""
        with tf.device("/gpu:0"):
            learning_rate_D = 0.0004  # tf.train.exponential_decay(0.001, self.global_step_D,
            #                                           100000, 0.96, staircase=True)
            learning_rate_G = 0.0004  # tf.train.exponential_decay(0.001, self.global_step_G,
            #                                           100000, 0.96, staircase=True)
            learning_rate_D_in = 0.0004  # tf.train.exponential_decay(0.001, self.global_step_D,
            #                                           100000, 0.96, staircase=True)
            self._train_op_D = tf.train.AdamOptimizer(
                learning_rate_D,
                beta1=0.5).minimize(self._D_loss,
                                    global_step=self.global_step_D,
                                    var_list=self.discriminator._theta)
            self._train_op_D_in = tf.train.AdamOptimizer(
                learning_rate_D_in,
                beta1=0.5).minimize(self._D_in_loss,
                                    global_step=self.global_step_D_in,
                                    var_list=self.discriminator_inner._theta)

            self._train_op_G = tf.train.AdamOptimizer(
                learning_rate_G,
                beta1=0.5).minimize(self._G_loss,
                                    global_step=self.global_step_G,
                                    var_list=self.generator._theta)

    def build_graph(self):
        """Add the model, global step, train_op and summaries to the graph"""
        tf.logging.info('Building graph...')
        t0 = time.time()
        # with tf.device("/gpu:0"):
        self._build_GAN()

        self.global_step_D = tf.Variable(0,
                                         name='global_step_D',
                                         trainable=False)
        self.global_step_D_in = tf.Variable(0,
                                            name='global_step_D_in',
                                            trainable=False)
        self.global_step_G = tf.Variable(0,
                                         name='global_step_G',
                                         trainable=False)
        self.global_step = tf.add(tf.add(self.global_step_G,
                                         self.global_step_D),
                                  self.global_step_D_in,
                                  name='global_step')

        tf.summary.scalar('global_step_D',
                          self.global_step_D,
                          collections=['All'])
        tf.summary.scalar('global_step_D_in',
                          self.global_step_D_in,
                          collections=['All'])
        tf.summary.scalar('global_step_G',
                          self.global_step_G,
                          collections=['All'])
        self._add_train_op()
        self._summaries_D = tf.summary.merge_all(key='Dis')
        self._summaries_D_in = tf.summary.merge_all(key='Dis_in')
        self._summaries_G = tf.summary.merge_all(key='Gen')
        self._summaries_All = tf.summary.merge_all(key='All')
        t1 = time.time()
        tf.logging.info('Time to build graph: %i seconds', t1 - t0)

    def run_train_step(self, sess, summary_writer, logging=False):
        """Runs one training iteration each for D, D_in, and G and writes their summaries."""

        ######
        to_return_D = {
            'train_op': self._train_op_D,
            'summaries': self._summaries_D,
            'summaries_all': self._summaries_All,
            'loss': self._D_loss,
            'global_step_D': self.global_step_D,
            'global_step': self.global_step,
        }
        results_D = sess.run(to_return_D)

        ######

        to_return_D_in = {
            'train_op': self._train_op_D_in,
            'summaries': self._summaries_D_in,
            'summaries_all': self._summaries_All,
            'loss': self._D_in_loss,
            'global_step_D_in': self.global_step_D_in,
            'global_step': self.global_step,
        }
        results_D_in = sess.run(to_return_D_in)

        ######

        to_return_G = {
            'train_op': self._train_op_G,
            'summaries': self._summaries_G,
            'summaries_all': self._summaries_All,
            'loss': self._G_loss,
            'global_step_G': self.global_step_G,
            'global_step': self.global_step,
        }

        results_G = sess.run(to_return_G)

        # we will write these summaries to tensorboard using summary_writer
        summaries_G = results_G['summaries']
        summaries_D = results_D['summaries']
        summaries_D_in = results_D_in['summaries']
        summaries_All = results_G['summaries_all']

        global_step_G = results_G['global_step_G']
        global_step_D = results_D['global_step_D']
        global_step_D_in = results_D_in['global_step_D_in']
        global_step = results_G['global_step']

        summary_writer.add_summary(summaries_G,
                                   global_step_G)  # write the summaries
        summary_writer.add_summary(summaries_D,
                                   global_step_D)  # write the summaries
        summary_writer.add_summary(summaries_D_in,
                                   global_step_D_in)  # write the summaries
        summary_writer.add_summary(summaries_All,
                                   global_step)  # write the summaries

        if logging:

            loss_D = results_D['loss']
            tf.logging.info('loss_D: %f', loss_D)  # print the loss to screen

            loss_D_in = results_D_in['loss']
            tf.logging.info('loss_D_in: %f',
                            loss_D_in)  # print the loss to screen

            loss_G = results_G['loss']
            tf.logging.info('loss_G: %f', loss_G)  # print the loss to screen

            if not np.isfinite(loss_G):
                raise Exception("Loss_G is not finite. Stopping.")
            if not np.isfinite(loss_D):
                raise Exception("Loss_D is not finite. Stopping.")
            if not np.isfinite(loss_D_in):
                raise Exception("Loss_D_in is not finite. Stopping.")

            # flush the summary writer every so often
            summary_writer.flush()

    def run_eval_step(self, sess):

        return sess.run([self.eval_score, self.frechet_distance])

    def sample_generator(self, sess):
        """Runs generator to generate samples"""

        to_return = {
            'g_sample': self.G_sample_test,
        }
        return sess.run(to_return)
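
The distinguishing step in this example is the sample selection: the generator draws 3 * batch_size candidates from _preZ, the inner discriminator scores them, and tf.nn.top_k plus tf.gather keep only the batch_size highest-scoring samples for the outer discriminator. A hypothetical NumPy sketch of that selection step outside the TensorFlow graph:

import numpy as np

batch_size = 4
candidates = np.random.uniform(-1.0, 1.0, size=(3 * batch_size, 8))  # stands in for G_presample
scores = np.random.rand(3 * batch_size)                              # stands in for D_in_fake_presample[:, 0]
top_indices = np.argsort(scores)[::-1][:batch_size]                  # tf.nn.top_k analogue
selected = candidates[top_indices]                                   # tf.gather analogue
print(selected.shape)  # (4, 8)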
Example #3
dis = Discriminator(  # NOTE: reconstructed opening; the original snippet begins mid-call
                    vocab_size=vocab_size,
                    paramSavePath=paramSavePath,
                    logPath=logPath,
                    input_dim=input_dim,
                    keep_prob=keep_prob,
                    reuse=reuse,
                    generator=gen,
                    timestr=timestr,
                    debug=debug)
print('Discriminator building finished!')
pred_w, gee = gen.generate(
    z)  # "gee" is short for "generatee" (a coined term)
# pred_w: [timestep * batch_size, vocab_size]
# gee   : [batch_size, timestep, input_dim, 1]
timestep = gen.timestep  # Adding PAD tokens increases the timestep, so it is adjusted here.
gee_cnn_out, gee_dised = dis.discriminate(
    gee)  # gee_cnn_out is the cnn output before the FC layer.
# gee_dised stands for generatee that has been discriminated.
# gee_cnn_out: [batch_size, input_dim * len(window)]
# gee_dised  : [batch_size, 1]
result3 = tf.reshape(tf.argmax(pred_w, axis=1),
                     [batch_size, timestep])  # [batch_size, timestep]
print('result3 prepared')
result5 = gen.max_print[0]  # [1, timestep]
print('result5 prepared')
# The commented-out block below is also not mentioned in the paper,
# but it seems to invert the exp applied in the generator.
# Undocumented part starts here ---------------------------|
# gee_recon = (gee_dised + 1)/2
# gee_recon = tf.log(gee_recon)
# z_code = tensor.cast(z[:, 0], dtype='int32')
# z_index = tensor.arange(n_batch)
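
result3 above converts per-step vocabulary logits into token indices and reshapes them into one row per sequence. A small NumPy sketch of the same shape transformation, with made-up sizes:

import numpy as np

batch_size, timestep, vocab_size = 2, 3, 5
pred_w = np.random.rand(timestep * batch_size, vocab_size)  # [timestep * batch_size, vocab_size]
token_ids = np.argmax(pred_w, axis=1)                       # [timestep * batch_size]
result3 = token_ids.reshape(batch_size, timestep)           # [batch_size, timestep]
print(result3.shape)  # (2, 3)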
Example #4
class GAN_model(object):
    """Vanilla GAN trained on MNIST-style TFRecord images."""
    def __init__(self, hps, s_size=4):
        self._hps = hps
        self.s_size = s_size
        #tf.set_random_seed(np.random.get_state())

    def inputs(self, batch_size, s_size):
        files = [
            os.path.join(self._hps.data_path, f)
            for f in os.listdir(self._hps.data_path)
            if f.endswith('.tfrecords')
        ]
        print("tfrecord files: ", files)
        fqueue = tf.train.string_input_producer(files)
        reader = tf.TFRecordReader()
        _, value = reader.read(fqueue)
        features = tf.parse_single_example(value,
                                           features={
                                               'height':
                                               tf.FixedLenFeature([],
                                                                  tf.int64),
                                               'width':
                                               tf.FixedLenFeature([],
                                                                  tf.int64),
                                               'depth':
                                               tf.FixedLenFeature([],
                                                                  tf.int64),
                                               'image_raw':
                                               tf.FixedLenFeature([],
                                                                  tf.string)
                                           })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image.set_shape((mnist.IMAGE_PIXELS))
        image = tf.cast(image, tf.float32) * (1.0 / 255)  # scale pixels to [0, 1]

        min_queue_examples = self._hps.batch_size * 2
        images = tf.train.shuffle_batch([image],
                                        batch_size=batch_size,
                                        capacity=min_queue_examples +
                                        3 * batch_size,
                                        min_after_dequeue=min_queue_examples)
        tf.summary.image('images', images)

        return images
        #return tf.image.resize_images(images, [s_size * 2 ** 4, s_size * 2 ** 4])

    def _build_GAN(self):

        self.initializer = tf.contrib.layers.xavier_initializer
        self.discriminator = Discriminator(self._hps)
        self.generator = Generator(self._hps)

        with tf.variable_scope('gan'):
            # discriminator input from real data
            image_input = self.inputs(self._hps.batch_size, self.s_size)
            tf.logging.info("image input")
            tf.logging.info(image_input)
            self._X = image_input  #tf.contrib.layers.flatten(image_input)

            # tf.placeholder(dtype=tf.float32, name='X',
            #                       shape=[None, self._hps.dis_input_size])
            # noise vector (generator input)
            self._Z = tf.placeholder(dtype="float32",
                                     name='Z',
                                     shape=[None, self._hps.gen_input_size])
            #tf.random_uniform([self._hps.batch_size, self._hps.gen_input_size], minval=-1.0, maxval=1.0)
            #self._Z_sample = tf.random_uniform([20, self._hps.gen_input_size], minval=-1.0, maxval=1.0)

            # Generator
            self.G_sample = self.generator.generate(self._Z)

            D_real, D_logit_real = self.discriminator.discriminate(self._X)
            D_fake, D_logit_fake = self.discriminator.discriminate(
                self.G_sample)

        with tf.variable_scope('D_loss'):
            D_loss_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
            D_loss_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
            self._D_loss = D_loss_real + D_loss_fake
            tf.summary.scalar('D_loss_real', D_loss_real, collections=['Dis'])
            tf.summary.scalar('D_loss_fake', D_loss_fake, collections=['Dis'])
            tf.summary.scalar('D_loss', self._D_loss, collections=['Dis'])

        with tf.variable_scope('G_loss'):
            self._G_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
            tf.summary.scalar('G_loss', self._G_loss, collections=['Gen'])

        with tf.variable_scope('GAN_Eval'):
            MNIST_CLASSIFIER_FROZEN_GRAPH = '../../models-master/research/gan/mnist/data/classify_mnist_graph_def.pb'
            tf.logging.info(self.G_sample.shape)
            eval_images = tf.reshape(self.G_sample, [-1, 28, 28, 1])
            tf.logging.info(eval_images.shape)

            self.eval_score = util.mnist_score(eval_images,
                                               MNIST_CLASSIFIER_FROZEN_GRAPH)
            self.frechet_distance = util.mnist_frechet_distance(
                tf.reshape(self._X[:20], [-1, 28, 28, 1]), eval_images,
                MNIST_CLASSIFIER_FROZEN_GRAPH)

            tf.summary.scalar('MNIST_Score',
                              self.eval_score,
                              collections=['All'])
            tf.summary.scalar('frechet_distance',
                              self.frechet_distance,
                              collections=['All'])

    def _add_train_op(self):
        """Sets the training ops for the discriminator and generator."""

        with tf.device("/gpu:0"):
            tf.logging.info(self.discriminator._theta)
            learning_rate_D = 0.0004  #tf.train.exponential_decay(0.001, self.global_step_D,
            #                                           100000, 0.96, staircase=True)
            learning_rate_G = 0.0004  # tf.train.exponential_decay(0.001, self.global_step_G,
            #                                           100000, 0.96, staircase=True)
            #learning_rate_D,beta1=0.5
            self._train_op_D = tf.train.AdamOptimizer(
                learning_rate_D,
                beta1=0.5).minimize(self._D_loss,
                                    global_step=self.global_step_D,
                                    var_list=self.discriminator._theta)
            tf.logging.info(self.generator._theta)
            #learning_rate_G,beta1=0.5
            self._train_op_G = tf.train.AdamOptimizer(
                learning_rate_G,
                beta1=0.5).minimize(self._G_loss,
                                    global_step=self.global_step_G,
                                    var_list=self.generator._theta)

        # Alternative: More control over optimization hyperparameters
        # # Take gradients of the trainable variables w.r.t. the loss function to minimize
        # gradients_D = tf.gradients(self._D_loss, self._theta_G,
        #                          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
        # gradients_G = tf.gradients(self._D_loss, self._theta_D,
        #                          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)
        #
        # # Clip the gradients
        # with tf.device("/gpu:0"):
        #   grads_G, global_norm_G = tf.clip_by_global_norm(gradients_G,
        #                                                   self._hps.max_grad_norm)
        #   grads_D, global_norm_D = tf.clip_by_global_norm(gradients_D,
        #                                                   self._hps.max_grad_norm)
        #
        # # Add a summary
        # tf.summary.scalar('global_norm_D', global_norm_D)
        # tf.summary.scalar('global_norm_G', global_norm_G)
        #
        # # Apply adagrad optimizer
        # optimizer = tf.train.AdagradOptimizer(self._hps.lr,
        #                                       initial_accumulator_value=self._hps.adagrad_init_acc)
        # with tf.device("/gpu:0"):
        #   self._train_op = optimizer.apply_gradients(zip(grads_D, self._theta_D),
        #                                              global_step=self.global_step,
        #                                              name='train_step')
        #   self._train_op = optimizer.apply_gradients(zip(grads_G, self._theta_G),
        #                                              global_step=self.global_step,
        #                                              name='train_step')

    def build_graph(self):
        """Add the model, global step, train_op and summaries to the graph"""
        tf.logging.info('Building graph...')
        t0 = time.time()
        with tf.device("/gpu:0"):
            self._build_GAN()
        self.global_step_D = tf.Variable(0,
                                         name='global_step_D',
                                         trainable=False)
        self.global_step_G = tf.Variable(0,
                                         name='global_step_G',
                                         trainable=False)
        self.global_step = tf.add(self.global_step_G,
                                  self.global_step_D,
                                  name='global_step')

        tf.summary.scalar('global_step_D',
                          self.global_step_D,
                          collections=['All'])
        tf.summary.scalar('global_step_G',
                          self.global_step_G,
                          collections=['All'])
        self._add_train_op()
        self._summaries_D = tf.summary.merge_all(key='Dis')
        self._summaries_G = tf.summary.merge_all(key='Gen')
        self._summaries_All = tf.summary.merge_all(key='All')
        t1 = time.time()
        tf.logging.info('Time to build graph: %i seconds', t1 - t0)

    def run_train_step(self, sess, summary_writer, logging=False):
        """Runs one training iteration each for D and G and writes their summaries."""
        feed_dict = {
            self._Z: np.random.uniform(-1,
                                       1,
                                       size=[20, self._hps.gen_input_size])
        }

        to_return_D = {
            'train_op': self._train_op_D,
            'summaries': self._summaries_D,
            'summaries_all': self._summaries_All,
            'loss': self._D_loss,
            'global_step_D': self.global_step_D,
            'global_step': self.global_step,
        }
        results_D = sess.run(to_return_D, feed_dict=feed_dict)

        to_return_G = {
            'train_op': self._train_op_G,
            'summaries': self._summaries_G,
            'summaries_all': self._summaries_All,
            'loss': self._G_loss,
            'global_step_G': self.global_step_G,
            'global_step': self.global_step,
        }
        results_G = sess.run(to_return_G, feed_dict=feed_dict)

        # we will write these summaries to tensorboard using summary_writer
        summaries_G = results_G['summaries']
        summaries_D = results_D['summaries']
        summaries_All = results_G['summaries_all']

        global_step_G = results_G['global_step_G']
        global_step_D = results_D['global_step_D']
        global_step = results_G['global_step']

        summary_writer.add_summary(summaries_G,
                                   global_step_G)  # write the summaries
        summary_writer.add_summary(summaries_D,
                                   global_step_D)  # write the summaries
        summary_writer.add_summary(summaries_All,
                                   global_step)  # write the summaries

        if logging:

            loss_D = results_D['loss']
            tf.logging.info('loss_D: %f', loss_D)  # print the loss to screen

            loss_G = results_G['loss']
            tf.logging.info('loss_G: %f', loss_G)  # print the loss to screen

            if not np.isfinite(loss_G):
                raise Exception("Loss_G is not finite. Stopping.")
            if not np.isfinite(loss_D):
                raise Exception("Loss_D is not finite. Stopping.")

            # flush the summary writer every so often
            summary_writer.flush()

    def run_eval_step(self, sess):

        feed_dic = {
            self._Z: np.random.uniform(-1,
                                       1,
                                       size=[20, self._hps.gen_input_size])
        }
        return sess.run([self.eval_score, self.frechet_distance],
                        feed_dict=feed_dic)

    def sample_generator(self, sess):
        """Runs generator to generate samples"""

        to_return = {
            'g_sample': self.G_sample,
        }
        feed_dic = {
            self._Z: np.random.uniform(-1,
                                       1,
                                       size=[20, self._hps.gen_input_size])
        }
        return sess.run(to_return, feed_dic)
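
A hypothetical driver for this class (not part of the original example), assuming an hps object with the referenced fields (data_path, batch_size, gen_input_size, ...) and a log_dir path; the queue runners are needed because inputs() uses tf.train.string_input_producer:

model = GAN_model(hps)  # hps is assumed to be defined elsewhere
model.build_graph()

config = tf.ConfigProto(allow_soft_placement=True)  # fall back to CPU if no GPU is available
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)  # log_dir is assumed
    for _ in range(100):
        model.run_train_step(sess, summary_writer, logging=True)
    coord.request_stop()
    coord.join(threads)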
Example #5
class BasicGAN:
    def __init__(self, hparams):
        self.hparams = hparams
        self._build()

    def _build(self):
        self.noise_dim = self.hparams.noise_dim
        self.learning_rate = self.hparams.learning_rate
        self.epoches = self.hparams.epoches
        self.batch_size = self.hparams.batch_size

        self.generator = Generator(self.hparams)
        self.discriminator = Discriminator(self.hparams)

        self._add_placeholder()
        self._add_loss()
        self._add_optim()
        self._add_saver()

    def _add_placeholder(self):
        self.rand_noises = tf.placeholder(
            dtype=tf.float32,
            shape=[self.batch_size, self.noise_dim],
            name="rand_noises")
        self.real_imgs = tf.placeholder(dtype=tf.float32,
                                        shape=[self.batch_size, 64, 64, 3],
                                        name="real_imgs")

    def _add_loss(self):
        self.fake_imgs = self.generator.generate(self.rand_noises)
        self.fake_logits = self.discriminator.discriminate(self.fake_imgs)
        self.real_logits = self.discriminator.discriminate(self.real_imgs,
                                                           reuse=True)

        self.d_accuarcy = (
            tf.reduce_mean(tf.cast(self.real_logits > 0, tf.float32)) +
            tf.reduce_mean(tf.cast(self.fake_logits < 0, tf.float32))) / 2

        if self.hparams.model == "GAN":
            # basic gan
            self.d_loss_real = tf.losses.sigmoid_cross_entropy(
                tf.ones_like(self.real_logits),
                logits=self.real_logits,
                label_smoothing=0.2)
            self.d_loss_fake = tf.losses.sigmoid_cross_entropy(
                tf.zeros_like(self.fake_logits),
                logits=self.fake_logits,
                label_smoothing=0.2)
            self.d_loss = (self.d_loss_fake + self.d_loss_real) / 2.0
            self.g_loss = tf.losses.sigmoid_cross_entropy(
                tf.ones_like(self.fake_logits),
                logits=self.fake_logits,
                label_smoothing=0.2)
        elif self.hparams.model == "LSGAN":
            # least squares GAN (LSGAN)
            self.d_loss_real = tf.losses.mean_squared_error(
                tf.ones_like(self.real_logits), self.real_logits)
            self.d_loss_fake = tf.losses.mean_squared_error(
                tf.zeros_like(self.fake_logits), self.fake_logits)
            self.d_loss = (self.d_loss_fake + self.d_loss_real) / 2.0
            self.g_loss = tf.losses.mean_squared_error(
                tf.ones_like(self.fake_logits), self.fake_logits)
        elif self.hparams.model == "WGAN":
            self.d_loss_real = tf.reduce_mean(self.real_logits)
            self.d_loss_fake = tf.reduce_mean(self.fake_logits)
            self.d_loss = self.d_loss_fake - self.d_loss_real
            self.g_loss = -self.d_loss_fake
        elif self.hparams.model == "WGAN-GP":
            # WGAN with gradient penalty (WGAN-GP)
            rand_alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1],
                                           minval=0,
                                           maxval=1,
                                           name="rand_alpha")
            inter_imgs = self.real_imgs * rand_alpha + self.fake_imgs * (
                1 - rand_alpha)
            inter_logits = self.discriminator.discriminate(inter_imgs,
                                                           reuse=True)
            inter_grads = tf.gradients(inter_logits, inter_imgs)[0]
            slops = tf.sqrt(
                tf.reduce_sum(tf.square(inter_grads), axis=[1, 2, 3]))
            penalty = tf.reduce_mean(tf.square(slops - 1))

            self.d_loss_real = tf.reduce_mean(self.real_logits)
            self.d_loss_fake = tf.reduce_mean(self.fake_logits)
            self.d_loss = self.d_loss_fake - self.d_loss_real + self.hparams.penalty_coef * penalty
            self.g_loss = -self.d_loss_fake
        else:
            raise NotImplementedError

    def _add_optim(self):
        tvars = tf.trainable_variables()
        self.d_vars = [var for var in tvars if 'discriminator' in var.name]
        self.g_vars = [var for var in tvars if 'generator' in var.name]

        self.global_step = tf.Variable(0, trainable=False)
        self.d_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.hparams.beta1).\
                        minimize(self.d_loss, var_list=self.d_vars, global_step=self.global_step)
        self.g_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.hparams.beta1).\
                        minimize(self.g_loss, var_list=self.g_vars)

        if self.hparams.model == "WGAN":
            self.d_optim = tf.train.RMSPropOptimizer(5e-5).\
                            minimize(self.d_loss, var_list=self.d_vars, global_step=self.global_step)
            self.g_optim = tf.train.RMSPropOptimizer(5e-5).\
                            minimize(self.g_loss, var_list=self.g_vars)
            self.clip_min = self.hparams.clip_min
            self.clip_max = self.hparams.clip_max
            with tf.control_dependencies([self.d_optim]):
                #                 self.d_optim = tf.group(
                #                     *(tf.assign(var, tf.clip_by_value(var, self.clip_min, self.clip_max))
                #                       for var in tvars if ("discriminator/filter" in var.name) or
                #                       ("discriminator/dense_weight" in var.name)))
                self.d_optim = tf.group(*(tf.assign(
                    var, tf.clip_by_value(var, self.clip_min, self.clip_max))
                                          for var in self.d_vars))

    def _add_saver(self):
        # checkpoint setup
        self.checkpoint_dir = os.path.abspath(
            os.path.join(self.hparams.checkpoint_dir, "checkpoints"))
        self.checkpoint_prefix = os.path.join(
            self.checkpoint_dir, "model_{}".format(self.hparams.model))
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        self.saver = tf.train.Saver(tf.global_variables(),
                                    max_to_keep=self.hparams.max_to_keep)

    def train(self, sess):
        # loss summaries
        d_summary_op = tf.summary.merge([
            tf.summary.histogram("d_real_prob", tf.sigmoid(self.real_logits)),
            tf.summary.histogram("d_fake_prob", tf.sigmoid(self.fake_logits)),
            tf.summary.scalar("d_loss_fake", self.d_loss_fake),
            tf.summary.scalar("d_loss_real", self.d_loss_real),
            tf.summary.scalar("d_loss", self.d_loss)
        ],
                                        name="discriminator_summary")
        g_summary_op = tf.summary.merge([
            tf.summary.histogram("g_prob", tf.sigmoid(self.fake_logits)),
            tf.summary.scalar("g_loss", self.g_loss),
            tf.summary.image("gen_images", self.fake_imgs)
        ],
                                        name="generator_summary")

        self.summary_dir = os.path.abspath(
            os.path.join(self.hparams.checkpoint_dir, "summary"))
        summary_writer = tf.summary.FileWriter(self.summary_dir, sess.graph)

        image_helper = ImageHelper()

        sess.run(tf.global_variables_initializer())

        for num_epoch, num_batch, batch_images in image_helper.iter_images(
                dirname=self.hparams.data_dir,
                batch_size=self.batch_size,
                epoches=self.epoches):
            if (num_epoch == 0) and (num_batch < self.hparams.d_pretrain):
                # pre-train discriminator
                _, current_step, d_loss, d_accuarcy = sess.run(
                    [
                        self.d_optim, self.global_step, self.d_loss,
                        self.d_accuarcy
                    ],
                    feed_dict={
                        self.rand_noises:
                        np.random.normal(
                            size=[self.batch_size, self.noise_dim]),
                        self.real_imgs:
                        batch_images
                    })
                if current_step == self.hparams.d_pretrain:
                    tf.logging.info("==== pre-train ==== current_step:{}, d_loss:{}, d_accuarcy:{}"\
                                    .format(current_step, d_loss, d_accuarcy))
            else:
                # optimize discriminator
                _, current_step, d_loss, d_accuarcy = sess.run(
                    [
                        self.d_optim, self.global_step, self.d_loss,
                        self.d_accuarcy
                    ],
                    feed_dict={
                        self.rand_noises:
                        np.random.normal(
                            size=[self.batch_size, self.noise_dim]),
                        self.real_imgs:
                        batch_images
                    })

                # optimize generator
                if current_step % self.hparams.d_schedule == 0:
                    _, g_loss = sess.run(
                        [self.g_optim, self.g_loss],
                        feed_dict={
                            self.rand_noises:
                            np.random.normal(
                                size=[self.batch_size, self.noise_dim])
                        })

                # summary
                if current_step % self.hparams.log_interval == 0:
                    d_summary_str, g_summary_str = sess.run(
                        [d_summary_op, g_summary_op],
                        feed_dict={
                            self.rand_noises:
                            np.random.normal(
                                size=[self.batch_size, self.noise_dim]),
                            self.real_imgs:
                            batch_images
                        })
                    summary_writer.add_summary(d_summary_str, current_step)
                    summary_writer.add_summary(g_summary_str, current_step)

                    tf.logging.info("step:{}, d_loss:{}, d_accuarcy:{}, g_loss:{}"\
                                    .format(current_step, d_loss, d_accuarcy, g_loss))

            if (num_epoch > 0) and (num_batch == 0):
                # generate images per epoch
                tf.logging.info(
                    "epoch:{} === generate images and save checkpoint".format(
                        num_epoch))
                fake_imgs = sess.run(
                    self.fake_imgs,
                    feed_dict={
                        self.rand_noises:
                        np.random.normal(
                            size=[self.batch_size, self.noise_dim])
                    })
                image_helper.save_imgs(fake_imgs,
                                       img_name="{}/fake-{}".format(
                                           self.hparams.sample_dir, num_epoch))
                # save model per epoch
                self.saver.save(sess,
                                self.checkpoint_prefix,
                                global_step=num_epoch)

    def infer(self, sess):
        # load the saved model
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            self.saver.restore(sess, ckpt.model_checkpoint_path)

        image_helper = ImageHelper()

        fake_imgs = sess.run(
            self.fake_imgs,
            feed_dict={
                self.rand_noises:
                np.random.normal(size=[self.batch_size, self.noise_dim])
            })
        img_name = "{}/infer-image".format(self.hparams.sample_dir)
        image_helper.save_imgs(fake_imgs, img_name=img_name)

        tf.logging.info(
            "====== generate images in file: {} ======".format(img_name))
Example #6
class BasicGAN:
    
    def __init__(self, hparams):
        self.hparams = hparams
        self._build()

    def _build(self):
        self.noise_dim = self.hparams.noise_dim
        self.learning_rate = self.hparams.learning_rate
        self.epoches = self.hparams.epoches
        self.batch_size = self.hparams.batch_size
        
        self.generator = Generator(self.hparams)
        self.discriminator = Discriminator(self.hparams)

        self._add_placeholder()
        self._add_loss()
        self._add_optim()
        self._add_saver()


    def _add_placeholder(self):
        self.rand_noises = tf.placeholder(tf.float32, [self.batch_size, self.noise_dim], "rand_noises")
        self.real_imgs = tf.placeholder(tf.float32, [self.batch_size, 64, 64, 3], "real_imgs")
        self.tags = tf.placeholder(tf.float32, [self.batch_size, 22], "tags")
        self.wrong_tags = tf.placeholder(tf.float32, [self.batch_size, 22], "wrong_tags") 


    def _add_loss(self):
        images_flipped = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), self.real_imgs)
        angles = tf.random_uniform([self.batch_size],
                                   minval=-15.0 * np.pi / 180.0,
                                   maxval=15.0 * np.pi / 180.0)
        self.rotated_imgs = tf.contrib.image.rotate(images_flipped, angles, interpolation='NEAREST')

        self.fake_imgs = self.generator.generate(self.rand_noises, self.tags)
        # logits for fake images with the correct tags
        self.fake_logits = self.discriminator.discriminate(self.fake_imgs, self.tags)
        # logits for real images with wrong tags (an additional loss term)
        self.wtag_logits = self.discriminator.discriminate(self.rotated_imgs, self.wrong_tags, reuse=True)
        # logits for real images with the correct tags
        self.real_logits = self.discriminator.discriminate(self.real_imgs, self.tags, reuse=True)

        self.d_accuarcy = (tf.reduce_mean(tf.cast(self.real_logits>0, tf.float32)) + 
                           tf.reduce_mean(tf.cast(self.fake_logits<0, tf.float32)) + 
                           tf.reduce_mean(tf.cast(self.wtag_logits<0, tf.float32))) / 3.0

#         self.d_accuarcy = (tf.reduce_mean(tf.cast(self.real_logits>0, tf.float32)) + 
#                            tf.reduce_mean(tf.cast(self.fake_logits<0, tf.float32))) / 2.0

        '''
        # basic gan
        self.d_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(self.real_logits), 
                                                           logits=self.real_logits, 
                                                           label_smoothing=0.2)
        self.d_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(self.fake_logits), 
                                                           logits=self.fake_logits,
                                                           label_smoothing=0.2)
        self.d_loss_wtag = tf.losses.sigmoid_cross_entropy(tf.zeros_like(self.wtag_logits), 
                                                           logits=self.wtag_logits, 
                                                           label_smoothing=0.2)
        self.d_loss = self.d_loss_real + (self.d_loss_fake + self.d_loss_wtag) / 2.0
        # self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(self.fake_logits),
                                                      logits=self.fake_logits,
                                                      label_smoothing=0.2)
        '''
        # Gradient Penalty
        rand_alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], 
                                       minval=0, maxval=1, name="rand_alpha")
        inter_imgs = self.real_imgs * rand_alpha + self.fake_imgs * (1-rand_alpha) 
        inter_logits = self.discriminator.discriminate(inter_imgs, self.tags, reuse=True)
        inter_grads = tf.gradients(inter_logits, inter_imgs)[0]
        slops = tf.sqrt(tf.reduce_sum(tf.square(inter_grads), axis=[1,2,3]))
        penalty = tf.reduce_mean(tf.square(slops - 1))

        self.d_loss_real = tf.reduce_mean(self.real_logits)
        self.d_loss_fake = tf.reduce_mean(self.fake_logits)
        self.d_loss_wtag = tf.reduce_mean(self.wtag_logits)
        self.d_loss = (self.d_loss_fake + self.d_loss_wtag) - self.d_loss_real + self.hparams.penalty_coef * penalty
        # self.d_loss = self.d_loss_fake - self.d_loss_real + self.hparams.penalty_coef * penalty
        self.g_loss = -self.d_loss_fake

        # self.d_loss_wtag = tf.reduce_mean(self.wtag_logits)
#         self.d_loss = (self.d_loss_fake + self.d_loss_wtag) * 0.5 - self.d_loss_real \
#                             + self.hparams.penalty_coef * gradient_penalty



    def _add_optim(self):
        tvars = tf.trainable_variables()
        self.d_vars = [var for var in tvars if 'discriminator' in var.name]
        self.g_vars = [var for var in tvars if 'generator' in var.name]

        self.global_step = tf.Variable(0, trainable=False)
        self.d_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.hparams.beta1).\
                        minimize(self.d_loss, var_list=self.d_vars, global_step=self.global_step)
        self.g_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.hparams.beta1).\
                        minimize(self.g_loss, var_list=self.g_vars)


    def _add_saver(self):
        # checkpoint setup
        self.checkpoint_dir = os.path.abspath(os.path.join(self.hparams.checkpoint_dir, "checkpoints"))
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "model_{}".format(self.hparams.model))
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.hparams.max_to_keep)


    def train(self, sess):
        # loss summaries
        d_summary_op = tf.summary.merge([tf.summary.histogram("d_real_prob", tf.sigmoid(self.real_logits)),
                                         tf.summary.histogram("d_fake_prob", tf.sigmoid(self.fake_logits)),
                                         tf.summary.scalar("d_loss_fake", self.d_loss_fake), 
                                         tf.summary.scalar("d_loss_real", self.d_loss_real), 
                                         tf.summary.scalar("d_loss", self.d_loss)],
                                        name="discriminator_summary")
        g_summary_op = tf.summary.merge([tf.summary.histogram("g_prob", tf.sigmoid(self.fake_logits)),
                                         tf.summary.scalar("g_loss", self.g_loss),
                                         tf.summary.image("gen_images", self.fake_imgs)],
                                        name="generator_summary")

        self.summary_dir = os.path.abspath(os.path.join(self.hparams.checkpoint_dir, "summary"))
        summary_writer = tf.summary.FileWriter(self.summary_dir, sess.graph)

        image_helper = ImageHelper()
        
        sess.run(tf.global_variables_initializer())

        test_tags = image_helper.get_test_tags()
        for batch_id, batch_data in image_helper.iter_images(batch_size=self.batch_size, 
                                                             epoches=self.epoches):
            num_epoch, num_batch = batch_id
            batch_images, batch_tags, batch_wtags = batch_data
            if (num_epoch == 0) and (num_batch < self.hparams.d_pretrain):
                # pre-train discriminator
                _, current_step, d_loss, d_accuarcy = sess.run(
                    [self.d_optim, self.global_step, self.d_loss, self.d_accuarcy], 
                    feed_dict={
                        self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim]),
                        self.real_imgs: batch_images,
                        self.tags: batch_tags,
                        self.wrong_tags: batch_wtags})
                if current_step == self.hparams.d_pretrain:
                    tf.logging.info("==== pre-train ==== current_step:{}, d_loss:{}, d_accuarcy:{}"\
                                    .format(current_step, d_loss, d_accuarcy))
            else:
                # optimize discriminator
                _, current_step, d_loss, d_accuarcy = sess.run(
                    [self.d_optim, self.global_step, self.d_loss, self.d_accuarcy], 
                    feed_dict={self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim]),
                               self.real_imgs: batch_images,
                               self.tags: batch_tags,
                               self.wrong_tags: batch_wtags})
                # import IPython
                # IPython.embed()
                # optimize generator
                if current_step % self.hparams.d_schedule == 0:
                    _, g_loss = sess.run(
                        [self.g_optim, self.g_loss], 
                        feed_dict={self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim]),
                                   self.tags: batch_tags})

                # summary
                if current_step % self.hparams.log_interval == 0:
#                     import IPython
#                     IPython.embed()
                    d_summary_str, g_summary_str = sess.run(
                        [d_summary_op, g_summary_op], 
                        feed_dict={self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim]),
                                   self.real_imgs: batch_images,
                                   self.tags: batch_tags,
                                   self.wrong_tags: batch_wtags})
                    summary_writer.add_summary(d_summary_str, current_step)
                    summary_writer.add_summary(g_summary_str, current_step)

                    tf.logging.info("step:{}, d_loss:{}, g_loss:{}, d_accuarcy:{}"\
                                    .format(current_step, d_loss, g_loss, d_accuarcy))

            if (num_epoch > 0) and (num_batch == 0):
                # generate images per epoch
                tf.logging.info("epoch:{} === generate images and save checkpoint".format(num_epoch))
                fake_imgs = sess.run(
                    self.fake_imgs, 
                    feed_dict={self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim]),
                               self.tags: test_tags})
                image_helper.save_imgs(fake_imgs, 
                                       img_name="{}/fake-{}".format(self.hparams.sample_dir, num_epoch))
                # save model per epoch
                self.saver.save(sess, self.checkpoint_prefix, global_step=num_epoch)


    def infer(self, sess):
        # load the saved model
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            self.saver.restore(sess, ckpt.model_checkpoint_path)

        image_helper = ImageHelper()

        fake_imgs = sess.run(
            self.fake_imgs, 
            feed_dict={self.rand_noises: np.random.normal(size=[self.batch_size, self.noise_dim])})
        img_name = "{}/infer-image".format(self.hparams.sample_dir)
        image_helper.save_imgs(fake_imgs, 
                               img_name=img_name)

        tf.logging.info("====== generate images in file: {} ======".format(img_name))
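
The gradient penalty used in this example (and in the WGAN-GP branch of Example #5) is the mean of (||grad D(x_hat)||_2 - 1)^2 over random interpolations x_hat between real and fake images. A hypothetical NumPy sketch of the same computation, with a random array standing in for the TensorFlow gradient:

import numpy as np

batch_size = 2
real = np.random.rand(batch_size, 64, 64, 3)
fake = np.random.rand(batch_size, 64, 64, 3)
alpha = np.random.uniform(size=(batch_size, 1, 1, 1))
inter = real * alpha + fake * (1.0 - alpha)             # interpolated images

grads = np.random.randn(*inter.shape)                   # stands in for tf.gradients(D(inter), inter)[0]
slopes = np.sqrt(np.square(grads).sum(axis=(1, 2, 3)))  # per-example gradient norm
penalty = np.mean(np.square(slopes - 1.0))              # penalise deviation from norm 1

The conditional wrong_tags input above additionally teaches the discriminator to reject real images paired with mismatched tags; the original ImageHelper is assumed to supply those mismatched tag batches alongside the correct ones.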