Пример #1
0
def reinforce_loss(probs,
                   doc,
                   max_num_of_sents=3,
                   max_num_of_chars=-1,
                   std_rouge=False,
                   rouge_metric="all"):
    """Return the greedy-selection reward and the lead-sentences reward for ``doc``.

    Despite the name, no loss is computed: the function scores (a) the greedy
    summary chosen from ``probs`` and (b) the first ``max_num_of_sents``
    sentences of the document, both with ``from_summary_index_compute_rouge``.

    :param probs: torch tensor of per-sentence scores; it is flattened with
        ``np.reshape(x, len(x))``, so presumably shaped ``(num_sents,)`` or
        ``(num_sents, 1)`` -- TODO confirm.
    :param doc: document object forwarded to ``from_summary_index_compute_rouge``.
    :param max_num_of_sents: upper bound on selected sentences; capped at the
        number of sentences in ``probs``.
    :param max_num_of_chars: length limit forwarded to BOTH reward computations
        (the meaning of -1 is defined by the ROUGE helper; presumably "no limit").
    :param std_rouge: forwarded to ``from_summary_index_compute_rouge``.
    :param rouge_metric: forwarded to ``from_summary_index_compute_rouge``.
    :return: ``(rl_baseline_reward, lead3_reward)``.
    """
    # Clamp scores into [1e-6, 1 - 1e-6] before handing them to the selector.
    probs = torch.clamp(probs, 1e-6, 1 - 1e-6)
    probs_numpy = probs.data.cpu().numpy()
    probs_numpy = np.reshape(probs_numpy, len(probs_numpy))
    # Cannot select more sentences than the document has.
    max_num_of_sents = min(len(probs_numpy), max_num_of_sents)

    rl_baseline_summary_index, _ = return_summary_index(
        probs_numpy, "greedy", max_num_of_sents)
    rl_baseline_reward = from_summary_index_compute_rouge(
        doc,
        rl_baseline_summary_index,
        std_rouge=std_rouge,
        rouge_metric=rouge_metric,
        max_num_of_chars=max_num_of_chars)

    # Apply the same length limit to the lead baseline so both rewards are
    # measured under identical conditions and remain comparable.
    lead3_reward = from_summary_index_compute_rouge(
        doc,
        range(max_num_of_sents),
        std_rouge=std_rouge,
        rouge_metric=rouge_metric,
        max_num_of_chars=max_num_of_chars)

    return rl_baseline_reward, lead3_reward
Пример #2
0
 def generate_reward(self, summary_index_list, max_num_of_chars=-1):
     """Return the ROUGE reward of extracting ``summary_index_list`` from ``self.doc``.

     ``self.std_rouge`` and ``self.rouge_metric`` select the ROUGE variant;
     ``max_num_of_chars`` is forwarded unchanged as the length limit.
     """
     document = self.doc
     rouge_options = dict(std_rouge=self.std_rouge,
                          rouge_metric=self.rouge_metric,
                          max_num_of_chars=max_num_of_chars)
     return from_summary_index_compute_rouge(document, summary_index_list, **rouge_options)
Пример #3
0
 def generate_reward(self, summary_index_list, max_num_of_bytes=-1):
     """Score the sentences at ``summary_index_list`` of ``self.doc`` with ROUGE.

     The scoring mode comes from ``self.std_rouge`` / ``self.rouge_metric`` and
     ``max_num_of_bytes`` is passed through as the summary length limit.
     """
     document = self.doc
     rouge_kwargs = {
         "std_rouge": self.std_rouge,
         "rouge_metric": self.rouge_metric,
         "max_num_of_bytes": max_num_of_bytes,
     }
     return from_summary_index_compute_rouge(document, summary_index_list, **rouge_kwargs)
Пример #4
0
 def validate(self, probs, doc, max_num_of_sents=3):
     """Return the ROUGE score of the greedy summary for one example.

     :param probs: per-sentence scores for ``doc``, passed to
         ``update_data_instance``.
     :param doc: the document being evaluated; passed to
         ``update_data_instance`` (which presumably sets ``self.doc``).
     :param max_num_of_sents: maximum number of sentences to select.
     :return: the result of ``from_summary_index_compute_rouge`` computed with
         ``rouge_metric="all"`` for the greedily selected sentences (presumably
         a tuple of ROUGE scores -- TODO confirm). No training loss is returned.
     """
     self.update_data_instance(probs, doc, max_num_of_sents)
     # Greedy decoding; the second return value (presumably a loss) is unused here.
     summary_index_list, _ = self.generate_index_list_and_loss("greedy")
     reward_tuple = from_summary_index_compute_rouge(self.doc, summary_index_list,
                                                     std_rouge=self.std_rouge,
                                                     rouge_metric="all")
     return reward_tuple
Пример #5
0
def _build_run_name(args, prefix, suffix):
    """Return the dotted run identifier used for checkpoint and log file names.

    :param args: parsed command-line arguments; hyper-parameters are read from it.
    :param prefix: path prefix placed before ``args.ext_model`` (e.g. ``"../model/"``).
    :param suffix: last dotted component (e.g. ``'ext'`` or ``'log'``).
    :return: the joined name string.
    """
    # NOTE(review): "max_sent" records args.oracle_length (same value as
    # "oracle_l"); kept unchanged so names of existing checkpoints still match.
    return ".".join((prefix + str(args.ext_model),
                     "termination_", str(args.terminated_way),
                     "pooling_", str(args.pooling_way),
                     "max_sent", str(args.oracle_length),
                     "min_sents", str(args.min_num_of_sents),
                     "rl_m", str(args.rl_baseline_method),
                     "oracle_l", str(args.oracle_length),
                     "bsz", str(args.batch_size),
                     "rl_loss", str(args.rl_loss_method),
                     "hidden", str(args.hidden),
                     "dropout", str(args.dropout),
                     suffix))


def _build_config(args, vocab):
    """Build the model ``Config`` from CLI arguments and the vocabulary embeddings."""
    return Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        position_size=500,
        position_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
        dropout=args.dropout,
        pooling_way=args.pooling_way,
        num_layers=args.num_layers,
        num_directions=args.num_directions,
        fixed_length=args.fixed_length,
        num_filters=args.num_filters,
        filter_sizes=args.filter_sizes,
        batch_size=args.batch_size,
        novelty=args.novelty,
    )


def extractive_training(args, vocab):
    """Train the SHE extractive summarizer with a REINFORCE reward and return it.

    For each training document the model scores sentences, ``reinforce.train``
    returns a loss and a ROUGE reward, and Adam takes one gradient-clipped step.
    Every 100 steps the mean training reward is printed and written to
    TensorBoard; every 2000 steps the model is evaluated on the "val" split and
    saved when the validation reward improves. Per-document failures are counted
    and logged instead of aborting training.

    :param args: parsed command-line arguments (model, RL and training settings).
        When ``args.fine_tune`` is set, ``args.std_rouge`` is forced to True and
        the checkpoint/log names get a ``.fine_tune`` suffix.
    :param vocab: vocabulary exposing ``embedding``, ``w2i`` and ``i2w``.
    :return: the model after the last epoch (not necessarily the best saved one).
    """
    writer = SummaryWriter('../log')
    print(args)
    print("generating config")
    config = _build_config(args, vocab)
    model_name = _build_run_name(args, "../model/", 'ext')
    print(model_name)
    log_name = _build_run_name(args, "../log/", 'log')

    print("init data loader and RL learner")
    data_loader = PickleReader(args.data_dir)

    # init statistics
    reward_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    if args.fine_tune:
        # Fine-tuning writes to separate files so the base model is preserved.
        model_save_name = model_name + ".fine_tune"
        log_name = log_name + ".fine_tune"
        args.std_rouge = True
        print("fine_tune model with std_rouge, args.std_rouge changed to %s" % args.std_rouge)

    print('init extractive model')
    extract_net = model.SHE(config).cuda()
    reinforce = ReinforceReward(terminated_way=args.terminated_way,
                                std_rouge=args.std_rouge,
                                rouge_metric=args.rouge_metric,
                                b=args.batch_size,
                                rl_baseline_method=args.rl_baseline_method,
                                loss_method=1)

    logging.basicConfig(filename='%s' % log_name,
                        level=logging.INFO, format='%(asctime)s [INFO] %(message)s')
    if args.load_ext:
        # Always start from the base checkpoint (model_name), also when fine-tuning.
        print("loading existing model%s" % model_name)
        extract_net = torch.load(model_name, map_location=lambda storage, loc: storage)
        extract_net.cuda()
        print("finish loading and evaluate model %s" % model_name)
        best_eval_reward, _ = evaluate.ext_model_eval(extract_net, vocab, args, "val")

    # Created after the optional reload so it tracks the parameters actually trained.
    optimizer_ext = torch.optim.Adam(extract_net.parameters(), lr=args.lr, betas=(0., 0.999))

    print("starting training")
    n_step = 100
    error_counter = 0
    # Monotonic across epochs; step_in_epoch resets each epoch and would make
    # TensorBoard points from different epochs collide.
    global_step = 0

    for epoch in range(args.epochs_ext):
        train_iter = data_loader.chunked_data_reader("train", data_quota=args.train_example_quota)
        step_in_epoch = 0
        for dataset in train_iter:
            for docs in BatchDataLoader(dataset, shuffle=True):
                try:
                    extract_net.train()
                    step_in_epoch += 1
                    global_step += 1
                    # Only the first document of each batch is used.
                    doc = docs[0]
                    doc.content = tokens_to_sentences(doc.content)
                    doc.summary = tokens_to_sentences(doc.summary)

                    if len(doc.content) == 0 or len(doc.summary) == 0:
                        continue

                    if len(doc.content) < 3:
                        # Too short to learn from: take every sentence, record the reward only.
                        summary_index_list = range(len(doc.content))
                        loss = 0
                        reward = from_summary_index_compute_rouge(doc, summary_index_list,
                                                                  std_rouge=args.std_rouge,
                                                                  rouge_metric=args.rouge_metric,
                                                                  max_num_of_bytes=args.length_limit)
                    else:
                        if args.oracle_length == -1:  # use true oracle length
                            oracle_summary_sent_num = len(doc.summary)
                        else:
                            oracle_summary_sent_num = args.oracle_length

                        x = prepare_data(doc, vocab)
                        if min(x.shape) == 0:
                            continue
                        sents = Variable(torch.from_numpy(x)).cuda()
                        outputs = extract_net(sents)

                        # When enabled, dump probabilities for roughly 1 in 1000 documents.
                        prt = bool(args.prt_inf and np.random.randint(0, 1000) == 0)
                        loss, reward = reinforce.train(outputs, doc,
                                                       min_num_of_sents=args.min_num_of_sents,
                                                       max_num_of_sents=oracle_summary_sent_num,
                                                       max_num_of_bytes=args.length_limit,
                                                       prt=prt)
                        if prt:
                            print('Probabilities: ', outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)
                    reward_list.append(reward)

                    # Short documents set loss = 0, so only backpropagate a real Variable loss.
                    if isinstance(loss, Variable):
                        loss.backward()

                    # One optimizer step per document.
                    torch.nn.utils.clip_grad_norm(extract_net.parameters(), 1)  # gradient clipping
                    optimizer_ext.step()
                    optimizer_ext.zero_grad()
                    logging.info('Epoch %d Step %d Reward %.4f' % (epoch, step_in_epoch, reward))

                except Exception as e:
                    # Best-effort training: skip the failing document, but keep the
                    # traceback and discard any partial gradients it left behind.
                    error_counter += 1
                    print(e)
                    logging.exception('Epoch %d Step %d failed', epoch, step_in_epoch)
                    optimizer_ext.zero_grad()

                if step_in_epoch % n_step == 0 and step_in_epoch != 0:
                    mean_reward = np.mean(reward_list)
                    print('Epoch ' + str(epoch) + ' Step ' + str(step_in_epoch) +
                          ' reward: ' + str(mean_reward))
                    print('error_count: ', error_counter)
                    writer.add_scalar('Train/SHE', mean_reward, global_step)
                    reward_list = []

                if step_in_epoch % 2000 == 0 and step_in_epoch != 0:
                    print("doing evaluation")
                    extract_net.eval()
                    eval_reward, lead3_reward = evaluate.ext_model_eval(extract_net, vocab, args, "val")
                    if eval_reward > best_eval_reward:
                        best_eval_reward = eval_reward
                        print("saving model %s with eval_reward:" % model_save_name, eval_reward, "leadreward",
                              lead3_reward)
                        # Save under model_save_name so fine-tuning never overwrites the base model.
                        torch.save(extract_net, model_save_name)
                    writer.add_scalar('val/SHE', eval_reward, global_step)
                    # NOTE(review): this path is relative to the CWD, unlike the '../log'
                    # paths used above -- confirm the intended location.
                    with open('log/learning_curve', 'a') as f:
                        f.write(str(eval_reward) + '\t' + str(lead3_reward) + '\n')
                    print('epoch ' + str(epoch) + ' reward in validation: '
                          + str(eval_reward) + ' lead3: ' + str(lead3_reward))
                    print('Error Counter: ', error_counter)

    # Flush pending TensorBoard events.
    writer.close()
    return extract_net