def reinforce_loss(probs, doc, max_num_of_sents=3, max_num_of_chars=-1,
                   std_rouge=False, rouge_metric="all"):
    # Clamp probabilities away from 0 and 1 for numerical stability.
    probs = torch.clamp(probs, 1e-6, 1 - 1e-6)
    probs_numpy = probs.data.cpu().numpy()
    probs_numpy = np.reshape(probs_numpy, len(probs_numpy))
    # Cap the summary length at the number of sentences in the document.
    max_num_of_sents = min(len(probs_numpy), max_num_of_sents)

    # Greedy (argmax) baseline summary and its ROUGE reward.
    rl_baseline_summary_index, _ = return_summary_index(probs_numpy, "greedy", max_num_of_sents)
    rl_baseline_reward = from_summary_index_compute_rouge(doc, rl_baseline_summary_index,
                                                          std_rouge=std_rouge,
                                                          rouge_metric=rouge_metric,
                                                          max_num_of_chars=max_num_of_chars)
    # Lead baseline: ROUGE of the first max_num_of_sents sentences.
    lead3_reward = from_summary_index_compute_rouge(doc, range(max_num_of_sents),
                                                    std_rouge=std_rouge,
                                                    rouge_metric=rouge_metric)
    return rl_baseline_reward, lead3_reward
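# --- Sketch (not part of the original training code) --------------------------------
# Minimal illustration of how the greedy-baseline reward computed above can enter a
# REINFORCE update: sample a summary, score it, and weight its log-probability by the
# advantage over the greedy baseline. Assumes return_summary_index also accepts a
# "sample" mode and that from_summary_index_compute_rouge returns a scalar here;
# neither is guaranteed by this file.
def reinforce_step_loss_sketch(probs, doc, max_num_of_sents=3):
    probs = torch.clamp(probs, 1e-6, 1 - 1e-6)
    probs_numpy = probs.data.cpu().numpy().ravel()
    k = min(len(probs_numpy), max_num_of_sents)

    # Baseline: reward of the greedy (argmax) summary.
    baseline_index, _ = return_summary_index(probs_numpy, "greedy", k)
    baseline_reward = from_summary_index_compute_rouge(doc, baseline_index)

    # Sampled summary and its reward.
    sample_index, _ = return_summary_index(probs_numpy, "sample", k)
    sample_reward = from_summary_index_compute_rouge(doc, sample_index)

    # Log-probability of the sampled sentence set under the current policy.
    log_prob = torch.log(probs.view(-1)[torch.LongTensor(list(sample_index))]).sum()

    # Samples that beat the greedy baseline get their log-probability pushed up.
    return -(sample_reward - baseline_reward) * log_prob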
def generate_reward(self, summary_index_list, max_num_of_bytes=-1):
    # ROUGE reward of the selected sentences, optionally limited to a byte budget.
    reward = from_summary_index_compute_rouge(self.doc, summary_index_list,
                                              std_rouge=self.std_rouge,
                                              rouge_metric=self.rouge_metric,
                                              max_num_of_bytes=max_num_of_bytes)
    return reward
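# --- Illustration (hypothetical helper, not called above) ---------------------------
# What a max_num_of_bytes budget typically means: the candidate summary is cut at the
# byte limit before ROUGE is computed. The real enforcement lives inside
# from_summary_index_compute_rouge; this sketch only makes the convention explicit.
def truncate_to_bytes(sentences, max_num_of_bytes=-1):
    hypothesis = " ".join(sentences)
    if max_num_of_bytes > 0:
        hypothesis = hypothesis.encode("utf-8")[:max_num_of_bytes].decode("utf-8", errors="ignore")
    return hypothesis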
def validate(self, probs, doc, max_num_of_sents=3):
    """
    :return: ROUGE reward tuple of the current example under greedy selection
    """
    self.update_data_instance(probs, doc, max_num_of_sents)
    summary_index_list, _ = self.generate_index_list_and_loss("greedy")
    reward_tuple = from_summary_index_compute_rouge(self.doc, summary_index_list,
                                                    std_rouge=self.std_rouge,
                                                    rouge_metric="all")
    return reward_tuple
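# --- Usage sketch (hypothetical, with assumed helpers) ------------------------------
# How validate() might be averaged over one dataset chunk. Assumes validate is a method
# of the same ReinforceReward object used in training below, that the rouge_metric="all"
# tuple is (ROUGE-1, ROUGE-2, ROUGE-L), and that BatchDataLoader, tokens_to_sentences
# and prepare_data behave as in extractive_training().
def validate_on_chunk(reinforce, extract_net, dataset, vocab):
    extract_net.eval()
    scores = []
    for docs in BatchDataLoader(dataset, shuffle=False):
        doc = docs[0]
        doc.content = tokens_to_sentences(doc.content)
        doc.summary = tokens_to_sentences(doc.summary)
        if len(doc.content) == 0 or len(doc.summary) == 0:
            continue
        x = prepare_data(doc, vocab)
        if min(x.shape) == 0:
            continue
        probs = extract_net(Variable(torch.from_numpy(x)).cuda())
        scores.append(reinforce.validate(probs, doc))
    return np.mean(scores, axis=0)  # mean ROUGE-1/2/L over the chunk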
def extractive_training(args, vocab):
    writer = SummaryWriter('../log')
    print(args)

    print("generating config")
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        position_size=500,
        position_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
        dropout=args.dropout,
        pooling_way=args.pooling_way,
        num_layers=args.num_layers,
        num_directions=args.num_directions,
        fixed_length=args.fixed_length,
        num_filters=args.num_filters,
        filter_sizes=args.filter_sizes,
        batch_size=args.batch_size,
        novelty=args.novelty,
    )

    model_name = ".".join(("../model/" + str(args.ext_model),
                           "termination_", str(args.terminated_way),
                           "pooling_", str(args.pooling_way),
                           "max_sent", str(args.oracle_length),
                           "min_sents", str(args.min_num_of_sents),
                           "rl_m", str(args.rl_baseline_method),
                           "oracle_l", str(args.oracle_length),
                           "bsz", str(args.batch_size),
                           "rl_loss", str(args.rl_loss_method),
                           "hidden", str(args.hidden),
                           "dropout", str(args.dropout),
                           'ext'))
    print(model_name)
    log_name = ".".join(("../log/" + str(args.ext_model),
                         "termination_", str(args.terminated_way),
                         "pooling_", str(args.pooling_way),
                         "max_sent", str(args.oracle_length),
                         "min_sents", str(args.min_num_of_sents),
                         "rl_m", str(args.rl_baseline_method),
                         "oracle_l", str(args.oracle_length),
                         "bsz", str(args.batch_size),
                         "rl_loss", str(args.rl_loss_method),
                         "hidden", str(args.hidden),
                         "dropout", str(args.dropout),
                         'log'))

    print("init data loader and RL learner")
    data_loader = PickleReader(args.data_dir)

    # init statistics
    reward_list = []
    best_eval_reward = 0.
    model_save_name = model_name
    if args.fine_tune:
        model_save_name = model_name + ".fine_tune"
        log_name = log_name + ".fine_tune"
        args.std_rouge = True
        print("fine-tune model with std_rouge, args.std_rouge changed to %s" % args.std_rouge)

    print('init extractive model')
    extract_net = model.SHE(config).cuda()
    reinforce = ReinforceReward(terminated_way=args.terminated_way, std_rouge=args.std_rouge,
                                rouge_metric=args.rouge_metric, b=args.batch_size,
                                rl_baseline_method=args.rl_baseline_method, loss_method=1)
    extract_net.cuda()

    logging.basicConfig(filename='%s' % log_name, level=logging.INFO,
                        format='%(asctime)s [INFO] %(message)s')

    if args.load_ext:
        print("loading existing model %s" % model_name)
        extract_net = torch.load(model_name, map_location=lambda storage, loc: storage)
        extract_net.cuda()
        print("finish loading and evaluate model %s" % model_name)
        best_eval_reward, _ = evaluate.ext_model_eval(extract_net, vocab, args, "val")

    # Loss and Optimizer
    optimizer_ext = torch.optim.Adam(extract_net.parameters(), lr=args.lr, betas=(0., 0.999))

    print("starting training")
    n_step = 100
    error_counter = 0
    for epoch in range(args.epochs_ext):
        train_iter = data_loader.chunked_data_reader("train", data_quota=args.train_example_quota)
        step_in_epoch = 0
        for dataset in train_iter:
            for step, docs in enumerate(BatchDataLoader(dataset, shuffle=True)):
                try:
                    extract_net.train()
                    step_in_epoch += 1
                    doc = docs[0]
                    doc.content = tokens_to_sentences(doc.content)
                    doc.summary = tokens_to_sentences(doc.summary)
                    if len(doc.content) == 0 or len(doc.summary) == 0:
                        continue
                    if len(doc.content) < 3:
                        # Too few sentences to select from: take them all, no policy update.
                        summary_index_list = range(min(len(doc.content), 3))
                        loss = 0
                        reward = from_summary_index_compute_rouge(doc, summary_index_list,
                                                                  std_rouge=args.std_rouge,
                                                                  rouge_metric=args.rouge_metric,
                                                                  max_num_of_bytes=args.length_limit)
                    else:
                        if args.oracle_length == -1:  # use true oracle length
                            oracle_summary_sent_num = len(doc.summary)
                        else:
                            oracle_summary_sent_num = args.oracle_length

                        x = prepare_data(doc, vocab)
                        if min(x.shape) == 0:
                            continue
                        sents = Variable(torch.from_numpy(x)).cuda()
                        outputs = extract_net(sents)

                        # Occasionally print the predicted probabilities for inspection.
                        prt = bool(args.prt_inf and np.random.randint(0, 1000) == 0)
                        loss, reward = reinforce.train(outputs, doc,
                                                       min_num_of_sents=args.min_num_of_sents,
                                                       max_num_of_sents=oracle_summary_sent_num,
                                                       max_num_of_bytes=args.length_limit,
                                                       prt=prt)
                        if prt:
                            print('Probabilities: ', outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)

                    reward_list.append(reward)
                    if isinstance(loss, Variable):
                        loss.backward()

                    if step % 1 == 0:  # update after every document
                        torch.nn.utils.clip_grad_norm(extract_net.parameters(), 1)  # gradient clipping
                        optimizer_ext.step()
                        optimizer_ext.zero_grad()
                    logging.info('Epoch %d Step %d Reward %.4f' % (epoch, step_in_epoch, reward))
                except Exception as e:
                    error_counter += 1
                    print(e)

                if step_in_epoch % n_step == 0 and step_in_epoch != 0:
                    print('Epoch ' + str(epoch) + ' Step ' + str(step_in_epoch) +
                          ' reward: ' + str(np.mean(reward_list)))
                    print('error_count: ', error_counter)
                    mean_reward = np.mean(reward_list)
                    writer.add_scalar('Train/SHE', mean_reward, step_in_epoch)
                    reward_list = []

                if step_in_epoch % 2000 == 0 and step_in_epoch != 0:
                    print("doing evaluation")
                    extract_net.eval()
                    eval_reward, lead3_reward = evaluate.ext_model_eval(extract_net, vocab, args, "val")
                    if eval_reward > best_eval_reward:
                        best_eval_reward = eval_reward
                        print("saving model %s with eval_reward:" % model_save_name,
                              eval_reward, "leadreward", lead3_reward)
                        torch.save(extract_net, model_save_name)
                    writer.add_scalar('val/SHE', eval_reward, step_in_epoch)
                    with open('log/learning_curve', 'a') as f:
                        f.write(str(eval_reward) + '\t' + str(lead3_reward) + '\n')
                    print('epoch ' + str(epoch) + ' reward in validation: ' + str(eval_reward) +
                          ' lead3: ' + str(lead3_reward))
                    print('Error Counter: ', error_counter)
    return extract_net
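# --- Inference sketch (hypothetical) -------------------------------------------------
# Loading a model saved by extractive_training() with torch.save(extract_net, ...) and
# extracting the k highest-scoring sentences in document order. The function name and
# the fixed k=3 are illustrative assumptions, not part of the original code.
def extract_summary(model_path, doc, vocab, k=3):
    net = torch.load(model_path, map_location=lambda storage, loc: storage).cuda()
    net.eval()
    doc.content = tokens_to_sentences(doc.content)
    x = prepare_data(doc, vocab)
    probs = net(Variable(torch.from_numpy(x)).cuda()).squeeze().data.cpu().numpy()
    # Keep the k highest-scoring sentences, presented in document order.
    top_k = np.argsort(-probs)[:min(k, len(doc.content))]
    return [doc.content[i] for i in sorted(top_k)]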