def __init__(self, model, batcher, vocab, dqn=None):
    """Initialize decoder.

    Args:
      model: a Seq2SeqAttentionModel object.
      batcher: a Batcher object.
      vocab: Vocabulary object.
      dqn: optional DQN object, used when FLAGS.ac_training is set.
    """
    self._model = model
    self._model.build_graph()
    self._batcher = batcher
    self._vocab = vocab
    self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=util.get_config())

    if FLAGS.ac_training:
      self._dqn = dqn
      self._dqn_graph = tf.Graph()
      with self._dqn_graph.as_default():
        self._dqn.build_graph()
        self._dqn_saver = tf.train.Saver()  # we use this to load checkpoints for decoding
        self._dqn_sess = tf.Session(config=util.get_config())
      _ = util.load_dqn_ckpt(self._dqn_saver, self._dqn_sess)

    # Load an initial checkpoint to use for decoding
    ckpt_path = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)

    if FLAGS.single_pass:
      # Make a descriptive decode directory name
      ckpt_name = "{}-ckpt-".format(FLAGS.decode_from) + ckpt_path.split('-')[-1]  # something of the form "train-ckpt-123456"
      self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
    else:
      # Generic decode dir name
      self._decode_dir = os.path.join(FLAGS.log_root, "decode")

    # Make the decode dir if necessary
    if not os.path.exists(self._decode_dir):
      os.mkdir(self._decode_dir)

    if FLAGS.single_pass:
      # Make the dirs to contain output written in the correct format for pyrouge
      self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
      if not os.path.exists(self._rouge_ref_dir):
        os.mkdir(self._rouge_ref_dir)
      self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
      if not os.path.exists(self._rouge_dec_dir):
        os.mkdir(self._rouge_dec_dir)
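# --- Illustrative sketch (not part of the original module) ---------------------------
# A minimal, standalone example of the checkpoint-to-decode-directory naming used in the
# constructor above. get_decode_dir_name lives elsewhere in this repo, so a simplified
# stand-in is used here; the helper name, example path, and flag values are hypothetical.
# Relies on the module-level `os` import already used above.
def _example_decode_dir(ckpt_path, decode_from="train", log_root="/tmp/log"):
    # "/tmp/log/train/model.ckpt-123456" -> step suffix "123456" -> "train-ckpt-123456"
    ckpt_name = "{}-ckpt-".format(decode_from) + ckpt_path.split('-')[-1]
    # simplified stand-in for get_decode_dir_name(ckpt_name)
    return os.path.join(log_root, "decode_{}".format(ckpt_name))

# e.g. _example_decode_dir("/tmp/log/train/model.ckpt-123456")
#        -> "/tmp/log/decode_train-ckpt-123456"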
def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries.
    Saves the model with the best loss seen so far."""
    self.model.build_graph()  # build the graph
    saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),
               feed_dict={self.model.embedding_place: self.word_vector})
    # make a subdir of the root dir for eval data
    eval_dir = os.path.join(FLAGS.log_root,
                            "eval_{}".format("rouge" if FLAGS.rouge_based_eval else "loss"))
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')  # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph()  # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time() - t1))
        self.dqn_target.build_graph()  # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time() - t1))
        dqn_saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0
    decay = 0.99

    while True:
      _ = util.load_ckpt(saver, sess)  # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess)  # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      greedy_rouges = []
      sampled_rouges = []
      # evaluate for eval_interval * batch_size examples before comparing the loss;
      # we do this due to memory constraints, so it is best to run eval on a separate machine with a large batch size
      while processed_batch < FLAGS.eval_interval * FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = full_batch = self.full_batcher.next_batch()
        if FLAGS.rl_training:
          partial_batch = self.partial_batcher.next_batch()
          batch = Batcher.merge_batches(full_batch, partial_batch)
          if batch.is_any_null():
            # debug aid left in by the author: inspect the offending batches before failing
            print(partial_batch.original_abstracts_sents)
            print(full_batch.original_abstracts_sents)
            import ipdb
            ipdb.set_trace()
            print(np.concatenate((full_batch.original_abstracts_sents,
                                  partial_batch.original_abstracts_sents), axis=0))
            raise Exception
        else:
          partial_batch = None

        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(
              sess, batch, train_step, batch.max_art_oovs)  # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time() - t0))
          with dqn_graph.as_default():
            # if using true Q-values to train the DQN, we do this as pre-training
            # for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                          use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x=b._x, return_best_action=True)
            q_estimates = dqn_results['estimates']  # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']
            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x=b_prime._x)
            q_vals_new_t = dqn_target_results['estimates']  # shape (len(transitions), vocab_size)
            # we need to expand the q_estimates to match the input batch max_art_oovs
            q_estimates = np.concatenate(
                [q_estimates, np.zeros((len(transitions), batch.max_art_oovs))], axis=-1)
            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(
                  batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training the DQN on true Q-values, we update the Q-values in our
              # transitions based on the q_estimates collected from the current DQN network
              for trans, q_val in zip(transitions, q_estimates):
                trans.q_values = q_val  # each has size vocab_extended
            q_estimates = np.reshape(
                q_estimates,
                [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1])  # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0 = time.time()
          results = self.model.run_eval_steps(sess, batch, train_step, q_estimates)
          t1 = time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0 = time.time()
          results = self.model.run_eval_steps(sess, batch, train_step)
          t1 = time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1 - t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss'] = results['pgen_loss']
        printer_helper['rl_full_reward_sampled'] = np.mean(results['full_ssr'])
        printer_helper['rl_full_reward_greedy'] = np.mean(results['full_gsr'])
        printer_helper['rl_full_reward_diff'] = results['full_reward_diff']
        if FLAGS.rl_training:
          loss = np.mean([results['full_rl_avg_logprobs'], results['partial_rl_avg_logprobs']])
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_full_loss'] = results['full_rl_loss']
          printer_helper['rl_full_avg_logprobs'] = results['full_rl_avg_logprobs']
          printer_helper['rl_partial_loss'] = results['partial_rl_loss']
          printer_helper['rl_partial_avg_logprobs'] = results['partial_rl_avg_logprobs']
          printer_helper['rl_partial_reward_sampled'] = results['partial_ssr']
          printer_helper['rl_partial_reward_greedy'] = results['partial_gsr']
          printer_helper['rl_partial_reward_diff'] = results['partial_reward_diff']
        if FLAGS.coverage:
          loss = printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          elif FLAGS.pointer_gen:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']

        if FLAGS.rl_training:
          greedy_rouges.append(np.mean([results['full_gsr'], results['partial_gsr']]))
          sampled_rouges.append(np.mean([results['full_ssr'], results['partial_ssr']]))
        else:
          greedy_rouges.append(np.mean(results['full_gsr']))
          sampled_rouges.append(np.mean(results['full_ssr']))

        for (k, v) in sorted(printer_helper.items(), key=lambda x: x[0]):
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k, v))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

      running_avg_loss = np.mean(avg_losses)
      running_greedy_rouge = np.mean(greedy_rouges)
      running_sampled_rouge = np.mean(sampled_rouges)
      self.summary_writer.add_summary(
          tf.Summary(value=[tf.Summary.Value(tag="running_greedy_rouge",
                                             simple_value=running_greedy_rouge)]),
          train_step)
      self.summary_writer.add_summary(
          tf.Summary(value=[tf.Summary.Value(tag="running_sampled_rouge",
                                             simple_value=running_sampled_rouge)]),
          train_step)
      self.summary_writer.add_summary(
          tf.Summary(value=[tf.Summary.Value(tag='running_avg_loss/decay=%f' % decay,
                                             simple_value=running_avg_loss)]),
          train_step)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('greedy rouges: {}\tsampled rouges: {}\t'.format(running_greedy_rouge, running_sampled_rouge))
      tf.logging.info('==========================================')

      # If the running average loss (or running greedy ROUGE) is the best so far, save this
      # checkpoint (early stopping). These checkpoints will appear as
      # bestmodel-<iteration_number> in the eval dir.
      if (best_loss is None) or (FLAGS.rouge_based_eval and running_greedy_rouge > best_loss) or (
          not FLAGS.rouge_based_eval and running_avg_loss < best_loss):
        tf.logging.info('Found new best model with %.3f %s. Saving to %s',
                        running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss,
                        "running_greedy_rouge" if FLAGS.rouge_based_eval else "running_avg_loss",
                        bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss
      time.sleep(15)

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries.
    Saves the model with the best loss seen so far."""
    self.model.build_graph()  # build the graph
    saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),
               feed_dict={self.model.embedding_place: self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')  # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph()  # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time() - t1))
        self.dqn_target.build_graph()  # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time() - t1))
        dqn_saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0

    while True:
      _ = util.load_ckpt(saver, sess)  # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess)  # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      # evaluate for 100 * batch_size examples before comparing the loss;
      # we do this due to memory constraints, so it is best to run eval on a separate machine with a large batch size
      while processed_batch < 100 * FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch()  # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(
              sess, batch, train_step, batch.max_art_oovs)  # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time() - t0))
          with dqn_graph.as_default():
            # if using true Q-values to train the DQN, we do this as pre-training
            # for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                          use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x=b._x, return_best_action=True)
            q_estimates = dqn_results['estimates']  # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']
            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x=b_prime._x)
            q_vals_new_t = dqn_target_results['estimates']  # shape (len(transitions), vocab_size)
            # we need to expand the q_estimates to match the input batch max_art_oovs
            q_estimates = np.concatenate(
                [q_estimates, np.zeros((len(transitions), batch.max_art_oovs))], axis=-1)
            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(
                  batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training the DQN on true Q-values, we update the Q-values in our
              # transitions based on the q_estimates collected from the current DQN network
              for trans, q_val in zip(transitions, q_estimates):
                trans.q_values = q_val  # each has size vocab_extended
            q_estimates = np.reshape(
                q_estimates,
                [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1])  # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0 = time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1 = time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0 = time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1 = time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1 - t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss'] = results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
          printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
          printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
          printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
        if FLAGS.ac_training:
          printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

        for (k, v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k, v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir.
      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s',
                        running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
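# --- Illustrative sketch (not part of the original module) ---------------------------
# Both run_eval variants smooth the per-batch loss through calc_running_avg_loss before
# doing early stopping. That helper is defined elsewhere in this repo; the sketch below
# shows the kind of exponentially decayed running average it is presumed to compute (the
# decay of 0.99 matches the value used above; the clip at 12 and the helper name are
# assumptions for illustration only).
def _example_running_avg_loss(loss, running_avg_loss, decay=0.99):
    if running_avg_loss == 0:  # first value: take the raw loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    return min(running_avg_loss, 12)  # clip so TensorBoard plots stay readable

# e.g. _example_running_avg_loss(3.2, 0.0) -> 3.2
#      _example_running_avg_loss(3.2, 4.0) -> 3.992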