def get_iterator(data):
  """Return the batch iterator for the data set selected by `FLAGS.data_set`.

  Args:
    data: Raw data, passed straight through to the data-set-specific loader.

  Returns:
    The result of `ptb_loader.ptb_iterator` when `FLAGS.data_set == 'ptb'`,
    or of `imdb_loader.imdb_iterator` when `FLAGS.data_set == 'imdb'`.

  Raises:
    ValueError: If `FLAGS.data_set` is neither 'ptb' nor 'imdb'.
  """
  if FLAGS.data_set == 'ptb':
    return ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                   FLAGS.sequence_length,
                                   FLAGS.epoch_size_override)
  if FLAGS.data_set == 'imdb':
    return imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                     FLAGS.sequence_length)
  # Fail loudly on an unsupported data set rather than returning an
  # undefined iterator (which surfaced as an UnboundLocalError).
  raise ValueError('Unsupported data_set: %r' % (FLAGS.data_set,))
def pretrain_generator(sv, sess, model, data, log, id_to_word,
                       data_ngram_counts, is_chief):
  """Pretrain the generator with classic language modeling training.

  Repeatedly sweeps over `data`, running `model.gen_pretrain_op` with every
  token marked present, until the global step reaches
  `FLAGS.gen_pretrain_steps`. When `is_chief` is true, summaries are written
  every `FLAGS.summaries_every` steps and progress is printed/logged every
  `FLAGS.print_every` steps.

  Args:
    sv: Object whose `SummaryComputed` method records summaries.
    sess: Session in which the model ops are run.
    model: Model exposing the pretraining ops and tensors used below.
    data: Raw data passed to `get_iterator`.
    log: Writable file-like object for the text log.
    id_to_word: Vocabulary passed to `evaluation_utils.generate_logs`.
    data_ngram_counts: Mapping from n (any key `int()` accepts) to the data's
      n-gram counts, used for n-gram evaluation.
    is_chief: Whether this worker writes summaries and logs.

  Raises:
    ValueError: If a full pass over `data` yields no batches (which would
      otherwise loop forever). Errors raised by `get_iterator` (e.g. for an
      unsupported `FLAGS.data_set`) propagate unchanged.
  """
  print('\nPretraining generator for %d steps.' % FLAGS.gen_pretrain_steps)
  log.write(
      '\nPretraining generator for %d steps.\n' % FLAGS.gen_pretrain_steps)

  # For pretraining with cross entropy loss, we have all tokens in the
  # forward sequence present (all True). The mask never changes, so build it
  # once instead of once per batch.
  p = np.ones(shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=bool)

  while True:
    costs = 0.
    iters = 0
    for x, y, _ in get_iterator(data):
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, 1.0)
      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [losses, cost_eval, _, step] = sess.run(
          [
              model.fake_cross_entropy_losses, model.avg_log_perplexity,
              model.gen_pretrain_op, model.global_step
          ],
          feed_dict=pretrain_feed)

      costs += cost_eval
      iters += FLAGS.sequence_length

      # Calculate rolling perplexity.
      perplexity = np.exp(costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' generator loss: %.3f' % np.mean(losses))
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' generator loss: %.3f\n' % np.mean(losses))
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= FLAGS.gen_pretrain_steps:
        return

    # An epoch with no batches can never advance the global step; without
    # this check the outer loop would spin forever.
    if iters == 0:
      raise ValueError('Generator pretraining data yielded no batches.')
def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
                           data_ngram_counts, is_chief):
  """Pretrain the discriminator on randomly masked sequences.

  Repeatedly sweeps over `data`. For each batch it:
    1. sets the model's present rate to `FLAGS.is_present_rate`;
    2. builds a random mask with `model_utils.generate_mask()`;
    3. runs `model.dis_pretrain_op`.
  Training stops once the global step reaches
  `FLAGS.dis_pretrain_steps + FLAGS.gen_pretrain_steps` (the latter treated
  as 0 when unset). When `is_chief` is true, summaries are written every
  `FLAGS.summaries_every` steps and progress is printed/logged every
  `FLAGS.print_every` steps.

  Args:
    sv: Object whose `SummaryComputed` method records summaries.
    sess: Session in which the model ops are run.
    model: Model exposing the pretraining ops and tensors used below.
    data: Raw data passed to `get_iterator`.
    log: Writable file-like object for the text log.
    id_to_word: Vocabulary passed to `evaluation_utils.generate_logs`.
    data_ngram_counts: Mapping from n (any key `int()` accepts) to the data's
      n-gram counts, used for n-gram evaluation.
    is_chief: Whether this worker writes summaries and logs.

  Raises:
    ValueError: If a full pass over `data` yields no batches (which would
      otherwise loop forever). Errors raised by `get_iterator` (e.g. for an
      unsupported `FLAGS.data_set`) propagate unchanged.
  """
  print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
  log.write(
      '\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)

  # The stop step adds the generator pretraining steps, presumably because
  # the global step keeps counting from any preceding generator
  # pretraining -- TODO confirm against the caller.
  final_step = FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0)

  while True:
    cumulative_costs = 0.
    iters = 0
    for x, y, _ in get_iterator(data):
      is_present_rate = FLAGS.is_present_rate
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, is_present_rate)

      # Randomly mask out tokens.
      p = model_utils.generate_mask()

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
          [
              model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
              model.global_step
          ],
          feed_dict=pretrain_feed)

      cumulative_costs += gen_log_perplexity_eval
      iters += 1

      # Calculate rolling perplexity.
      perplexity = np.exp(cumulative_costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' discriminator loss: %.3f' % dis_loss_eval)
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.items():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= final_step:
        return

    # An epoch with no batches can never advance the global step; without
    # this check the outer loop would spin forever.
    if iters == 0:
      raise ValueError('Discriminator pretraining data yielded no batches.')