Example #1
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images_vae, one_hot_labels, _ = provide_data(
                'train', FLAGS.batch_size, FLAGS.dataset_dir, num_threads=4)
            images_gan = 2.0 * images_vae - 1.0

    my_vae = VAE("train", z_dim=64, data_tensor=images_vae)
    rec = my_vae.reconstruct(images_vae)

    vae_checkpoint_path = tf.train.latest_checkpoint(FLAGS.vae_checkpoint_folder)
    saver = tf.train.Saver()

    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=images_gan,
        generator_inputs=[
            tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]),
            tf.reshape(rec, [FLAGS.batch_size, 28, 28, 1])
        ])

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, True)

    with tf.name_scope('loss'):

        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
             tf.train.LoggingTensorHook([status_message], every_n_iter=10)] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(hooks=hooks,
                                           save_summaries_steps=500,
                                           checkpoint_dir=FLAGS.train_log_dir) as sess:
        saver.restore(sess, vae_checkpoint_path)
        loss = None
        while not sess.should_stop():
            loss = sess.run(train_ops.global_step_inc_op)
Example #2
def get_model_and_loss(condition, real_image):
    gan_model = tfgan.gan_model(generator_fn=generator_fn,
                                discriminator_fn=discriminator_fn,
                                real_data=real_image,
                                generator_inputs=condition)
    gan_loss = tfgan.gan_loss(gan_model,
                              generator_loss_fn=generator_loss_fn,
                              discriminator_loss_fn=discriminator_loss_fn)

    return gan_model, gan_loss
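
A minimal usage sketch for the helper above (the condition_batch and image_batch tensors and the optimizer settings are hypothetical; the tfgan.gan_train_ops call mirrors the other examples):

# Hypothetical wiring of get_model_and_loss into TF-GAN train ops.
gan_model, gan_loss = get_model_and_loss(condition_batch, image_batch)
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))
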
Example #3
def train_noestimator(features,
                      labels,
                      noise_dims=64,
                      batch_size=32,
                      num_steps=1200,
                      num_eval=20,
                      seed=0):
    """ Input features (images) and labels, noise vector dimension, batch size, seed for reproducibility """
    # Input training data and noise
    train_input_fn, train_input_hook = \
            _get_train_input_fn(features, labels, batch_size, noise_dims, seed)
    noise, next_image_batch = train_input_fn()

    # Define GAN model, loss, and optimizers
    model = tfgan.gan_model(generator_fn, discriminator_fn, next_image_batch,
                            noise)
    loss = tfgan.gan_loss(
        model,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        gradient_penalty_weight=1.0)
    generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
    discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
    gan_train_ops = tfgan.gan_train_ops(model, loss, generator_optimizer,
                                        discriminator_optimizer)

    # We'll evaluate images during training to see how the generator improves
    with tf.variable_scope('Generator', reuse=True):
        predict_input_fn = _get_predict_input_fn(num_eval, noise_dims)
        eval_images = model.generator_fn(predict_input_fn(), is_training=False)

    # Train, outputting evaluation occasionally
    train_step_fn = tfgan.get_sequential_train_steps()
    global_step = tf.train.get_or_create_global_step()

    with tf.train.SingularMonitoredSession(hooks=[train_input_hook]) as sess:
        for i in range(num_steps + 1):
            cur_loss, _ = train_step_fn(sess,
                                        gan_train_ops,
                                        global_step,
                                        train_step_kwargs={})
            if i % 400 == 0:
                generated_images = sess.run(eval_images)
                print("Iteration", i, "- Loss:", cur_loss)
                show(generated_images)
Example #4
def model_fn(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError()
    else:
        # Pull images from input
        x = features['x']
        # Generate latent samples of same batch size as images
        n = tf.shape(x)[0]
        rnd = tf.random_normal(shape=(n, params.latent_units),
                               mean=0.,
                               stddev=1.,
                               dtype=tf.float32)
        # Build GAN Model
        gan_model = tfgan.gan_model(generator_fn=generator_fn,
                                    discriminator_fn=discriminator_fn,
                                    real_data=x,
                                    generator_inputs=rnd)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.modified_generator_loss,
            discriminator_loss_fn=tfgan.losses.modified_discriminator_loss)

        if mode == tf.estimator.ModeKeys.TRAIN:
            generate_grid(gan_model, params)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=tf.train.RMSPropOptimizer(params.gen_lr),
                discriminator_optimizer=tf.train.RMSPropOptimizer(
                    params.dis_lr))
            gan_hooks = tfgan.get_sequential_train_hooks(
                GANTrainSteps(params.generator_steps,
                              params.discriminator_steps))(train_ops)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=gan_loss.discriminator_loss,
                train_op=train_ops.global_step_inc_op,
                training_hooks=gan_hooks)
        else:
            eval_metric_ops = {}
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=gan_loss.discriminator_loss,
                                              eval_metric_ops=eval_metric_ops)
Example #5
def aegan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        encoder_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        encoder_scope='Encoder',
        # Options.
        check_shapes=True):

    gan_model = tfgan.gan_model(generator_fn,
                                discriminator_fn,
                                real_data,
                                generator_inputs,
                                generator_scope=generator_scope,
                                discriminator_scope=discriminator_scope,
                                check_shapes=check_shapes)

    with tf.variable_scope(encoder_scope) as enc_scope:
        encoder_gen_outputs = encoder_fn(gan_model.generated_data)
    with tf.variable_scope(enc_scope, reuse=True):
        real_data = tf.convert_to_tensor(real_data)
        encoder_real_outputs = encoder_fn(real_data)

    encoder_variables = tf.trainable_variables(scope=encoder_scope)

    return AEGANModel(
        generator_inputs, gan_model.generated_data,
        gan_model.generator_variables, gan_model.generator_scope, generator_fn,
        real_data, gan_model.discriminator_real_outputs,
        gan_model.discriminator_gen_outputs, gan_model.discriminator_variables,
        gan_model.discriminator_scope, discriminator_fn, encoder_real_outputs,
        encoder_gen_outputs, encoder_variables, enc_scope, encoder_fn)
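Example #6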
def build_gan_harness(image_input: tf.Tensor,
                      noise: tf.Tensor,
                      generator: tf.keras.Model,
                      discriminator: tf.keras.Model,
                      generator_learning_rate=0.01,
                      discriminator_learning_rate=0.01,
                      noise_format: str = 'SPHERE',
                      adversarial_training: str = 'WASSERSTEIN',
                      feature_matching: bool = False,
                      no_trainer: bool = False,
                      summarize_activations: bool = False) -> tuple:
    image_size = image_input.shape.as_list()[1]
    nchannels = image_input.shape.as_list()[3]
    print("Plain Generative Adversarial Network: {}x{}x{} images".format(
        image_size, image_size, nchannels))

    def _generator_fn(z):
        return generator([z], training=True)

    def _discriminator_fn(x, z):
        return discriminator([x, z], training=True)

    gan_model = tfgan.gan_model(
        _generator_fn,
        _discriminator_fn,
        image_input,
        noise,
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        check_shapes=True)  # set to False for 2-level architectures

    sampled_x = gan_model.generated_data
    image_grid_summary(sampled_x, grid_size=3, name='generated_data')
    if summarize_activations:
        tf.contrib.layers.summarize_activations()
    tf.contrib.layers.summarize_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    loss = gan_loss_by_name(gan_model,
                            adversarial_training,
                            feature_matching=feature_matching,
                            add_summaries=True)

    if adversarial_training not in ('WASSERSTEIN', 'RELATIVISTIC_AVG'):
        disc_accuracy_gen = basic_accuracy(
            tf.zeros_like(gan_model.discriminator_gen_outputs),
            gan_model.discriminator_gen_outputs)
        disc_accuracy_real = basic_accuracy(
            tf.ones_like(gan_model.discriminator_real_outputs),
            gan_model.discriminator_real_outputs)
        disc_accuracy = (disc_accuracy_gen + disc_accuracy_real) * 0.5
        with tf.name_scope('Discriminator'):
            tf.summary.scalar('accuracy', disc_accuracy)

    if no_trainer:
        train_ops = None
    else:
        train_ops = tfgan.gan_train_ops(
            gan_model,
            loss,
            generator_optimizer=tf.train.AdamOptimizer(generator_learning_rate,
                                                       beta1=0.,
                                                       beta2=0.99),
            discriminator_optimizer=tf.train.AdamOptimizer(
                discriminator_learning_rate, beta1=0., beta2=0.99),
            summarize_gradients=True)
    return (gan_model, loss, train_ops)
Example #7
def run_discgan():
    """ Constructs and trains the discriminative GAN consisting of
        Jerry and Diego.
    """
    # code follows the examples from
    # https://github.com/tensorflow/models/blob/master/research/gan/tutorial.ipynb

    # build the GAN model
    discgan = tfgan.gan_model(
        generator_fn=generator,
        discriminator_fn=adversary_conv(OUTPUT_SIZE),
        real_data=tf.random_uniform(shape=[BATCH_SIZE, OUTPUT_SIZE]),
        generator_inputs=get_input_tensor(BATCH_SIZE, MAX_VAL))
    # Build the GAN loss
    discgan_loss = tfgan.gan_loss(
        discgan,
        generator_loss_fn=tfgan.losses.least_squares_generator_loss,
        discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
    # Create the train ops, which calculate gradients and apply updates to weights.
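    # GEN_OPT and OPP_OPT are assumed to be tf.train optimizer instances
    # defined elsewhere in this script.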
    train_ops = tfgan.gan_train_ops(discgan,
                                    discgan_loss,
                                    generator_optimizer=GEN_OPT,
                                    discriminator_optimizer=OPP_OPT)
    # start TensorFlow session
    with tf.train.SingularMonitoredSession() as sess:
        pretrain_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(0, PRE_STEPS))
        train_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(1, ADV_MULT))
        global_step = tf.train.get_or_create_global_step()

        # pretrain discriminator
        print('\n\nPretraining ... ', end="", flush=True)
        try:
            pretrain_steps_fn(sess,
                              train_ops,
                              global_step,
                              train_step_kwargs={})
        except KeyboardInterrupt:
            pass
        print('[DONE]\n\n')

        # train both models
        losses_jerry = []
        losses_diego = []
        try:
            evaluate(sess, discgan.generated_data, discgan.generator_inputs, 0,
                     'jerry')

            for step in range(STEPS):
                train_steps_fn(sess,
                               train_ops,
                               global_step,
                               train_step_kwargs={})

                # if performed right number of steps, log
                if step % LOG_EVERY_N == 0:
                    gen_l = discgan_loss.generator_loss.eval(session=sess)
                    disc_l = discgan_loss.discriminator_loss.eval(session=sess)

                    debug.print_step(step, gen_l, disc_l)
                    losses_jerry.append(gen_l)
                    losses_diego.append(disc_l)

        except KeyboardInterrupt:
            print('[INTERRUPTED BY USER] -- evaluating')

        # produce output
        files.write_to_file(losses_jerry, PLOT_DIR + '/jerry_loss.txt')
        files.write_to_file(losses_diego, PLOT_DIR + '/diego_loss.txt')
        evaluate(sess, discgan.generated_data, discgan.generator_inputs, 1,
                 'jerry')
Example #8
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.fully_connected(fragment, 64)
        net = layers.dropout(net, keep_prob=0.75)
        net = layers.fully_connected(net, 32)
        net = layers.fully_connected(net, 16,
                                     normalizer_fn=layers.batch_norm,
                                     activation_fn=tf.tanh)
        return layers.linear(net, 1, normalizer_fn=None, activation_fn=tf.tanh)

real_data_normed = tf.divide(tf.convert_to_tensor(real, dtype=tf.float32),
                             tf.constant(MAX_VAL, dtype=tf.float32))
chunk_queue = tf.train.slice_input_producer([real_data_normed])


# Build the generator and discriminator.
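# `noise` is assumed to be a generator-input tensor defined earlier in the script.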
gan_model = tfgan.gan_model(
    generator_fn=generator_fn,  # you define
    discriminator_fn=discriminator_fn,  # you define
    real_data=chunk_queue,
    generator_inputs=noise)

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)

l1_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_loss, weight_factor=FLAGS.weight_factor)

train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(
        learning_rate=0.001, beta1=0.85, beta2=0.999, epsilon=1e-5),
    discriminator_optimizer=tf.train.AdamOptimizer(
        learning_rate=0.000001, beta1=0.85, beta2=0.999, epsilon=1e-5))
#train_ops.global_step_inc_op = tf.train.get_global_step().assign_add(1)
Example #9
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        net = layers.fully_connected(net,
                                     1024,
                                     normalizer_fn=layers.batch_norm)

        return layers.linear(net, 1)


noise_dims = 64
conditional_gan_model = tfgan.gan_model(
    generator_fn=conditional_generator_fn,
    discriminator_fn=conditional_discriminator_fn,
    real_data=real_images,
    generator_inputs=(tf.random_normal([batch_size, noise_dims]),
                      one_hot_labels))

# Sanity check that currently generated images are garbage.
cond_generated_data_to_visualize = tfgan.eval.image_reshaper(
    conditional_gan_model.generated_data[:20, ...], num_cols=10)
visualize_digits(cond_generated_data_to_visualize)

gan_loss = tfgan.gan_loss(conditional_gan_model, gradient_penalty_weight=1.0)

# Sanity check that we can evaluate our losses.
evaluate_tfgan_loss(gan_loss)

generator_optimizer = tf.train.AdamOptimizer(0.0009, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.00009, beta1=0.5)
Example #10
def model_fn(features, labels, mode, params):
    is_chief = not tf.get_variable_scope().reuse

    batch_size = tf.shape(labels)[0]
    noise = tf.random_normal([batch_size, FLAGS.emb_dim])
    noise = tf.nn.l2_normalize(noise, axis=1)
    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=features[:, 1:],
                                generator_inputs=(noise, labels - 1),
                                check_shapes=False)
    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook({'fake': sentence, 'real': real},
                                       every_n_iter=100))
        tf.summary.text('fake', sentence)

    gen_var = tf.trainable_variables('Generator')
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    saver = tf.train.Saver(gen_var + dis_var)

    def init_fn(scaffold, session):
        saver.restore(session, FLAGS.sae_ckpt)
        pass

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
Example #11
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        with framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net,
                                         1024,
                                         normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)


noise_dims = 64
gan_model = tfgan.gan_model(generator_fn,
                            discriminator_fn,
                            real_data=real_images,
                            generator_inputs=tf.random_normal(
                                [batch_size, noise_dims]))

# Sanity check that generated images before training are garbage.
check_generated_digits = tfgan.eval.image_reshaper(
    gan_model.generated_data[:20, ...], num_cols=10)
visualize_digits(check_generated_digits)

# We can use the minimax loss from the original paper.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)

# We could also use the Wasserstein loss (https://arxiv.org/abs/1701.07875) with a
# gradient penalty by passing the corresponding tfgan.losses functions and a
# gradient_penalty_weight to tfgan.gan_loss, as in the examples above.
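Example #12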
def model_fn(features, labels, mode, params):
    """The full unsupervised captioning model."""
    is_chief = not tf.get_variable_scope().reuse

    with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
        net, _ = inception_v4.inception_v4(features['im'],
                                           None,
                                           is_training=False)
    net = tf.squeeze(net, [1, 2])
    inc_saver = tf.train.Saver(tf.global_variables('InceptionV4'))

    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=labels['sentence'][:, 1:],
                                generator_inputs=(net, labels['len'] - 1),
                                check_shapes=False)

    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        pool_fn = functools.partial(tfgan.features.tensor_pool,
                                    pool_size=FLAGS.pool_size)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            tensor_pool_fn=pool_fn if FLAGS.use_pool else None,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model,
                       gan_loss,
                       features['classes'],
                       features['scores'],
                       features['num'],
                       add_summaries=is_chief)
    sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss
    gan_loss = gan_loss._replace(generator_loss=gan_loss.generator_loss +
                                 sen_ae_loss)

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=not FLAGS.use_pool,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    # Summary the generated caption on the fly.
    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook({'fake': sentence, 'real': real},
                                       every_n_iter=100))
        tf.summary.text('fake', sentence)
        tf.summary.image('im', features['im'][None, 0])

    gen_saver = tf.train.Saver(tf.trainable_variables('Generator'))
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    dis_var.extend(tf.trainable_variables('Discriminator/fc'))
    dis_saver = tf.train.Saver(dis_var)

    def init_fn(scaffold, session):
        inc_saver.restore(session, FLAGS.inc_ckpt)
        if FLAGS.imcap_ckpt:
            gen_saver.restore(session, FLAGS.imcap_ckpt)
        if FLAGS.sae_ckpt:
            dis_saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
Example #13
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images, one_hot_labels, _ = provide_data('train',
                                                     FLAGS.batch_size,
                                                     FLAGS.dataset_dir,
                                                     num_threads=4)
            images = 2.0 * images - 1.0

    gan_model = tfgan.gan_model(generator_fn=gan_networks.generator,
                                discriminator_fn=gan_networks.discriminator,
                                real_data=images,
                                generator_inputs=tf.random_normal(
                                    [FLAGS.batch_size, FLAGS.noise_dims]))

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, False)

    with tf.variable_scope('Generator', reuse=True):
        eval_images = gan_model.generator_fn(
            tf.random_normal([FLAGS.num_images_eval, FLAGS.noise_dims]),
            is_training=False)

    # Calculate Inception score.
    tf.summary.scalar(
        "Inception score",
        util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH))

    # Calculate Frechet Inception distance.
    with tf.device('/cpu:0'):
        real_images, labels, _ = provide_data('train', FLAGS.num_images_eval,
                                              FLAGS.dataset_dir)
    tf.summary.scalar(
        "Frechet distance",
        util.mnist_frechet_distance(real_images, eval_images,
                                    MNIST_CLASSIFIER_FROZEN_GRAPH))

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(gan_model,
                                  gradient_penalty_weight=1.0,
                                  mutual_information_penalty_weight=0.0,
                                  add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10)
    ] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        loss = None
        while not sess.should_stop():
            loss = sess.run(train_ops.global_step_inc_op)