Example #1
def run_rl_eval_gramma(model, batcher, sess, eta):
    # Sum the pointer-generator (ML) loss over the evaluation batches.
    loss = 0
    batches = batcher.fill_batch_queue(is_training=False)
    for batch in batches:
        feed_dict = make_feed_dict(model, batch)
        feed_dict[model._eta] = eta
        eloss = sess.run(model._pgen_loss, feed_dict)
        loss += eloss
    return loss
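
Every example on this page relies on a make_feed_dict helper that is not shown. A minimal sketch of what it plausibly looks like, assuming the model exposes placeholders named _enc_batch, _enc_lens, _enc_padding_mask, _dec_batch, _target_batch and _dec_padding_mask and that the batch object carries matching arrays (these attribute names are assumptions for illustration; only the just_enc flag appears in Example #2 below):

def make_feed_dict(model, batch, just_enc=False):
    # Map the batch's arrays onto the model's placeholders; the attribute
    # names used here are illustrative, not taken from the project.
    feed_dict = {
        model._enc_batch: batch.enc_batch,
        model._enc_lens: batch.enc_lens,
        model._enc_padding_mask: batch.enc_padding_mask,
    }
    if not just_enc:
        # Decoder inputs and targets are only needed beyond the encoder pass.
        feed_dict[model._dec_batch] = batch.dec_batch
        feed_dict[model._target_batch] = batch.target_batch
        feed_dict[model._dec_padding_mask] = batch.dec_padding_mask
    return feed_dict
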
Example #2
def run_encoder(self, sess, batch):
    feed_dict = make_feed_dict(
        self, batch, just_enc=True)  # feed the batch into the placeholders
    enc_states, dec_in_state = sess.run(
        [self._enc_states, self._dec_in_state], feed_dict)  # run the encoder
    # During decoding the batch repeats one example, so the returned state is
    # identical across rows; keep only the first one.
    dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0],
                                                 dec_in_state.h[0])
    return enc_states, dec_in_state
Example #3
def run_training():
    # Supervised (maximum-likelihood) training loop for the pointer network.
    print('batch size', FLAGS.batch_size)
    summarizationModel = PointerNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path,
                          vocab,
                          FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    sess.run(tf.global_variables_initializer())

    eval_max_reward = -float('inf')
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_path:
        print('loading params...')
        saver.restore(sess, FLAGS.restore_path)
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        print('load batch...')
        for batch in batches:
            print('start training...')
            step += 1
            feed_dict = make_feed_dict(summarizationModel, batch)
            loss, _ = sess.run(
                [summarizationModel.loss, summarizationModel.train_op],
                feed_dict)
            print("epoch : {0}, step : {1}, loss : {2}".format(
                abs(epoch - FLAGS.epoch), step, loss))
            if step % FLAGS.eval_step == 0:
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward ', eval_reward)
                if eval_max_reward < eval_reward:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(
                                       step, eval_reward)))
                    eval_max_reward = eval_reward
                    patient = FLAGS.patient
                print('eval max reward : {0}'.format(eval_max_reward))
                # Early stopping: give up after `patient` evaluations without
                # sufficient improvement.
                if patient < 0:
                    break

                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data
def run_eval(model, batcher, sess):
    print('eval start ......')
    reward = []
    batches = batcher.fill_batch_queue(is_training=False)
    for batch in batches:
        feed_dict = make_feed_dict(model, batch)
        # Score the decoded output for this batch against its references.
        reward.extend(get_reward(sess, model, feed_dict, batch))
    return np.mean(reward)
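
The get_reward helper used above is also not shown. One way it could work, assuming the model exposes greedy_search_sentences (as RLNet does in Example #6) and reusing the same reward_function as training; this is a sketch under those assumptions, not the project's code:

def get_reward(sess, model, feed_dict, batch):
    # Decode greedily and score each output against its reference.
    outputs = sess.run(model.greedy_search_sentences, feed_dict)
    rewards = []
    for decoded, target in zip(outputs, batch.target_batch):
        reference = ' '.join(str(k) for k in target)
        decoded_sent = ' '.join(str(k) for k in decoded[0])
        rewards.append(reward_function(reference, decoded_sent))
    return rewards
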
Example #5
def run_rl_eval(model, batcher, sess, eta):
    # Sum the evaluation loss; with coverage enabled, add the weighted
    # coverage loss to the pointer-generator loss.
    loss = 0
    batches = batcher.fill_batch_queue(is_training=False)
    for batch in batches:
        feed_dict = make_feed_dict(model, batch)
        feed_dict[model._eta] = eta
        if FLAGS.coverage:
            eloss = sess.run(
                model._pgen_loss + FLAGS.cov_loss_wt * model._coverage_loss,
                feed_dict)
        else:
            eloss = sess.run(model._pgen_loss, feed_dict)
        loss += eloss
    return loss
Example #6
def run_rl_training():
    # Self-critical RL fine-tuning: the sampled sequence's reward is
    # baselined against the greedy sequence's reward.
    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path,
                          vocab,
                          FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=100)
    if FLAGS.restore_rl_path:
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        sess.run(tf.global_variables_initializer())
        sess.run(
            loading_variable([v for v in tf.trainable_variables()],
                             reader_params(load_best_model(
                                 FLAGS.restore_path))))
    print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient  # early-stopping budget
    eval_loss = float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            sampled_sentence_r_values = []
            greedy_sentence_r_values = []
            feed_dict = make_feed_dict(summarizationModel, batch)

            to_return = {
                'sampled_sentences': summarizationModel.sampled_sentences,
                'greedy_search_sentences': summarizationModel.greedy_search_sentences
            }
            ret_dict = sess.run(to_return, feed_dict)
            # calculate reward
            for sampled_sentence, greedy_search_sentence, target_sentence in zip(
                    ret_dict['sampled_sentences'],
                    ret_dict['greedy_search_sentences'], batch.target_batch):
                assert len(sampled_sentence[0]) == len(target_sentence) == len(
                    greedy_search_sentence[0])
                reference_sent = ' '.join([str(k) for k in target_sentence])
                sampled_sent = ' '.join([str(k) for k in sampled_sentence[0]])
                sampled_sentence_r_values.append(
                    reward_function(reference_sent, sampled_sent))
                greedy_sent = ' '.join(
                    [str(k) for k in greedy_search_sentence[0]])
                greedy_sentence_r_values.append(
                    reward_function(reference_sent, greedy_sent))

            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss
            }
            to_return['s_r'] = summarizationModel._sampled_sentence_r_values
            to_return['g_r'] = summarizationModel._greedy_sentence_r_values

            feed_dict[summarizationModel._sampled_sentence_r_values] = sampled_sentence_r_values
            feed_dict[summarizationModel._greedy_sentence_r_values] = greedy_sentence_r_values
            feed_dict[summarizationModel._eta] = 0.5
            res = sess.run(to_return, feed_dict)

            print(
                'step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'
                .format(step, res['pgen_loss'], res['rl_loss'], res['loss'],
                        np.sum(res['s_r'] - res['g_r'])))
            if step % FLAGS.eval_step == 0:
                eval_ = run_rl_eval(summarizationModel, val_batcher, sess, 0.5)
                if eval_ < eval_loss:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(step, eval_)))
                    eval_loss = eval_
                    patient = FLAGS.patient
                print('eval loss : ', eval_loss)
                # Early stopping: give up after `patient` evaluations without
                # sufficient improvement.
                if patient < 0:
                    break

                if eval_ - eval_loss > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data
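
reward_function above receives the reference and the decoded sequence as space-joined token-id strings and returns a scalar reward. Its body is not shown in these examples; a common stand-in for this kind of self-critical training is a sentence-level ROUGE-L F1 computed from the longest common subsequence, sketched below (an assumption, not the project's implementation):

def rouge_l_f1(reference_sent, decoded_sent, eps=1e-8):
    # ROUGE-L-style F1 over space-separated token ids, via the LCS.
    ref, dec = reference_sent.split(), decoded_sent.split()
    dp = [[0] * (len(dec) + 1) for _ in range(len(ref) + 1)]
    for i in range(1, len(ref) + 1):
        for j in range(1, len(dec) + 1):
            if ref[i - 1] == dec[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    lcs = float(dp[len(ref)][len(dec)])
    precision = lcs / (len(dec) + eps)
    recall = lcs / (len(ref) + eps)
    return 2 * precision * recall / (precision + recall + eps)
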
Example #7
def run_rl_training_gamma(FLAGS, vocab):
    # RL fine-tuning that feeds discounted per-step returns (gamma-weighted)
    # to the model instead of a self-critical greedy baseline.
    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path, vocab, FLAGS, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path, vocab, FLAGS, single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_rl_path:
        print('restore rl model...')
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        sess.run(tf.global_variables_initializer())
        sess.run(
            loading_variable([v for v in tf.trainable_variables()], reader_params(load_best_model(FLAGS.restore_path))))
        print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    eval_max_reward = -float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            sampled_sentence_r_values = []
            greedy_sentence_r_values = []
            feed_dict = make_feed_dict(summarizationModel, batch)

            to_return = {
                'sampled_sentences': summarizationModel.sampled_sentences,
            }
            ret_dict = sess.run(to_return, feed_dict)
            Rs = []
            # Compute a discounted return for every decoding step by scanning
            # the per-step rewards backwards: R = r + gamma * R.
            for sampled_sentence, target_sentence in zip(
                    ret_dict['sampled_sentences'], batch.target_batch):
                reward = compute_reward(sampled_sentence[0], target_sentence)
                R = 0
                R_l = []
                for r in reward[::-1]:
                    R = r + FLAGS.gamma * R
                    R_l.insert(0, R)
                Rs.append(R_l)
            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss
            }
            to_return['reward'] = summarizationModel._reward

            feed_dict[summarizationModel._reward] = Rs
            feed_dict[summarizationModel._eta] = 0.1
            res = sess.run(to_return, feed_dict)

            print(
                'step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'
                .format(step, res['pgen_loss'], res['rl_loss'], res['loss'],
                        np.mean(res['reward'], axis=0)[0]))
            if step % FLAGS.eval_step == 0:
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward ', eval_reward)
                if eval_reward > eval_max_reward:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(
                                       step, eval_reward)))
                    eval_max_reward = eval_reward
                    patient = FLAGS.patient
                print('eval max reward ', eval_max_reward)
                # Early stopping: give up after `patient` evaluations without
                # sufficient improvement.
                if patient < 0:
                    break

                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data
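
For reference, the discounted-return loop in Example #7 (R = r + FLAGS.gamma * R, scanned backwards over the per-step rewards) behaves as in this small self-contained illustration with a made-up reward vector:

gamma = 0.99
reward = [0.0, 0.0, 1.0]   # hypothetical per-step rewards for a 3-token output
R, returns = 0.0, []
for r in reversed(reward):
    R = r + gamma * R      # return from this step onward
    returns.insert(0, R)
print(returns)             # [0.9801, 0.99, 1.0]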