Example #1
0
  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries.

    Saves the model with the best score seen so far. Each outer iteration
    reloads the latest checkpoint (plus the latest DQN checkpoint when
    FLAGS.ac_training is set), evaluates FLAGS.eval_interval batches, writes
    running-average summaries and, if the score improved, saves a 'bestmodel'
    checkpoint. The score is the running greedy ROUGE when
    FLAGS.rouge_based_eval is set (higher is better) and the running average
    loss otherwise (lower is better).

    The loop never returns on its own.

    Raises:
      Exception: if a batch reports null entries (batch.is_any_null()) or a
        logged metric is not finite.
    """
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    # use a separate eval subdir per selection criterion ("eval_rouge" / "eval_loss")
    eval_dir = os.path.join(FLAGS.log_root, "eval_{}".format(
      "rouge" if FLAGS.rouge_based_eval else "loss"))
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # best score so far (ROUGE or loss); may be None
    train_step = 0
    decay = 0.99  # only used to label the running_avg_loss summary tag below

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      greedy_rouges = []
      sampled_rouges = []
      # evaluate for FLAGS.eval_interval * batch_size examples before comparing the score
      # we do this due to memory constraint, best to run eval on different machines with large batch size

      while processed_batch < FLAGS.eval_interval*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = full_batch = self.full_batcher.next_batch()
        partial_batch = None
        if FLAGS.rl_training:
          partial_batch = self.partial_batcher.next_batch()
          batch = Batcher.merge_batches(full_batch, partial_batch)
        if batch.is_any_null():
          # Log the offending abstracts and fail loudly rather than opening an
          # interactive debugger, which would hang an unattended eval job.
          tf.logging.error('full batch abstracts: {}'.format(full_batch.original_abstracts_sents))
          if partial_batch is not None:
            tf.logging.error('partial batch abstracts: {}'.format(partial_batch.original_abstracts_sents))
          raise Exception('Batch contains null entries; stopping eval.')
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            batch_len = len(transitions)
            # the current DQN scores the current decoder state (use_state_prime=False);
            # the target DQN scores the next decoder state (use_state_prime=True)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = False, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # expand q_estimates to cover this batch's in-article OOV ids; every OOV column
            # reuses the estimate of column 0 (the UNK token), the same fill used in training
            q_estimates = np.concatenate([q_estimates,
              np.reshape(q_estimates[:,0],[-1,1])*np.ones((batch_len,batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            # target for the taken action: reward if the episode is done, otherwise
            # reward + gamma * target-net value of the current net's best action
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update Q-values in our transitions based on this q_estimates we collected from DQN current network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        printer_helper['rl_full_reward_sampled'] = np.mean(results['full_ssr'])
        printer_helper['rl_full_reward_greedy'] = np.mean(results['full_gsr'])
        printer_helper['rl_full_reward_diff'] = results['full_reward_diff']
        if FLAGS.rl_training:
          loss = np.mean([results['full_rl_avg_logprobs'], results['partial_rl_avg_logprobs']])
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_full_loss'] = results['full_rl_loss']
          printer_helper['rl_full_avg_logprobs'] = results['full_rl_avg_logprobs']
          printer_helper['rl_partial_loss'] = results['partial_rl_loss']
          printer_helper['rl_partial_avg_logprobs'] = results['partial_rl_avg_logprobs']
          # reduce to scalars like the full-batch rewards above: the np.isfinite
          # check below cannot take the truth value of a multi-element array
          printer_helper['rl_partial_reward_sampled'] = np.mean(results['partial_ssr'])
          printer_helper['rl_partial_reward_greedy'] = np.mean(results['partial_gsr'])
          printer_helper['rl_partial_reward_diff'] = results['partial_reward_diff']
        if FLAGS.coverage:
          loss = printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          elif FLAGS.pointer_gen:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training:
          greedy_rouges.append(np.mean([results['full_gsr'],results['partial_gsr']]))
          sampled_rouges.append(np.mean([results['full_ssr'],results['partial_ssr']]))
        else:
          greedy_rouges.append(np.mean(results['full_gsr']))
          sampled_rouges.append(np.mean(results['full_ssr']))

        for (k,v) in sorted(printer_helper.items(), key=lambda x: x[0]):
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss; np.asarray(...).item() replaces np.asscalar,
        # which was removed in NumPy 1.23
        avg_losses.append(self.calc_running_avg_loss(np.asarray(loss).item(), running_avg_loss))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

      running_avg_loss = np.mean(avg_losses)
      running_greedy_rouge = np.mean(greedy_rouges)
      running_sampled_rouge = np.mean(sampled_rouges)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_greedy_rouge", simple_value=running_greedy_rouge), ]),
        train_step)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_sampled_rouge", simple_value=running_sampled_rouge), ]),
        train_step)

      self.summary_writer.add_summary(tf.Summary(
          value=[tf.Summary.Value(tag='running_avg_loss/decay=%f' % (decay), simple_value=running_avg_loss), ]),
                                      train_step)

      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('greedy rouges: {}\tsampled rouges: {}\t'.format(running_greedy_rouge, running_sampled_rouge))
      tf.logging.info('==========================================')

      # If the score is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if (best_loss is None) or (FLAGS.rouge_based_eval and running_greedy_rouge > best_loss) or (
              not FLAGS.rouge_based_eval and running_avg_loss < best_loss):
        tf.logging.info('Found new best model with %.3f %s. Saving to %s',
                        running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss,
                        "running_greedy_rouge" if FLAGS.rouge_based_eval else "running_avg_loss",
                        bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss
        time.sleep(15)

      # flush once per evaluation round; the global step reached at the end of a round
      # is rarely a multiple of 100, so a step-based check almost never fired
      self.summary_writer.flush()
Example #2
0
  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries.

    Saves the model with the best (lowest) running average loss seen so far.
    Each outer iteration reloads the latest checkpoint (plus the latest DQN
    checkpoint when FLAGS.ac_training is set), evaluates 100 batches, and saves
    a 'bestmodel' checkpoint when the running average loss improves. The loop
    never returns on its own.

    Raises:
      Exception: if a logged metric is not finite.
    """
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # best loss so far; may be None
    train_step = 0

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      # evaluate for 100 * batch_size before comparing the loss
      # we do this due to memory constraint, best to run eval on different machines with large batch size
      while processed_batch < 100*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch() # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            batch_len = len(transitions)
            # the current DQN scores the current decoder state (use_state_prime=False);
            # the target DQN scores the next decoder state (use_state_prime=True)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = False, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # expand q_estimates to cover this batch's in-article OOV ids; every OOV column
            # reuses the estimate of column 0 (the UNK token), the same fill used in training
            q_estimates = np.concatenate([q_estimates,
              np.reshape(q_estimates[:,0],[-1,1])*np.ones((batch_len,batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            # target for the taken action: reward if the episode is done, otherwise
            # reward + gamma * target-net value of the current net's best action
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update Q-values in our transitions based on this q_estimates we collected from DQN current network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          # select the coverage total loss for the active training mode; the pointer
          # total is only read when neither RL nor actor-critic training is on
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          else:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
          printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
          printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
          printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
        if FLAGS.ac_training:
          printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

        for (k,v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss; np.asarray(...).item() replaces np.asscalar,
        # which was removed in NumPy 1.23
        avg_losses.append(self.calc_running_avg_loss(np.asarray(loss).item(), running_avg_loss, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush once per evaluation round; the global step reached at the end of a round
      # is rarely a multiple of 100, so a step-based check almost never fired
      self.summary_writer.flush()
Example #3
0
  def run_training(self):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries.

    When FLAGS.ac_training is set, daemon threads running self.dqn_training and
    self.watch_threads are started first. Each step then collects DQN
    transitions from the seq2seq model, corrects the Q-estimates with the
    current and target DDQN networks, and adds the transitions to
    self.replay_buffer. Unless FLAGS.dqn_pretrain is set, it then trains the
    seq2seq model with those estimates. Training stops once self.train_step
    exceeds FLAGS.max_iter.

    Raises:
      Exception: if a logged metric is not finite.
    """
    tf.logging.info("Starting run_training")

    if FLAGS.debug: # start the tensorflow debugger
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      # DDQN training is done asynchronously along with model training
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      # second daemon thread running self.watch_threads (presumably monitors the DQN thread)
      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    # starting the main thread
    tf.logging.info('Starting Seq2Seq training...')
    while True: # repeats until max_iter is exceeded or the process is interrupted
      batch = self.batcher.next_batch()
      t0=time.time()
      if FLAGS.ac_training:
        # For DDQN, we first collect the model output to calculate the reward and Q-estimates
        # Then we fix the estimation either using our target network or using the true Q-values
        # This process will usually take time and we are working on improving it.
        transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
        tf.logging.info('Q-values collection time: {}'.format(time.time()-t0))
        # whenever we are working with the DDQN, we switch using DDQN graph rather than default graph
        with self.dqn_graph.as_default():
          batch_len = len(transitions)
          # we use current decoder state to predict q_estimates, use_state_prime = False
          b = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = False, max_art_oovs = batch.max_art_oovs)
          # we also get the next decoder state to correct the estimation, use_state_prime = True
          b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = True, max_art_oovs = batch.max_art_oovs)
          # use current DQN to estimate values from current decoder state
          dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True)
          q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
          dqn_best_action = dqn_results['best_action']

          # use target DQN to estimate values for the next decoder state
          dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x)
          q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

          # we need to expand the q_estimates to match the input batch max_art_oov
          # we use the q_estimate of UNK token for all the OOV tokens
          q_estimates = np.concatenate([q_estimates,
            np.reshape(q_estimates[:,0],[-1,1])*np.ones((batch_len,batch.max_art_oovs))],axis=-1)
          # modify Q-estimates using the result collected from current and target DQN.
          # check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf
          for i, tr in enumerate(transitions):
            if tr.done:
              q_estimates[i][tr.action] = tr.reward
            else:
              q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
          # use scheduled sampling to whether use true Q-values or DDQN estimation
          if FLAGS.dqn_scheduled_sampling:
            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
          if not FLAGS.calculate_true_q:
            # when we are not training DDQN based on true Q-values,
            # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network.
            for trans, q_val in zip(transitions,q_estimates):
              trans.q_values = q_val # each have the size vocab_extended
          q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
        # Once we are done with modifying Q-values, we can use them to train the DDQN model.
        # In this paper, we use a priority experience buffer which always selects states with higher quality
        # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer.
        # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training
        # are full, the DDQN will start the training.
        self.replay_buffer.add(transitions)
        # If dqn_pretrain flag is on, it means that we use a fixed Actor to only collect experiences for
        # DDQN pre-training
        if FLAGS.dqn_pretrain:
          tf.logging.info('RUNNNING DQN PRETRAIN: Adding data to relplay buffer only...')
          continue
        # if not, use the q_estimation to update the loss.
        results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
      else:
        results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
      self.train_step = results['global_step'] # we need this to update our running average loss
      tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss']= results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      if FLAGS.rl_training or FLAGS.ac_training:
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_loss'] = results['rl_loss']
        printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
      if FLAGS.rl_training:
        printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
        printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
        # logged as sampled minus greedy reward
        printer_helper['r_diff'] = printer_helper['sampled_r'] - printer_helper['greedy_r']
      if FLAGS.ac_training:
        printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0

      for (k,v) in printer_helper.items():
        if not np.isfinite(v):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) # write the summaries
      if self.train_step % 100 == 0: # flush the summary writers every so often
        self.summary_writer.flush()
        # previously flushed on every step, which forced a disk write per training step
        if FLAGS.ac_training:
          self.dqn_summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break
Example #4
0
  def run_training(self):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries.

    When FLAGS.ac_training is set, daemon threads running self.dqn_training and
    self.watch_threads are started first. Each step then collects DQN
    transitions from the seq2seq model, corrects the Q-estimates with the
    current and target DDQN networks, and adds the transitions to
    self.replay_buffer. Unless FLAGS.dqn_pretrain is set, it then trains the
    seq2seq model with those estimates. Training stops once self.train_step
    exceeds FLAGS.max_iter.

    Raises:
      Exception: if a logged metric is not finite.
    """
    tf.logging.info("Starting run_training")

    if FLAGS.debug: # start the tensorflow debugger
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      # DDQN training is done asynchronously along with model training
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      # second daemon thread running self.watch_threads (presumably monitors the DQN thread)
      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    # starting the main thread
    tf.logging.info('Starting Seq2Seq training...')
    while True: # repeats until max_iter is exceeded or the process is interrupted
      batch = self.batcher.next_batch()
      t0=time.time()
      if FLAGS.ac_training:
        # For DDQN, we first collect the model output to calculate the reward and Q-estimates
        # Then we fix the estimation either using our target network or using the true Q-values
        # This process will usually take time and we are working on improving it.
        transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
        tf.logging.info('Q-values collection time: {}'.format(time.time()-t0))
        # whenever we are working with the DDQN, we switch using DDQN graph rather than default graph
        with self.dqn_graph.as_default():
          batch_len = len(transitions)
          # we use current decoder state to predict q_estimates, use_state_prime = False
          b = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = False, max_art_oovs = batch.max_art_oovs)
          # we also get the next decoder state to correct the estimation, use_state_prime = True
          b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, batch_len, use_state_prime = True, max_art_oovs = batch.max_art_oovs)
          # use current DQN to estimate values from current decoder state
          dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True)
          q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
          dqn_best_action = dqn_results['best_action']

          # use target DQN to estimate values for the next decoder state
          dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x)
          q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

          # we need to expand the q_estimates to match the input batch max_art_oov
          # we use the q_estimate of UNK token for all the OOV tokens
          q_estimates = np.concatenate([q_estimates,
            np.reshape(q_estimates[:,0],[-1,1])*np.ones((batch_len,batch.max_art_oovs))],axis=-1)
          # modify Q-estimates using the result collected from current and target DQN.
          # check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf
          for i, tr in enumerate(transitions):
            if tr.done:
              q_estimates[i][tr.action] = tr.reward
            else:
              q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
          # use scheduled sampling to whether use true Q-values or DDQN estimation
          if FLAGS.dqn_scheduled_sampling:
            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
          if not FLAGS.calculate_true_q:
            # when we are not training DDQN based on true Q-values,
            # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network.
            for trans, q_val in zip(transitions,q_estimates):
              trans.q_values = q_val # each have the size vocab_extended
          q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
        # Once we are done with modifying Q-values, we can use them to train the DDQN model.
        # In this paper, we use a priority experience buffer which always selects states with higher quality
        # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer.
        # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training
        # are full, the DDQN will start the training.
        self.replay_buffer.add(transitions)
        # If dqn_pretrain flag is on, it means that we use a fixed Actor to only collect experiences for
        # DDQN pre-training
        if FLAGS.dqn_pretrain:
          tf.logging.info('RUNNNING DQN PRETRAIN: Adding data to relplay buffer only...')
          continue
        # if not, use the q_estimation to update the loss.
        results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
      else:
        results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
      self.train_step = results['global_step'] # we need this to update our running average loss
      tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss']= results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      if FLAGS.rl_training or FLAGS.ac_training:
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_loss'] = results['rl_loss']
        printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
      if FLAGS.rl_training:
        printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
        printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
        # logged as greedy minus sampled reward
        printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
      if FLAGS.ac_training:
        printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0

      for (k,v) in printer_helper.items():
        if not np.isfinite(v):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) # write the summaries
      if self.train_step % 100 == 0: # flush the summary writers every so often
        self.summary_writer.flush()
        # previously flushed on every step, which forced a disk write per training step
        if FLAGS.ac_training:
          self.dqn_summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break
Example #5
0
    def run_training(self):
        """Repeatedly run training iterations, logging losses and writing summaries.

        Each iteration pulls a batch from ``self.batcher`` and runs one training
        step of the seq2seq model.

        When ``FLAGS.ac_training`` is set:
          * ``self.dqn_training`` runs on a daemon background thread, with a
            ``self.watch_threads`` watcher.
          * Each batch's transitions are scored by the DDQN networks (see
            ``_estimate_ddqn_q_values``).
          * The transitions are pushed to ``self.replay_buffer``.
          * The resulting Q-value estimates are passed to the model's train step.

        The loop exits once the ``global_step`` returned by the model exceeds
        ``FLAGS.max_iter``. In ``FLAGS.dqn_pretrain`` mode the actor is never
        trained, so the loop only fills the replay buffer until interrupted.

        Raises:
            Exception: if any logged loss/reward metric is NaN or infinite.
        """
        tf.logging.info("Starting run_training")

        if FLAGS.debug:  # start the tensorflow debugger
            self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
            self.sess.add_tensor_filter("has_inf_or_nan",
                                        tf_debug.has_inf_or_nan)

        self.train_step = 0
        if FLAGS.ac_training:
            # DDQN training runs on its own thread, concurrently with the
            # actor's training loop below.
            tf.logging.info('Starting DQN training thread...')
            self.dqn_train_step = 0
            self.thrd_dqn_training = Thread(target=self.dqn_training)
            self.thrd_dqn_training.daemon = True
            self.thrd_dqn_training.start()

            watcher = Thread(target=self.watch_threads)
            watcher.daemon = True
            watcher.start()

        tf.logging.info('Starting Seq2Seq training...')
        while True:  # repeats until interrupted or max_iter is exceeded
            batch = self.batcher.next_batch()
            t0 = time.time()
            if FLAGS.ac_training:
                transitions = self.model.collect_dqn_transitions(
                    self.sess, batch, self.train_step,
                    batch.max_art_oovs)  # len(batch_size * k * max_dec_steps)
                tf.logging.info(
                    'Q-values collection time: {}'.format(time.time() - t0))
                q_estimates = self._estimate_ddqn_q_values(batch, transitions)
                # The transitions (whose q_values may have just been updated)
                # are handed to the DDQN trainer via the replay buffer.
                self.replay_buffer.add(transitions)
                if FLAGS.dqn_pretrain:
                    # Fixed actor: only collect experience for DDQN pre-training.
                    tf.logging.info(
                        'RUNNNING DQN PRETRAIN: Adding data to relplay buffer only...'
                    )
                    continue
                results = self.model.run_train_steps(self.sess, batch,
                                                     self.train_step,
                                                     q_estimates)
            else:
                results = self.model.run_train_steps(self.sess, batch,
                                                     self.train_step)
            t1 = time.time()
            # Summaries go to tensorboard; global_step drives max_iter / flushing.
            summaries = results['summaries']
            self.train_step = results['global_step']
            tf.logging.info('seconds for training step {}: {}'.format(
                self.train_step, t1 - t0))

            self._log_training_metrics(results)

            self.summary_writer.add_summary(summaries, self.train_step)
            if self.train_step % 100 == 0:  # flush the summary writer every so often
                self.summary_writer.flush()
            if FLAGS.ac_training:
                self.dqn_summary_writer.flush()
            if self.train_step > FLAGS.max_iter:
                break

    def _estimate_ddqn_q_values(self, batch, transitions):
        """Build per-token Q-value targets for the actor from the DDQN networks.

        For each transition, the entry for the taken action is replaced as follows:
          * If ``tr.done``: ``reward``.
          * Otherwise: ``reward + gamma * Q_target(next state)[a*]``, where
            ``a*`` is the best action reported by the current DQN.
        All other entries keep the current DQN's estimates.

        Args:
            batch: the current input batch; only ``max_art_oovs`` is read here.
            transitions: transitions collected from the actor for ``batch``.
                When ``FLAGS.calculate_true_q`` is off, each transition's
                ``q_values`` attribute is overwritten in place.

        Returns:
            Array reshaped to ``(batch_size, k, max_dec_steps, -1)``.
            Presumably the last axis is the extended vocabulary size; confirm
            against the DQN output width.
        """
        with self.dqn_graph.as_default():
            num_transitions = len(transitions)
            # ``b`` must carry the current decoder states (scored by the current
            # DQN) and ``b_prime`` the next states (scored by the target DQN).
            # Building both with use_state_prime=True made them identical.
            # NOTE(review): assumes use_state_prime=False selects the current
            # state; confirm against ReplayBuffer.create_batch.
            b = ReplayBuffer.create_batch(
                self.dqn_hps,
                transitions,
                num_transitions,
                use_state_prime=False,
                max_art_oovs=batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(
                self.dqn_hps,
                transitions,
                num_transitions,
                use_state_prime=True,
                max_art_oovs=batch.max_art_oovs)

            dqn_results = self.dqn.run_test_steps(
                sess=self.dqn_sess, x=b._x, return_best_action=True)
            q_estimates = dqn_results['estimates']  # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            dqn_target_results = self.dqn_target.run_test_steps(
                self.dqn_sess, x=b_prime._x)
            q_vals_new_t = dqn_target_results['estimates']  # shape (len(transitions), vocab_size)

            # Append zero-valued columns for the article's OOV ids so that an
            # extended-vocabulary action id can index q_estimates.
            q_estimates = np.concatenate(
                [q_estimates, np.zeros((num_transitions, batch.max_art_oovs))],
                axis=-1)

            for i, tr in enumerate(transitions):
                if tr.done:
                    q_estimates[i][tr.action] = tr.reward
                else:
                    q_estimates[i][tr.action] = (
                        tr.reward
                        + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]])

            if FLAGS.dqn_scheduled_sampling:
                q_estimates = self.scheduled_sampling(
                    num_transitions, FLAGS.sampling_probability,
                    b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
                # Without true Q-values, the DDQN is trained on these estimates,
                # so store them on the transitions before they reach the buffer.
                for trans, q_val in zip(transitions, q_estimates):
                    trans.q_values = q_val  # each have the size vocab_extended
            return np.reshape(
                q_estimates,
                [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1])

    def _log_training_metrics(self, results):
        """Log the scalar metrics found in ``results`` and stop on NaN/inf.

        Args:
            results: dict returned by ``self.model.run_train_steps``. Which
                keys are read depends on the coverage / RL / actor-critic flags.

        Raises:
            Exception: if any collected metric is not finite.
        """
        uses_rl_loss = FLAGS.rl_training or FLAGS.ac_training
        metrics = {'pgen_loss': results['pgen_loss']}
        if FLAGS.coverage:
            metrics['coverage_loss'] = results['coverage_loss']
            if uses_rl_loss:
                metrics['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
            else:
                metrics['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if uses_rl_loss:
            metrics['shared_loss'] = results['shared_loss']
            metrics['rl_loss'] = results['rl_loss']
            metrics['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
            metrics['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
            metrics['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
            # Mean sampled-sentence reward minus mean greedy-sentence reward.
            metrics['r_diff'] = metrics['sampled_r'] - metrics['greedy_r']
        if FLAGS.ac_training:
            metrics['dqn_loss'] = (np.mean(self.avg_dqn_loss)
                                   if len(self.avg_dqn_loss) > 0 else 0)

        for k, v in metrics.items():
            if not np.isfinite(v):
                raise Exception("{} is not finite. Stopping.".format(k))
            tf.logging.info('{}: {}\t'.format(k, v))
        tf.logging.info('-------------------------------------------')