Example #1
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images_vae, one_hot_labels, _ = provide_data('train',
                                                          FLAGS.batch_size,
                                                          FLAGS.dataset_dir,
                                                          num_threads=4)
            images_gan = 2.0 * images_vae - 1.0

    # Reconstruct the inputs with a pre-trained VAE; the reconstructions are
    # fed to the GAN generator alongside the noise vector.
    my_vae = VAE("train", z_dim=64, data_tensor=images_vae)
    rec = my_vae.reconstruct(images_vae)

    vae_checkpoint_path = tf.train.latest_checkpoint(FLAGS.vae_checkpoint_folder)
    # Built before the GAN graph, this Saver covers only the VAE variables.
    saver = tf.train.Saver()

    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=images_gan,
        generator_inputs=[
            tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]),
            tf.reshape(rec, [FLAGS.batch_size, 28, 28, 1])
        ])

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, True)

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GAN train ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
             tf.train.LoggingTensorHook([status_message], every_n_iter=10)] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(hooks=hooks,
                                           save_summaries_steps=500,
                                           checkpoint_dir=FLAGS.train_log_dir) as sess:
        saver.restore(sess, vae_checkpoint_path)
        # The step hooks run the generator and discriminator train ops;
        # the loop body only advances the global step.
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)
Example #2
def gan_loss(gan_model: tfgan.GANModel,
             generator_loss_fn=tfgan.losses.modified_generator_loss,
             discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
             gradient_penalty_weight=None,
             gradient_penalty_epsilon=1e-10,
             gradient_penalty_target=1.0,
             feature_matching=False,
             add_summaries=False):
    """ Create A GAN loss set, with support for feature matching.
    Args:
        bigan_model: the model
        feature_matching: Whether to add a feature matching loss to the encoder
      and generator.
    """
    gan_loss = tfgan.gan_loss(gan_model,
                              generator_loss_fn=generator_loss_fn,
                              discriminator_loss_fn=discriminator_loss_fn,
                              gradient_penalty_weight=gradient_penalty_weight,
                              gradient_penalty_epsilon=gradient_penalty_epsilon,
                              gradient_penalty_target=gradient_penalty_target,
                              add_summaries=add_summaries)

    if feature_matching:
        fm_loss = feature_matching_loss(
            scope=gan_model.discriminator_scope.name)
        if add_summaries:
            tf.summary.scalar("feature_matching_loss", fm_loss)
        # Combine the original adversarial loss with the feature matching loss.
        gen_loss = gan_loss.generator_loss + fm_loss
        disc_loss = gan_loss.discriminator_loss
        gan_loss = tfgan.GANLoss(gen_loss, disc_loss)

    return gan_loss
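
The helper `feature_matching_loss` is not shown in this snippet. A minimal sketch of what it could look like, assuming (hypothetically) that the discriminator publishes a matching real/generated feature pair to a custom graph collection while building both towers:

def feature_matching_loss(scope, collection='feature_matching'):
    # Hypothetical convention: the discriminator adds its intermediate
    # features for the real batch and the generated batch (in that order)
    # to `collection` while it is built under `scope`.
    real_feats, gen_feats = tf.get_collection(collection, scope=scope)
    # Match the first moments of the features across the batch, following
    # "Improved Techniques for Training GANs" (Salimans et al., 2016).
    return tf.reduce_mean(tf.squared_difference(
        tf.reduce_mean(real_feats, axis=0),
        tf.reduce_mean(gen_feats, axis=0)))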
Example #3
def construct_gan_loss(training_params, gan_model_dict):
  gan_model = gan_model_dict['gan_model']
  gan_loss = tfgan.gan_loss(
      gan_model,
      # The configured loss fns are given by name; eval() resolves them.
      generator_loss_fn=eval(training_params.generator_params.loss_fn),
      discriminator_loss_fn=eval(
        training_params.discriminator_params.loss_fn),
      gradient_penalty_weight=(
        training_params.discriminator_params.gradient_penalty_weight),
      add_summaries=True,
  )

  if training_params.discriminator_params.eps_drift:
    gan_loss = _add_drift_loss(gan_loss, gan_model_dict, training_params)

  if (training_params.generator_params.consistency_loss
      or training_params.generator_params.consistency_loss_msssim):
    gan_loss = _add_consistency_loss(gan_loss, gan_model_dict, training_params)

  if training_params.infogan_cont_weight is not None:
    gan_loss = _add_continuous_mutual_information_penalty(
      gan_loss, gan_model_dict, training_params)

  if training_params.infogan_cat_weight is not None:
    gan_loss = _add_categorical_mutual_information_penalty(
      gan_loss, gan_model_dict, training_params)

  return gan_loss
Example #4
def get_model_and_loss(condition, real_image):
    gan_model = tfgan.gan_model(generator_fn=generator_fn,
                                discriminator_fn=discriminator_fn,
                                real_data=real_image,
                                generator_inputs=condition)
    gan_loss = tfgan.gan_loss(gan_model,
                              generator_loss_fn=generator_loss_fn,
                              discriminator_loss_fn=discriminator_loss_fn)

    return gan_model, gan_loss
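
For context, the returned pair plugs straight into the train-op helper used throughout these examples; a short usage sketch (optimizer settings are illustrative, and `condition` / `real_image` are assumed to be tensors built elsewhere):

gan_model, gan_loss = get_model_and_loss(condition, real_image)
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-3, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))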
Example #5
def train_noestimator(features,
                      labels,
                      noise_dims=64,
                      batch_size=32,
                      num_steps=1200,
                      num_eval=20,
                      seed=0):
    """ Input features (images) and labels, noise vector dimension, batch size, seed for reproducibility """
    # Input training data and noise
    train_input_fn, train_input_hook = \
            _get_train_input_fn(features, labels, batch_size, noise_dims, seed)
    noise, next_image_batch = train_input_fn()

    # Define GAN model, loss, and optimizers
    model = tfgan.gan_model(generator_fn, discriminator_fn, next_image_batch,
                            noise)
    loss = tfgan.gan_loss(
        model,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        gradient_penalty_weight=1.0)
    generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
    discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
    gan_train_ops = tfgan.gan_train_ops(model, loss, generator_optimizer,
                                        discriminator_optimizer)

    # We'll evaluate images during training to see how the generator improves
    with tf.variable_scope('Generator', reuse=True):
        predict_input_fn = _get_predict_input_fn(num_eval, noise_dims)
        eval_images = model.generator_fn(predict_input_fn(), is_training=False)

    # Train, outputting evaluation occasionally
    train_step_fn = tfgan.get_sequential_train_steps()
    global_step = tf.train.get_or_create_global_step()

    with tf.train.SingularMonitoredSession(hooks=[train_input_hook]) as sess:
        for i in range(num_steps + 1):
            cur_loss, _ = train_step_fn(sess,
                                        gan_train_ops,
                                        global_step,
                                        train_step_kwargs={})
            if i % 400 == 0:
                generated_images = sess.run(eval_images)
                print("Iteration", i, "- Loss:", cur_loss)
                show(generated_images)
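
`_get_predict_input_fn` is not shown in this snippet. A plausible sketch, assuming it only needs to produce a fixed batch of generator noise at evaluation time:

def _get_predict_input_fn(num_eval, noise_dims):
    def predict_input_fn():
        # A fixed-size batch of Gaussian noise for the generator.
        return tf.random_normal([num_eval, noise_dims])
    return predict_input_fn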
Example #6
def model_fn(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError()
    else:
        # Pull images from input
        x = features['x']
        # Generate latent samples of same batch size as images
        n = tf.shape(x)[0]
        rnd = tf.random_normal(shape=(n, params.latent_units),
                               mean=0.,
                               stddev=1.,
                               dtype=tf.float32)
        # Build GAN Model
        gan_model = tfgan.gan_model(generator_fn=generator_fn,
                                    discriminator_fn=discriminator_fn,
                                    real_data=x,
                                    generator_inputs=rnd)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.modified_generator_loss,
            discriminator_loss_fn=tfgan.losses.modified_discriminator_loss)

        if mode == tf.estimator.ModeKeys.TRAIN:
            generate_grid(gan_model, params)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=tf.train.RMSPropOptimizer(params.gen_lr),
                discriminator_optimizer=tf.train.RMSPropOptimizer(
                    params.dis_lr))
            gan_hooks = tfgan.get_sequential_train_hooks(
                GANTrainSteps(params.generator_steps,
                              params.discriminator_steps))(train_ops)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=gan_loss.discriminator_loss,
                train_op=train_ops.global_step_inc_op,
                training_hooks=gan_hooks)
        else:
            eval_metric_ops = {}
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=gan_loss.discriminator_loss,
                                              eval_metric_ops=eval_metric_ops)
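
Wiring this `model_fn` into an Estimator is then standard. A minimal sketch, where the HParams fields mirror what the function reads and `train_input_fn` is an assumed input pipeline:

params = tf.contrib.training.HParams(latent_units=64,
                                     gen_lr=1e-4,
                                     dis_lr=1e-4,
                                     generator_steps=1,
                                     discriminator_steps=1)
estimator = tf.estimator.Estimator(model_fn=model_fn, params=params)
estimator.train(input_fn=train_input_fn, max_steps=10000)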
Example #7
def run_discgan():
    """ Constructs and trains the discriminative GAN consisting of
        Jerry and Diego.
    """
    # code follows the examples from
    # https://github.com/tensorflow/models/blob/master/research/gan/tutorial.ipynb

    # build the GAN model
    discgan = tfgan.gan_model(
        generator_fn=generator,
        discriminator_fn=adversary_conv(OUTPUT_SIZE),
        real_data=tf.random_uniform(shape=[BATCH_SIZE, OUTPUT_SIZE]),
        generator_inputs=get_input_tensor(BATCH_SIZE, MAX_VAL))
    # Build the GAN loss
    discgan_loss = tfgan.gan_loss(
        discgan,
        generator_loss_fn=tfgan.losses.least_squares_generator_loss,
        discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
    # Create the train ops, which calculate gradients and apply updates to weights.
    train_ops = tfgan.gan_train_ops(discgan,
                                    discgan_loss,
                                    generator_optimizer=GEN_OPT,
                                    discriminator_optimizer=OPP_OPT)
    # start TensorFlow session
    with tf.train.SingularMonitoredSession() as sess:
        pretrain_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(0, PRE_STEPS))
        train_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(1, ADV_MULT))
        global_step = tf.train.get_or_create_global_step()

        # pretrain discriminator
        print('\n\nPretraining ... ', end="", flush=True)
        try:
            pretrain_steps_fn(sess,
                              train_ops,
                              global_step,
                              train_step_kwargs={})
        except KeyboardInterrupt:
            pass
        print('[DONE]\n\n')

        # train both models
        losses_jerry = []
        losses_diego = []
        try:
            evaluate(sess, discgan.generated_data, discgan.generator_inputs, 0,
                     'jerry')

            for step in range(STEPS):
                train_steps_fn(sess,
                               train_ops,
                               global_step,
                               train_step_kwargs={})

                # Log the current losses every LOG_EVERY_N steps.
                if step % LOG_EVERY_N == 0:
                    gen_l = discgan_loss.generator_loss.eval(session=sess)
                    disc_l = discgan_loss.discriminator_loss.eval(session=sess)

                    debug.print_step(step, gen_l, disc_l)
                    losses_jerry.append(gen_l)
                    losses_diego.append(disc_l)

        except KeyboardInterrupt:
            print('[INTERRUPTED BY USER] -- evaluating')

        # produce output
        files.write_to_file(losses_jerry, PLOT_DIR + '/jerry_loss.txt')
        files.write_to_file(losses_diego, PLOT_DIR + '/diego_loss.txt')
        evaluate(sess, discgan.generated_data, discgan.generator_inputs, 1,
                 'jerry')
Example #8
        # (Fragment: the tail of the discriminator network definition.)
        return layers.linear(net, 1, normalizer_fn=None, activation_fn=tf.tanh)

real_data_normed = tf.divide(tf.convert_to_tensor(real, dtype=tf.float32),
                             tf.constant(MAX_VAL, dtype=tf.float32))
chunk_queue = tf.train.slice_input_producer([real_data_normed])


# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=generator_fn,  # you define
    discriminator_fn=discriminator_fn,  # you define
    real_data=chunk_queue,
    generator_inputs=noise)

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)

l1_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_loss, weight_factor=FLAGS.weight_factor)

train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(
        learning_rate=0.001, beta1=0.85, beta2=0.999, epsilon=1e-5),
    discriminator_optimizer=tf.train.AdamOptimizer(
        learning_rate=0.000001, beta1=0.85, beta2=0.999, epsilon=1e-5))
#train_ops.global_step_inc_op = tf.train.get_global_step().assign_add(1)


#store_output_and_check_loss(gan_loss, gan_model.generated_data, gan_model.real_data, num_of_samples=3, prefix='gen',logdir=log_folder)

# gan_train_ops already created a global step; fetch it rather than adding a
# second, conflicting 'global_step' variable.
global_step = tf.train.get_or_create_global_step()
Example #9
generator_fn = functools.partial(infogan_generator, categorical_dim=cat_dim)
discriminator_fn = functools.partial(
    infogan_discriminator, categorical_dim=cat_dim,
    continuous_dim=cont_dim)
unstructured_inputs, structured_inputs = util.get_infogan_noise(
    batch_size, cat_dim, cont_dim, noise_dims)

infogan_model = tfgan.infogan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    real_data=real_images,
    unstructured_generator_inputs=unstructured_inputs,
    structured_generator_inputs=structured_inputs)

infogan_loss = tfgan.gan_loss(
    infogan_model,
    gradient_penalty_weight=1.0,
    mutual_information_penalty_weight=1.0)

# Sanity check that we can evaluate our losses.
evaluate_tfgan_loss(infogan_loss)


generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.00009, beta1=0.5)
gan_train_ops = tfgan.gan_train_ops(
    infogan_model,
    infogan_loss,
    generator_optimizer,
    discriminator_optimizer)

# Set up images to evaluate MNIST score.
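
This example and the next call `evaluate_tfgan_loss`, a helper from the TFGAN tutorial notebook these snippets follow. A sketch along the lines of the tutorial version:

def evaluate_tfgan_loss(gan_loss, name=None):
    # Run both loss tensors once in a throwaway session as a sanity check
    # that the loss graph is wired up correctly.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            gen_loss_np = sess.run(gan_loss.generator_loss)
            dis_loss_np = sess.run(gan_loss.discriminator_loss)
    prefix = (name + ' ') if name else ''
    print('%sgenerator loss: %f' % (prefix, gen_loss_np))
    print('%sdiscriminator loss: %f' % (prefix, dis_loss_np))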
Example #10

noise_dims = 64
conditional_gan_model = tfgan.gan_model(
    generator_fn=conditional_generator_fn,
    discriminator_fn=conditional_discriminator_fn,
    real_data=real_images,
    generator_inputs=(tf.random_normal([batch_size,
                                        noise_dims]), one_hot_labels))

# Sanity check that currently generated images are garbage.
cond_generated_data_to_visualize = tfgan.eval.image_reshaper(
    conditional_gan_model.generated_data[:20, ...], num_cols=10)
visualize_digits(cond_generated_data_to_visualize)

gan_loss = tfgan.gan_loss(conditional_gan_model, gradient_penalty_weight=1.0)

# Sanity check that we can evaluate our losses.
evaluate_tfgan_loss(gan_loss)

generator_optimizer = tf.train.AdamOptimizer(0.0009, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.00009, beta1=0.5)
gan_train_ops = tfgan.gan_train_ops(conditional_gan_model, gan_loss,
                                    generator_optimizer,
                                    discriminator_optimizer)

# Set up class-conditional visualization. We feed class labels to the generator
# so that the first column is `0`, the second column is `1`, etc.
images_to_eval = 500
assert images_to_eval % 10 == 0
Example #11
def model_fn(features, labels, mode, params):
    is_chief = not tf.get_variable_scope().reuse

    batch_size = tf.shape(labels)[0]
    noise = tf.random_normal([batch_size, FLAGS.emb_dim])
    noise = tf.nn.l2_normalize(noise, axis=1)
    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=features[:, 1:],
                                generator_inputs=(noise, labels - 1),
                                check_shapes=False)
    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # No adversarial generator loss here; the RL term added by rl_loss
        # below supplies the actual generator objective.
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        # The discriminator returns (logits, mask) pairs; keep only the
        # logits at valid positions.
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook({
                'fake': sentence,
                'real': real
            },
                                       every_n_iter=100))
        tf.summary.text('fake', sentence)

    gen_var = tf.trainable_variables('Generator')
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    saver = tf.train.Saver(gen_var + dis_var)

    def init_fn(scaffold, session):
        saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
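
`transform_grads_fn` is not shown here (it is also used in Example #13). tfgan.gan_train_ops applies it to the list of (gradient, variable) pairs before the update; a hypothetical sketch using gradient clipping, a common choice for this hook:

def transform_grads_fn(grads):
    # `grads` is a list of (gradient, variable) pairs; clip before applying.
    return tf.contrib.training.clip_gradient_norms(grads, max_norm=5.0)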
Example #12
noise_dims = 64
gan_model = tfgan.gan_model(generator_fn,
                            discriminator_fn,
                            real_data=real_images,
                            generator_inputs=tf.random_normal(
                                [batch_size, noise_dims]))

# Sanity check that generated images before training are garbage.
check_generated_digits = tfgan.eval.image_reshaper(
    gan_model.generated_data[:20, ...], num_cols=10)
visualize_digits(check_generated_digits)

# We can use the minimax loss from the original paper.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)

# We can use the Wasserstein loss (https://arxiv.org/abs/1701.07875) with the
# gradient penalty from the improved Wasserstein loss paper
# (https://arxiv.org/abs/1704.00028).
improved_wgan_loss = tfgan.gan_loss(
    gan_model,
    # We make the loss explicit for demonstration, even though the default is
    # Wasserstein loss.
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)


# We can also define custom losses to use with the rest of the TFGAN framework.
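The custom-loss contract is small: a loss fn receives the GANModel tuple plus an `add_summaries` flag and returns a scalar loss tensor. An illustrative example (not from the original snippet):

def custom_generator_loss(gan_model, add_summaries=False):
    # Push the discriminator's scores on generated data towards 1.0,
    # a least-squares-style generator objective.
    loss = tf.reduce_mean(
        tf.squared_difference(gan_model.discriminator_gen_outputs, 1.0))
    if add_summaries:
        tf.summary.scalar('custom_generator_loss', loss)
    return loss

custom_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=custom_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)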
Example #13
def model_fn(features, labels, mode, params):
    """The full unsupervised captioning model."""
    is_chief = not tf.get_variable_scope().reuse

    with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
        net, _ = inception_v4.inception_v4(features['im'],
                                           None,
                                           is_training=False)
    net = tf.squeeze(net, [1, 2])
    inc_saver = tf.train.Saver(tf.global_variables('InceptionV4'))

    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=labels['sentence'][:, 1:],
                                generator_inputs=(net, labels['len'] - 1),
                                check_shapes=False)

    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # As in Example #11, the RL term added below supplies the generator
        # objective, so the adversarial generator loss is zero.
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        pool_fn = functools.partial(tfgan.features.tensor_pool,
                                    pool_size=FLAGS.pool_size)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            tensor_pool_fn=pool_fn if FLAGS.use_pool else None,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model,
                       gan_loss,
                       features['classes'],
                       features['scores'],
                       features['num'],
                       add_summaries=is_chief)
    sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss
    gan_loss = gan_loss._replace(generator_loss=gan_loss.generator_loss +
                                 sen_ae_loss)

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=not FLAGS.use_pool,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    # Summarize the generated caption on the fly.
    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook({
                'fake': sentence,
                'real': real
            },
                                       every_n_iter=100))
        tf.summary.text('fake', sentence)
        tf.summary.image('im', features['im'][None, 0])

    gen_saver = tf.train.Saver(tf.trainable_variables('Generator'))
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    dis_var.extend(tf.trainable_variables('Discriminator/fc'))
    dis_saver = tf.train.Saver(dis_var)

    def init_fn(scaffold, session):
        inc_saver.restore(session, FLAGS.inc_ckpt)
        if FLAGS.imcap_ckpt:
            gen_saver.restore(session, FLAGS.imcap_ckpt)
        if FLAGS.sae_ckpt:
            dis_saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
Example #14
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images, one_hot_labels, _ = provide_data('train',
                                                     FLAGS.batch_size,
                                                     FLAGS.dataset_dir,
                                                     num_threads=4)
            images = 2.0 * images - 1.0

    gan_model = tfgan.gan_model(generator_fn=gan_networks.generator,
                                discriminator_fn=gan_networks.discriminator,
                                real_data=images,
                                generator_inputs=tf.random_normal(
                                    [FLAGS.batch_size, FLAGS.noise_dims]))

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, False)

    with tf.variable_scope('Generator', reuse=True):
        eval_images = gan_model.generator_fn(tf.random_normal(
            [FLAGS.num_images_eval, FLAGS.noise_dims]),
                                             is_training=False)

    # Calculate Inception score.
    tf.summary.scalar(
        "Inception score",
        util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH))

    # Calculate Frechet Inception distance.
    with tf.device('/cpu:0'):
        real_images, labels, _ = provide_data('train', FLAGS.num_images_eval,
                                              FLAGS.dataset_dir)
    tf.summary.scalar(
        "Frechet distance",
        util.mnist_frechet_distance(real_images, eval_images,
                                    MNIST_CLASSIFIER_FROZEN_GRAPH))

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(gan_model,
                                  gradient_penalty_weight=1.0,
                                  mutual_information_penalty_weight=0.0,
                                  add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10)
    ] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        # The step hooks run the generator and discriminator train ops;
        # the loop body only advances the global step.
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)