def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """Generate samples.

  Restores (or initializes) the model from log_dir and writes generated
  samples for FLAGS.number_epochs passes over the data to output_file.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file: Output file for the samples.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor.  It will take care of initialization, summaries,
    # checkpoints, and recovery.  NOTE: fixed `tf.Supervisor` ->
    # `tf.train.Supervisor`; the former is not exported at the top-level `tf`
    # namespace in TF 1.x (matches the usage in `train_model`).
    sv = tf.train.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)

    # Get an initialized, and possibly recovered session.  Launch the
    # services: Checkpointing, Summaries, step counting.
    #
    # When multiple replicas of this program are running the services are
    # only launched by the 'chief' replica.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:
      # Generator statefulness over the epoch.
      [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
          [model.eval_initial_state, model.fake_gen_initial_state])

      for n in xrange(FLAGS.number_epochs):
        print('Epoch number: %d' % n)

        iterator = get_iterator(data)

        for x, y, _ in iterator:
          if FLAGS.eval_language_model:
            # Language-model evaluation masks everything.
            is_present_rate = 0.
          else:
            is_present_rate = FLAGS.is_present_rate
          tf.logging.info(
              'Evaluating on is_present_rate=%.3f.' % is_present_rate)

          model_utils.assign_percent_real(sess, model.percent_real_update,
                                          model.new_rate, is_present_rate)

          # Randomly mask out tokens.
          p = model_utils.generate_mask()

          eval_feed = {model.inputs: x, model.targets: y, model.present: p}

          if FLAGS.data_set == 'ptb':
            # Statefulness for *evaluation* Generator.
            for i, (c, h) in enumerate(model.eval_initial_state):
              eval_feed[c] = gen_initial_state_eval[i].c
              eval_feed[h] = gen_initial_state_eval[i].h

            # Statefulness for the Generator.
            for i, (c, h) in enumerate(model.fake_gen_initial_state):
              eval_feed[c] = fake_gen_initial_state_eval[i].c
              eval_feed[h] = fake_gen_initial_state_eval[i].h

          # Carry the final recurrent states forward to the next batch.
          [gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
              [
                  model.eval_final_state, model.fake_gen_final_state,
                  model.global_step
              ],
              feed_dict=eval_feed)

          generate_logs(sess, model, output_file, id_to_word, eval_feed)
      output_file.close()
      print('Closing output_file.')
      return
def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """Generate samples.

  Restores (or initializes) the model from log_dir and writes generated
  samples for FLAGS.number_epochs passes over the data to output_file.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file: Output file for the samples.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor.  It will take care of initialization, summaries,
    # checkpoints, and recovery.  NOTE: fixed `tf.Supervisor` ->
    # `tf.train.Supervisor`; the former is not exported at the top-level `tf`
    # namespace in TF 1.x (matches the usage in `train_model`).
    sv = tf.train.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)

    # Get an initialized, and possibly recovered session.  Launch the
    # services: Checkpointing, Summaries, step counting.
    #
    # When multiple replicas of this program are running the services are
    # only launched by the 'chief' replica.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:
      # Generator statefulness over the epoch.
      [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
          [model.eval_initial_state, model.fake_gen_initial_state])

      for n in xrange(FLAGS.number_epochs):
        print('Epoch number: %d' % n)

        iterator = get_iterator(data)

        for x, y, _ in iterator:
          if FLAGS.eval_language_model:
            # Language-model evaluation masks everything.
            is_present_rate = 0.
          else:
            is_present_rate = FLAGS.is_present_rate
          tf.logging.info(
              'Evaluating on is_present_rate=%.3f.' % is_present_rate)

          model_utils.assign_percent_real(sess, model.percent_real_update,
                                          model.new_rate, is_present_rate)

          # Randomly mask out tokens.
          p = model_utils.generate_mask()

          eval_feed = {model.inputs: x, model.targets: y, model.present: p}

          if FLAGS.data_set == 'ptb':
            # Statefulness for *evaluation* Generator.
            for i, (c, h) in enumerate(model.eval_initial_state):
              eval_feed[c] = gen_initial_state_eval[i].c
              eval_feed[h] = gen_initial_state_eval[i].h

            # Statefulness for the Generator.
            for i, (c, h) in enumerate(model.fake_gen_initial_state):
              eval_feed[c] = fake_gen_initial_state_eval[i].c
              eval_feed[h] = fake_gen_initial_state_eval[i].h

          # Carry the final recurrent states forward to the next batch.
          [gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
              [
                  model.eval_final_state, model.fake_gen_final_state,
                  model.global_step
              ],
              feed_dict=eval_feed)

          generate_logs(sess, model, output_file, id_to_word, eval_feed)
      output_file.close()
      print('Closing output_file.')
      return
def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
                           data_ngram_counts, is_chief):
  """Pretrain the discriminator for FLAGS.dis_pretrain_steps steps.

  Repeatedly iterates the data, running `model.dis_pretrain_op`, until the
  global step passes `dis_pretrain_steps` (offset by any generator
  pretraining steps already taken).

  Args:
    sv: Supervisor, used to record computed summaries.
    sess: A session to use.
    model: The GAN model.
    data: Dataset to draw (input, target) batches from.
    log: Readable log file for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    is_chief: Whether this replica is the chief; only the chief writes
      summaries and prints/logs progress.
  """
  print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
  log.write(
      '\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)

  is_pretraining = True

  while is_pretraining:
    cumulative_costs = 0.
    iters = 0
    if FLAGS.data_set == 'ptb':
      iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length,
                                         FLAGS.epoch_size_override)
    elif FLAGS.data_set == 'imdb':
      iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                           FLAGS.sequence_length)

    for x, y, _ in iterator:
      is_present_rate = FLAGS.is_present_rate
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, is_present_rate)

      # Randomly mask out tokens.
      p = model_utils.generate_mask()

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
          [
              model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
              model.global_step
          ],
          feed_dict=pretrain_feed)

      cumulative_costs += gen_log_perplexity_eval
      iters += 1

      # Calculate rolling perplexity.
      perplexity = np.exp(cumulative_costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.  NOTE: `dict.iteritems()` is Python 2 only;
        # `.items()` matches the rest of the file.
        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' discriminator loss: %.3f' % dis_loss_eval)
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      # Stop once the global step has covered both pretraining phases.
      if step >= FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0):
        is_pretraining = False
        break
  return
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate model for a number of steps.

  Restores the latest checkpoint from train_dir, runs one pass over `data`
  (threading the Generator's recurrent state between batches for PTB), then
  prints/logs losses, rolling perplexity and n-gram capture rates and records
  the corresponding summaries via the supervisor.

  Args:
    data: Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    eval_saver: Evaluation saver.

  Raises:
    FloatingPointError: If perplexity is non-finite or exceeds
      FLAGS.perplexity_threshold.
  """
  tf.logging.info('Evaluate Once.')
  # Load the last model checkpoint, or initialize the graph.
  model_save_path = tf.train.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s' % model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Run the requested number of evaluation steps
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Average percent captured for each of the n-grams.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus): Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')

  iterator = get_iterator(data)

  for x, y, _ in iterator:
    if FLAGS.eval_language_model:
      is_present_rate = 0.
    else:
      is_present_rate = FLAGS.is_present_rate
    tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()

    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    # One evaluation step; final recurrent states are fed back in as the
    # next batch's initial states.
    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    for n, data_ngram_count in data_ngram_counts.items():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval
    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)
    gen_iters += 1

  # Calculate rolling metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  for n, _ in avg_percent_captured.items():
    avg_percent_captured[n] /= gen_iters

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss,
                                        avg_epoch_gen_loss)

  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    print(' percent of %d-grams captured: %.3f.'
          % (n, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent))

  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(model.text_summary_op,
                         {model.text_summary_placeholder:
                          '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary: n-gram
  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % n,
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary: geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary: arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary: perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Raises:
    ValueError: If FLAGS.num_rollouts < 1.
    FloatingPointError: If training perplexity becomes non-finite or exceeds
      FLAGS.perplexity_threshold.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          # Previously a bare `raise ValueError` with no message.
          raise ValueError(
              'FLAGS.num_rollouts must be >= 1, got: %d' % FLAGS.num_rollouts)

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:
          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word,
                                                 data_ngram_counts, is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Initial indicators for printing and summarizing.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate
            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)
            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)
            for i in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

              # Determine the state had the Generator run over real data.  We
              # use this state for the Discriminator.
              [dis_initial_state_eval] = sess.run(
                  [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for i in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  # NOTE(review): this resets `dis_iterator` rather than
                  # `iterator` — looks suspicious but preserved as-is.
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  next(dis_iterator)

            for x, y, _ in iterator:
              # Discriminator training sub-loop.
              for _ in xrange(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y,
                            model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Determine whether to decay learning rate.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess,
                                               model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval,
               fake_gen_initial_state_eval] = sess.run(
                   [model.eval_final_state, model.fake_gen_final_state],
                   train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              if is_chief and (step / FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step / FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  # Fixed message typo: was 'FloatingPoinError'.
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary: n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  # Fix: store the batch value so the geometric/arithmetic
                  # averages below are not computed over all-zero entries
                  # (matches the printing branch).
                  avg_percent_captured[n] = batch_percent_captured
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary: geometric_avg
                geometric_avg = compute_geometric_average(
                    avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg',
                        simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary: arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary: perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step / FLAGS.print_every >
                               print_step_division):
                print_step_division = (step / FLAGS.print_every)
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                # Added missing '\n' (was running into the next log line).
                log.write(' gen_learning_rate: %.6f\n' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))

                geometric_avg = compute_geometric_average(
                    avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                # Added missing '\n'.
                log.write(' geometric_avg: %.3f.\n' % geometric_avg)

                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                # Added missing '\n'.
                log.write(' arithmetic_avg: %.3f.\n' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate model for a number of steps.

  Restores the latest checkpoint from train_dir, runs one pass over `data`,
  then prints/logs losses, rolling perplexity and n-gram capture rates and
  records the corresponding summaries via the supervisor.

  Args:
    data: Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    eval_saver: Evaluation saver.

  Raises:
    FloatingPointError: If perplexity is non-finite or exceeds
      FLAGS.perplexity_threshold.
  """
  tf.logging.info('Evaluate Once.')
  # Load the last model checkpoint, or initialize the graph.
  # NOTE: fixed `tf.latest_checkpoint` -> `tf.train.latest_checkpoint`; the
  # former is not a top-level TF 1.x symbol.
  model_save_path = tf.train.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s' % model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Run the requested number of evaluation steps
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Average percent captured for each of the n-grams.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus): Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')

  iterator = get_iterator(data)

  for x, y, _ in iterator:
    if FLAGS.eval_language_model:
      is_present_rate = 0.
    else:
      is_present_rate = FLAGS.is_present_rate
    tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()

    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    # One evaluation step; final recurrent states are fed back in as the
    # next batch's initial states.
    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    # NOTE: `dict.iteritems()` is Python 2 only; `.items()` matches the
    # other copy of this function in the file.
    for n, data_ngram_count in data_ngram_counts.items():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval
    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)
    gen_iters += 1

  # Calculate rolling metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  for n, _ in avg_percent_captured.items():
    avg_percent_captured[n] /= gen_iters

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss,
                                        avg_epoch_gen_loss)

  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent))

  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(model.text_summary_op,
                         {model.text_summary_placeholder:
                          '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary: n-gram
  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % n,
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary: geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary: arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary: perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Builds the MaskGAN graph (with rollouts when FLAGS.num_rollouts > 1),
  optionally pretrains the Generator and Discriminator, then runs alternating
  Discriminator/Generator updates inside a tf.train.Supervisor managed
  session.  The chief replica additionally writes TF summaries and a
  human-readable log on a step-bucketed schedule.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Raises:
    ValueError: If FLAGS.num_rollouts is neither 1 nor > 1.
    FloatingPointError: If the rolling training perplexity becomes
      non-finite or reaches FLAGS.perplexity_threshold.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.  A single rollout uses the plain MaskGAN;
        # multiple rollouts use the rollout variant of the graph.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.  summary_op=None because
        # summaries are computed and reported manually below.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:
          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word, data_ngram_counts,
                                                 is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Initial indicators for printing and summarizing.  -1 guarantees
          # the first step-division comparison fires on the chief.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            # NOTE(review): is_present_rate is re-read from the flag every
            # outer pass, so the decay factor is applied once per pass and
            # does not compound across passes — confirm that is intended.
            is_present_rate = FLAGS.is_present_rate
            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)
            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.  Rolling per-pass accumulators.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.  Each worker task skips ahead a
            ## task-dependent number of batches so replicas train on
            ## different data.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)
            for i in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                # Data exhausted: restart the iterator and reset state.
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

              # Determine the state had the Generator run over real data.  We
              # use this state for the Discriminator.
              [dis_initial_state_eval] = sess.run(
                  [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for i in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  # NOTE(review): this handler resets dis_iterator, not
                  # iterator, even though the StopIteration came from
                  # iterator — looks like a copy/paste bug; confirm whether
                  # the generator iterator should be restarted here instead.
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  next(dis_iterator)

            for x, y, _ in iterator:
              # Train the Discriminator for a fixed number of inner steps
              # per Generator step.
              for _ in xrange(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    # Re-fetch a fresh initial state at the epoch boundary.
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y, model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Determine whether to decay learning rate.  Exponential decay
              # kicks in only after gen_full_learning_rate_steps.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess, model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state, model.fake_gen_final_state],
                  train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              # NOTE(review): the step-division bucketing below relies on
              # integer (floor) division of step / FLAGS.summaries_every —
              # true under Python 2 (this file uses xrange/iteritems), but
              # would become float division under Python 3; confirm.
              if is_chief and (step / FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step / FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPoinError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary: n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.iteritems():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary: geometric_avg
                geometric_avg = compute_geometric_average(avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg',
                        simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary: arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary: perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step / FLAGS.print_every > print_step_division):
                print_step_division = (step / FLAGS.print_every)
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.iteritems():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                # Sample logging differs by training strategy (RL vs. MLE).
                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()