Example #1
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)  # give the batcher's background queue-filling threads time to start

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token found; keep the full sequence

            original_abstract = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: it holds beam_size hypotheses; initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            topk_log_probs, topk_ids = torch.topk(final_dist,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
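
The search above assumes a Beam hypothesis object providing extend(), latest_token and avg_log_prob. A minimal sketch of what that class might look like (the real one ships with the repository; the details here are assumptions):

class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # token ids decoded so far, starting with [START]
        self.log_probs = log_probs  # per-token log-probabilities
        self.state = state          # decoder (h, c) after the last step
        self.context = context      # last attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis with one more decoded token appended.
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)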
Example #2
def main(unused_argv):
    print("unused_argv: ", unused_argv)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))
    print("FLAGS.vocab_size: ", FLAGS.vocab_size)
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary
    print("vocab size: ", vocab.size())
    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception(
            "The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage',
        'cov_loss_wt', 'pointer_gen', 'fine_tune', 'train_size', 'subred_size',
        'use_doc_vec', 'use_multi_attn', 'use_multi_pgen', 'use_multi_pvocab',
        'create_ckpt'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path,
                      vocab,
                      hps,
                      single_pass=FLAGS.single_pass)

    tf.set_random_seed(111)  # a seed value for randomness

    #   return

    if hps.mode.value == 'train':
        print("creating model...")
        model = SummarizationModel(hps, vocab)

        # -------------------------------------
        if hps.create_ckpt.value:
            step = 0

            model.build_graph()
            print("get value")
            pretrained_ckpt = '/home/cs224u/pointer/log/pretrained_model_tf1.2.1/train/model-238410'
            reader = pywrap_tensorflow.NewCheckpointReader(pretrained_ckpt)
            var_to_shape_map = reader.get_variable_to_shape_map()
            value = {}
            for key in var_to_shape_map:
                value[key] = reader.get_tensor(key)

            print("assign op")
            assign_op = []
            if hps.use_multi_pvocab.value:
                new_key = [
                    "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_0/Bias",
                    "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_1/Bias"
                ]
                for v in tf.trainable_variables():
                    key = v.name.split(":")[0]
                    if key in new_key:
                        origin_key = "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/" + key.split(
                            "/")[-1]
                        a_op = v.assign(tf.convert_to_tensor(
                            value[origin_key]))
                    else:
                        a_op = v.assign(tf.convert_to_tensor(value[key]))
                    # if key == "seq2seq/embedding/embedding":
                    # a_op = v.assign(tf.convert_to_tensor(value[key]))
                    assign_op.append(a_op)
            else:
                for v in tf.trainable_variables():
                    key = v.name.split(":")[0]
                    if key == "seq2seq/embedding/embedding":
                        a_op = v.assign(tf.convert_to_tensor(value[key]))
                        assign_op.append(a_op)
            # ratio = 1
            # for v in tf.trainable_variables():
            #   key = v.name.split(":")[0]
            #   # embedding (50000, 128) -> (50000, 32)

            #   if key == "seq2seq/embedding/embedding":
            #       print (key)
            #       print (value[key].shape)
            #       d1 = value[key].shape[1]
            #       a_op = v.assign(tf.convert_to_tensor(value[key][:,:d1//ratio]))
            #   # kernel (384, 1024) -> (96, 256)
            #   # w_reduce_c (512, 256) -> (128, 64)
            #   elif key == "seq2seq/encoder/bidirectional_rnn/fw/lstm_cell/kernel" or \
            #   key == "seq2seq/encoder/bidirectional_rnn/bw/lstm_cell/kernel" or \
            #   key == "seq2seq/reduce_final_st/w_reduce_c" or \
            #   key == "seq2seq/reduce_final_st/w_reduce_h" or \
            #   key == "seq2seq/decoder/attention_decoder/Linear/Matrix" or \
            #   key == "seq2seq/decoder/attention_decoder/lstm_cell/kernel" or \
            #   key == "seq2seq/decoder/attention_decoder/Attention/Linear/Matrix" or \
            #   key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Matrix":
            #       print (key)
            #       print (value[key].shape)
            #       d0, d1 = value[key].shape[0], value[key].shape[1]
            #       a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio, :d1//ratio]))
            #   # bias (1024,) -> (256,)
            #   elif key == "seq2seq/encoder/bidirectional_rnn/fw/lstm_cell/bias" or \
            #   key == "seq2seq/encoder/bidirectional_rnn/bw/lstm_cell/bias" or \
            #   key == "seq2seq/reduce_final_st/bias_reduce_c" or \
            #   key == "seq2seq/reduce_final_st/bias_reduce_h" or \
            #   key == "seq2seq/decoder/attention_decoder/lstm_cell/bias" or \
            #   key == "seq2seq/decoder/attention_decoder/v" or \
            #   key == "seq2seq/decoder/attention_decoder/Attention/Linear/Bias" or \
            #   key == "seq2seq/decoder/attention_decoder/Linear/Bias" or \
            #   key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Bias":
            #       print (key)
            #       print (value[key].shape)
            #       d0 = value[key].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio]))
            #   # W_h (1, 1, 512, 512) -> (1, 1, 128, 128)
            #   elif key == "seq2seq/decoder/attention_decoder/W_h":
            #       print (key)
            #       print (value[key].shape)
            #       d2, d3 = value[key].shape[2], value[key].shape[3]
            #       a_op = v.assign(tf.convert_to_tensor(value[key][:,:,:d2//ratio,:d3//ratio]))
            #   # Matrix (1152, 1) -> (288, 1)
            #   elif key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Matrix" or \
            #   key == "seq2seq/output_projection/w":
            #       print (key)
            #       print (value[key].shape)
            #       d0 = value[key].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio,:]))
            #   # Bias (1,) -> (1,)
            #   elif key == "seq2seq/output_projection/v" or \
            #   key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Bias":
            #       print (key)
            #       print (value[key].shape)
            #       a_op = v.assign(tf.convert_to_tensor(value[key]))

            #   # multi_attn
            #   if hps.use_multi_attn.value:
            #     if key == "seq2seq/decoder/attention_decoder/attn_0/v" or \
            #     key == "seq2seq/decoder/attention_decoder/attn_1/v":
            #     # key == "seq2seq/decoder/attention_decoder/attn_2/v":
            #       k = "seq2seq/decoder/attention_decoder/v"
            #       print (key)
            #       print (value[k].shape)
            #       d0 = value[k].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))
            #     if key == "seq2seq/decoder/attention_decoder/Attention/Linear_0/Bias" or \
            #     key == "seq2seq/decoder/attention_decoder/Attention/Linear_1/Bias":
            #     # key == "seq2seq/decoder/attention_decoder/Attention/Linear_2/Bias":
            #       k = "seq2seq/decoder/attention_decoder/Attention/Linear/Bias"
            #       print (key)
            #       print (value[k].shape)
            #       d0 = value[k].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))
            #   elif hps.use_multi_pgen.value:
            #     if key == "seq2seq/decoder/attention_decoder/Linear_0/Bias" or \
            #     key == "seq2seq/decoder/attention_decoder/Linear_1/Bias":
            #     # key == "seq2seq/decoder/attention_decoder/Linear_2/Bias":
            #       k = "seq2seq/decoder/attention_decoder/Linear/Bias"
            #       print (key)
            #       print (value[k].shape)
            #       d0 = value[k].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))
            #     if key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_0/Bias" or \
            #     key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_1/Bias":
            #     # key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_2/Bias":
            #       k = "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Bias"
            #       print (key)
            #       print (value[k].shape)
            #       a_op = v.assign(tf.convert_to_tensor(value[k]))
            #   elif hps.use_multi_pvocab.value:
            #     if key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_0/Bias" or \
            #     key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_1/Bias":
            #     # key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_2/Bias":
            #       k = "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Bias"
            #       print (key)
            #       print (value[k].shape)
            #       d0 = value[k].shape[0]
            #       a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))

            #    assign_op.append(a_op)

            # Add an op to initialize the variables.
            init_op = tf.global_variables_initializer()
            # Add ops to save and restore all the variables.
            saver = tf.train.Saver()
            with tf.Session(config=util.get_config()) as sess:
                sess.run(init_op)
                # Do some work with the model.
                for a_op in assign_op:
                    a_op.op.run()

                for _ in range(0):  # intentionally zero training steps here; we only re-save the weights
                    batch = batcher.next_batch()
                    results = model.run_train_step(sess, batch)

                # Save the variables to disk.
                if hps.use_multi_attn.value:
                    ckpt_tag = "multi_attn_2_attn_proj"
                elif hps.use_multi_pgen.value:
                    ckpt_tag = "multi_attn_2_pgen_proj"
                elif hps.use_multi_pvocab.value:
                    ckpt_tag = "big_multi_attn_2_pvocab_proj"
                else:
                    ckpt_tag = "pointer_proj"

                ckpt_to_save = '/home/cs224u/pointer/log/ckpt/' + ckpt_tag + '/model.ckpt-' + str(step)
                save_path = saver.save(sess, ckpt_to_save)
                print("Model saved in path: %s" % save_path)

        # -------------------------------------
        else:
            setup_training(model, batcher, hps)

    elif hps.mode.value == 'eval':
        model = SummarizationModel(hps, vocab)
        run_eval(model, batcher, vocab)
    elif hps.mode.value == 'decode':
        decode_model_hps = hps  # This will be the hyperparameters for the decoder model
        # The model is configured with max_dec_steps=1 because we only ever run one step of the
        # decoder at a time (to do beam search). Note that the batcher is initialized with
        # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
        decode_model_hps = hps._replace(max_dec_steps=1)
        model = SummarizationModel(decode_model_hps, vocab)
        decoder = BeamSearchDecoder(model, batcher, vocab)
        decoder.decode()  # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
Example #3
class Seq2Seq(object):

  def calc_running_avg_loss(self, loss, running_avg_loss, step, decay=0.99):
    """Calculate the running average loss via exponential decay.
    This is used to implement early stopping w.r.t. a smoother loss curve than the raw loss curve.

    Args:
      loss: loss on the most recent eval step
      running_avg_loss: running_avg_loss so far
      step: training iteration step
      decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.

    Returns:
      running_avg_loss: new running average loss
    """
    if running_avg_loss == 0:  # on the first iteration just take the loss
      running_avg_loss = loss
    else:
      running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    loss_sum = tf.Summary()
    tag_name = 'running_avg_loss/decay=%f' % (decay)
    loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
    self.summary_writer.add_summary(loss_sum, step)
    tf.logging.info('running_avg_loss: %f', running_avg_loss)
    return running_avg_loss
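
  # For intuition, one update of this exponential moving average with the default
  # decay works out as follows (a standalone arithmetic sketch, not part of the class):
  #   previous running_avg_loss = 2.0, new eval loss = 3.0, decay = 0.99
  #   2.0 * 0.99 + (1 - 0.99) * 3.0 = 1.98 + 0.03 = 2.01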

  def restore_best_model(self):
    """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
    tf.logging.info("Restoring bestmodel for training...")

    # Initialize all vars in the model
    sess = tf.Session(config=util.get_config())
    print("Initializing all variables...")
    sess.run(tf.initialize_all_variables())

    # Restore the best model from eval dir
    saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = util.load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    # Save this model to train dir and quit
    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
    print("Saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
    new_saver.save(sess, new_fname)
    print("Saved.")
    exit()

  def restore_best_eval_model(self):
    # load best evaluation loss so far
    best_loss = None
    best_step = None
    # go through all event files, select the best loss achieved, and return it
    event_files = sorted(glob('{}/eval/events*'.format(FLAGS.log_root)))
    for ef in event_files:
      try:
        for e in tf.train.summary_iterator(ef):
          for v in e.summary.value:
            step = e.step
            if 'running_avg_loss/decay' in v.tag:
              running_avg_loss = v.simple_value
              if best_loss is None or running_avg_loss < best_loss:
                best_loss = running_avg_loss
                best_step = step
      except Exception:  # skip event files that cannot be read
        continue
    tf.logging.info('restoring best loss from the current logs: {}\tstep: {}'.format(best_loss, best_step))
    return best_loss

  def convert_to_coverage_model(self):
    """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
    tf.logging.info("converting non-coverage model to coverage model..")

    # initialize an entire coverage model from scratch
    sess = tf.Session(config=util.get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    # load all non-coverage weights from checkpoint
    saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
    print("restoring non-coverage variables...")
    curr_ckpt = util.load_ckpt(saver, sess)
    print("restored.")

    # save this model and quit
    new_fname = curr_ckpt + '_cov_init'
    print("saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print("saved.")
    exit()

  def convert_to_reinforce_model(self):
    """Load non-reinforce checkpoint, add initialized extra variables for reinforce, and save as new checkpoint"""
    tf.logging.info("converting non-reinforce model to reinforce model..")

    # initialize an entire reinforce model from scratch
    sess = tf.Session(config=util.get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    # load all non-reinforce weights from checkpoint
    saver = tf.train.Saver([v for v in tf.global_variables() if "reinforce" not in v.name and "Adagrad" not in v.name])
    print("restoring non-reinforce variables...")
    curr_ckpt = util.load_ckpt(saver, sess)
    print("restored.")

    # save this model and quit
    new_fname = curr_ckpt + '_rl_init'
    print("saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print("saved.")
    exit()

  def setup_training(self):
    """Does setup before starting training (run_training)"""
    train_dir = os.path.join(FLAGS.log_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)
    if FLAGS.ac_training:
      dqn_train_dir = os.path.join(FLAGS.log_root, "dqn", "train")
      if not os.path.exists(dqn_train_dir): os.makedirs(dqn_train_dir)
    #replaybuffer_pcl_path = os.path.join(FLAGS.log_root, "replaybuffer.pcl")
    #if not os.path.exists(dqn_target_train_dir): os.makedirs(dqn_target_train_dir)

    self.model.build_graph() # build the graph

    if FLAGS.convert_to_reinforce_model:
      assert (FLAGS.rl_training or FLAGS.ac_training), "To convert your pointer model to a reinforce model, run with convert_to_reinforce_model=True and either rl_training=True or ac_training=True"
      self.convert_to_reinforce_model()
    if FLAGS.convert_to_coverage_model:
      assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
      self.convert_to_coverage_model()
    if FLAGS.restore_best_model:
      self.restore_best_model()
    saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time

    # Loads pre-trained word-embedding. By default the model learns the embedding.
    if FLAGS.embedding:
      self.vocab.LoadWordEmbedding(FLAGS.embedding, FLAGS.emb_dim)
      word_vector = self.vocab.getWordEmbedding()

    self.sv = tf.train.Supervisor(logdir=train_dir,
                       is_chief=True,
                       saver=saver,
                       summary_op=None,
                       save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                       save_model_secs=60, # checkpoint every 60 secs
                       global_step=self.model.global_step,
                       init_feed_dict= {self.model.embedding_place:word_vector} if FLAGS.embedding else None
                       )
    self.summary_writer = self.sv.summary_writer
    self.sess = self.sv.prepare_or_wait_for_session(config=util.get_config())
    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      # We create a separate graph for DDQN
      self.dqn_graph = tf.Graph()
      with self.dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))

        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))

        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        self.dqn_sv = tf.train.Supervisor(logdir=dqn_train_dir,
                           is_chief=True,
                           saver=dqn_saver,
                           summary_op=None,
                           save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                           save_model_secs=60, # checkpoint every 60 secs
                           global_step=self.dqn.global_step,
                           )
        self.dqn_summary_writer = self.dqn_sv.summary_writer
        self.dqn_sess = self.dqn_sv.prepare_or_wait_for_session(config=util.get_config())
      ''' #### TODO: try loading a previously saved replay buffer
      # right now this doesn't work due to running DQN on a thread
      if os.path.exists(replaybuffer_pcl_path):
        tf.logging.info('Loading Replay Buffer...')
        try:
          self.replay_buffer = pickle.load(open(replaybuffer_pcl_path, "rb"))
          tf.logging.info('Replay Buffer loaded...')
        except:
          tf.logging.info('Couldn\'t load Replay Buffer file...')
          self.replay_buffer = ReplayBuffer(self.dqn_hps)
      else:
        self.replay_buffer = ReplayBuffer(self.dqn_hps)
      tf.logging.info("Building DDQN took {} seconds".format(time.time()-t1))
      '''
      self.replay_buffer = ReplayBuffer(self.dqn_hps)
    tf.logging.info("Preparing or waiting for session...")
    tf.logging.info("Created session.")
    try:
      self.run_training() # this is an infinite loop until interrupted
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      if FLAGS.ac_training:
        self.dqn_sv.stop()

  def run_training(self):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("Starting run_training")

    if FLAGS.debug: # start the tensorflow debugger
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      # DDQN training is done asynchronously along with model training
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    # starting the main thread
    tf.logging.info('Starting Seq2Seq training...')
    while True: # repeats until interrupted
      batch = self.batcher.next_batch()
      t0=time.time()
      if FLAGS.ac_training:
        # For DDQN, we first collect the model output to calculate the reward and Q-estimates
        # Then we fix the estimation either using our target network or using the true Q-values
        # This process will usually take time and we are working on improving it.
        transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
        tf.logging.info('Q-values collection time: {}'.format(time.time()-t0))
        # whenever we are working with the DDQN, we switch to using the DDQN graph rather than the default graph
        with self.dqn_graph.as_default():
          batch_len = len(transitions)
          # we use current decoder state to predict q_estimates, use_state_prime = False
          b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = False, max_art_oovs = batch.max_art_oovs)
          # we also get the next decoder state to correct the estimation, use_state_prime = True
          b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
          # use current DQN to estimate values from current decoder state
          dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True)
          q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
          dqn_best_action = dqn_results['best_action']
          #dqn_q_estimate_loss = dqn_results['loss']

          # use target DQN to estimate values for the next decoder state
          dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x)
          q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

          # we need to expand the q_estimates to match the input batch max_art_oov
          # we use the q_estimate of UNK token for all the OOV tokens
          q_estimates = np.concatenate([q_estimates,
            np.reshape(q_estimates[:,0],[-1,1])*np.ones((len(transitions),batch.max_art_oovs))],axis=-1)
          # modify Q-estimates using the result collected from current and target DQN.
          # check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf
          for i, tr in enumerate(transitions):
            if tr.done:
              q_estimates[i][tr.action] = tr.reward
            else:
              q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
          # use scheduled sampling to decide whether to use true Q-values or the DDQN estimation
          if FLAGS.dqn_scheduled_sampling:
            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
          if not FLAGS.calculate_true_q:
            # when we are not training DDQN based on true Q-values,
            # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network.
            for trans, q_val in zip(transitions,q_estimates):
              trans.q_values = q_val # each have the size vocab_extended
          q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
        # Once we are done with modifying Q-values, we can use them to train the DDQN model.
        # In this paper, we use a priority experience buffer which always selects states with higher quality
        # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer.
        # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training
        # are full, the DDQN will start the training.
        self.replay_buffer.add(transitions)
        # If dqn_pretrain flag is on, it means that we use a fixed Actor to only collect experiences for
        # DDQN pre-training
        if FLAGS.dqn_pretrain:
          tf.logging.info('RUNNING DQN PRETRAIN: Adding data to replay buffer only...')
          continue
        # if not, use the q_estimation to update the loss.
        results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
      else:
        results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
      self.train_step = results['global_step'] # we need this to update our running average loss
      tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss']= results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      if FLAGS.rl_training or FLAGS.ac_training:
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_loss'] = results['rl_loss']
        printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
      if FLAGS.rl_training:
        printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
        printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
        printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
      if FLAGS.ac_training:
        printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0

      for (k,v) in printer_helper.items():
        if not np.isfinite(v):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) # write the summaries
      if self.train_step % 100 == 0: # flush the summary writer every so often
        self.summary_writer.flush()
      if FLAGS.ac_training:
        self.dqn_summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break

  def dqn_training(self):
    """ training the DDQN network."""
    try:
      while True:
        if self.dqn_train_step == FLAGS.dqn_pretrain_steps: raise SystemExit()
        _t = time.time()
        self.avg_dqn_loss = []
        avg_dqn_target_loss = []
        # Get a batch of size dqn_batch_size from replay buffer to train the model
        dqn_batch = self.replay_buffer.next_batch()
        if dqn_batch is None:
          tf.logging.info('replay buffer not loaded enough yet...')
          time.sleep(60)
          continue
        # Run train step for Current DQN model and collect the results
        dqn_results = self.dqn.run_train_steps(self.dqn_sess, dqn_batch)
        # Run test step for Target DQN model and collect the results and monitor the difference in loss between the two
        dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=dqn_batch._x, y=dqn_batch._y, return_loss=True)
        self.dqn_train_step = dqn_results['global_step']
        self.dqn_summary_writer.add_summary(dqn_results['summaries'], self.dqn_train_step) # write the summaries
        self.avg_dqn_loss.append(dqn_results['loss'])
        avg_dqn_target_loss.append(dqn_target_results['loss'])
        self.dqn_train_step = self.dqn_train_step + 1
        tf.logging.info('seconds for training dqn model: {}'.format(time.time()-_t))
        # UPDATING TARGET DDQN NETWORK WITH CURRENT MODEL
        with self.dqn_graph.as_default():
          current_model_weights = self.dqn_sess.run([self.dqn.model_trainables])[0] # get weights of current model
          self.dqn_target.run_update_weights(self.dqn_sess, self.dqn_train_step, current_model_weights) # update target model weights with current model weights
        tf.logging.info('DQN loss at step {}: {}'.format(self.dqn_train_step, np.mean(self.avg_dqn_loss)))
        tf.logging.info('DQN Target loss at step {}: {}'.format(self.dqn_train_step, np.mean(avg_dqn_target_loss)))
        # sleeping is required if you want the keyboard interruption to work
        time.sleep(FLAGS.dqn_sleep_time)
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      self.dqn_sv.stop()

  def watch_threads(self):
    """Watch example queue and batch queue threads and restart if dead."""
    while True:
      time.sleep(60)
      if not self.thrd_dqn_training.is_alive(): # if the thread is dead
        tf.logging.error('Found DQN Learning thread dead. Restarting.')
        self.thrd_dqn_training = Thread(target=self.dqn_training)
        self.thrd_dqn_training.daemon = True
        self.thrd_dqn_training.start()

  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      # evaluate for 100 * batch_size before comparing the loss
      # we do this due to memory constraints; it is best to run eval on different machines with a large batch size
      while processed_batch < 100*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch() # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            # if using true Q-value to train DQN network,
            # we do this as the pre-training for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # we need to expand the q_estimates to match the input batch max_art_oov
            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update Q-values in our transitions based on the q_estimates we collected from the current DQN network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
          loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
          printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
          printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
          printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
        if FLAGS.ac_training:
          printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

        for (k,v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
      #time.sleep(600) # run eval every 10 minute

  def main(self, unused_argv):
    if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
      raise Exception("Problem with flags: %s" % unused_argv)

    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
    tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    flags = getattr(FLAGS,"__flags")

    if not os.path.exists(FLAGS.log_root):
      if FLAGS.mode=="train":
        os.makedirs(FLAGS.log_root)
      else:
        raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

    fw = open('{}/config.txt'.format(FLAGS.log_root), 'w')
    for k, v in flags.items():
      fw.write('{}\t{}\n'.format(k, v))
    fw.close()

    self.vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
      FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode!='decode':
      raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs

    hparam_list = ['mode', 'lr', 'gpu_num',
    #'sampled_greedy_flag', 
    'gamma', 'eta', 
    'fixed_eta', 'reward_function', 'intradecoder', 
    'use_temporal_attention', 'ac_training','rl_training', 'matrix_attention', 'calculate_true_q',
    'enc_hidden_dim', 'dec_hidden_dim', 'k', 
    'scheduled_sampling', 'sampling_probability','fixed_sampling_probability',
    'alpha', 'hard_argmax', 'greedy_scheduled_sampling',
    'adagrad_init_acc', 'rand_unif_init_mag', 
    'trunc_norm_init_std', 'max_grad_norm', 
    'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps',
    'dqn_scheduled_sampling', 'dqn_sleep_time', 'E2EBackProp',
    'coverage', 'cov_loss_wt', 'pointer_gen']
    hps_dict = {}
    for key,val in flags.items(): # for each flag
      if key in hparam_list: # if it's in the list
        hps_dict[key] = val.value # add it to the dict
    if FLAGS.ac_training:
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
    self.hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
    # creating all the required parameters for DDQN model.
    if FLAGS.ac_training:
      hparam_list = ['lr', 'dqn_gpu_num', 
      'dqn_layers', 
      'dqn_replay_buffer_size', 
      'dqn_batch_size', 
      'dqn_target_update',
      'dueling_net',
      'dqn_polyak_averaging',
      'dqn_sleep_time',
      'dqn_scheduled_sampling',
      'max_grad_norm']
      hps_dict = {}
      for key,val in flags.items(): # for each flag
        if key in hparam_list: # if it's in the list
          hps_dict[key] = val.value # add it to the dict
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
      hps_dict.update({'vocab_size':self.vocab.size()})
      self.dqn_hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data
    self.batcher = Batcher(FLAGS.data_path, self.vocab, self.hps, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)

    tf.set_random_seed(111) # a seed value for randomness

    if self.hps.mode == 'train':
      print("creating model...")
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        # current DQN with parameters \Psi
        self.dqn = DQN(self.dqn_hps,'current')
        # target DQN with parameters \Psi^{\prime}
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.setup_training()
    elif self.hps.mode == 'eval':
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        self.dqn = DQN(self.dqn_hps,'current')
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.run_eval()
    elif self.hps.mode == 'decode':
      decode_model_hps = self.hps  # This will be the hyperparameters for the decoder model
      decode_model_hps = self.hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
      model = SummarizationModel(decode_model_hps, self.vocab)
      if FLAGS.ac_training:
        # We need our target DDQN network for collecting Q-estimation at each decoder step.
        dqn_target = DQN(self.dqn_hps,'target')
      else:
        dqn_target = None
      decoder = BeamSearchDecoder(model, self.batcher, self.vocab, dqn = dqn_target)
      decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
    else:
      raise ValueError("The 'mode' flag must be one of train/eval/decode")

  # Scheduled sampling used for either selecting true Q-estimates or the DDQN estimation
  # based on https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      result = result1 + result2
      return result
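
Conceptually, scheduled_sampling mixes rows of true and estimate according to a per-row Bernoulli draw; a rough NumPy equivalent (an illustrative sketch, not the graph code used above) is:

import numpy as np

def scheduled_sampling_np(batch_size, sampling_probability, true, estimate, rng=np.random):
    # For each row, take the DDQN estimate with probability sampling_probability,
    # otherwise keep the true Q-values (mirrors the scatter_nd-based graph version above).
    use_estimate = rng.random(batch_size) < sampling_probability
    return np.where(use_estimate[:, None], estimate, true)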
Example #4
class Seq2Seq(object):

  def calc_running_avg_loss(self, loss, running_avg_loss, step, decay=0.99):
    """Calculate the running average loss via exponential decay.
    This is used to implement early stopping w.r.t. a smoother loss curve than the raw loss curve.

    Args:
      loss: loss on the most recent eval step
      running_avg_loss: running_avg_loss so far
      step: training iteration step
      decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.

    Returns:
      running_avg_loss: new running average loss
    """
    if running_avg_loss == 0:  # on the first iteration just take the loss
      running_avg_loss = loss
    else:
      running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    loss_sum = tf.Summary()
    tag_name = 'running_avg_loss/decay=%f' % (decay)
    loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
    self.summary_writer.add_summary(loss_sum, step)
    tf.logging.info('running_avg_loss: %f', running_avg_loss)
    return running_avg_loss

  def restore_best_model(self):
    """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
    tf.logging.info("Restoring bestmodel for training...")

    # Initialize all vars in the model
    sess = tf.Session(config=util.get_config())
    print "Initializing all variables..."
    sess.run(tf.initialize_all_variables())

    # Restore the best model from eval dir
    saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
    print "Restoring all non-adagrad variables from best model in eval dir..."
    curr_ckpt = util.load_ckpt(saver, sess, "eval")
    print "Restored %s." % curr_ckpt

    # Save this model to train dir and quit
    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
    print "Saving model to %s..." % (new_fname)
    new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
    new_saver.save(sess, new_fname)
    print "Saved."
    exit()

  def restore_best_eval_model(self):
    # load best evaluation loss so far
    best_loss = None
    best_step = None
    # go through all event files, select the best loss achieved, and return it
    event_files = sorted(glob('{}/eval/events*'.format(FLAGS.log_root)))
    for ef in event_files:
      try:
        for e in tf.train.summary_iterator(ef):
          for v in e.summary.value:
            step = e.step
            if 'running_avg_loss/decay' in v.tag:
              running_avg_loss = v.simple_value
              if best_loss is None or running_avg_loss < best_loss:
                best_loss = running_avg_loss
                best_step = step
      except Exception:  # skip event files that cannot be read
        continue
    tf.logging.info('restoring best loss from the current logs: {}\tstep: {}'.format(best_loss, best_step))
    return best_loss

  def convert_to_coverage_model(self):
    """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
    tf.logging.info("converting non-coverage model to coverage model..")

    # initialize an entire coverage model from scratch
    sess = tf.Session(config=util.get_config())
    print "initializing everything..."
    sess.run(tf.global_variables_initializer())

    # load all non-coverage weights from checkpoint
    saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
    print "restoring non-coverage variables..."
    curr_ckpt = util.load_ckpt(saver, sess)
    print "restored."

    # save this model and quit
    new_fname = curr_ckpt + '_cov_init'
    print "saving model to %s..." % (new_fname)
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print "saved."
    exit()

  def convert_to_reinforce_model(self):
    """Load non-reinforce checkpoint, add initialized extra variables for reinforce, and save as new checkpoint"""
    tf.logging.info("converting non-reinforce model to reinforce model..")

    # initialize an entire reinforce model from scratch
    sess = tf.Session(config=util.get_config())
    print "initializing everything..."
    sess.run(tf.global_variables_initializer())

    # load all non-reinforce weights from checkpoint
    saver = tf.train.Saver([v for v in tf.global_variables() if "reinforce" not in v.name and "Adagrad" not in v.name])
    print "restoring non-reinforce variables..."
    curr_ckpt = util.load_ckpt(saver, sess)
    print "restored."

    # save this model and quit
    new_fname = curr_ckpt + '_rl_init'
    print "saving model to %s..." % (new_fname)
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print "saved."
    exit()

  def setup_training(self):
    """Does setup before starting training (run_training)"""
    train_dir = os.path.join(FLAGS.log_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)
    if FLAGS.ac_training:
      dqn_train_dir = os.path.join(FLAGS.log_root, "dqn", "train")
      if not os.path.exists(dqn_train_dir): os.makedirs(dqn_train_dir)
    #replaybuffer_pcl_path = os.path.join(FLAGS.log_root, "replaybuffer.pcl")
    #if not os.path.exists(dqn_target_train_dir): os.makedirs(dqn_target_train_dir)

    self.model.build_graph() # build the graph

    if FLAGS.convert_to_reinforce_model:
      assert (FLAGS.rl_training or FLAGS.ac_training), "To convert your pointer model to a reinforce model, run with convert_to_reinforce_model=True and either rl_training=True or ac_training=True"
      self.convert_to_reinforce_model()
    if FLAGS.convert_to_coverage_model:
      assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
      self.convert_to_coverage_model()
    if FLAGS.restore_best_model:
      self.restore_best_model()
    saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time

    # Loads pre-trained word-embedding. By default the model learns the embedding.
    if FLAGS.embedding:
      self.vocab.LoadWordEmbedding(FLAGS.embedding, FLAGS.emb_dim)
      word_vector = self.vocab.getWordEmbedding()

    self.sv = tf.train.Supervisor(logdir=train_dir,
                       is_chief=True,
                       saver=saver,
                       summary_op=None,
                       save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                       save_model_secs=60, # checkpoint every 60 secs
                       global_step=self.model.global_step,
                       init_feed_dict= {self.model.embedding_place:word_vector} if FLAGS.embedding else None
                       )
    self.summary_writer = self.sv.summary_writer
    self.sess = self.sv.prepare_or_wait_for_session(config=util.get_config())
    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      # We create a separate graph for DDQN
      self.dqn_graph = tf.Graph()
      with self.dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))

        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))

        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        self.dqn_sv = tf.train.Supervisor(logdir=dqn_train_dir,
                           is_chief=True,
                           saver=dqn_saver,
                           summary_op=None,
                           save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                           save_model_secs=60, # checkpoint every 60 secs
                           global_step=self.dqn.global_step,
                           )
        self.dqn_summary_writer = self.dqn_sv.summary_writer
        self.dqn_sess = self.dqn_sv.prepare_or_wait_for_session(config=util.get_config())
      ''' #### TODO: try loading a previously saved replay buffer
      # right now this doesn't work due to running DQN on a thread
      if os.path.exists(replaybuffer_pcl_path):
        tf.logging.info('Loading Replay Buffer...')
        try:
          self.replay_buffer = pickle.load(open(replaybuffer_pcl_path, "rb"))
          tf.logging.info('Replay Buffer loaded...')
        except:
          tf.logging.info('Couldn\'t load Replay Buffer file...')
          self.replay_buffer = ReplayBuffer(self.dqn_hps)
      else:
        self.replay_buffer = ReplayBuffer(self.dqn_hps)
      tf.logging.info("Building DDQN took {} seconds".format(time.time()-t1))
      '''
      self.replay_buffer = ReplayBuffer(self.dqn_hps)
    tf.logging.info("Preparing or waiting for session...")
    tf.logging.info("Created session.")
    try:
      self.run_training() # this is an infinite loop until interrupted
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      if FLAGS.ac_training:
        self.dqn_sv.stop()

  def run_training(self):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("Starting run_training")

    if FLAGS.debug: # start the tensorflow debugger
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      # DDQN training is done asynchronously along with model training
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    # starting the main thread
    tf.logging.info('Starting Seq2Seq training...')
    while True: # repeats until interrupted
      batch = self.batcher.next_batch()
      t0=time.time()
      if FLAGS.ac_training:
        # For DDQN, we first collect the model output to calculate the reward and Q-estimates
        # Then we fix the estimation either using our target network or using the true Q-values
        # This process will usually take time and we are working on improving it.
        transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
        tf.logging.info('Q-values collection time: {}'.format(time.time()-t0))
        # whenever we are working with the DDQN, we switch using DDQN graph rather than default graph
        with self.dqn_graph.as_default():
          batch_len = len(transitions)
          # we use current decoder state to predict q_estimates, use_state_prime = False
          b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = False, max_art_oovs = batch.max_art_oovs)
          # we also get the next decoder state to correct the estimation, use_state_prime = True
          b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
          # use current DQN to estimate values from current decoder state
          dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True)
          q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
          dqn_best_action = dqn_results['best_action']
          #dqn_q_estimate_loss = dqn_results['loss']

          # use target DQN to estimate values for the next decoder state
          dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x)
          q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

          # we need to expand the q_estimates to match the input batch max_art_oov
          # we use the q_estimate of UNK token for all the OOV tokens
          q_estimates = np.concatenate([q_estimates,
            np.reshape(q_estimates[:,0],[-1,1])*np.ones((len(transitions),batch.max_art_oovs))],axis=-1)
          # modify Q-estimates using the result collected from current and target DQN.
          # check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf
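          # Bellman-style correction for the action actually taken in each transition:
          #   terminal transition     -> the target is just the observed reward
          #   non-terminal transition -> reward + gamma * Q_target(next state)[best action picked by the current DQN]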
          for i, tr in enumerate(transitions):
            if tr.done:
              q_estimates[i][tr.action] = tr.reward
            else:
              q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
          # use scheduled sampling to decide whether to use the true Q-values or the DDQN estimates
          if FLAGS.dqn_scheduled_sampling:
            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
          if not FLAGS.calculate_true_q:
            # when we are not training DDQN based on true Q-values,
            # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network.
            for trans, q_val in zip(transitions,q_estimates):
              trans.q_values = q_val # each have the size vocab_extended
          q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
        # Once we are done with modifying Q-values, we can use them to train the DDQN model.
        # In this paper, we use a priority experience buffer which always selects states with higher quality
        # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer.
        # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training
        # are full, the DDQN will start the training.
        self.replay_buffer.add(transitions)
        # If dqn_pretrain flag is on, it means that we use a fixed Actor to only collect experiences for
        # DDQN pre-training
        if FLAGS.dqn_pretrain:
          tf.logging.info('RUNNING DQN PRETRAIN: Adding data to replay buffer only...')
          continue
        # if not, use the q_estimation to update the loss.
        results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
      else:
        results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      # get the summaries and iteration number so we can write summaries to tensorboard
      summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
      self.train_step = results['global_step'] # we need this to update our running average loss
      tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss']= results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      if FLAGS.rl_training or FLAGS.ac_training:
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_loss'] = results['rl_loss']
        printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
      if FLAGS.rl_training:
        printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
        printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
        printer_helper['r_diff'] = printer_helper['sampled_r'] - printer_helper['greedy_r']
      if FLAGS.ac_training:
        printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0

      for (k,v) in printer_helper.items():
        if not np.isfinite(v):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) # write the summaries
      if self.train_step % 100 == 0: # flush the summary writer every so often
        self.summary_writer.flush()
      if FLAGS.ac_training:
        self.dqn_summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break

  def dqn_training(self):
    """ training the DDQN network."""
    try:
      while True:
        if self.dqn_train_step == FLAGS.dqn_pretrain_steps: raise SystemExit()
        _t = time.time()
        self.avg_dqn_loss = []
        avg_dqn_target_loss = []
        # Get a batch of size dqn_batch_size from replay buffer to train the model
        dqn_batch = self.replay_buffer.next_batch()
        if dqn_batch is None:
          tf.logging.info('replay buffer not loaded enough yet...')
          time.sleep(60)
          continue
        # Run train step for Current DQN model and collect the results
        dqn_results = self.dqn.run_train_steps(self.dqn_sess, dqn_batch)
        # Run test step for Target DQN model and collect the results and monitor the difference in loss between the two
        dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=dqn_batch._x, y=dqn_batch._y, return_loss=True)
        self.dqn_train_step = dqn_results['global_step']
        self.dqn_summary_writer.add_summary(dqn_results['summaries'], self.dqn_train_step) # write the summaries
        self.avg_dqn_loss.append(dqn_results['loss'])
        avg_dqn_target_loss.append(dqn_target_results['loss'])
        self.dqn_train_step = self.dqn_train_step + 1
        tf.logging.info('seconds for training dqn model: {}'.format(time.time()-_t))
        # UPDATING TARGET DDQN NETWORK WITH CURRENT MODEL
        with self.dqn_graph.as_default():
          current_model_weights = self.dqn_sess.run([self.dqn.model_trainables])[0] # get weights of current model
          self.dqn_target.run_update_weights(self.dqn_sess, self.dqn_train_step, current_model_weights) # update target model weights with current model weights
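          # Note: run_update_weights copies the current DQN weights into the target network; depending on the
          # dqn_polyak_averaging flag this is typically either a hard copy or a soft blend such as
          # target <- tau * current + (1 - tau) * target (the exact rule lives inside the DQN class, not shown here).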
        tf.logging.info('DQN loss at step {}: {}'.format(self.dqn_train_step, np.mean(self.avg_dqn_loss)))
        tf.logging.info('DQN Target loss at step {}: {}'.format(self.dqn_train_step, np.mean(avg_dqn_target_loss)))
        # sleeping is required if you want the keyboard interruption to work
        time.sleep(FLAGS.dqn_sleep_time)
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      self.dqn_sv.stop()

  def watch_threads(self):
    """Watch example queue and batch queue threads and restart if dead."""
    while True:
      time.sleep(60)
      if not self.thrd_dqn_training.is_alive(): # if the thread is dead
        tf.logging.error('Found DQN Learning thread dead. Restarting.')
        self.thrd_dqn_training = Thread(target=self.dqn_training)
        self.thrd_dqn_training.daemon = True
        self.thrd_dqn_training.start()

  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      # evaluate for 100 * batch_size examples before comparing the loss
      # we do this due to memory constraints; it is best to run eval on a different machine with a large batch size
      while processed_batch < 100*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch() # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            # if using true Q-value to train DQN network,
            # we do this as the pre-training for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # we need to expand the q_estimates to match the input batch max_art_oov
            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update the Q-values in our transitions based on the q_estimates we collected from the current DQN network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
          else:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']

        for (k,v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush the summary writer every so often
      if train_step % 100 == 0:
        summary_writer.flush()
      #time.sleep(600) # run eval every 10 minute

  def main(self, unused_argv):
    if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
      raise Exception("Problem with flags: %s" % unused_argv)

    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
    tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    flags = getattr(FLAGS,"__flags")

    if not os.path.exists(FLAGS.log_root):
      if FLAGS.mode=="train":
        os.makedirs(FLAGS.log_root)
        fw = open('{}/config.txt'.format(FLAGS.log_root),'w')
        for k, v in flags.items():
          fw.write('{}\t{}\n'.format(k,v))
        fw.close()
      else:
        raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

    self.vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
      FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode!='decode':
      raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs

    hparam_list = ['mode', 'lr', 'gpu_num',
    #'sampled_greedy_flag', 
    'gamma', 'eta', 
    'fixed_eta', 'reward_function', 'intradecoder', 
    'use_temporal_attention', 'ac_training','rl_training', 'matrix_attention', 'calculate_true_q',
    'enc_hidden_dim', 'dec_hidden_dim', 'k', 
    'scheduled_sampling', 'sampling_probability','fixed_sampling_probability',
    'alpha', 'hard_argmax', 'greedy_scheduled_sampling',
    'adagrad_init_acc', 'rand_unif_init_mag', 
    'trunc_norm_init_std', 'max_grad_norm', 
    'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps',
    'dqn_scheduled_sampling', 'dqn_sleep_time', 'E2EBackProp',
    'coverage', 'cov_loss_wt', 'pointer_gen']
    hps_dict = {}
    for key, val in flags.items(): # for each flag
      if key in hparam_list: # if it's in the list
        hps_dict[key] = val # add it to the dict
    if FLAGS.ac_training:
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
    self.hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
    # creating all the required parameters for DDQN model.
    if FLAGS.ac_training:
      hparam_list = ['lr', 'dqn_gpu_num', 
      'dqn_layers', 
      'dqn_replay_buffer_size', 
      'dqn_batch_size', 
      'dqn_target_update',
      'dueling_net',
      'dqn_polyak_averaging',
      'dqn_sleep_time',
      'dqn_scheduled_sampling',
      'max_grad_norm']
      hps_dict = {}
      for key, val in flags.items(): # for each flag
        if key in hparam_list: # if it's in the list
          hps_dict[key] = val # add it to the dict
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
      hps_dict.update({'vocab_size':self.vocab.size()})
      self.dqn_hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data
    self.batcher = Batcher(FLAGS.data_path, self.vocab, self.hps, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)

    tf.set_random_seed(111) # a seed value for randomness

    if self.hps.mode == 'train':
      print "creating model..."
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        # current DQN with parameters \Psi
        self.dqn = DQN(self.dqn_hps,'current')
        # target DQN with parameters \Psi^{\prime}
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.setup_training()
    elif self.hps.mode == 'eval':
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        self.dqn = DQN(self.dqn_hps,'current')
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.run_eval()
    elif self.hps.mode == 'decode':
      decode_model_hps = self.hps  # This will be the hyperparameters for the decoder model
      decode_model_hps = self.hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
      model = SummarizationModel(decode_model_hps, self.vocab)
      if FLAGS.ac_training:
        # We need our target DDQN network for collecting Q-estimation at each decoder step.
        dqn_target = DQN(self.dqn_hps,'target')
      else:
        dqn_target = None
      decoder = BeamSearchDecoder(model, self.batcher, self.vocab, dqn = dqn_target)
      decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
    else:
      raise ValueError("The 'mode' flag must be one of train/eval/decode")

  # Scheduled sampling used to select between the true Q-values and the DDQN estimates
  # based on https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      return result1 + result2
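As a side note, the TensorFlow gather/scatter dance in scheduled_sampling above boils down to choosing, row by row, between the DDQN estimate and the true Q-values. A rough NumPy sketch of the same idea (the function name and shapes are illustrative, not part of the snippet):

import numpy as np

def scheduled_sampling_np(sampling_probability, true, estimate, rng=np.random):
    # true, estimate: arrays of shape [batch_size, vocab_size_extended]
    # For each row, a Bernoulli draw decides whether to keep the DDQN estimate
    # (with probability sampling_probability) or to fall back to the true Q-values.
    batch_size = true.shape[0]
    use_estimate = rng.random_sample(batch_size) < sampling_probability
    return np.where(use_estimate[:, None], estimate, true)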
Example #5
class BeamSearch(object):
    def __init__(self, model, config, step):
        self.config = config
        self.model = model.to(device)

        self._decode_dir = os.path.join(config.log_root,
                                        'decode_S%s' % str(step))
        self._rouge_ref = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec = os.path.join(self._decode_dir, 'rouge_dec')

        if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

        self.vocab = Vocab(config.vocab_file, config.vocab_size)
        self.test_data = CNNDMDataset('test', config.data_path, config,
                                      self.vocab)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    @staticmethod
    def report_rouge(ref_path, dec_path):
        print("Now starting ROUGE eval...")
        files_rouge = FilesRouge(dec_path, ref_path)
        scores = files_rouge.get_scores(avg=True)
        logging(str(scores))

    #@staticmethod
    def get_summary(self, best_summary, batch):
        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = output2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if self.config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index('<end>')
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            decoded_words = decoded_words
        decoded_abstract = ' '.join(decoded_words)
        return decoded_abstract

    def decode(self):
        config = self.config
        start = time.time()
        counter = 0
        test_loader = DataLoader(
            self.test_data,
            batch_size=1,
            shuffle=False,
            collate_fn=Collate(beam_size=config.beam_size))

        ref = open(self._rouge_ref, 'w')
        dec = open(self._rouge_dec, 'w')

        for batch in test_loader:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            original_abstract = batch.original_abstract[0]
            decoded_abstract = self.get_summary(best_summary, batch)

            ref.write(original_abstract + '\n')
            dec.write(decoded_abstract + '\n')

            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

        print("Decoder has finished reading dataset for single_pass.")
        ref.close()
        dec.close()
        self.report_rouge(self._rouge_ref, self._rouge_dec)

    def beam_search(self, batch):
        config = self.config
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, config, device)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size

        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id('<start>')],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id('<unk>') \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.tensor(latest_tokens)).to(device)

            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id('<end>'):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break
            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
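A hypothetical way to drive this class, assuming a trained model and a config object like the ones referenced above (the names below are assumptions, not taken from the snippet):

# hypothetical usage; `model`, `config` and `step` come from the surrounding training code
searcher = BeamSearch(model, config, step)
searcher.decode()  # writes the rouge_ref / rouge_dec files, then reports ROUGE via report_rouge()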
Example #6
class BeamSearch(object):
    def __init__(self, model_file_path, data_path, data_class='val'):
        self.data_class = data_class
        if self.data_class not in ['val', 'test']:
            raise ValueError("data_class must be 'val' or 'test'.")

        # model_file_path e.g. --> ../log/{MODE NAME}/best_model/model_best_XXXXX
        model_name = os.path.basename(model_file_path)
        # log_root e.g. --> ../log/{MODE NAME}/
        log_root = os.path.dirname(os.path.dirname(model_file_path))
        # _decode_dir e.g. --> ../log/{MODE NAME}/decode_model_best_XXXXX/
        self._decode_dir = os.path.join(log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        self._result_path = os.path.join(self._decode_dir, 'result_%s_%s.txt' \
                                                        % (model_name, self.data_class))
        # remove the result file if it exists
        if os.path.isfile(self._result_path):
            os.remove(self._result_path)
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(5)

        self.model = Model(model_file_path, is_eval=True)


    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0 # 1 x 2H
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]

            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h =[]
            all_state_c = []
            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item(),
                                   state=state_i,
                                   context=context_i,
                                   coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)
        return beams_sorted[0]
    
    
    def decode(self):
        start = time.time()
        counter = 0
        bleu_scores = []
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            original_articles = batch.original_articles[0]
            original_abstracts = batch.original_abstracts_sents[0]
            reference = original_abstracts[0].strip().split()
            bleu = nltk.translate.bleu_score.sentence_bleu([reference], decoded_words, weights = (0.5, 0.5))
            bleu_scores.append(bleu)

            # write_for_rouge(original_abstracts, decoded_words, counter,
            #                 self._rouge_ref_dir, self._rouge_dec_dir)

            write_for_result(original_articles, original_abstracts, decoded_words, \
                                                self._result_path, self.data_class)

            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()
        
        '''
        # uncomment this if you successfully install `pyrouge`
        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        '''

        if self.data_class == 'val':
            print('Average BLEU score:', np.mean(bleu_scores))
            with open(self._result_path, "a") as f:
                print('Average BLEU score:', np.mean(bleu_scores), file=f)

    def get_processed_path(self):
        # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}.txt
        input_path = self._result_path
        temp = os.path.splitext(input_path)
        # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}_processed.txt
        output_path = temp[0] + "_processed" + temp[1]
        return input_path, output_path
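For reference, the BLEU computation in decode() above scores one decoded hypothesis against a single reference with weights (0.5, 0.5), i.e. unigrams and bigrams weighted equally. A self-contained example with made-up sentences:

import nltk

reference = 'the cat sat on the mat'.split()
hypothesis = 'the cat is on the mat'.split()
bleu = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis, weights=(0.5, 0.5))
print('BLEU (unigrams + bigrams): %.4f' % bleu)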
Example #7
def main():
    embedding_dict_file = os.path.join(os.path.dirname(hps.word_count_path),
                                       'emb_dict_50000.pkl')
    vocab = Vocab(hps.word_count_path, hps.glove_path, hps.embedding_dim,
                  hps.max_vocab_size, embedding_dict_file)
    train_file = os.path.join(hps.data_path, 'train_raw.json')
    dev_file = os.path.join(hps.data_path, 'dev_raw.json')

    if (not os.path.exists(train_file)) \
    or (not os.path.exists(dev_file)):
        raise Exception(
            'train or dev data does not exist in data_path, please check')

    if hps.save and not hps.exp_dir:
        raise Exception(
            'please specify exp_dir when you want to save experiment info')

    print(vars(hps))
    if hps.save:
        utils.save_hps(hps.exp_dir, hps)

    net = PointerNet(hps, vocab.emb_mat)
    net = net.cuda()

    model_parameters = list(filter(lambda p: p.requires_grad,
                                   net.parameters()))
    print('the number of parameters in model:',
          sum(p.numel() for p in model_parameters))
    optimizer = optim.Adam(model_parameters)

    train_data_batcher = Batcher(train_file, vocab, hps, hps.single_pass)
    dev_data_batcher = Batcher(dev_file, vocab, hps, hps.single_pass)

    if hps.reward_metric == 'bleu':
        reward = get_batch_bleu

    global_step = 0
    dev_loss_track = []
    min_dev_loss = math.inf
    for i in range(hps.num_epoch):
        epoch_loss_track = []
        train_data_batcher.setup()
        while True:
            start = time.time()
            try:
                batch = train_data_batcher.next_batch()
                #print('get next batch time:', time.time()-start)
            except StopIteration:
                # do evaluation here, if necessary, to save best model
                dev_data_batcher.setup()
                dev_loss = run_eval(dev_data_batcher, net)
                print(
                    "epoch {}: avg train loss: {:>10.4f}, dev_loss: {:>10.4f}".
                    format(i + 1,
                           sum(epoch_loss_track) / len(epoch_loss_track),
                           dev_loss))
                dev_loss_track.append(dev_loss)

                if i > hps.early_stopping_from:
                    last5devloss = sum(dev_loss_track[i - 4:i + 1])
                    last10devloss = sum(dev_loss_track[i - 9:i - 4])
                    if hps.early_stopping_from and last5devloss >= last10devloss:
                        print("early stopping by dev_loss!")
                        sys.exit()

                if dev_loss < min_dev_loss:
                    min_dev_loss = dev_loss
                    if hps.save:
                        utils.save_model(hps.exp_dir, net, min_dev_loss)
                break

            paragraph_tensor = torch.tensor(batch.enc_batch,
                                            dtype=torch.int64,
                                            requires_grad=False).cuda()
            question_tensor = torch.tensor(batch.dec_batch,
                                           dtype=torch.int64,
                                           requires_grad=False).cuda()
            answer_position_tensor = torch.tensor(batch.ans_indices,
                                                  dtype=torch.int64,
                                                  requires_grad=False).cuda()
            target_tensor = torch.tensor(batch.target_batch,
                                         dtype=torch.int64,
                                         requires_grad=False).cuda()

            paragraph_batch_extend_vocab = None
            max_para_oovs = None
            if hps.pointer_gen:
                paragraph_batch_extend_vocab = torch.tensor(
                    batch.enc_batch_extend_vocab,
                    dtype=torch.int64,
                    requires_grad=False).cuda()
                max_para_oovs = batch.max_para_oovs

            optimizer.zero_grad()
            net.train()

            vocab_scores, vocab_dists, attn_dists, final_dists = net(
                paragraph_tensor, question_tensor, answer_position_tensor,
                paragraph_batch_extend_vocab, max_para_oovs)

            dec_padding_mask = torch.ne(target_tensor, 0).float().cuda()
            # for self-critic
            if hps.self_critic:
                greedy_seq = [
                    torch.argmax(dist, dim=1, keepdim=True)
                    for dist in final_dists
                ]  # each dist = [batch_size, vsize]
                greedy_seq_tensor = torch.cat(greedy_seq,
                                              dim=1)  # [batch_size, seq_len]

                sample_seq = []
                for dist in final_dists:
                    m = torch.distributions.categorical.Categorical(probs=dist)
                    sample_seq.append(m.sample())  # each is [batch_size,]
                sample_seq_tensor = torch.stack(sample_seq, dim=1)

                if hps.pointer_gen:
                    loss_per_step = []
                    for dist, sample_tgt in zip(final_dists, sample_seq):
                        # dist = [batch_size, extended_vsize]
                        probs = torch.gather(
                            dist, 1, sample_tgt.unsqueeze(1)).squeeze()
                        losses = -torch.log(probs)
                        loss_per_step.append(losses)  # a list of [batch_size,]
                    rl_loss = mask_and_avg(loss_per_step,
                                           dec_padding_mask,
                                           batch_average=False,
                                           step_average=False)
                    # this rl_loss = [batch_size, ]
                else:
                    # a list of dec_max_len (vocab_scores)
                    loss_batch_by_step = F.cross_entropy(
                        torch.stack(vocab_scores,
                                    dim=1).reshape(-1, vocab.size()),
                        sample_seq_tensor.reshape(-1),
                        size_average=False,
                        reduce=False)
                    # loss [batch_size*dec_max_len,]
                    mask_loss_batch_by_step = loss_batch_by_step * dec_padding_mask.reshape(
                        -1)
                    batch_size = vocab_scores[0].size(0)
                    rl_loss = torch.sum(mask_loss_batch_by_step.reshape(
                        batch_size, -1),
                                        dim=1)

                r1 = reward(target_tensor, greedy_seq_tensor)
                r2 = reward(target_tensor, sample_seq_tensor)
                reward_diff = r1 - r2
                final_rl_loss = reward_diff * rl_loss
                loss = torch.mean(final_rl_loss)
                print(
                    'r1: %.3f, r2: %.3f, reward_diff: %.3f, final rl loss: %.3f, loss batch mean: %.3f'
                    % (torch.max(r1).item(), torch.max(r2).item(),
                       torch.max(reward_diff).item(),
                       torch.max(final_rl_loss).item(), loss.item()))

            # for maximum likelihood
            if hps.maxium_likelihood:
                if hps.pointer_gen:
                    loss_per_step = []
                    for dec_step, dist in enumerate(final_dists):
                        # dist = [batch_size, extended_vsize]
                        targets = target_tensor[:, dec_step]
                        gold_probs = torch.gather(
                            dist, 1, targets.unsqueeze(1)).squeeze()
                        losses = -torch.log(gold_probs)
                        loss_per_step.append(losses)  # a list of [batch_size,]
                    loss = mask_and_avg(loss_per_step, dec_padding_mask)
                else:
                    # a list of dec_max_len (vocab_scores)
                    loss_batch_by_step = F.cross_entropy(
                        torch.stack(vocab_scores,
                                    dim=1).reshape(-1, vocab.size()),
                        target_tensor.reshape(-1),
                        size_average=False,
                        reduce=False)
                    # loss [batch_size*dec_max_len,]
                    loss = torch.sum(loss_batch_by_step *
                                     dec_padding_mask.reshape(-1)) / torch.sum(
                                         dec_padding_mask)

            epoch_loss_track.append(loss.item())
            global_step += 1

            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), max_norm=hps.norm_limit)
            optimizer.step()
            #print('time one step:', time.time()-start)
            if (global_step == 1) or (global_step % hps.print_every == 0):
                print('Step {:>5}: ave loss: {:>10.4f}, speed: {:.1f} case/s'.
                      format(global_step,
                             sum(epoch_loss_track) / len(epoch_loss_track),
                             hps.batch_size / (time.time() - start)))
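The self_critic branch above is a self-critical policy-gradient style loss: the sampled sequence's masked negative log-likelihood (rl_loss) is scaled by the difference between the greedy reward and the sampled reward, then averaged over the batch. A condensed sketch of that arithmetic with toy tensors (all values below are made up):

import torch

# toy per-example quantities, mirroring r1, r2 and rl_loss in the snippet above
r_greedy = torch.tensor([0.40, 0.55])    # reward(target_tensor, greedy_seq_tensor)
r_sample = torch.tensor([0.35, 0.60])    # reward(target_tensor, sample_seq_tensor)
nll_sample = torch.tensor([12.3, 9.8])   # masked -log p(sampled sequence), one value per example

final_rl_loss = (r_greedy - r_sample) * nll_sample  # reward_diff * rl_loss
loss = torch.mean(final_rl_loss)
print('loss: %.4f' % loss.item())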
Example #8
def main():
    vocab = Vocab(hps.word_count_path, hps.glove_path, hps.embedding_dim)

    net = PointerNet(hps, vocab.emb_mat)
    net = net.cuda()

    data_batcher = batcher(hps.data_path, vocab, hps, hps.single_pass)
    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.Adam(model_parameters)

    loss_track = []
    global_step = 0
    while True:
        start = time.time()
        batch = next(data_batcher)
        #batch = pickle.load(open('one_batch.pkl', 'rb'))
        paragraph_tensor = torch.tensor(batch.enc_batch, dtype=torch.int64, requires_grad=False).cuda()
        question_tensor = torch.tensor(batch.dec_batch, dtype=torch.int64, requires_grad=False).cuda()
        answer_position_tensor = torch.tensor(batch.ans_indices, dtype=torch.int64, requires_grad=False).cuda()
        target_tensor = torch.tensor(batch.target_batch, dtype=torch.int64, requires_grad=False).cuda()
        
        paragraph_batch_extend_vocab = None
        max_para_oovs = None
        if hps.pointer_gen:
            paragraph_batch_extend_vocab = torch.tensor(batch.enc_batch_extend_vocab, dtype=torch.int64, requires_grad=False).cuda()
            max_para_oovs = batch.max_para_oovs
        
        vocab_scores, vocab_dists, attn_dists, final_dists = net(paragraph_tensor, question_tensor, answer_position_tensor,
                                                                 paragraph_batch_extend_vocab, max_para_oovs)
        
        optimizer.zero_grad()
        dec_padding_mask = torch.ne(target_tensor, 0).float().cuda()
        if hps.pointer_gen:
            loss_per_step = []
            for dec_step, dist in enumerate(final_dists):
                # dist = [batch_size, extended_vsize]
                targets = target_tensor[:,dec_step]
                gold_probs = torch.gather(dist, 1, targets.unsqueeze(1)).squeeze()
                losses = -torch.log(gold_probs)
                loss_per_step.append(losses) # a list of [batch_size,]
            loss = mask_and_avg(loss_per_step, dec_padding_mask)
        else:
            # a list of dec_max_len (vocab_scores)
            loss_batch_by_step = F.cross_entropy(torch.cat(vocab_scores, dim=1).reshape(-1, vocab.size()), target_tensor.reshape(-1), size_average=False, reduce=False)
            # loss [batch_size*dec_max_len,]
            loss = torch.sum(loss_batch_by_step * dec_padding_mask.reshape(-1))/torch.sum(dec_padding_mask)
        loss_track.append(loss.item())
        global_step += 1

        loss.backward()
        optimizer.step()
        if global_step % hps.print_every == 0:
            print('Step {:>10}: ave loss: {:>10.4f}, speed: {:.4f}/case'.format(global_step, sum(loss_track)/len(loss_track), (time.time()-start)/hps.batch_size))
            loss_track = []
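mask_and_avg itself is not shown in these snippets. Under the usual pointer-generator convention it masks each decoder step's loss with the padding mask, averages over the real (non-padded) steps of each example, and then averages over the batch; the version called in Example #7 also takes batch_average / step_average switches that are omitted here. A minimal sketch under that assumption:

import torch

def mask_and_avg(loss_per_step, dec_padding_mask):
    # loss_per_step: list of [batch_size] tensors, one entry per decoder step
    # dec_padding_mask: [batch_size, max_dec_steps], 1.0 for real tokens and 0.0 for padding
    dec_lens = torch.sum(dec_padding_mask, dim=1)                    # number of real tokens per example
    losses = torch.stack(loss_per_step, dim=1) * dec_padding_mask    # zero out padded steps
    loss_per_example = torch.sum(losses, dim=1) / dec_lens           # time-average per example
    return torch.mean(loss_per_example)                              # batch average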