def rollout_generator(hparams,
                      inputs,
                      input_present,
                      is_training,
                      is_validating,
                      reuse=None):
  """Define the Generator graph which does rollouts.

  G will now impute tokens that have been masked from the input sequence.
  """
  rollouts = []

  with tf.variable_scope('gen_rollout'):
    for n in xrange(FLAGS.num_rollouts):
      if n > 0:
        # TODO(liamfedus): Why is it necessary here to manually set reuse?
        reuse = True
        tf.get_variable_scope().reuse_variables()

      [sequence, logits, log_probs] = model_construction.create_generator(
          hparams,
          inputs,
          input_present,
          is_training,
          is_validating,
          reuse=reuse)

      rollouts.append([sequence, logits, log_probs])

  # Length assertion.
  assert len(rollouts) == FLAGS.num_rollouts

  return rollouts
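

# NOTE: The helper below is an illustrative sketch and is not part of the
# original MaskGAN code.  It only documents the shape/dtype contract of the
# `input_present` / `present` placeholders used throughout this file: a
# boolean [batch_size, sequence_length] mask where True marks tokens the
# Generator is given and False marks tokens it must impute.
def _random_present_mask_sketch(keep_prob=0.5):
  """Returns a random boolean mask with roughly `keep_prob` tokens present."""
  import numpy as np  # Local import; numpy availability is assumed.
  return np.random.rand(FLAGS.batch_size, FLAGS.sequence_length) < keep_prob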


def create_rollout_MaskGAN(hparams, is_training):
  """Create the MaskGAN model.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    is_training:  Boolean indicating operational mode (train/inference);
      evaluation is performed with a teacher-forcing regime.

  Returns:
    model:  Namedtuple for specifying the MaskGAN.
  """
  global_step = tf.Variable(0, name='global_step', trainable=False)

  new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
  learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
  learning_rate_update = tf.assign(learning_rate, new_learning_rate)

  new_rate = tf.placeholder(tf.float32, [], name='new_rate')
  percent_real_var = tf.Variable(0.0, trainable=False)
  percent_real_update = tf.assign(percent_real_var, new_rate)

  ## Placeholders.
  inputs = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  inv_present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])

  ## Rollout Generator.
  fwd_gen_rollouts = rollout_generator(
      hparams, inputs, present, is_training=is_training, is_validating=False)
  inv_gen_rollouts = rollout_generator(
      hparams,
      inputs,
      inv_present,
      is_training=is_training,
      is_validating=False,
      reuse=True)

  ## Rollout Discriminator.
  fwd_dis_rollouts = rollout_discriminator(
      hparams, fwd_gen_rollouts, is_training=is_training)
  inv_dis_rollouts = rollout_discriminator(
      hparams, inv_gen_rollouts, is_training=is_training, reuse=True)

  ## Discriminator Loss.
  [dis_loss, dis_loss_pred, dis_loss_inv_pred] = rollout_discriminator_loss(
      fwd_dis_rollouts, present, inv_dis_rollouts, inv_present)

  ## Average log-perplexity for only missing words.  However, to do this,
  # the logits are still computed using teacher forcing, that is, the ground
  # truth tokens are fed in at each time point to be valid.
  # TODO(liamfedus): Fix the naming convention.
  with tf.variable_scope('gen_rollout'):
    _, fwd_eval_logits, _ = model_construction.create_generator(
        hparams,
        inputs,
        present,
        is_training=False,
        is_validating=True,
        reuse=True)

  avg_log_perplexity = model_losses.calculate_log_perplexity(
      fwd_eval_logits, inputs, present)

  ## Generator Loss.
  # 1.  Cross Entropy losses on missing tokens.
  [fwd_cross_entropy_losses,
   inv_cross_entropy_losses] = rollout_masked_cross_entropy_loss(
       inputs, present, inv_present, fwd_gen_rollouts, inv_gen_rollouts)

  # 2.  GAN losses on missing tokens.
  [fwd_RL_loss,
   fwd_RL_statistics, fwd_averages_op] = rollout_reinforce_objective(
       hparams, fwd_gen_rollouts, fwd_dis_rollouts, present)
  [inv_RL_loss,
   inv_RL_statistics, inv_averages_op] = rollout_reinforce_objective(
       hparams, inv_gen_rollouts, inv_dis_rollouts, inv_present)

  # TODO(liamfedus): Generalize this to use all logs.
  [fwd_sequence, fwd_logits, fwd_log_probs] = fwd_gen_rollouts[-1]
  [inv_sequence, inv_logits, inv_log_probs] = inv_gen_rollouts[-1]

  # TODO(liamfedus): Generalize this to use all logs.
  fwd_predictions = fwd_dis_rollouts[-1]
  inv_predictions = inv_dis_rollouts[-1]

  # TODO(liamfedus): Generalize this to use all logs.
  [fwd_log_probs, fwd_rewards,
   fwd_advantages, fwd_baselines] = fwd_RL_statistics[-1]
  [inv_log_probs, inv_rewards,
   inv_advantages, inv_baselines] = inv_RL_statistics[-1]

  ## Pre-training.
  if FLAGS.gen_pretrain_steps:
    # TODO(liamfedus): Rewrite this.
    fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
    gen_pretrain_op = model_optimization.create_gen_pretrain_op(
        hparams, fwd_cross_entropy_loss, global_step)
  else:
    gen_pretrain_op = tf.no_op('gen_pretrain_no_op')

  if FLAGS.dis_pretrain_steps:
    dis_pretrain_op = model_optimization.create_dis_pretrain_op(
        hparams, dis_loss, global_step)
  else:
    dis_pretrain_op = tf.no_op('dis_pretrain_no_op')

  ## Generator Train Op.
  # 1.  Cross-Entropy.
  if FLAGS.gen_training_strategy == 'cross_entropy':
    gen_loss = tf.reduce_mean(
        fwd_cross_entropy_losses + inv_cross_entropy_losses) / 2.
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_gen_train_op(
         hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')

  # 2.  GAN (REINFORCE)
  elif FLAGS.gen_training_strategy == 'reinforce':
    gen_loss = (fwd_RL_loss + inv_RL_loss) / 2.
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_reinforce_gen_train_op(
         hparams, learning_rate, gen_loss, fwd_averages_op, inv_averages_op,
         global_step)

  else:
    raise NotImplementedError

  ## Discriminator Train Op.
  dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
      hparams, dis_loss, global_step)

  ## Summaries.
  with tf.name_scope('general'):
    tf.summary.scalar('percent_real', percent_real_var)
    tf.summary.scalar('learning_rate', learning_rate)

  with tf.name_scope('generator_losses'):
    tf.summary.scalar('gen_loss', tf.reduce_mean(gen_loss))
    tf.summary.scalar('gen_loss_fwd_cross_entropy',
                      tf.reduce_mean(fwd_cross_entropy_losses))
    tf.summary.scalar('gen_loss_inv_cross_entropy',
                      tf.reduce_mean(inv_cross_entropy_losses))

  with tf.name_scope('REINFORCE'):
    with tf.name_scope('objective'):
      tf.summary.scalar('fwd_RL_loss', tf.reduce_mean(fwd_RL_loss))
      tf.summary.scalar('inv_RL_loss', tf.reduce_mean(inv_RL_loss))

    with tf.name_scope('rewards'):
      helper.variable_summaries(fwd_rewards, 'fwd_rewards')
      helper.variable_summaries(inv_rewards, 'inv_rewards')

    with tf.name_scope('advantages'):
      helper.variable_summaries(fwd_advantages, 'fwd_advantages')
      helper.variable_summaries(inv_advantages, 'inv_advantages')

    with tf.name_scope('baselines'):
      helper.variable_summaries(fwd_baselines, 'fwd_baselines')
      helper.variable_summaries(inv_baselines, 'inv_baselines')

    with tf.name_scope('log_probs'):
      helper.variable_summaries(fwd_log_probs, 'fwd_log_probs')
      helper.variable_summaries(inv_log_probs, 'inv_log_probs')

  with tf.name_scope('discriminator_losses'):
    tf.summary.scalar('dis_loss', dis_loss)
    tf.summary.scalar('dis_loss_fwd_sequence', dis_loss_pred)
    tf.summary.scalar('dis_loss_inv_sequence', dis_loss_inv_pred)

  with tf.name_scope('logits'):
    helper.variable_summaries(fwd_logits, 'fwd_logits')
    helper.variable_summaries(inv_logits, 'inv_logits')

  for v, g in zip(gen_vars, gen_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  for v, g in zip(dis_vars, dis_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  merge_summaries_op = tf.summary.merge_all()

  # Model saver.
  saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)

  # Named tuple that captures elements of the MaskGAN model.
  Model = collections.namedtuple('Model', [
      'inputs', 'present', 'inv_present', 'percent_real_update', 'new_rate',
      'fwd_sequence', 'fwd_logits', 'fwd_rewards', 'fwd_advantages',
      'fwd_log_probs', 'fwd_predictions', 'fwd_cross_entropy_losses',
      'inv_sequence', 'inv_logits', 'inv_rewards', 'inv_advantages',
      'inv_log_probs', 'inv_predictions', 'inv_cross_entropy_losses',
      'avg_log_perplexity', 'dis_loss', 'gen_loss', 'dis_train_op',
      'gen_train_op', 'gen_pretrain_op', 'dis_pretrain_op',
      'merge_summaries_op', 'global_step', 'new_learning_rate',
      'learning_rate_update', 'saver'
  ])

  model = Model(
      inputs, present, inv_present, percent_real_update, new_rate,
      fwd_sequence, fwd_logits, fwd_rewards, fwd_advantages, fwd_log_probs,
      fwd_predictions, fwd_cross_entropy_losses, inv_sequence, inv_logits,
      inv_rewards, inv_advantages, inv_log_probs, inv_predictions,
      inv_cross_entropy_losses, avg_log_perplexity, dis_loss, gen_loss,
      dis_train_op, gen_train_op, gen_pretrain_op, dis_pretrain_op,
      merge_summaries_op, global_step, new_learning_rate,
      learning_rate_update, saver)
  return model
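

# NOTE: A minimal, hypothetical training-step sketch (not the repo's actual
# training loop) showing how the namedtuple returned by
# `create_rollout_MaskGAN` might be driven.  `data_batch()` is an assumed
# helper yielding int32 token ids plus forward/inverse boolean masks of
# shape [batch_size, sequence_length].
def _rollout_maskgan_train_sketch(hparams, learning_rate=0.001, num_steps=10):
  model = create_rollout_MaskGAN(hparams, is_training=True)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Set the learning-rate Variable from its placeholder once up front.
    sess.run(model.learning_rate_update,
             feed_dict={model.new_learning_rate: learning_rate})
    for _ in xrange(num_steps):
      tokens, mask, inv_mask = data_batch()  # Assumed data source.
      feed = {model.inputs: tokens,
              model.present: mask,
              model.inv_present: inv_mask}
      # Alternate one discriminator and one generator update per step.
      sess.run(model.dis_train_op, feed_dict=feed)
      sess.run([model.gen_train_op, model.global_step], feed_dict=feed)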


def create_MaskGAN(hparams, is_training):
  """Create the MaskGAN model.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    is_training:  Boolean indicating operational mode (train/inference);
      evaluation is performed with a teacher-forcing regime.

  Returns:
    model:  Namedtuple for specifying the MaskGAN.
  """
  global_step = tf.Variable(0, name='global_step', trainable=False)

  new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
  learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
  learning_rate_update = tf.assign(learning_rate, new_learning_rate)

  new_rate = tf.placeholder(tf.float32, [], name='new_rate')
  percent_real_var = tf.Variable(0.0, trainable=False)
  percent_real_update = tf.assign(percent_real_var, new_rate)

  ## Placeholders.
  inputs = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  targets = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  # TODO(adai): Placeholder for IMDB label.

  ## Real Sequence is the targets.
  real_sequence = targets

  ## Fake Sequence from the Generator.
  # TODO(adai): Generator must have IMDB labels placeholder.
  (fake_sequence, fake_logits, fake_log_probs, fake_gen_initial_state,
   fake_gen_final_state, _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=is_training,
       is_validating=False)
  (_, eval_logits, _, eval_initial_state, eval_final_state,
   _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=False,
       is_validating=True,
       reuse=True)

  ## Discriminator.
  fake_predictions = model_construction.create_discriminator(
      hparams,
      fake_sequence,
      is_training=is_training,
      inputs=inputs,
      present=present)
  real_predictions = model_construction.create_discriminator(
      hparams,
      real_sequence,
      is_training=is_training,
      reuse=True,
      inputs=inputs,
      present=present)

  ## Critic.
  # The critic will be used to estimate the forward rewards to the Generator.
  if FLAGS.baseline_method == 'critic':
    est_state_values = model_construction.create_critic(
        hparams, fake_sequence, is_training=is_training)
  else:
    est_state_values = None

  ## Discriminator Loss.
  [dis_loss, dis_loss_fake, dis_loss_real] = model_losses.create_dis_loss(
      fake_predictions, real_predictions, present)

  ## Average log-perplexity for only missing words.  However, to do this,
  # the logits are still computed using teacher forcing, that is, the ground
  # truth tokens are fed in at each time point to be valid.
  avg_log_perplexity = model_losses.calculate_log_perplexity(
      eval_logits, targets, present)

  ## Generator Objective.
  # 1.  Cross Entropy losses on missing tokens.
  fake_cross_entropy_losses = model_losses.create_masked_cross_entropy_loss(
      targets, present, fake_logits)

  # 2.  GAN REINFORCE losses.
  [
      fake_RL_loss, fake_log_probs, fake_rewards, fake_advantages,
      fake_baselines, fake_averages_op, critic_loss, cumulative_rewards
  ] = model_losses.calculate_reinforce_objective(
      hparams, fake_log_probs, fake_predictions, present, est_state_values)

  ## Pre-training.
  if FLAGS.gen_pretrain_steps:
    raise NotImplementedError
    # # TODO(liamfedus): Rewrite this.
    # fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
    # gen_pretrain_op = model_optimization.create_gen_pretrain_op(
    #     hparams, fwd_cross_entropy_loss, global_step)
  else:
    gen_pretrain_op = None
  if FLAGS.dis_pretrain_steps:
    dis_pretrain_op = model_optimization.create_dis_pretrain_op(
        hparams, dis_loss, global_step)
  else:
    dis_pretrain_op = None

  ## Generator Train Op.
  # 1.  Cross-Entropy.
  if FLAGS.gen_training_strategy == 'cross_entropy':
    gen_loss = tf.reduce_mean(fake_cross_entropy_losses)
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_gen_train_op(
         hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')

  # 2.  GAN (REINFORCE)
  elif FLAGS.gen_training_strategy == 'reinforce':
    gen_loss = fake_RL_loss
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_reinforce_gen_train_op(
         hparams, learning_rate, gen_loss, fake_averages_op, global_step)

  else:
    raise NotImplementedError

  ## Discriminator Train Op.
  dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
      hparams, dis_loss, global_step)

  ## Critic Train Op.
  if critic_loss is not None:
    [critic_train_op, _, _] = model_optimization.create_critic_train_op(
        hparams, critic_loss, global_step)
    dis_train_op = tf.group(dis_train_op, critic_train_op)

  ## Summaries.
  with tf.name_scope('general'):
    tf.summary.scalar('percent_real', percent_real_var)
    tf.summary.scalar('learning_rate', learning_rate)

  with tf.name_scope('generator_objectives'):
    tf.summary.scalar('gen_objective', tf.reduce_mean(gen_loss))
    tf.summary.scalar('gen_loss_cross_entropy',
                      tf.reduce_mean(fake_cross_entropy_losses))

  with tf.name_scope('REINFORCE'):
    with tf.name_scope('objective'):
      tf.summary.scalar('fake_RL_loss', tf.reduce_mean(fake_RL_loss))

    with tf.name_scope('rewards'):
      helper.variable_summaries(cumulative_rewards, 'rewards')

    with tf.name_scope('advantages'):
      helper.variable_summaries(fake_advantages, 'advantages')

    with tf.name_scope('baselines'):
      helper.variable_summaries(fake_baselines, 'baselines')

    with tf.name_scope('log_probs'):
      helper.variable_summaries(fake_log_probs, 'log_probs')

  with tf.name_scope('discriminator_losses'):
    tf.summary.scalar('dis_loss', dis_loss)
    tf.summary.scalar('dis_loss_fake_sequence', dis_loss_fake)
    tf.summary.scalar('dis_loss_prob_fake_sequence', tf.exp(-dis_loss_fake))
    tf.summary.scalar('dis_loss_real_sequence', dis_loss_real)
    tf.summary.scalar('dis_loss_prob_real_sequence', tf.exp(-dis_loss_real))

  if critic_loss is not None:
    with tf.name_scope('critic_losses'):
      tf.summary.scalar('critic_loss', critic_loss)

  with tf.name_scope('logits'):
    helper.variable_summaries(fake_logits, 'fake_logits')

  for v, g in zip(gen_vars, gen_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  for v, g in zip(dis_vars, dis_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  merge_summaries_op = tf.summary.merge_all()
  text_summary_placeholder = tf.placeholder(tf.string)
  text_summary_op = tf.summary.text('Samples', text_summary_placeholder)

  # Model saver.
  saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)

  # Named tuple that captures elements of the MaskGAN model.
  Model = collections.namedtuple('Model', [
      'inputs', 'targets', 'present', 'percent_real_update', 'new_rate',
      'fake_sequence', 'fake_logits', 'fake_rewards', 'fake_baselines',
      'fake_advantages', 'fake_log_probs', 'fake_predictions',
      'real_predictions', 'fake_cross_entropy_losses',
      'fake_gen_initial_state', 'fake_gen_final_state', 'eval_initial_state',
      'eval_final_state', 'avg_log_perplexity', 'dis_loss', 'gen_loss',
      'critic_loss', 'cumulative_rewards', 'dis_train_op', 'gen_train_op',
      'gen_pretrain_op', 'dis_pretrain_op', 'merge_summaries_op',
      'global_step', 'new_learning_rate', 'learning_rate_update', 'saver',
      'text_summary_op', 'text_summary_placeholder'
  ])

  model = Model(
      inputs, targets, present, percent_real_update, new_rate, fake_sequence,
      fake_logits, fake_rewards, fake_baselines, fake_advantages,
      fake_log_probs, fake_predictions, real_predictions,
      fake_cross_entropy_losses, fake_gen_initial_state, fake_gen_final_state,
      eval_initial_state, eval_final_state, avg_log_perplexity, dis_loss,
      gen_loss, critic_loss, cumulative_rewards, dis_train_op, gen_train_op,
      gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step,
      new_learning_rate, learning_rate_update, saver, text_summary_op,
      text_summary_placeholder)
  return model
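

# NOTE: A hypothetical evaluation sketch (an assumption, not the repo's eval
# loop) for the graph built by `create_MaskGAN`: it reads the teacher-forced
# masked-token log-perplexity and writes generated samples as text
# summaries.  `eval_batch()` and `ids_to_text()` are assumed helpers.
def _maskgan_eval_sketch(model, sess, summary_writer):
  tokens, target_tokens, mask = eval_batch()  # Assumed data source.
  feed = {model.inputs: tokens,
          model.targets: target_tokens,
          model.present: mask}
  perplexity, sequence, step = sess.run(
      [model.avg_log_perplexity, model.fake_sequence, model.global_step],
      feed_dict=feed)
  text_summary = sess.run(
      model.text_summary_op,
      feed_dict={model.text_summary_placeholder: ids_to_text(sequence)})
  summary_writer.add_summary(text_summary, step)
  return perplexity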