Example #1
def run_training():
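    """Maximum-likelihood training loop for PointerNet.

    Evaluates on the validation set every FLAGS.eval_step steps and saves a
    checkpoint whenever the validation reward improves.
    """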
    print('batch size', FLAGS.batch_size)
    summarizationModel = PointerNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path,
                          vocab,
                          FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    sess.run(tf.global_variables_initializer())

    eval_max_reward = -float('inf')
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_path:
        print('loading params...')
        saver.restore(sess, FLAGS.restore_path)
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        print('load batch...')
        for batch in batches:
            print('start training...')
            step += 1
            feed_dict = make_feed_dict(summarizationModel, batch)
            loss, _ = sess.run(
                [summarizationModel.loss, summarizationModel.train_op],
                feed_dict)
            print("epoch : {0}, step : {1}, loss : {2}".format(
                abs(epoch - FLAGS.epoch), step, loss))
            if step % FLAGS.eval_step == 0:
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward ', eval_reward)
                if eval_max_reward < eval_reward:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(
                                       step, eval_reward)))
                    eval_max_reward = eval_reward
                    patient = FLAGS.patient
                print('eval max reward : {0}'.format(eval_max_reward))
                if patient < 0:
                    break

                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data finished
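
The helper make_feed_dict used above is not shown in these examples. As a rough, hypothetical sketch (the placeholder and batch attribute names below are assumptions, not necessarily the real PointerNet / Batcher fields), it would map one Batch object onto the model's input placeholders:

def make_feed_dict(model, batch):
    # Hypothetical sketch: map one Batch onto the model's placeholders.
    # All attribute names here are assumed, not taken from the repository.
    feed_dict = {
        model.enc_batch: batch.enc_batch,                # encoder token ids
        model.enc_lens: batch.enc_lens,                  # encoder lengths
        model.enc_padding_mask: batch.enc_padding_mask,  # 1/0 mask over encoder steps
        model.dec_batch: batch.dec_batch,                # decoder input ids
        model.target_batch: batch.target_batch,          # decoder target ids
        model.dec_padding_mask: batch.dec_padding_mask,  # 1/0 mask over decoder steps
    }
    if FLAGS.pointer_gen:
        # extra inputs required by the copy mechanism
        feed_dict[model.enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
        feed_dict[model.max_art_oovs] = batch.max_art_oovs
    return feed_dict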
Example #2
def decode(test_path, rl):
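    """Decode the examples in test_path with greedy search (FLAGS.beam False) or
    beam search, restoring the best checkpoint and writing decoded and reference
    summaries to disk for later ROUGE evaluation.
    """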
    sess = tf.Session(config=get_config())
    if FLAGS.beam:
        FLAGS.batch_size = FLAGS.beam_size
    FLAGS.max_dec_steps = 1
    print('batch size ', FLAGS.batch_size)
    #if rl == False:
    summarizationModel = PointerNet(FLAGS, vocab)
    #elif rl==True:
    #    if FLAGS.gamma > 0:
    #        import rl_model_gamma
    #        summarizationModel = rl_model_gamma.RLNet(FLAGS, vocab)
    #    else:
    #        import rl_model
    #        summarizationModel = rl_model.RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    saver = tf.train.Saver()
    best_model = load_best_model(FLAGS.restore_path)
    print('best model : {0}'.format(best_model))
    saver.restore(sess, save_path=best_model)
    counter = 0
    batcher = Batcher(test_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    batches = batcher.fill_batch_queue(
        is_training=False)  # 1 example repeated across batch
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    if not FLAGS.beam:
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents  # list of strings
            #print('*****************start**************')
            best_hyps = beam_search.run_greedy_search(sess, summarizationModel,
                                                      vocab, batch)
            output_ids = [[int(t) for t in best_hyp.tokens[1:]]
                          for best_hyp in best_hyps]
            decoded_words = data.outputids2words_greedy(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            decoded_words = remove_stop_index(decoded_words, data)
            write_for_rouge_greedy(
                original_abstract_sents, decoded_words, article, counter,
                FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += FLAGS.batch_size  # this is how many examples we've decoded
            print('counter ... ', counter)
            if counter % (5 * 64) == 0:
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    else:
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings
            #print('*****************start**************')
            best_hyps = beam_search.run_beam_search(sess, summarizationModel,
                                                    vocab, batch)
            #print('best hyp : {0}'.format(best_hyp))
            output_ids = [int(t) for t in best_hyps.tokens[1:]]
            decoded_words = data.outputids2words_beam(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] symbol produced; keep the full sequence
            #decoded_words = ' '.join(decoded_words)
            write_for_rouge_beam(
                original_abstract_sents, decoded_words, article, counter,
                FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path
            )  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += 1  # this is how many examples we've decoded
            print('counter ... ', counter)
            if counter % 100 == 0:
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
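
load_best_model is also not defined here. Assuming it scans the checkpoint directory written by the training loops (files named model_{step}_{reward}.ckpt) and returns the checkpoint prefix with the highest reward encoded in its file name, a minimal sketch could be:

import glob
import os

def load_best_model(checkpoint_dir):
    # Collect checkpoint prefixes such as '<dir>/model_1200_0.3125.ckpt'
    # from the .index/.meta/.data-* files TensorFlow writes for each save.
    prefixes = set()
    for path in glob.glob(os.path.join(checkpoint_dir, 'model_*.ckpt*')):
        prefixes.add(path.split('.ckpt')[0] + '.ckpt')

    # The reward is the last underscore-separated field of the file name,
    # so the "best" model is the prefix with the largest reward value.
    def reward_of(prefix):
        return float(os.path.basename(prefix)[:-len('.ckpt')].split('_')[-1])

    return max(prefixes, key=reward_of)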
Example #3
def run_rl_training():
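    """Self-critical RL fine-tuning: rewards of sampled vs. greedy sequences are
    computed outside the graph and fed back, and the final loss mixes the
    pointer-generator loss with the RL loss through summarizationModel._eta.
    """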

    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path,
                      vocab,
                      FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path,
                          vocab,
                          FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=100)
    if FLAGS.restore_rl_path:
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        sess.run(tf.global_variables_initializer())
        sess.run(
            loading_variable([v for v in tf.trainable_variables()],
                             reader_params(load_best_model(
                                 FLAGS.restore_path))))
    print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient  # early-stopping patience counter
    eval_loss = float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            sampled_sentence_r_values = []
            greedy_sentence_r_values = []
            feed_dict = make_feed_dict(summarizationModel, batch)

            to_return = {
                'sampled_sentences':
                summarizationModel.sampled_sentences,
                'greedy_search_sentences':
                summarizationModel.greedy_search_sentences
            }
            ret_dict = sess.run(to_return, feed_dict)
            # calculate reward
            for sampled_sentence, greedy_search_sentence, target_sentence in zip(
                    ret_dict['sampled_sentences'],
                    ret_dict['greedy_search_sentences'], batch.target_batch):
                assert len(sampled_sentence[0]) == len(target_sentence) == len(
                    greedy_search_sentence[0])
                reference_sent = ' '.join([str(k) for k in target_sentence])
                sampled_sent = ' '.join([str(k) for k in sampled_sentence[0]])
                sampled_sentence_r_values.append(
                    reward_function(reference_sent, sampled_sent))
                greedy_sent = ' '.join(
                    [str(k) for k in greedy_search_sentence[0]])
                greedy_sentence_r_values.append(
                    reward_function(reference_sent, greedy_sent))

            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss
            }
            to_return['s_r'] = summarizationModel._sampled_sentence_r_values
            to_return['g_r'] = summarizationModel._greedy_sentence_r_values

            feed_dict[summarizationModel.
                      _sampled_sentence_r_values] = sampled_sentence_r_values
            feed_dict[summarizationModel.
                      _greedy_sentence_r_values] = greedy_sentence_r_values
            feed_dict[summarizationModel._eta] = 0.5
            res = sess.run(to_return, feed_dict)

            print(
                'step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'
                .format(step, res['pgen_loss'], res['rl_loss'], res['loss'],
                        np.sum(res['s_r'] - res['g_r'])))
            if step % FLAGS.eval_step == 0:
                eval_ = run_rl_eval(summarizationModel, val_batcher, sess, 0.5)
                if eval_ < eval_loss:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(step, eval_)))
                    eval_loss = eval_
                    patient = FLAGS.patient
                print('eval loss : ', eval_loss)
                if patient < 0:
                    break

                if eval_ - eval_loss > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data finished
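
reward_function above is not shown. In self-critical training it is typically a ROUGE-style similarity between the decoded sequence and the reference; the exact metric used by this project is unknown, so the following is only a stand-in sketch computing ROUGE-L F1 over the whitespace-joined token ids:

def reward_function(reference_sent, decoded_sent):
    # Hypothetical stand-in: ROUGE-L F1 between two whitespace-separated
    # token-id strings. The project's real reward may differ.
    ref = reference_sent.split()
    dec = decoded_sent.split()
    # longest common subsequence length via dynamic programming
    dp = [[0] * (len(dec) + 1) for _ in range(len(ref) + 1)]
    for i, r_tok in enumerate(ref, 1):
        for j, d_tok in enumerate(dec, 1):
            if r_tok == d_tok:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    lcs = dp[len(ref)][len(dec)]
    if lcs == 0:
        return 0.0
    precision = lcs / float(len(dec))
    recall = lcs / float(len(ref))
    return 2 * precision * recall / (precision + recall)  # ROUGE-L F1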
Example #4
def run_rl_training_gamma(FLAGS, vocab):
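    """RL fine-tuning with discounted per-step returns: rewards from
    compute_reward are discounted with FLAGS.gamma and fed to the model, and
    the checkpoint with the best validation reward is kept.
    """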
    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path, vocab, FLAGS, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path, vocab, FLAGS, single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_rl_path:
        print('restore rl model...')
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        sess.run(tf.global_variables_initializer())
        sess.run(
            loading_variable([v for v in tf.trainable_variables()], reader_params(load_best_model(FLAGS.restore_path))))
        print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    eval_max_reward = -float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            sampled_sentence_r_values = []
            greedy_sentence_r_values = []
            feed_dict = make_feed_dict(summarizationModel, batch)

            to_return = {
                'sampled_sentences': summarizationModel.sampled_sentences,
            }
            ret_dict = sess.run(to_return, feed_dict)
            Rs = []
            # calculate reward
            for sampled_sentence, target_sentence in zip(
                    ret_dict['sampled_sentences'], batch.target_batch):
                # print('sampled : ', sampled_sentence)
                # print('target : ', target_sentence)
                reward = compute_reward(sampled_sentence[0], target_sentence)
                R = 0
                R_l = []
                for r in reward[::-1]:
                    R = r + FLAGS.gamma * R
                    R_l.insert(0, R)
                #avg = np.mean(R_l)
                #R_l = list(map(lambda a:a-avg, R_l))
                Rs.append(R_l)
            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss
            }
            to_return['reward'] = summarizationModel._reward

            
            feed_dict[summarizationModel._reward] = Rs
            feed_dict[summarizationModel._eta] = 0.1
            res = sess.run(to_return, feed_dict)

            print('step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'.format(
                step, res['pgen_loss'], res['rl_loss'], res['loss'],
                np.mean(res['reward'], axis=0)[0]))
            if step % FLAGS.eval_step == 0:
                #eval_ = run_rl_eval_gramma(summarizationModel, val_batcher, sess, 0.5)
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward  ', eval_reward)
                if eval_reward > eval_max_reward:
                    if not os.path.exists(FLAGS.checkpoint): os.mkdir(FLAGS.checkpoint)
                    saver.save(sess, save_path=os.path.join(FLAGS.checkpoint, 'model_{0}_{1}.ckpt'.format(step, eval_reward)))
                    eval_max_reward = eval_reward
                    patient = FLAGS.patient
                print('eval max reward ', eval_max_reward)
                if patient < 0:
                    break

                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # one full pass over the training data finished
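
For reference, the discounted-return recursion used in the loop above (R_t = r_t + gamma * R_{t+1}) can be isolated into a small helper; the function name here is ours, not the repository's:

def discounted_returns(rewards, gamma):
    # Walk backwards through the per-step rewards, accumulating the return
    # R_t = r_t + gamma * R_{t+1}, exactly as in the inner loop above.
    R = 0.0
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return returns

# e.g. per-token rewards [0, 0, 1] with gamma = 0.5 give returns [0.25, 0.5, 1.0]
assert discounted_returns([0, 0, 1], 0.5) == [0.25, 0.5, 1.0]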