import os
import time

import numpy as np
import tensorflow as tf

# Project-local names assumed to be defined/imported elsewhere in this repo:
# FLAGS, vocab, PointerNet, RLNet, Batcher, data, beam_search, make_feed_dict,
# run_eval, run_rl_eval, reward_function, compute_reward, loading_variable,
# reader_params, load_best_model, remove_stop_index, write_for_rouge_greedy,
# write_for_rouge_beam.


def run_training():
    """Supervised (MLE) training of the pointer-generator model."""
    print('batch size', FLAGS.batch_size)
    summarizationModel = PointerNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path, vocab, FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path, vocab, FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    sess.run(tf.global_variables_initializer())
    eval_max_reward = -float('inf')
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_path:
        print('loading params...')
        saver.restore(sess, FLAGS.restore_path)
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        print('load batch...')
        for batch in batches:
            print('start training...')
            step += 1
            feed_dict = make_feed_dict(summarizationModel, batch)
            loss, _ = sess.run(
                [summarizationModel.loss, summarizationModel.train_op],
                feed_dict)
            print('epoch : {0}, step : {1}, loss : {2}'.format(
                abs(epoch - FLAGS.epoch), step, loss))
            if step % FLAGS.eval_step == 0:
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward ', eval_reward)
                if eval_max_reward < eval_reward:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(step, eval_reward)))
                    eval_max_reward = eval_reward
                    # Reset the early-stopping budget on improvement.
                    patient = FLAGS.patient
                print('eval max reward : {0}'.format(eval_max_reward))
                if patient < 0:
                    break
                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # fix: decrement, otherwise the while loop never terminates
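# `get_config()` is used by every function in this file but is not defined in
# this section. A minimal sketch, assuming the usual TF1 GPU settings (soft
# placement plus on-demand memory growth); the repo's real helper may differ.
def get_config():
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    return config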
def decode(test_path, rl):
    """Decode the test set with the best checkpoint, greedy or beam search."""
    sess = tf.Session(config=get_config())
    if FLAGS.beam:
        # Beam search decodes one example at a time: the batch is the beam.
        FLAGS.batch_size = FLAGS.beam_size
        FLAGS.max_dec_steps = 1
    print('batch size ', FLAGS.batch_size)
    # The RL variants are currently disabled; decoding always uses PointerNet.
    # if rl:
    #     if FLAGS.gamma > 0:
    #         import rl_model_gamma
    #         summarizationModel = rl_model_gamma.RLNet(FLAGS, vocab)
    #     else:
    #         import rl_model
    #         summarizationModel = rl_model.RLNet(FLAGS, vocab)
    summarizationModel = PointerNet(FLAGS, vocab)
    summarizationModel.build_graph()
    saver = tf.train.Saver()
    best_model = load_best_model(FLAGS.restore_path)
    print('best model : {0}'.format(best_model))
    saver.restore(sess, save_path=best_model)
    counter = 0
    batcher = Batcher(test_path, vocab, FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    batches = batcher.fill_batch_queue(is_training=False)  # 1 example repeated across batch
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    if not FLAGS.beam:
        # Greedy decoding: every example in the batch is decoded independently.
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents  # list of strings
            best_hyps = beam_search.run_greedy_search(sess, summarizationModel, vocab, batch)
            output_ids = [[int(t) for t in best_hyp.tokens[1:]] for best_hyp in best_hyps]
            decoded_words = data.outputids2words_greedy(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            decoded_words = remove_stop_index(decoded_words, data)
            # Write ref summary and decoded summary to file, to eval with pyrouge later.
            write_for_rouge_greedy(original_abstract_sents, decoded_words, article,
                                   counter, FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path)
            counter += FLAGS.batch_size  # this is how many examples we've decoded
            print('counter ... ', counter)
            if counter % (5 * 64) == 0:
                print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    else:
        # Beam search: the whole batch encodes one example; keep the single
        # best hypothesis returned by run_beam_search.
        for batch in batches:
            article = batch.original_articles[0]
            original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings
            best_hyp = beam_search.run_beam_search(sess, summarizationModel, vocab, batch)
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words_beam(
                output_ids, vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            try:
                # Truncate at the first [STOP] symbol, if any.
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] produced; keep the full sequence
            # Write ref summary and decoded summary to file, to eval with pyrouge later.
            write_for_rouge_beam(original_abstract_sents, decoded_words, article,
                                 counter, FLAGS.dec_path, FLAGS.ref_path, FLAGS.all_path)
            counter += 1  # this is how many examples we've decoded
            print('counter ... ', counter)
            if counter % 100 == 0:
                print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
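# `load_best_model(FLAGS.restore_path)` is not defined in this section.
# Checkpoints are saved above as 'model_{step}_{reward}.ckpt', so a helper
# that picks the checkpoint with the highest reward suffix might look like
# the sketch below (hypothetical; the repo's actual selection logic may differ).
def load_best_model_sketch(checkpoint_dir):
    import glob
    import re
    best_score, best_prefix = -float('inf'), None
    # Each saved checkpoint produces files like model_{step}_{reward}.ckpt.index.
    for path in glob.glob(os.path.join(checkpoint_dir, 'model_*.ckpt.index')):
        prefix = path[:-len('.index')]
        m = re.match(r'.*model_\d+_(-?\d+(?:\.\d+)?)\.ckpt$', prefix)
        if m and float(m.group(1)) > best_score:
            best_score, best_prefix = float(m.group(1)), prefix
    return best_prefix  # pass this prefix to saver.restore()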
def run_rl_training():
    """RL fine-tuning with a self-critical (greedy) baseline."""
    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path, vocab, FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path, vocab, FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=100)
    if FLAGS.restore_rl_path:
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        # Warm-start the RL model from the best supervised checkpoint.
        sess.run(tf.global_variables_initializer())
        sess.run(loading_variable([v for v in tf.trainable_variables()],
                                  reader_params(load_best_model(FLAGS.restore_path))))
        print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient  # fix: initialize before first use below
    eval_loss = float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            sampled_sentence_r_values = []
            greedy_sentence_r_values = []
            feed_dict = make_feed_dict(summarizationModel, batch)
            # First pass: draw a sampled sequence and a greedy sequence.
            to_return = {
                'sampled_sentences': summarizationModel.sampled_sentences,
                'greedy_search_sentences': summarizationModel.greedy_search_sentences
            }
            ret_dict = sess.run(to_return, feed_dict)
            # Calculate the reward of both sequences against the reference;
            # the greedy reward serves as the self-critical baseline.
            for sampled_sentence, greedy_search_sentence, target_sentence in zip(
                    ret_dict['sampled_sentences'],
                    ret_dict['greedy_search_sentences'],
                    batch.target_batch):
                assert len(sampled_sentence[0]) == len(target_sentence) \
                    == len(greedy_search_sentence[0])
                reference_sent = ' '.join([str(k) for k in target_sentence])
                sampled_sent = ' '.join([str(k) for k in sampled_sentence[0]])
                sampled_sentence_r_values.append(
                    reward_function(reference_sent, sampled_sent))
                greedy_sent = ' '.join([str(k) for k in greedy_search_sentence[0]])
                greedy_sentence_r_values.append(
                    reward_function(reference_sent, greedy_sent))
            # Second pass: feed the rewards back in and take a training step.
            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss,
                's_r': summarizationModel._sampled_sentence_r_values,
                'g_r': summarizationModel._greedy_sentence_r_values
            }
            feed_dict[summarizationModel._sampled_sentence_r_values] = sampled_sentence_r_values
            feed_dict[summarizationModel._greedy_sentence_r_values] = greedy_sentence_r_values
            feed_dict[summarizationModel._eta] = 0.5
            res = sess.run(to_return, feed_dict)
            print('step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'.format(
                step, res['pgen_loss'], res['rl_loss'], res['loss'],
                np.sum(res['s_r'] - res['g_r'])))
            if step % FLAGS.eval_step == 0:
                eval_ = run_rl_eval(summarizationModel, val_batcher, sess, 0.5)
                if eval_ < eval_loss:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(step, eval_)))
                    eval_loss = eval_
                    # Reset the early-stopping budget on improvement.
                    patient = FLAGS.patient
                print('eval loss : ', eval_loss)
                if patient < 0:
                    break
                if eval_ - eval_loss > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # fix: decrement, otherwise the while loop never terminates
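# run_rl_training() feeds sampled/greedy rewards and eta into RLNet without
# showing how the loss combines them. A minimal numpy sketch of the
# self-critical objective this setup implies: the RL term weights the
# sample's negative log-likelihood by how much the sample beat the greedy
# baseline, and eta mixes it with the MLE (pgen) loss. All names here are
# illustrative, not the repo's actual internals.
def mixed_rl_loss_sketch(log_probs, r_sampled, r_greedy, pgen_loss, eta=0.5):
    """log_probs: per-example log p(sampled sequence); r_*: scalar rewards."""
    advantage = np.array(r_sampled) - np.array(r_greedy)  # self-critical baseline
    rl_loss = np.mean(-advantage * np.array(log_probs))   # reward beating greedy
    return eta * rl_loss + (1.0 - eta) * pgen_loss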
def run_rl_training_gamma(FLAGS, vocab):
    """RL fine-tuning with per-step discounted returns (discount FLAGS.gamma)."""
    summarizationModel = RLNet(FLAGS, vocab)
    summarizationModel.build_graph()
    batcher = Batcher(FLAGS.data_path, vocab, FLAGS,
                      single_pass=FLAGS.single_pass,
                      decode_after=FLAGS.decode_after)
    val_batcher = Batcher(FLAGS.val_data_path, vocab, FLAGS,
                          single_pass=FLAGS.single_pass,
                          decode_after=FLAGS.decode_after)
    sess = tf.Session(config=get_config())
    saver = tf.train.Saver(max_to_keep=10)
    if FLAGS.restore_rl_path:
        print('restore rl model...')
        saver.restore(sess, FLAGS.restore_rl_path)
    else:
        # Warm-start the RL model from the best supervised checkpoint.
        sess.run(tf.global_variables_initializer())
        sess.run(loading_variable([v for v in tf.trainable_variables()],
                                  reader_params(load_best_model(FLAGS.restore_path))))
        print('loading params...')
    epoch = FLAGS.epoch
    step = 0
    patient = FLAGS.patient
    eval_max_reward = -float('inf')
    while epoch > 0:
        batches = batcher.fill_batch_queue()
        for batch in batches:
            step += 1
            feed_dict = make_feed_dict(summarizationModel, batch)
            to_return = {
                'sampled_sentences': summarizationModel.sampled_sentences,
            }
            ret_dict = sess.run(to_return, feed_dict)
            Rs = []
            # Calculate per-step discounted returns for each sampled sequence:
            # R_t = r_t + gamma * R_{t+1}, accumulated from the last step back.
            for sampled_sentence, target_sentence in zip(ret_dict['sampled_sentences'],
                                                         batch.target_batch):
                reward = compute_reward(sampled_sentence[0], target_sentence)
                R = 0
                R_l = []
                for r in reward[::-1]:
                    R = r + FLAGS.gamma * R
                    R_l.insert(0, R)
                # Optional mean-baseline subtraction (disabled):
                # avg = np.mean(R_l)
                # R_l = list(map(lambda a: a - avg, R_l))
                Rs.append(R_l)
            to_return = {
                'train_op': summarizationModel.train_op,
                'pgen_loss': summarizationModel._pgen_loss,
                'rl_loss': summarizationModel._rl_loss,
                'loss': summarizationModel.loss,
                'reward': summarizationModel._reward
            }
            feed_dict[summarizationModel._reward] = Rs
            feed_dict[summarizationModel._eta] = 0.1
            res = sess.run(to_return, feed_dict)
            print('step : {0}, pgen_loss : {1}, rl_loss : {2}, loss : {3}, reward : {4}'.format(
                step, res['pgen_loss'], res['rl_loss'], res['loss'],
                np.mean(res['reward'], axis=0)[0]))
            if step % FLAGS.eval_step == 0:
                eval_reward = run_eval(summarizationModel, val_batcher, sess)
                print('eval reward ', eval_reward)
                if eval_reward > eval_max_reward:
                    if not os.path.exists(FLAGS.checkpoint):
                        os.mkdir(FLAGS.checkpoint)
                    saver.save(sess,
                               save_path=os.path.join(
                                   FLAGS.checkpoint,
                                   'model_{0}_{1}.ckpt'.format(step, eval_reward)))
                    eval_max_reward = eval_reward
                    # Reset the early-stopping budget on improvement.
                    patient = FLAGS.patient
                print('eval max reward ', eval_max_reward)
                if patient < 0:
                    break
                if eval_max_reward - eval_reward > FLAGS.threshold:
                    patient -= 1
        epoch -= 1  # fix: decrement, otherwise the while loop never terminates
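# The inner loop above implements the discounted-return recurrence
# R_t = r_t + gamma * R_{t+1}. Pulled out as a small standalone helper
# (illustrative only) with a worked example:
def discounted_returns(rewards, gamma):
    R, returns = 0.0, []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return returns

# e.g. discounted_returns([0.0, 0.0, 1.0], gamma=0.9) == [0.81, 0.9, 1.0]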