Example #1
  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # keep the 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval_{}".format(
      "rouge" if FLAGS.rouge_based_eval else "loss"))  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        t1 = time.time()
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 # the eval job keeps a smoothed running-average loss to decide when to apply early stopping
    best_loss = self.restore_best_eval_model() # holds the best loss (or ROUGE score) achieved so far
    train_step = 0
    decay = 0.99 # decay factor for the running-average loss summary written below

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      greedy_rouges = []
      sampled_rouges = []
      # evaluate eval_interval * batch_size examples before comparing the loss;
      # we do this due to memory constraints, and it is best to run eval on a separate machine with a large batch size

      while processed_batch < FLAGS.eval_interval*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = full_batch = self.full_batcher.next_batch()
        partial_batch = None
        if FLAGS.rl_training:
          partial_batch = self.partial_batcher.next_batch()
          batch = Batcher.merge_batches(full_batch, partial_batch)
        if batch.is_any_null():
          if partial_batch is not None:
            print(partial_batch.original_abstracts_sents)
          print(full_batch.original_abstracts_sents)
          raise Exception("Encountered a batch with null values. Stopping.")
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # a list of length batch_size * k * max_dec_steps
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            # if we are training the DQN on true Q-values, this step serves as
            # pre-training for the DQN network so that it produces better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions), use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions), use_state_prime=True, max_art_oovs=batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x=b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x=b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # we need to expand the q_estimates to match the input batch max_art_oovs
            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
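            # Double-DQN-style targets: for the action actually taken, use the observed
            # reward if the transition is terminal, otherwise
            # reward + gamma * Q_target(s', argmax_a Q_current(s', a)), as computed below.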
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training the DQN on true Q-values,
              # update the Q-values in our transitions with the estimates collected from the current DQN network
              for trans, q_val in zip(transitions, q_estimates):
                trans.q_values = q_val # each has size vocab_size_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss'] = results['pgen_loss']
        printer_helper['rl_full_reward_sampled'] = np.mean(results['full_ssr'])
        printer_helper['rl_full_reward_greedy'] = np.mean(results['full_gsr'])
        printer_helper['rl_full_reward_diff'] = results['full_reward_diff']
        if FLAGS.rl_training:
          loss = np.mean([results['full_rl_avg_logprobs'], results['partial_rl_avg_logprobs']])
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_full_loss'] = results['full_rl_loss']
          printer_helper['rl_full_avg_logprobs'] = results['full_rl_avg_logprobs']
          printer_helper['rl_partial_loss'] = results['partial_rl_loss']
          printer_helper['rl_partial_avg_logprobs'] = results['partial_rl_avg_logprobs']
          printer_helper['rl_partial_reward_sampled'] = results['partial_ssr']
          printer_helper['rl_partial_reward_greedy'] = results['partial_gsr']
          printer_helper['rl_partial_reward_diff'] = results['partial_reward_diff']
        if FLAGS.coverage:
          loss = printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          elif FLAGS.pointer_gen:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training:
          greedy_rouges.append(np.mean([results['full_gsr'],results['partial_gsr']]))
          sampled_rouges.append(np.mean([results['full_ssr'],results['partial_ssr']]))
        else:
          greedy_rouges.append(np.mean(results['full_gsr']))
          sampled_rouges.append(np.mean(results['full_ssr']))

        for (k,v) in sorted(printer_helper.items(), key=lambda x: x[0]):
          if not np.all(np.isfinite(v)):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(float(loss), running_avg_loss))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

      running_avg_loss = np.mean(avg_losses)
      running_greedy_rouge = np.mean(greedy_rouges)
      running_sampled_rouge = np.mean(sampled_rouges)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_greedy_rouge", simple_value=running_greedy_rouge), ]),
        train_step)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_sampled_rouge", simple_value=running_sampled_rouge), ]),
        train_step)

      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag='running_avg_loss/decay=%f' % (decay), simple_value=running_avg_loss), ]),
        train_step)

      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('greedy rouges: {}\tsampled rouges: {}\t'.format(running_greedy_rouge, running_sampled_rouge))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if (best_loss is None) or (FLAGS.rouge_based_eval and running_greedy_rouge > best_loss) or (
              not FLAGS.rouge_based_eval and running_avg_loss < best_loss):
        tf.logging.info('Found new best model with %.3f %s. Saving to %s',
                        running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss,
                        "running_greedy_rouge" if FLAGS.rouge_based_eval else "running_avg_loss",
                        bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss
        time.sleep(15)

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
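
The running-average loss that drives the early-stopping check above comes from self.calc_running_avg_loss (with decay = 0.99), which is not part of this excerpt. A minimal sketch of a plausible exponential moving average with that decay, assuming the original helper behaves similarly:

def calc_running_avg_loss(loss, running_avg_loss, decay=0.99):
  # Exponential moving average of the eval loss (a sketch, not the original helper).
  # Seeded with the raw loss on the first call, when running_avg_loss is still 0.
  if running_avg_loss == 0:
    return loss
  return running_avg_loss * decay + (1.0 - decay) * loss

run_eval appends the smoothed value to avg_losses for every batch, and the mean over an eval interval is what gets compared against best_loss.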
Example #2
  def run_training(self):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("Starting run_training")

    if FLAGS.debug: # start the tensorflow debugger
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      # DDQN training is done asynchronously along with model training
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    # start the main Seq2Seq training loop
    tf.logging.info('Starting Seq2Seq training...')
    while True: # repeats until interrupted
      batch = full_batch = self.full_batcher.next_batch()
      partial_batch = None
      if FLAGS.rl_training:
        partial_batch = self.partial_batcher.next_batch()
        batch = Batcher.merge_batches(full_batch, partial_batch)
      if batch.is_any_null():
        if partial_batch is not None:
          print(partial_batch.original_abstracts_sents)
        print(full_batch.original_abstracts_sents)
        raise Exception("Encountered a batch with null values. Stopping.")
      t0=time.time()
      results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
      self.train_step = results['global_step'] # we need this to update our running average loss
      self.epoch = (self.train_step * FLAGS.batch_size)/FLAGS.train_size
      tf.logging.info('seconds for training step {} in epoch {}: {}'.format(self.train_step, self.epoch, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss'] = results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      printer_helper['rl_full_reward_sampled'] = results['full_ssr']
      printer_helper['rl_full_reward_greedy'] = results['full_gsr']
      printer_helper['rl_full_reward_diff'] = results['full_reward_diff']
      if FLAGS.rl_training:
        printer_helper['rl_full_loss'] = results['full_rl_loss']
        printer_helper['rl_full_avg_logprobs'] = results['full_rl_avg_logprobs']
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_partial_loss'] = results['partial_rl_loss']
        printer_helper['rl_partial_avg_logprobs'] = results['partial_rl_avg_logprobs']
        printer_helper['rl_partial_reward_sampled'] = results['partial_ssr']
        printer_helper['rl_partial_reward_greedy'] = results['partial_gsr']
        printer_helper['rl_partial_reward_diff'] = results['partial_reward_diff']

      for (k,v) in sorted(printer_helper.items(), key=lambda x: x[0]):
        if not np.all(np.isfinite(v)):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) # write the summaries
      if self.train_step % 100 == 0: # flush the summary writer every so often
        self.summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break
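
run_training offloads DDQN updates to a daemon thread (thrd_dqn_training) plus a watcher thread, so seq2seq training and DQN training proceed asynchronously. The actual dqn_training and watch_threads methods are not shown in this excerpt; the sketch below only illustrates the daemon-worker + watcher pattern, using hypothetical names (dqn_worker, start_worker, watch_workers):

import time
from threading import Thread

def dqn_worker():
  # stand-in for the DQN update loop over the shared replay buffer
  while True:
    time.sleep(1)

def start_worker():
  t = Thread(target=dqn_worker)
  t.daemon = True  # daemon threads die with the main process
  t.start()
  return t

def watch_workers(workers, poll_interval=60):
  # restart any worker that has died, so a crash cannot silently stall training
  while True:
    for i, t in enumerate(workers):
      if not t.is_alive():
        workers[i] = start_worker()
    time.sleep(poll_interval)

The main loop keeps calling run_train_steps while these background threads run, which mirrors how run_training uses thrd_dqn_training and watcher.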