def main(_):
    """Train a GAN whose generator is conditioned on VAE reconstructions.

    Builds the input pipeline, a VAE that reconstructs the input batch, and a
    TF-GAN model whose generator receives random noise together with the VAE
    reconstruction. The VAE weights are restored from the latest checkpoint
    in ``FLAGS.vae_checkpoint_folder`` before the training loop starts.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run`` -- confirm against the caller).

    Raises:
        ValueError: If no VAE checkpoint can be found in
            ``FLAGS.vae_checkpoint_folder``.
    """
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images_vae, _, _ = provide_data(
                'train', FLAGS.batch_size, FLAGS.dataset_dir, num_threads=4)
            # The GAN consumes the batch rescaled as 2*x - 1, while the VAE
            # receives the batch exactly as provided.
            images_gan = 2.0 * images_vae - 1.0

    my_vae = VAE("train", z_dim=64, data_tensor=images_vae)
    rec = my_vae.reconstruct(images_vae)

    # Fail fast: latest_checkpoint() returns None when nothing is found, and
    # the error would otherwise only surface inside the training session.
    vae_checkpoint_path = tf.train.latest_checkpoint(
        FLAGS.vae_checkpoint_folder)
    if vae_checkpoint_path is None:
        raise ValueError(
            'No VAE checkpoint found in %r.' % (FLAGS.vae_checkpoint_folder,))
    # Created before the GAN graph is built, so it does not cover the GAN
    # variables that are defined below.
    saver = tf.train.Saver()

    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=images_gan,
        generator_inputs=[
            tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]),
            tf.reshape(rec, [FLAGS.batch_size, 28, 28, 1]),
        ])
    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, True)

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=(
                tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N))

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        saver.restore(sess, vae_checkpoint_path)
        # The generator/discriminator updates are driven by `step_hooks`;
        # each run of this op advances the global step by one.
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)
def get_model_and_loss(condition, real_image):
    """Build the TF-GAN model and its loss for one (condition, image) pair.

    Args:
        condition: Passed to TF-GAN as the generator inputs.
        real_image: Passed to TF-GAN as the real data.

    Returns:
        A ``(gan_model, gan_loss)`` tuple as produced by ``tfgan.gan_model``
        and ``tfgan.gan_loss``.
    """
    model = tfgan.gan_model(
        generator_fn,
        discriminator_fn,
        real_image,
        condition,
    )
    # Loss functions come from the module-level generator_loss_fn /
    # discriminator_loss_fn definitions.
    loss = tfgan.gan_loss(
        model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
    )
    return model, loss
def train_noestimator(features, labels, noise_dims=64, batch_size=32,
                      num_steps=1200, num_eval=20, seed=0):
    """Train a Wasserstein GAN with a hand-written session loop.

    Args:
        features: Training inputs (images), forwarded to
            ``_get_train_input_fn``.
        labels: Training labels, forwarded to ``_get_train_input_fn``.
        noise_dims: Dimension of the noise vector.
        batch_size: Training batch size.
        num_steps: The loop body runs ``num_steps + 1`` times.
        num_eval: Forwarded to ``_get_predict_input_fn`` (presumably the
            number of images generated per evaluation -- confirm).
        seed: Seed for reproducibility, forwarded to the input function.
    """
    # Training inputs: a noise batch and the matching real-image batch.
    input_fn, input_hook = _get_train_input_fn(
        features, labels, batch_size, noise_dims, seed)
    noise, next_image_batch = input_fn()

    # GAN model, Wasserstein loss with gradient penalty, and train ops.
    gan = tfgan.gan_model(
        generator_fn, discriminator_fn, next_image_batch, noise)
    wgan_loss = tfgan.gan_loss(
        gan,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        gradient_penalty_weight=1.0)
    gen_opt = tf.train.AdamOptimizer(0.001, beta1=0.5)
    dis_opt = tf.train.AdamOptimizer(0.0001, beta1=0.5)
    gan_train_ops = tfgan.gan_train_ops(gan, wgan_loss, gen_opt, dis_opt)

    # Re-enter the generator scope with reuse so the evaluation images are
    # produced by the same generator variables that are being trained.
    with tf.variable_scope('Generator', reuse=True):
        predict_input_fn = _get_predict_input_fn(num_eval, noise_dims)
        eval_images = gan.generator_fn(predict_input_fn(), is_training=False)

    global_step = tf.train.get_or_create_global_step()
    run_train_step = tfgan.get_sequential_train_steps()

    with tf.train.SingularMonitoredSession(hooks=[input_hook]) as sess:
        for i in range(num_steps + 1):
            cur_loss, _ = run_train_step(
                sess, gan_train_ops, global_step, train_step_kwargs={})
            if i % 400:
                continue
            # Every 400th iteration: sample, report the loss, show samples.
            samples = sess.run(eval_images)
            print("Iteration", i, "- Loss:", cur_loss)
            show(samples)
def model_fn(features, labels, mode, params):
    """Estimator model function for an unconditional GAN.

    Args:
        features: Mapping holding the real images under the key ``'x'``.
        labels: Unused.
        mode: A ``tf.estimator.ModeKeys`` value; ``PREDICT`` is unsupported.
        params: Object exposing ``latent_units``, ``gen_lr``, ``dis_lr``,
            ``generator_steps`` and ``discriminator_steps``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` whose loss is the discriminator loss.

    Raises:
        NotImplementedError: If ``mode`` is ``PREDICT``.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError()

    real_images = features['x']
    # One latent sample per real image in the batch.
    num_samples = tf.shape(real_images)[0]
    latent = tf.random_normal(
        shape=(num_samples, params.latent_units),
        mean=0.,
        stddev=1.,
        dtype=tf.float32)

    gan_model = tfgan.gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=real_images,
        generator_inputs=latent)
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=tfgan.losses.modified_generator_loss,
        discriminator_loss_fn=tfgan.losses.modified_discriminator_loss)

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Evaluation: report the discriminator loss, no extra metrics.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=gan_loss.discriminator_loss,
            eval_metric_ops={})

    generate_grid(gan_model, params)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.train.RMSPropOptimizer(params.gen_lr),
        discriminator_optimizer=tf.train.RMSPropOptimizer(params.dis_lr))
    # The hooks run the generator/discriminator updates; the Estimator's own
    # train_op only increments the global step.
    steps = GANTrainSteps(params.generator_steps, params.discriminator_steps)
    gan_hooks = tfgan.get_sequential_train_hooks(steps)(train_ops)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=gan_loss.discriminator_loss,
        train_op=train_ops.global_step_inc_op,
        training_hooks=gan_hooks)
def aegan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        encoder_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        encoder_scope='Encoder',
        # Options.
        check_shapes=True):
    """Build a GAN model extended with an encoder applied to both data kinds.

    A regular TF-GAN model is created first; the encoder is then applied to
    the generated data and, sharing the same variables, to the real data.

    Args:
        generator_fn: Generator function, forwarded to ``tfgan.gan_model``.
        discriminator_fn: Discriminator function, forwarded to
            ``tfgan.gan_model``.
        encoder_fn: Callable taking a data tensor and returning the encoder
            outputs.
        real_data: Real data; converted to a tensor before being encoded.
        generator_inputs: Inputs for the generator.
        generator_scope: Variable scope for the generator.
        discriminator_scope: Variable scope for the discriminator.
        encoder_scope: Variable scope for the encoder.
        check_shapes: Forwarded to ``tfgan.gan_model``.

    Returns:
        An ``AEGANModel`` holding the GAN pieces plus the encoder outputs,
        variables, scope and function.
    """
    gan_model = tfgan.gan_model(
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

    with tf.variable_scope(encoder_scope) as enc_scope:
        encoder_gen_outputs = encoder_fn(gan_model.generated_data)
    # Second application re-enters the captured scope with reuse=True so that
    # both encoder calls share one set of variables.
    with tf.variable_scope(enc_scope, reuse=True):
        real_data = tf.convert_to_tensor(real_data)
        encoder_real_outputs = encoder_fn(real_data)

    # Filter by the fully-qualified name of the scope that was actually
    # opened: the raw `encoder_scope` argument misses the variables when this
    # function is called inside an enclosing variable scope, and it is not a
    # valid filter at all when a VariableScope object is passed in.
    encoder_variables = tf.trainable_variables(scope=enc_scope.name)

    return AEGANModel(
        generator_inputs,
        gan_model.generated_data,
        gan_model.generator_variables,
        gan_model.generator_scope,
        generator_fn,
        real_data,
        gan_model.discriminator_real_outputs,
        gan_model.discriminator_gen_outputs,
        gan_model.discriminator_variables,
        gan_model.discriminator_scope,
        discriminator_fn,
        encoder_real_outputs,
        encoder_gen_outputs,
        encoder_variables,
        enc_scope,
        encoder_fn)
def build_gan_harness(image_input: tf.Tensor,
                      noise: tf.Tensor,
                      generator: tf.keras.Model,
                      discriminator: tf.keras.Model,
                      generator_learning_rate=0.01,
                      discriminator_learning_rate=0.01,
                      noise_format: str = 'SPHERE',
                      adversarial_training: str = 'WASSERSTEIN',
                      feature_matching: bool = False,
                      no_trainer: bool = False,
                      summarize_activations: bool = False) -> tuple:
    """Wire Keras generator/discriminator models into a TF-GAN harness.

    Args:
        image_input: Real image batch; axes 1 and 3 of its static shape are
            reported as image size and channel count.
        noise: Generator input tensor.
        generator: Keras model called as ``generator([z], training=True)``.
        discriminator: Keras model called as
            ``discriminator([x, z], training=True)``.
        generator_learning_rate: Adam learning rate for the generator.
        discriminator_learning_rate: Adam learning rate for the
            discriminator.
        noise_format: Not used inside this function.
        adversarial_training: Loss name handed to ``gan_loss_by_name``.
        feature_matching: Forwarded to ``gan_loss_by_name``.
        no_trainer: When true, no train ops are built.
        summarize_activations: When true, activation and trainable-variable
            summaries are added.

    Returns:
        ``(gan_model, loss, train_ops)``; ``train_ops`` is ``None`` when
        ``no_trainer`` is set.
    """
    static_shape = image_input.shape.as_list()
    side, channels = static_shape[1], static_shape[3]
    print("Plain Generative Adversarial Network: {}x{}x{} images".format(
        side, side, channels))

    def _run_generator(latent):
        return generator([latent], training=True)

    def _run_discriminator(sample, latent):
        return discriminator([sample, latent], training=True)

    gan_model = tfgan.gan_model(
        _run_generator,
        _run_discriminator,
        image_input,
        noise,
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        check_shapes=True)  # set to False for 2-level architectures

    image_grid_summary(
        gan_model.generated_data, grid_size=3, name='generated_data')

    if summarize_activations:
        tf.contrib.layers.summarize_activations()
        tf.contrib.layers.summarize_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)

    loss = gan_loss_by_name(
        gan_model,
        adversarial_training,
        feature_matching=feature_matching,
        add_summaries=True)

    # Discriminator accuracy is skipped for these two loss types.
    if adversarial_training not in ('WASSERSTEIN', 'RELATIVISTIC_AVG'):
        fake_outputs = gan_model.discriminator_gen_outputs
        real_outputs = gan_model.discriminator_real_outputs
        # Targets: zeros for generated samples, ones for real samples.
        accuracy_on_fake = basic_accuracy(
            tf.zeros_like(fake_outputs), fake_outputs)
        accuracy_on_real = basic_accuracy(
            tf.ones_like(real_outputs), real_outputs)
        mean_accuracy = (accuracy_on_fake + accuracy_on_real) * 0.5
        with tf.name_scope('Discriminator'):
            tf.summary.scalar('accuracy', mean_accuracy)

    if no_trainer:
        return (gan_model, loss, None)

    gen_optimizer = tf.train.AdamOptimizer(
        generator_learning_rate, beta1=0., beta2=0.99)
    dis_optimizer = tf.train.AdamOptimizer(
        discriminator_learning_rate, beta1=0., beta2=0.99)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        loss,
        generator_optimizer=gen_optimizer,
        discriminator_optimizer=dis_optimizer,
        summarize_gradients=True)
    return (gan_model, loss, train_ops)
def run_discgan():
    """
    Constructs and trains the discriminative GAN consisting of Jerry and Diego.

    Jerry's loss is the generator loss and Diego's loss is the discriminator
    loss. The discriminator is pretrained first, then both models are trained
    for ``STEPS`` rounds. Every ``LOG_EVERY_N`` rounds both losses are
    recorded, and at the end they are written to ``PLOT_DIR``. A
    ``KeyboardInterrupt`` in either phase ends that phase early instead of
    aborting the run.
    """
    # code follows the examples from
    # https://github.com/tensorflow/models/blob/master/research/gan/tutorial.ipynb

    # build the GAN model
    discgan = tfgan.gan_model(
        generator_fn=generator,
        discriminator_fn=adversary_conv(OUTPUT_SIZE),
        real_data=tf.random_uniform(shape=[BATCH_SIZE, OUTPUT_SIZE]),
        generator_inputs=get_input_tensor(BATCH_SIZE, MAX_VAL))

    # Build the GAN loss
    discgan_loss = tfgan.gan_loss(
        discgan,
        generator_loss_fn=tfgan.losses.least_squares_generator_loss,
        discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

    # Create the train ops, which calculate gradients and apply updates to
    # weights.
    train_ops = tfgan.gan_train_ops(
        discgan,
        discgan_loss,
        generator_optimizer=GEN_OPT,
        discriminator_optimizer=OPP_OPT)

    # Both losses are fetched in ONE run call so that the logged pair comes
    # from the same sampled batch and costs a single forward pass.
    loss_fetches = [discgan_loss.generator_loss,
                    discgan_loss.discriminator_loss]

    # start TensorFlow session
    with tf.train.SingularMonitoredSession() as sess:
        pretrain_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(0, PRE_STEPS))
        train_steps_fn = tfgan.get_sequential_train_steps(
            tfgan.GANTrainSteps(1, ADV_MULT))
        global_step = tf.train.get_or_create_global_step()

        # pretrain discriminator
        print('\n\nPretraining ... ', end="", flush=True)
        try:
            pretrain_steps_fn(
                sess, train_ops, global_step, train_step_kwargs={})
        except KeyboardInterrupt:
            # Deliberate: the user may cut pretraining short and move on.
            pass
        print('[DONE]\n\n')

        # train both models
        losses_jerry = []
        losses_diego = []
        try:
            evaluate(sess, discgan.generated_data, discgan.generator_inputs,
                     0, 'jerry')

            for step in range(STEPS):
                train_steps_fn(
                    sess, train_ops, global_step, train_step_kwargs={})

                # if performed right number of steps, log
                if step % LOG_EVERY_N == 0:
                    gen_l, disc_l = sess.run(loss_fetches)
                    debug.print_step(step, gen_l, disc_l)
                    losses_jerry.append(gen_l)
                    losses_diego.append(disc_l)
        except KeyboardInterrupt:
            print('[INTERRUPTED BY USER] -- evaluating')

        # produce output
        files.write_to_file(losses_jerry, PLOT_DIR + '/jerry_loss.txt')
        files.write_to_file(losses_diego, PLOT_DIR + '/diego_loss.txt')
        evaluate(sess, discgan.generated_data, discgan.generator_inputs,
                 1, 'jerry')
weights_regularizer=layers.l2_regularizer(weight_decay), biases_regularizer=layers.l2_regularizer(weight_decay)): net = layers.fully_connected(fragment, 64) net = layers.dropout(net, keep_prob=0.75) net = layers.fully_connected(net, 32) net = layers.fully_connected(net, 16, normalizer_fn=layers.batch_norm,activation_fn=tf.tanh) return layers.linear(net, 1, normalizer_fn=None,activation_fn=tf.tanh) real_data_normed = tf.divide(tf.convert_to_tensor(real, dtype=tf.float32), tf.constant(MAX_VAL, dtype=tf.float32)) chunk_queue = tf.train.slice_input_producer([real_data_normed]) # Build the generator and discriminator. gan_model = tfgan.gan_model( generator_fn=generator_fn, # you define discriminator_fn=discriminator_fn, # you define real_data=chunk_queue, generator_inputs=noise) gan_loss = tfgan.gan_loss( gan_model, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty_weight=1.0) l1_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_loss, weight_factor=FLAGS.weight_factor) train_ops = tfgan.gan_train_ops(gan_model,gan_loss,generator_optimizer=tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.85, beta2=0.999, epsilon=1e-5),discriminator_optimizer=tf.train.AdamOptimizer(learning_rate=0.000001, beta1=0.85, beta2=0.999, epsilon=1e-5)) #train_ops.global_step_inc_op = tf.train.get_global_step().assign_add(1)
biases_regularizer=layers.l2_regularizer(weight_decay)): net = layers.conv2d(img, 64, [4, 4], stride=2) net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels) net = layers.conv2d(net, 128, [4, 4], stride=2) net = layers.flatten(net) net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm) return layers.linear(net, 1) noise_dims = 64 conditional_gan_model = tfgan.gan_model( generator_fn=conditional_generator_fn, discriminator_fn=conditional_discriminator_fn, real_data=real_images, generator_inputs=(tf.random_normal([batch_size, noise_dims]), one_hot_labels)) # Sanity check that currently generated images are garbage. cond_generated_data_to_visualize = tfgan.eval.image_reshaper( conditional_gan_model.generated_data[:20, ...], num_cols=10) visualize_digits(cond_generated_data_to_visualize) gan_loss = tfgan.gan_loss(conditional_gan_model, gradient_penalty_weight=1.0) # Sanity check that we can evaluate our losses. evaluate_tfgan_loss(gan_loss) generator_optimizer = tf.train.AdamOptimizer(0.0009, beta1=0.5) discriminator_optimizer = tf.train.AdamOptimizer(0.00009, beta1=0.5)
def model_fn(features, labels, mode, params):
    """Estimator model function for the sentence GAN.

    The generator is fed L2-normalised random noise together with
    ``labels - 1``; the real data is ``features`` without its first column.
    Generator and part of the discriminator are initialised from
    ``FLAGS.sae_ckpt``.

    Args:
        features: Tensor sliced as ``features[:, 1:]`` to obtain real data.
        labels: Tensor whose leading dimension gives the batch size.
        mode: Estimator mode; only forwarded to the returned spec.
        params: Object exposing ``gen_lr``, ``dis_lr`` and ``multi_gpu``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying the summed loss, the
        global-step increment op, the scaffold and the training hooks.
    """
    # Summaries and logging are only added when the current variable scope
    # is not in reuse mode.
    is_chief = not tf.get_variable_scope().reuse

    num_rows = tf.shape(labels)[0]
    latent = tf.random_normal([num_rows, FLAGS.emb_dim])
    latent = tf.nn.l2_normalize(latent, axis=1)

    gan_model = tfgan.gan_model(
        generator_fn=generator,
        discriminator_fn=discriminator,
        real_data=features[:, 1:],
        generator_inputs=(latent, labels - 1),
        check_shapes=False)

    if is_chief:
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # Placeholder: the generator loss is supplied later by rl_loss().
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        # Each discriminator output is indexed as (logits, mask); only the
        # masked-in logits enter the loss.
        real_out = gan_model.discriminator_real_outputs
        fake_out = gan_model.discriminator_gen_outputs
        masked_real = tf.boolean_mask(real_out[0], real_out[1])
        masked_fake = tf.boolean_mask(fake_out[0], fake_out[1])
        return modified_discriminator_loss(
            masked_real, masked_fake, add_summaries=add_summaries)

    with tf.name_scope('losses'):
        penalty_weight = 10 if FLAGS.wass else 0
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=penalty_weight,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
        gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief)
        loss = gan_loss.generator_loss + gan_loss.discriminator_loss

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=True,
            aggregation_method=(
                tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N))
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    if is_chief:
        # Vocabulary: first token of every line, plus a trailing '<unk>'.
        with open('data/word_counts.txt', 'r') as vocab_file:
            vocab = [line.split()[0] for line in vocab_file]
        vocab.append('<unk>')
        vocab = tf.convert_to_tensor(vocab)

        # Decode the first generated and the first real sentence to words.
        fake_words = tf.gather(
            vocab,
            crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id))
        real_words = tf.gather(
            vocab, crop_sentence(gan_model.real_data[0], FLAGS.end_id))
        train_hooks.append(
            tf.train.LoggingTensorHook(
                {'fake': fake_words, 'real': real_words}, every_n_iter=100))
        tf.summary.text('fake', fake_words)

    # Variables that are initialised from the pretrained checkpoint.
    restore_vars = (tf.trainable_variables('Generator')
                    + tf.trainable_variables('Discriminator/rnn')
                    + tf.trainable_variables('Discriminator/embedding'))
    saver = tf.train.Saver(restore_vars)

    def init_fn(scaffold, session):
        saver.restore(session, FLAGS.sae_ckpt)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        scaffold=tf.train.Scaffold(init_fn=init_fn),
        training_hooks=train_hooks)
weights_regularizer=layers.l2_regularizer(weight_decay), biases_regularizer=layers.l2_regularizer(weight_decay)): net = layers.conv2d(img, 64, [4, 4], stride=2) net = layers.conv2d(net, 128, [4, 4], stride=2) net = layers.flatten(net) with framework.arg_scope([layers.batch_norm], is_training=is_training): net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm) return layers.linear(net, 1) noise_dims = 64 gan_model = tfgan.gan_model(generator_fn, discriminator_fn, real_data=real_images, generator_inputs=tf.random_normal( [batch_size, noise_dims])) # Sanity check that generated images before training are garbage. check_generated_digits = tfgan.eval.image_reshaper( gan_model.generated_data[:20, ...], num_cols=10) visualize_digits(check_generated_digits) # We can use the minimax loss from the original paper. vanilla_gan_loss = tfgan.gan_loss( gan_model, generator_loss_fn=tfgan.losses.minimax_generator_loss, discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss) # We can use the Wasserstein loss (https://arxiv.org/abs/1701.07875) with the
def model_fn(features, labels, mode, params):
    """The full unsupervised captioning model.

    Image features are extracted with InceptionV4 (run with
    ``is_training=False``) and fed, together with ``labels['len'] - 1``, to
    the GAN generator; the real data is ``labels['sentence']`` without its
    first column. The total loss is the sum of the generator loss, the
    discriminator loss and the sentence auto-encoder loss.

    Args:
        features: Mapping read for the keys ``'im'``, ``'classes'``,
            ``'scores'`` and ``'num'``.
        labels: Mapping read for the keys ``'sentence'`` and ``'len'``.
        mode: Estimator mode; only forwarded to the returned spec.
        params: Object exposing ``gen_lr``, ``dis_lr`` and ``multi_gpu``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` with the total loss, the global-step
        increment op, a scaffold whose ``init_fn`` restores the pretrained
        checkpoints, and the training hooks.
    """
    # Summaries and logging are only added when the current variable scope
    # is not in reuse mode.
    is_chief = not tf.get_variable_scope().reuse

    with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
        net, _ = inception_v4.inception_v4(features['im'], None,
                                           is_training=False)
    # Drop axes 1 and 2 of the InceptionV4 output.
    net = tf.squeeze(net, [1, 2])
    # Saver restricted to the InceptionV4 variables; used in init_fn below.
    inc_saver = tf.train.Saver(tf.global_variables('InceptionV4'))

    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=labels['sentence'][:, 1:],
                                generator_inputs=(net, labels['len'] - 1),
                                check_shapes=False)

    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # Placeholder: the generator loss is supplied later by rl_loss().
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        # Each discriminator output is indexed as (logits, mask); only the
        # masked-in logits enter the loss.
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits, gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        # Only passed to gan_loss when FLAGS.use_pool is set.
        pool_fn = functools.partial(tfgan.features.tensor_pool,
                                    pool_size=FLAGS.pool_size)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            tensor_pool_fn=pool_fn if FLAGS.use_pool else None,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
        gan_loss = rl_loss(gan_model, gan_loss, features['classes'],
                           features['scores'], features['num'],
                           add_summaries=is_chief)
        sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief)
        # `loss` is what the Estimator reports; the auto-encoder term is also
        # folded into the generator loss that the train ops consume.
        loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss
        gan_loss = gan_loss._replace(generator_loss=gan_loss.generator_loss + sen_ae_loss)

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            # The unused-update-op check is switched off when the tensor
            # pool is in use.
            check_for_unused_update_ops=not FLAGS.use_pool,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    # Summarize the generated caption on the fly.
    if is_chief:
        # Vocabulary: first token of every line, plus a trailing '<unk>'.
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        # Decode the first generated and the first real sentence to words.
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook({
                'fake': sentence,
                'real': real
            }, every_n_iter=100))
        tf.summary.text('fake', sentence)
        # First image of the batch, re-expanded to a batch of one.
        tf.summary.image('im', features['im'][None, 0])

    # Separate savers so each pretrained checkpoint restores only its part.
    gen_saver = tf.train.Saver(tf.trainable_variables('Generator'))
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    dis_var.extend(tf.trainable_variables('Discriminator/fc'))
    dis_saver = tf.train.Saver(dis_var)

    def init_fn(scaffold, session):
        # InceptionV4 weights are always restored; the generator and
        # discriminator checkpoints only when their flags are set.
        inc_saver.restore(session, FLAGS.inc_ckpt)
        if FLAGS.imcap_ckpt:
            gen_saver.restore(session, FLAGS.imcap_ckpt)
        if FLAGS.sae_ckpt:
            dis_saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
def main(_):
    """Train an MNIST GAN while tracking classifier-based quality metrics.

    Besides the usual TF-GAN training graph, a second generator pass (sharing
    the trained variables) produces ``FLAGS.num_images_eval`` images that are
    scored with ``util.mnist_score`` and compared against real images with
    ``util.mnist_frechet_distance``; both values are written as summaries.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run`` -- confirm against the caller).
    """
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        with tf.device('/cpu:0'):
            images, _, _ = provide_data(
                'train', FLAGS.batch_size, FLAGS.dataset_dir, num_threads=4)
            # The GAN is trained on the batch rescaled as 2*x - 1.
            images = 2.0 * images - 1.0

    gan_model = tfgan.gan_model(
        generator_fn=gan_networks.generator,
        discriminator_fn=gan_networks.discriminator,
        real_data=images,
        generator_inputs=tf.random_normal(
            [FLAGS.batch_size, FLAGS.noise_dims]))
    tfgan.eval.add_gan_model_image_summaries(
        gan_model, FLAGS.grid_size, False)

    # Re-enter the generator scope with reuse so the evaluation images come
    # from the same variables that are being trained.
    with tf.variable_scope('Generator', reuse=True):
        eval_images = gan_model.generator_fn(
            tf.random_normal([FLAGS.num_images_eval, FLAGS.noise_dims]),
            is_training=False)

    # Calculate Inception score.
    tf.summary.scalar(
        "Inception score",
        util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH))

    # Calculate Frechet Inception distance.
    with tf.device('/cpu:0'):
        real_images, _, _ = provide_data(
            'train', FLAGS.num_images_eval, FLAGS.dataset_dir)
    # Apply the same rescaling as for the training batch: the generator is
    # trained to match rescaled images, so comparing its samples against
    # un-rescaled real images would measure the range mismatch rather than
    # sample quality.
    real_images = 2.0 * real_images - 1.0
    tf.summary.scalar(
        "Frechet distance",
        util.mnist_frechet_distance(
            real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH))

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)
        # tfgan.eval.add_regularization_loss_summaries(gan_model)

    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=(
                tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N))

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        # The generator/discriminator updates are driven by `step_hooks`;
        # each run of this op advances the global step by one.
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)