Example #1
def compute_reward_old(pred_str,
                       pred_sent_list,
                       trg_str,
                       trg_sent_list,
                       reward_type,
                       regularization_factor=0.0,
                       regularization_type=0,
                       entropy=None):

    if regularization_type == 1:
        raise ValueError("Not implemented.")
    elif regularization_type == 2:
        regularization = entropy
    else:
        regularization = 0.0

    if reward_type == 0:
        tmp_reward = 0.3 * compute_rouge_n(
            pred_str, trg_str, n=1, mode='f') + 0.2 * compute_rouge_n(
                pred_str, trg_str, n=2, mode='f') + 0.5 * compute_rouge_l_summ(
                    pred_sent_list, trg_sent_list, mode='f')
    elif reward_type == 1:
        tmp_reward = compute_rouge_l_summ(pred_sent_list,
                                          trg_sent_list,
                                          mode='f')
    else:
        raise ValueError

    # Add the regularization term to the reward only if regularization type != 0
    if regularization_type == 0 or regularization_factor == 0:
        reward = tmp_reward
    else:
        reward = (1 - regularization_factor
                  ) * tmp_reward + regularization_factor * regularization
    return reward
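For reference, the reward above is a weighted mix of ROUGE-1, ROUGE-2 and sentence-level ROUGE-L F-scores (0.3 / 0.2 / 0.5). A minimal, self-contained sketch of the ROUGE-N F-score that compute_rouge_n is assumed to compute (plain n-gram overlap F1, without the smoothing or tokenization details of the project's metric module):

from collections import Counter

def _ngrams(tokens, n):
    # multiset of n-grams in a token list
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def rouge_n_f(pred_tokens, ref_tokens, n=1):
    # illustrative ROUGE-N F1: n-gram overlap between prediction and reference
    pred, ref = _ngrams(pred_tokens, n), _ngrams(ref_tokens, n)
    overlap = sum((pred & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

print(rouge_n_f('the cat sat on the mat'.split(),
                'a cat was on the mat'.split(), n=1))  # ~0.67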
Example #2
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    indices = list(range(len(art_sents)))
    for abst in abs_sents:

        rouges = list(map(compute_rouge_l(reference=abst, mode='f'),
                          art_sents))
        rouge1 = list(
            map(compute_rouge_n(reference=abst, n=1, mode='f'), art_sents))
        rouge2 = list(
            map(compute_rouge_n(reference=abst, n=2, mode='f'), art_sents))

        # ext = max(indices, key=lambda i: (rouges[i]))
        ext = max(indices,
                  key=lambda i: (rouges[i] + rouge1[i] + rouge2[i]) / 3)
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])

        if not indices:
            break

    return extracted, scores
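A hypothetical call of the greedy labeller above, assuming art_sents and abs_sents are lists of token lists as elsewhere in the project (the exact indices and scores depend on the underlying metric implementation):

art_sents = [['police', 'arrested', 'a', 'suspect', 'on', 'monday'],
             ['the', 'suspect', 'was', 'later', 'released'],
             ['officials', 'declined', 'to', 'comment']]
abs_sents = [['a', 'suspect', 'was', 'arrested', 'and', 'then', 'released']]

extracted, scores = get_extract_label(art_sents, abs_sents)
# extracted: index of the best-matching article sentence for each abstract sentence
# scores:    the ROUGE-L F-score of each matched sentence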
Example #3
def xsum_mixed_rouge_reward(pred_str_list, pred_sent_2d_list, trg_str_list,
                            trg_sent_2d_list, batch_size, device):
    #reward = np.zeros(batch_size)
    reward = []
    for idx, (pred_str, pred_sent_list, trg_str, trg_sent_list) in enumerate(
            zip(pred_str_list, pred_sent_2d_list, trg_str_list,
                trg_sent_2d_list)):
        #reward[idx] = 0.2 * compute_rouge_n(pred_str, trg_str, n=1, mode='f') + 0.5 * compute_rouge_n(pred_str, trg_str, n=2, mode='f') + 0.3 * compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f')
        reward.append(
            0.2 * compute_rouge_n(pred_str, trg_str, n=1, mode='f') +
            0.5 * compute_rouge_n(pred_str, trg_str, n=2, mode='f') + 0.3 *
            compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f'))
    return torch.FloatTensor(reward).to(device)  # tensor: [batch_size]
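A sketch of the expected input shapes for the batched reward above, with a made-up two-example batch (assumes torch and the metric functions are imported as in the original file): the *_str_list arguments hold flat token lists, while the *_sent_2d_list arguments hold the same summaries split into sentences.

pred_str_list = [['the', 'cat', 'sat'], ['dogs', 'bark', 'loudly']]
pred_sent_2d_list = [[['the', 'cat', 'sat']], [['dogs', 'bark', 'loudly']]]
trg_str_list = [['a', 'cat', 'sat', 'down'], ['the', 'dog', 'barks']]
trg_sent_2d_list = [[['a', 'cat', 'sat', 'down']], [['the', 'dog', 'barks']]]

rewards = xsum_mixed_rouge_reward(pred_str_list, pred_sent_2d_list,
                                  trg_str_list, trg_sent_2d_list,
                                  batch_size=2, device='cpu')
# rewards: FloatTensor of shape [2], one 0.2*R1 + 0.5*R2 + 0.3*RL score per example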
Example #4
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch,
        args.gamma, args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net']           = 'rnn-ext_abs_rl'
    meta['net_args']      = net_args
    meta['train_params']  = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
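Note that stop_reward_fn = compute_rouge_n(n=1) above is a partial application with no text arguments, which only works because the metric functions appear to be curried (e.g. with cytoolz.curry). A minimal self-contained illustration of that pattern with a hypothetical stand-in metric:

from cytoolz import curry

@curry
def token_overlap(output, reference, n=1):
    # hypothetical stand-in for a curried metric such as compute_rouge_n
    return len(set(output) & set(reference)) / max(len(set(reference)), 1)

stop_reward_fn = token_overlap(n=1)                       # partially applied, no texts yet
score = stop_reward_fn(['a', 'b', 'c'], ['b', 'c', 'd'])  # remaining arguments supplied later
print(score)  # 0.666...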
Example #5
def main(args):

    # Define directories of train set and decoded summaries
    DATA_DIR = os.environ['DATA']
    train_data_dir = os.path.join(DATA_DIR, 'train')
    dec_dir = args.decoded_data_dir

    # Get reference files list
    ref_files_df = pd.read_csv(os.path.join(DATA_DIR, args.bucket_file_path))
    ref_files = ref_files_df['filename']

    # Get rouge scores list between reference and generated texts
    rouge_scores_list = []
    for ind, act_file_name in enumerate(ref_files):

        with open(join(train_data_dir, act_file_name)) as f:
            js = json.loads(f.read())
            abstract = '\n'.join(js['abstract'])

        dec_file_name = str(ind) + '.dec'
        with open(join(dec_dir, dec_file_name)) as f:
            generation = f.read()

        rouge_score = compute_rouge_n(generation.split(), abstract.split())
        rouge_scores_list.append(rouge_score)

    df = pd.DataFrame()
    df['filename'] = ref_files
    df['rouge_score'] = rouge_scores_list
    df.to_csv(join('./rouge_score_files', args.rouge_scores_file_name))
Example #6
File: rl.py Project: ShawnXiha/fast_abs_rl
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices)-1)]
                ext_sents += [raw_arts[idx.item()]
                              for idx in indices if idx.item() < len(raw_arts)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j+n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)), n=1)
                i += 1
    avg_reward /= (i/100)  # equivalent to (avg_reward / i) * 100, i.e. mean ROUGE-1 as a percentage
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time()-start)), avg_reward))
    return {'reward': avg_reward}
Example #7
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, topic_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts, topic in zip(art_batch, topic_batch):
                indices = agent(raw_arts, topic)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [
                    raw_arts[idx.item()] for idx in indices
                    if idx.item() < len(raw_arts)
                ]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)),
                                              n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #8
def main(args):

    dec_dir = args.decoded_data_dir
    act_dir = args.actual_data_dir

    decoded_files = os.listdir(dec_dir)
    ref_files = os.listdir(act_dir)
    file_ids = [dec_file.split('.')[0] for dec_file in decoded_files]

    rouge_scores_list = []
    for ind, file_id in enumerate(file_ids):
        act_file_name = '.'.join([file_id, 'json'])
        with open(join(act_dir, act_file_name)) as f:
            js = json.loads(f.read())
            abstract = '\n'.join(js['abstract'])

        dec_file_name = decoded_files[ind]
        with open(join(dec_dir, dec_file_name)) as f:
            generation = f.read()

        rouge_score = compute_rouge_n(generation.split(), abstract.split())
        rouge_scores_list.append(rouge_score)

    df = pd.DataFrame()
    df['file_id'] = file_ids
    df['rouge_score'] = rouge_scores_list
    df.to_csv(join('./rouge_score_files', args.rouge_scores_file_name))
Example #9
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, extract in loader:
            greedy_inputs = []
            for idx, raw_arts in enumerate(art_batch):
                greedy, sample, log_probs = agent(raw_arts,
                                                  sample_time=1,
                                                  validate=True)
                sample = sample[0]
                log_probs = log_probs[0]
                greedy_sents = [raw_arts[ind] for ind in greedy]
                greedy_sents = [word for sent in greedy_sents for word in sent]
                #print(greedy_sents)
                #greedy_sents = list(concat(greedy_sents))
                greedy_sents = []
                ext_sent = []
                for ids in greedy:
                    if ids < len(raw_arts):
                        if ids == 0:
                            if ext_sent:
                                greedy_sents.append(ext_sent)
                            ext_sent = []
                        else:
                            ext_sent += raw_arts[ids]
                if greedy[-1] != 0 and ext_sent:
                    greedy_sents.append(ext_sent)
                #print(greedy_sents)
                #exit()
                greedy_inputs.append(greedy_sents)
            greedy_abstracts = []
            for abs_src in greedy_inputs:
                with torch.no_grad():
                    greedy_outputs = abstractor(abs_src)
                #greedy_abstract = []
                #for greedy_sents in greedy_outputs:
                #    greedy_sents = sent_tokenize(' '.join(greedy_sents))
                #    greedy_sents = [sent.strip().split(' ') for sent in greedy_sents]
                #    greedy_abstract += greedy_sents
                greedy_abstract = list(concat(greedy_outputs))
                greedy_abstracts.append(greedy_abstract)
            for idx, greedy_sents in enumerate(greedy_abstracts):
                abss = abs_batch[idx]
                bs = compute_rouge_n(greedy_sents, list(concat(abss)))
                avg_reward += bs
                i += 1
                #print(i)
                #print(avg_reward)
                #exit()
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #10
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, sent_batch in loader:
            print(i)
            ext_sents = []
            ext_inds = []
            masks = []
            dirty = []
            for raw_arts, sent_labels in zip(art_batch, sent_batch):
                indices = agent(raw_arts, sent_labels)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                assert indices[-1][-1].item() == len(raw_arts) + 1
                tmp_stop = indices[-1][-1].item()
                tmp_truncate = tmp_stop - 1
                str_arts = list(map(lambda x: ' '.join(x), raw_arts))
                for idx in indices:
                    t, m = rl_edu_to_sentence(str_arts, idx)
                    if t == []:
                        assert len(idx) == 1
                        id = idx[0].item()
                        if id == tmp_truncate:
                            dirty.append(len(ext_sents))
                            ext_sents.append(label)
                            masks.append(label_mask)
                    else:
                        if idx[-1].item() != tmp_stop:
                            ext_sents.append(t)
                            masks.append(m)
            all_summs = abstractor(ext_sents, masks)
            for d in dirty:
                all_summs[d] = []
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)),
                                              n=1)
                i += 1
                if i % 100 == 1:
                    print(avg_reward / i, i)
                '''
                with open('./compare/rl/' + str(i - 1) + '.dec', 'w') as f:
                    for s in summs:
                        s = ' '.join(s)
                        f.write(s + '\n')
                '''
            #if i > 1000:
            #    break
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #11
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, ext_batch in loader:
            ext_sents = []
            ext_inds = []
            sent_acts = []
            for raw_arts in art_batch:
                (indices, _), actions = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [
                    raw_arts[idx.item()] for idx in indices
                    if idx.item() < len(raw_arts)
                ]

                sent_acts += [
                    actions[j] for j, idx in enumerate(indices)
                    if idx.item() < len(raw_arts)
                ]

            assert len(ext_sents) == len(sent_acts)

            all_summs = []
            need_abs_sents = [
                ext_sents[iters] for iters, act in enumerate(sent_acts)
                if act == 0
            ]
            if len(need_abs_sents) > 0:
                turn_abs_sents = abstractor(need_abs_sents)

            for nums, action in enumerate(sent_acts):
                if action == 0:
                    all_summs += turn_abs_sents.pop(0)
                else:
                    all_summs += ext_sents[nums]

            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)),
                                              n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #12
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    new_art_sents, composed = my_compose(art_sents)
    indices = list(range(len(new_art_sents)))

    for abst in abs_sents:
        rouges = list(map(compute_rouge_n(reference=abst, mode='f'),
                          new_art_sents))
        ext = max(indices, key=lambda i: rouges[i])
        extracted.append(composed[ext])
        scores.append(rouges[ext])
    #print(extracted, scores)
    return extracted, scores
Example #13
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch, _ in loader:
            #print(art_batch, abs_batch, _)
            ext_sents = []
            ext_inds = []
            new_indices = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                extL = len(ext_sents)
                #ext_sents += [raw_arts[idx.item()]
                #              for idx in indices if idx.item() < len(raw_arts)]
                inds_ = []
                ext_sent = []
                for idx in indices:
                    if idx.item() < len(raw_arts):
                        if idx.item() == 0:
                            if ext_sent:
                                ext_sents.append(ext_sent)
                            ext_sent = []
                            inds_.append(1)
                        else:
                            ext_sent += raw_arts[idx.item()]
                if indices[-1].item() != 0 and ext_sent:
                    ext_sents.append(ext_sent)
                    inds_.append(1)
                new_indices.append(inds_)
                indxL = len(new_indices) - 1
                ext_inds += [(extL, indxL)]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)),
                                              n=1)
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #14
def sc_validate(agent, abstractor, loader, entity=False, bert=False):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0
    with torch.no_grad():
        for art_batch, abs_batch in loader:
            greedy_inputs = []
            for idx, raw_arts in enumerate(art_batch):
                greedy, sample, log_probs = agent(raw_arts,
                                                  sample_time=1,
                                                  validate=True)
                if entity or bert:
                    raw_arts = raw_arts[0]
                # sample = sample[0]
                # log_probs = log_probs[0]
                greedy_sents = [raw_arts[ind] for ind in greedy]
                #greedy_sents = list(concat(greedy_sents))
                greedy_sents = [word for sent in greedy_sents for word in sent]
                greedy_inputs.append(greedy_sents)
            with torch.no_grad():
                greedy_outputs = abstractor(greedy_inputs)
            greedy_abstracts = []
            for greedy_sents in greedy_outputs:
                greedy_sents = sent_tokenize(' '.join(greedy_sents))
                greedy_sents = [
                    sent.strip().split(' ') for sent in greedy_sents
                ]
                greedy_abstracts.append(greedy_sents)
            for idx, greedy_sents in enumerate(greedy_abstracts):
                abss = abs_batch[idx]
                bs = compute_rouge_n(list(concat(greedy_sents)),
                                     list(concat(abss)))
                avg_reward += bs
                i += 1
    avg_reward /= (i / 100)
    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #15
File: rl.py Project: aniket03/fast_abs_rl
def a2c_validate(agent, abstractor, loader):
    agent.eval()
    start = time()
    print('start running validation...', end='')
    avg_reward = 0
    i = 0

    # IF using BERTScore
    # reward_fn = compute_bertscore_wo_baseline_rescaling
    # reward_fn.metric = datasets.load_metric('bertscore')

    with torch.no_grad():
        for art_batch, abs_batch in loader:
            ext_sents = []
            ext_inds = []
            for raw_arts in art_batch:
                indices = agent(raw_arts)
                ext_inds += [(len(ext_sents), len(indices) - 1)]
                ext_sents += [
                    raw_arts[idx.item()] for idx in indices
                    if idx.item() < len(raw_arts)
                ]
            all_summs = abstractor(ext_sents)
            for (j, n), abs_sents in zip(ext_inds, abs_batch):
                summs = all_summs[j:j + n]
                # python ROUGE-1 (not official evaluation)
                avg_reward += compute_rouge_n(list(concat(summs)),
                                              list(concat(abs_sents)),
                                              n=1)
                # IF using BERT score
                # avg_reward += reward_fn(' '.join(concat(summs)), ' '.join(concat(abs_sents)))
                i += 1
    avg_reward /= (i / 100)

    # IF using BERTScore
    # del reward_fn

    print('finished in {}! avg reward: {:.2f}'.format(
        timedelta(seconds=int(time() - start)), avg_reward))
    return {'reward': avg_reward}
Example #16
    def coll(tokenizer, batch):
        blank = '[unused0]'
        questions, context, _ids, abstract = list(
            filter(bool, list(zip(*batch))))
        system = context
        # print('q:', questions)
        # print('c:', context)
        if len(abstract) > 0:
            rouges = {
                'RLr':
                compute_rouge_l_summ(
                    [_c.lower().split(' ') for _c in context[0]],
                    abstract[0],
                    mode='r'),
                'RLf':
                compute_rouge_l_summ(
                    [_c.lower().split(' ') for _c in context[0]],
                    abstract[0],
                    mode='f'),
                'R1':
                compute_rouge_n(' '.join(context[0]).split(' '),
                                list(concat(abstract[0])),
                                mode='f',
                                n=1),
                'R2':
                compute_rouge_n(' '.join(context[0]).split(' '),
                                list(concat(abstract[0])),
                                mode='f',
                                n=2)
            }
        if len(questions[0]) != 0:
            questions = questions[0]
            context = context[0]
            context = ' '.join(context)

            choicess = [[
                question['answer'], question['choice1'], question['choice2'],
                question['choice3']
            ] for question in questions]
            questions = [
                question['question'].replace('<\\blank>', blank)
                for question in questions
            ]
            questions = [[
                tokenizer.tokenize(qp.lower()) for qp in question.split(blank)
            ] for question in questions]
            new_questions = []
            for question in questions:
                new_q = ['[CLS]']
                for q in question:
                    new_q += q + [blank]
                new_q.pop()
                new_questions.append(new_q)
            questions = new_questions
            contexts = [['[SEP]'] + tokenizer.tokenize(context.lower())
                        for _ in range(len(questions))]
            choicess = [[[tokenizer.tokenize(c.lower()) for c in choice]
                         for choice in choices] for choices in choicess]
            choicess = [[['[SEP]'] + choice[0] + ['[SEP]'] +
                         choice[1] if len(choice) == 2 else ['[SEP]'] +
                         choice[0] for choice in choices]
                        for choices in choicess]
            _inputs = [[
                tokenizer.convert_tokens_to_ids(
                    (question + context + choice)[:MAX_LEN])
                for choice in choices
            ]
                       for question, context, choices in zip(
                           questions, contexts, choicess)]
            _inputs = pad_batch_tensorize_3d(_inputs, pad=0, cuda=False)
        else:
            _inputs = []

        return (_inputs, rouges, system, abstract)
Example #17
def a2c_train_step(agent,
                   target_agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()

    def length_penalty(l, eps=0.7):
        return (5 + len(l) - 1)**eps / (5 + 1)**eps

    indices, probs, baselines = [], [], []
    target_indices, target_probs, target_baselines = [], [], []
    act, act_probs, act_baselines = [], [], []
    pure_act = []
    ext_sents = []  # all sentences
    abstract = []
    pres_or_rewr = []
    act_precision = []
    target_probs = []

    true_indices, false_indices = [], []
    recall, precision = [], []
    art_batch, abs_batch, ext_indices = next(loader)

    for n, (raw_arts, raw_abs,
            ext_index) in enumerate(zip(art_batch, abs_batch, ext_indices)):

        with torch.no_grad():
            (target_inds, target_ms), target_bs, (
                target_act_inds,
                target_act_ms), target_act_bs = target_agent(raw_arts)

        (inds, ms), bs, (act_inds,
                         act_ms), act_bs = agent(raw_arts,
                                                 observations=target_inds)

        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)

        target_probs.append(target_ms)

        act.append(act_inds)
        act_probs.append(act_ms)
        act_baselines.append(act_bs)
        pure_act += [a.detach() for a in act_inds[:-1]]

        def precision_recall():
            """ sentence-level precision/recall """
            true_positive_extracts = []
            false_extracts = []
            # acc_pre = []
            for k, index in enumerate(inds[:-1]):
                if index.item() in ext_index:
                    true_positive_extracts.append(index.detach().item())
                else:
                    false_extracts.append(index.detach().item())

            each_recall = (len(true_positive_extracts)) / (len(ext_index))
            each_precision = (len(true_positive_extracts)) / (
                len(inds) - 1) if len(inds) > 1 else 0
            recall.append(each_recall)
            precision.append(each_precision)
            true_indices.append(true_positive_extracts)
            false_indices.append(false_extracts)

        precision_recall()
        """ 存取原句 """
        current_sentence = [
            raw_arts[idx.item()] for idx in inds if idx.item() < len(raw_arts)
        ]
        ext_sents += current_sentence

    assert len(indices) == len(act)
    assert len(probs) == len(act_probs)
    assert len(baselines) == len(act_baselines)
    """ 呼叫改寫模型 """
    with torch.no_grad():
        rewrite_sentences = abstractor(ext_sents)
    """ 經過abs agent後,重組句子 """
    results = [
        rewrite_sentences[n] if bins.item() == 0 else ext_sents[n]
        for n, bins in enumerate(pure_act)
    ]
    """ 原本選句子該得到的 reward 計算  """
    i = 0
    rewards = []
    avg_reward = 0
    avg_reward_abs_rew = 0
    avg_reward_abs_pre = 0
    rewrite_rewards, preserve_rewards = [], []

    for inds, act_inds, abss, labeled in zip(indices, act, abs_batch,
                                             ext_indices):
        """ 全部都做 總得分 """
        """ 原先extractor agent取的原始句子 """
        rs = ([
            reward_fn(ext_sents[i + j], abss[j])
            for j in range(min(len(inds) - 1, len(abss)))
        ] + [0 for _ in range(max(0,
                                  len(inds) - 1 - len(abss)))] +
              [
                  stop_coeff *
                  stop_reward_fn(list(concat(ext_sents[i:i + len(inds) - 1])),
                                 list(concat(abss)))
              ])
        # rs = ([reward_fn(sum(ext_sents[i:i+j+1],[]), sum(abss[:j+1],[])) for j in range(min(len(inds)-1, len(abss)))]
        #       + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
        #       + [stop_coeff*stop_reward_fn(
        #           list(concat(ext_sents[i:i+len(inds)-1])),
        #           list(concat(abss)))])
        """ 單句 """
        def single_compared_MC_TD(choice):

            # before_rewrite = ([stop_reward_fn(ext_sents[i+j], abss[j])
            #                     for j in range(min(len(inds)-1, len(abss)))]
            #                 + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
            #                 + [0])
            # after_rewrite = ([stop_reward_fn(rewrite_sentences[i+j], abss[j])
            #                     for j in range(min(len(inds)-1, len(abss)))]
            #                 + [1 for _ in range(max(0, len(inds)-1-len(abss)))]
            #                 + [0])
            remain_sents = min(len(inds) - 1, len(abss))
            if choice == 'TD':
                previous_results = [[]] + results[i:i + remain_sents]
                """  abstractor agent 做完 rewrite 後 能得到的分數 """
                rew_rs = (
                    [
                        reward_fn(results[i + j], abss[j]) -
                        reward_fn(previous_results[j], abss[j])
                        if act_inds[j].item() == 0 else 0
                        for j in range(min(len(inds) - 1, len(abss)))
                    ]
                    # + [reward_fn(sum(results[i:i+remain_sents+j+1],[]), sum(abss[:-1],[])) - reward_fn(sum(results[i:i+remain_sents+j],[]), sum(abss[:-1],[]))
                    #     if act_inds[j].item()==0 else 0 for j in range(max(0, len(inds)-1-len(abss)))])
                    + [0 for _ in range(max(0,
                                            len(inds) - 1 - len(abss)))])
                # + [stop_coeff*stop_reward_fn(
                # list(concat(
                # [results[x] for x in range(i, i+len(inds)-1)
                # if act_inds[x-i].item()==0])),
                # list(concat(abss)))])
                rew_rs += [sum(rew_rs)]
                """  abstractor agent 做完 preserve 後 能得到的分數 """
                pres_rs = (
                    [
                        reward_fn(results[i + j], abss[j]) -
                        reward_fn(previous_results[j], abss[j])
                        if act_inds[j].item() == 1 else 0
                        for j in range(min(len(inds) - 1, len(abss)))
                    ]
                    # + [reward_fn(sum(results[i:i+remain_sents+j+1],[]), sum(abss[:-1],[])) - reward_fn(sum(results[i:i+remain_sents+j],[]), sum(abss[:-1],[]))
                    # if act_inds[j].item()==1 else 0 for j in range(max(0, len(inds)-1-len(abss)))])
                    + [0 for _ in range(max(0,
                                            len(inds) - 1 - len(abss)))])
                # + [stop_coeff*stop_reward_fn(
                # list(concat(
                # [results[x] for x in range(i, i+len(inds)-1)
                # if act_inds[x-i].item()==1])),
                # list(concat(abss)))])
                pres_rs += [sum(pres_rs)]

            elif choice == 'MC':
                """  abstractor agent 做完 rewrite 後 能得到的分數 """
                rew_rs = (
                    [
                        reward_fn(results[i + j], abss[j])
                        if act_inds[j].item() == 0 else 0
                        for j in range(min(len(inds) - 1, len(abss)))
                    ]
                    # + [reward_fn(results[i+remain_sents+j], sum(abss[:-1],[]))
                    # if act_inds[j].item()==0 else 0 for j in range(max(0, len(inds)-1-len(abss)))])
                    + [0 for _ in range(max(0,
                                            len(inds) - 1 - len(abss)))])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                rew_rs += [sum(rew_rs)]
                """  abstractor agent 做完 preserve 後 能得到的分數 """
                pres_rs = (
                    [
                        reward_fn(results[i + j], abss[j])
                        if act_inds[j].item() == 1 else 0
                        for j in range(min(len(inds) - 1, len(abss)))
                    ]
                    # + [reward_fn(results[i+remain_sents+j], sum(abss[:-1],[]))
                    # if act_inds[j].item()==1 else 0 for j in range(max(0, len(inds)-1-len(abss)))])
                    + [0 for _ in range(max(0,
                                            len(inds) - 1 - len(abss)))])
                pres_rs += [sum(pres_rs)]
            return rew_rs, pres_rs

        rew_rs, pres_rs = single_compared_MC_TD(choice='TD')
        """ 累計  """

        def accumulated_compared_MC_TD(choice):
            remain_sents = min(len(inds) - 1, len(abss))
            if choice == 'MC':
                """  abstractor agent 做完 rewrite 後 能得到的分數 """
                rew_rs = ([
                    reward_fn(results[i + j], sum(abss[:j + 1], []))
                    if act_inds[j].item() == 0 else 0
                    for j in range(min(len(inds) - 1, len(abss)))
                ] + [
                    reward_fn(results[i + remain_sents + j], sum(
                        abss[:-1], [])) if act_inds[j].item() == 0 else 0
                    for j in range(max(0,
                                       len(inds) - 1 - len(abss)))
                ])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                rew_rs += [sum(rew_rs)]
                """  abstractor agent 做完 preserve 後 能得到的分數 """
                pres_rs = ([
                    reward_fn(results[i + j], sum(abss[:j + 1], []))
                    if act_inds[j].item() == 1 else 0
                    for j in range(min(len(inds) - 1, len(abss)))
                ] + [
                    reward_fn(results[i + remain_sents + j], sum(
                        abss[:-1], [])) if act_inds[j].item() == 1 else 0
                    for j in range(max(0,
                                       len(inds) - 1 - len(abss)))
                ])
                #  + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                pres_rs += [sum(pres_rs)]
            elif choice == 'TD':
                """  abstractor agent 做完 rewrite 後 能得到的分數 """
                rew_rs = ([
                    reward_fn(sum(results[i:i + j +
                                          1], []), sum(abss[:j + 1], [])) -
                    reward_fn(sum(results[i:i + j], []), sum(abss[:j], []))
                    if act_inds[j].item() == 0 else 0
                    for j in range(min(len(inds) - 1, len(abss)))
                ] + [
                    reward_fn(sum(results[i:i + remain_sents + j +
                                          1], []), sum(abss[:-1], [])) -
                    reward_fn(sum(results[i:i + remain_sents +
                                          j], []), sum(abss[:-1], []))
                    if act_inds[j].item() == 0 else 0
                    for j in range(max(0,
                                       len(inds) - 1 - len(abss)))
                ])
                # + [0 for _ in range(max(0, len(inds)-1-len(abss)))])
                rew_rs += [sum(rew_rs)]
                """  abstractor agent 做完 preserve 後 能得到的分數 """
                pres_rs = ([
                    reward_fn(sum(results[i:i + j +
                                          1], []), sum(abss[:j + 1], [])) -
                    reward_fn(sum(results[i:i + j], []), sum(abss[:j], []))
                    if act_inds[j].item() == 1 else 0
                    for j in range(min(len(inds) - 1, len(abss)))
                ] + [
                    reward_fn(sum(results[i:i + remain_sents + j +
                                          1], []), sum(abss[:-1], [])) -
                    reward_fn(sum(results[i:i + remain_sents +
                                          j], []), sum(abss[:-1], []))
                    if act_inds[j].item() == 1 else 0
                    for j in range(max(0,
                                       len(inds) - 1 - len(abss)))
                ])
                pres_rs += [sum(pres_rs)]
            return rew_rs, pres_rs

        # rew_rs, pres_rs = accumulated_compared_MC_TD(choice='MC')

        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        avg_reward_abs_rew += rew_rs[-1] / stop_coeff
        avg_reward_abs_pre += pres_rs[-1] / stop_coeff
        i += len(inds) - 1

        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        rewrite_rewards += rew_rs
        preserve_rewards += pres_rs

    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))

    target_probs = list(concat(target_probs))

    act = list(concat(act))
    act_probs = list(concat(act_probs))
    act_baselines = list(concat(act_baselines))

    # act_precision = list(concat(act_precision))
    # act_precision = torch.Tensor(act_precision).to(act_baselines[0].device)
    """ 三個隱動作 reward 計算 """
    rewrite_rewards = torch.Tensor(rewrite_rewards).to(act_baselines[0].device)
    rewrite_rewards = (rewrite_rewards - rewrite_rewards.mean()) / (
        rewrite_rewards.std() + float(np.finfo(np.float32).eps))

    preserve_rewards = torch.Tensor(preserve_rewards).to(
        act_baselines[0].device)
    preserve_rewards = (preserve_rewards - preserve_rewards.mean()) / (
        preserve_rewards.std() + float(np.finfo(np.float32).eps))

    complex_rewards = torch.stack([rewrite_rewards, preserve_rewards], dim=-1)

    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))

    baseline = torch.cat(baselines).squeeze()
    act_baselines = torch.cat(act_baselines).squeeze().view(-1, 2)

    assert len(indices) == len(probs) == len(reward) == len(baseline)
    assert len(act) == len(act_probs) == len(act_baselines)

    avg_advantage = 0
    losses = []
    entropy_target = []
    entropy_value = []
    ratios = []

    for action, p, tp, r, b in zip(indices, probs, target_probs, reward,
                                   baseline):

        # ratio = torch.exp(p.log_prob(action) - tp.log_prob(action))
        # ratios.append(ratio)
        entropy_target.append(tp.entropy().detach().item())
        entropy_value.append(p.entropy().detach().item())

        advantage = r - b
        avg_advantage += advantage
        current_avg_advantage = advantage / len(indices)
        losses.append(-p.log_prob(action) * current_avg_advantage)

    abs_losses_rew = []
    abs_losses_pre = []

    for action, p, r, b in zip(act, act_probs, complex_rewards, act_baselines):
        advantage = r - b
        current_avg_advantage = advantage / len(act)
        # Rewrite
        abs_losses_rew.append(-p.log_prob(action) *
                              current_avg_advantage[0])  # divide by T*B
        # Preserve
        abs_losses_pre.append(-p.log_prob(action) *
                              current_avg_advantage[1])  # divide by T*B

    critic_loss = F.mse_loss(baseline, reward)
    critic_loss_re = F.mse_loss(act_baselines, complex_rewards)

    # backprop and update
    autograd.backward(tensors=[critic_loss] + losses,
                      grad_tensors=[torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)),
                      retain_graph=True)

    autograd.backward(tensors=[critic_loss_re] + abs_losses_rew +
                      abs_losses_pre,
                      grad_tensors=[torch.ones(1).to(critic_loss.device)] *
                      (1 + len(abs_losses_rew) + len(abs_losses_pre)))

    opt.step()
    target_agent.load_state_dict(agent.state_dict())
    ## clear cache
    torch.cuda.empty_cache()

    grad_log = grad_fn()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['entropy'] = sum(entropy_value) / len(entropy_value)
    log_dict['entropy_target'] = sum(entropy_target) / len(entropy_target)
    # log_dict['ratios'] = sum(ratios)/len(ratios)

    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['abs_reward_rew'] = avg_reward_abs_rew / len(art_batch)
    log_dict['abs_reward_pre'] = avg_reward_abs_pre / len(art_batch)

    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['act_mse'] = critic_loss_re.item()
    log_dict['recall'] = sum(recall) / len(recall)
    log_dict['precision'] = sum(precision) / len(precision)

    assert not math.isnan(log_dict['grad_norm'])

    return log_dict
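The reward standardization used above (and in the other training steps below) can be checked in isolation; a minimal sketch with made-up rewards:

import numpy as np
import torch

rewards = torch.tensor([0.21, 0.48, 0.05, 0.37])
eps = float(np.finfo(np.float32).eps)
standardized = (rewards - rewards.mean()) / (rewards.std() + eps)
# zero mean, roughly unit (sample) std; eps guards against a zero-variance batch
print(standardized.mean().item(), standardized.std().item())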
Example #18
File: rl.py Project: ShawnXiha/fast_abs_rl
def a2c_train_step(agent, abstractor, loader, opt, grad_fn,
                   gamma=0.99, reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1), stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [raw_arts[idx.item()]
                      for idx in inds if idx.item() < len(raw_arts)]
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs = ([reward_fn(summaries[i+j], abss[j])
              for j in range(min(len(inds)-1, len(abss)))]
              + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
              + [stop_coeff*stop_reward_fn(
                  list(concat(summaries[i:i+len(inds)-1])),
                  list(concat(abss)))])
        assert len(rs) == len(inds)
        avg_reward += rs[-1]/stop_coeff
        i += len(inds)-1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].get_device())
    reward = (reward - reward.mean()) / (
        reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action)
                      * (advantage/len(indices))) # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward(
        [critic_loss] + losses,
        [torch.ones(1).to(critic_loss.get_device())]*(1+len(losses))
    )
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward/len(art_batch)
    log_dict['advantage'] = avg_advantage.item()/len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
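The per-example discounted-return loop above (R = r + gamma * R, walked back to front) in isolation; a minimal sketch with made-up step rewards:

def discounted_returns(rs, gamma=0.99):
    # each step receives its own reward plus the discounted return of the following steps
    R, disc = 0.0, []
    for r in reversed(rs):
        R = r + gamma * R
        disc.insert(0, R)
    return disc

print(discounted_returns([0.2, 0.0, 0.5]))  # [0.69005, 0.495, 0.5]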
Example #19
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    masks = []
    art_batch, abs_batch, sent_label_batch = next(loader)
    leng = []
    avg_leng = []
    dirty = []
    time1 = time()
    for raw_arts, sent_labels in zip(art_batch, sent_label_batch):
        (inds, ms), bs = agent(raw_arts, sent_labels)
        assert inds[-1][-1].item() == len(raw_arts) + 1
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        try:
            avg_leng.append(
                sum(list(map(lambda x: len(x) - 1, inds[:-1]))) /
                (len(inds) - 1))
        except:
            pass
        leng.append(len(inds) - 1)
        tmp_stop = inds[-1][-1].item()
        tmp_truncate = tmp_stop - 1
        str_arts = list(map(lambda x: ' '.join(x), raw_arts))
        for idx in inds:
            t, m = rl_edu_to_sentence(str_arts, idx)
            assert len(t) == len(m)
            if t == []:
                assert len(idx) == 1
                id = idx[0].item()
                if id == tmp_truncate:
                    dirty.append(len(ext_sents))
                    ext_sents.append(label)
                    masks.append(label_mask)
            else:
                if idx[-1].item() != tmp_stop:
                    ext_sents.append(t)
                    masks.append(m)
    print('lengths:', leng)
    leng = list(map(lambda x: x[0] - len(x[1]), zip(leng, abs_batch)))
    avg_dis = sum(leng) / len(leng)
    print('average gap:', avg_dis)
    print('average number of EDUs:', sum(avg_leng) / len(avg_leng))
    time2 = time()
    with torch.no_grad():
        summaries = abstractor(ext_sents, masks)
        for d in dirty:
            summaries[d] = []
    time3 = time()
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs_abs = []
        rs_len = []
        abs_num = min(len(inds) - 1, len(abss))
        for j in range(abs_num):
            rs_abs.append(reward_fn(summaries[i + j], abss[j]))
            rs_len.append(len(inds[j]))
        rs_zero = []
        for j in range(max(0, len(inds) - 1 - len(abss))):
            rs_zero += [0] * len(inds[j + abs_num])
        rs_zero += [0] * (len(inds[-1]) - 1)
        rs_final = stop_coeff * stop_reward_fn(
            list(concat(summaries[i:i + len(inds) - 1])), list(concat(abss)))
        avg_reward += rs_final / stop_coeff
        i += len(inds) - 1
        disc_rs = [rs_final]
        R = rs_final
        for _ in rs_zero:
            R = R * gamma
            disc_rs.append(R)
        for r, leng in zip(rs_abs, rs_len):
            R = r + R * gamma
            disc_rs += [R] * leng
        disc_rs = list(reversed(disc_rs))
        assert len(disc_rs) == sum(list(map(lambda x: len(x), inds)))
        rewards += disc_rs

    indices = list(concat(list(concat(indices))))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    assert len(reward) == len(probs) and len(baselines) == len(probs) and len(
        baselines) == len(indices)

    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))

    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-(p.log_prob(action)) *
                      (advantage / len(indices)))  # divide by T*B

    critic_loss = F.mse_loss(baseline, reward)
    time4 = time()
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    #print(avg_reward)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    time5 = time()
    print(time2 - time1, time3 - time2, time4 - time3, time5 - time4)
    return log_dict
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles, abstract, extracted = unzip(batch)
        articles = list(filter(bool, articles))
        abstract = list(filter(bool, abstract))
        extracted =  list(filter(bool, extracted))
        return articles, abstract, extracted

    dataset = DecodeDataset(split)
    n_data = len(dataset[0]) # article sentence
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )
    # prepare save paths and logs
    if not os.path.exists(join(save_path, 'output')):
        os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse

    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    
    file_path = os.path.join(save_path, 'Attention')
    act_path = os.path.join(save_path, 'Actions')

    header = "index,rouge_score1,rouge_score2,"+\
    "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,doc_words_nums,"+\
    "ext_words_nums, abs_words_nums, diff,"+\
    "recall, precision, less_rewrite, preserve_action, rewrite_action, each_actions,"+\
    "top3AsAns, top3AsGold, any_top2AsAns, any_top2AsGold,true_rewrite,true_preserve\n"


    if not os.path.exists(file_path):
        print('create dir:{}'.format(file_path))
        os.makedirs(file_path)

    if not os.path.exists(act_path):
        print('create dir:{}'.format(act_path))
        os.makedirs(act_path)

    with open(join(save_path,'_statisticsDecode.log.csv'),'w') as w:
        w.write(header)  
        
    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, (raw_article_batch, raw_abstract_batch, extracted_batch) in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch)
            token_nums_batch = list(map(token_nums(None), raw_article_batch))

            ext_nums = []
            ext_arts = []
            ext_inds = []
            rewrite_less_rouge = []
            dec_outs_act = []
            ext_acts = []
            abs_collections = []
            ext_collections = []

            # extract sentences
            for ind, (raw_art_sents, abs_sents) in enumerate(zip(tokenized_article_batch ,tokenized_abstract_batch)):

                (ext, (state, act_dists)), act = extractor(raw_art_sents)  # exclude EOE
                extracted_state = state[extracted_batch[ind]]
                attn = torch.softmax(state.mm(extracted_state.transpose(1,0)),dim=-1)
                # (_, abs_state), _ = extractor(abs_sents)  # exclude EOE
                
                def plot_actDist(actons, nums):
                    print('index: {} distribution ...'.format(nums))
                    # Write MDP State Attention weight matrix   
                    file_name = os.path.join(act_path, '{}.attention.pdf'.format(nums))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(actons.cpu().numpy(), name='{}-th article'.format(nums),
                        X_label=list(range(len(raw_art_sents))), Y_label=list(range(len(ext))),
                        dirpath=save_path, pdf_page=pdf_pages,action=True)
                    pdf_pages.close()
                # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i)

                def plot_attn():
                    print('index: {} write_attention_pdf ...'.format(i + ind))
                    # Write MDP State Attention weight matrix   
                    file_name = os.path.join(file_path, '{}.attention.pdf'.format(i+ind))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(attn.cpu().numpy(), name='{}-th article'.format(i+ind),
                        X_label=extracted_batch[ind],Y_label=list(range(len(raw_art_sents))),
                        dirpath=save_path, pdf_page=pdf_pages) 
                    pdf_pages.close()
                # plot_attn()

                ext = ext[:-1]
                act = act[:-1]

                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                    act = list([1]*5)[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    act = [i.item() for i in act]

                ext_nums.append(ext)

                ext_inds += [(len(ext_arts), len(ext))] # [(0,5),(5,7),(7,3),...]
                ext_arts += [raw_art_sents[k] for k in ext]
                ext_acts += [k for k in act]

                # accumulate the extracted sentences
                ext_collections += [sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0]+k+1],[]) for k in range(ext_inds[-1][1])]

                abs_collections += [sum(abs_sents[:k+1],[]) if k<len(abs_sents) 
                                        else sum(abs_sents[0:len(abs_sents)],[]) 
                                        for k in range(ext_inds[-1][1])]

            if beam_size > 1: # do n times abstract
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)

                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]

                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each sent in extracted digest
                    # All abstract mapping
                    rouge_before_rewriten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewriten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewriten - rouge_after_rewriten
                    rewrite_less_rouge.append(diff_ins)
            
            else: # do 1st abstract
                dec_outs = abstractor(ext_arts)
                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]
                # dec_outs_act = dec_outs
                # dec_outs_act = ext_arts
                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each cumulative prefix of the extracted digest,
                    # compare ROUGE-1 against the full abstract before and after rewriting
                    rouge_before_rewritten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewritten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewritten - rouge_after_rewritten
                    rewrite_less_rouge.append(diff_ins)

            assert i == batch_size*i_debug

            for iters, (j, n) in enumerate(ext_inds):        
                
                do_right_rewrite = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge<0 and action==0])
                do_right_preserve = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge>=0 and action==1])
                
                decoded_words_nums = [len(dec) for dec in dec_outs_act[j:j+n]]
                ext_words_nums = [token_nums_batch[iters][x] for x in range(len(token_nums_batch[iters])) if x in ext_nums[iters]]

                # (alternative, disabled) take the gold extracted labels directly
                # decoded_sents = [raw_article_batch[iters][x] for x in extracted_batch[iters]]         
                # statistics logging [START]
                decoded_sents = [' '.join(dec) for dec in dec_outs_act[j:j+n]]
                rouge_score1 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=1)
                rouge_score2 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=2)
                rouge_scorel = compute_rouge_l(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]))
                
                dec_sent_nums = len(decoded_sents)
                abs_sent_nums = len(raw_abstract_batch[iters])
                doc_sent_nums = len(raw_article_batch[iters])
                
                doc_words_nums = sum(token_nums_batch[iters])
                ext_words_nums = sum(ext_words_nums)
                abs_words_nums = sum(decoded_words_nums)

                label_recall = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(extracted_batch[iters])
                label_precision = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(ext_nums[iters])
                less_rewrite = rewrite_less_rouge[j+n-1]
                dec_one_action_num = sum(ext_acts[j:j+n])
                dec_zero_action_num = n - dec_one_action_num

                ext_indices = '_'.join([str(i) for i in ext_nums[iters]])
                
                top3 = set([0,1,2]) <= set(ext_nums[iters])
                top3_gold = set([0,1,2]) <= set(extracted_batch[iters])
                
                # Any Top 2 
                top2 = set([0,1]) <= set(ext_nums[iters]) or set([1,2]) <= set(ext_nums[iters]) or set([0,2]) <= set(ext_nums[iters])
                top2_gold = set([0,1]) <= set(extracted_batch[iters]) or set([1,2]) <= set(extracted_batch[iters]) or set([0,2]) <= set(extracted_batch[iters])
                
                with open(join(save_path,'_statisticsDecode.log.csv'),'a') as w:
                    w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(i,rouge_score1,
                     rouge_score2, rouge_scorel, dec_sent_nums,
                      abs_sent_nums, doc_sent_nums, doc_words_nums,
                      ext_words_nums,abs_words_nums,(ext_words_nums - abs_words_nums),
                      label_recall, label_precision,
                      less_rewrite, dec_one_action_num, dec_zero_action_num, 
                      ext_indices, top3, top3_gold, top2, top2_gold,do_right_rewrite,do_right_preserve))
                # statistics logging [END]

                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    decoded_sents = [i for i in decoded_sents if i!='']
                    if len(decoded_sents) > 0:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    else:
                        f.write('')

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
            
    print()
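
A side note on the statistics logged above: rewrite_less_rouge stores, for each cumulative prefix of the extracted digest, ROUGE-1 of the extracted text minus ROUGE-1 of its abstractive rewrite, so a negative value means rewriting helped; do_right_rewrite and do_right_preserve then count whether the rewrite (action 0) or preserve (action 1) decision agreed with that sign. A minimal, self-contained illustration, with a simplified unigram F1 standing in for compute_rouge_n(..., n=1, mode='f'):

from collections import Counter

def rouge1_f(pred_tokens, ref_tokens):
    """Simplified unigram F1; a stand-in, not the repo's ROUGE implementation."""
    if not pred_tokens or not ref_tokens:
        return 0.0
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    p, r = overlap / len(pred_tokens), overlap / len(ref_tokens)
    return 2 * p * r / (p + r)

def rewrite_gain(ext_tokens, rewritten_tokens, abstract_tokens):
    """Positive: rewriting lost ROUGE-1; negative: rewriting helped."""
    return rouge1_f(ext_tokens, abstract_tokens) - rouge1_f(rewritten_tokens, abstract_tokens)

def is_right_decision(gain, action):
    """action 0 = rewrite, action 1 = preserve, as in the loop above."""
    return (gain < 0 and action == 0) or (gain >= 0 and action == 1)
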
예제 #21
0
def rl_validate(net,
                val_batches,
                reward_func=None,
                reward_coef=0.01,
                local_coh_func=None,
                local_coh_coef=0.005,
                bert=False):
    print('running validation ... ', end='')
    if bert:
        tokenizer = net._bert_model._tokenizer
        end = tokenizer.encoder[tokenizer._eos_token]
        unk = tokenizer.encoder[tokenizer._unk_token]

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns):
        if bert:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == end:
                        break
                    elif id_[i] == unk:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        else:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        return dec_sents

    net.eval()
    start = time()
    i = 0
    score = 0
    score_reward = 0
    score_local_coh = 0
    score_r2 = 0
    score_r1 = 0
    bl_r2 = []
    bl_r1 = []
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if reward_func is not None:
                questions = bw_args[3]
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            if reward_func is not None:
                bl_coh_inputs = []
            if local_coh_func is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if bert:
                    text = ''.join(baseline)
                    baseline = bytearray([
                        tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=tokenizer.errors)
                    baseline = baseline.split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=tokenizer.errors)
                    target = target.split(' ')

                bss = sent_tokenize(' '.join(baseline))
                if reward_func is not None:
                    bl_coh_inputs.append(bss)

                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_r1.append(
                    compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1))
                bl_r2.append(
                    compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)

            if reward_func is not None:
                bl_reward_scores = reward_func.score(questions, bl_coh_inputs)
                bl_reward_scores = torch.tensor(bl_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)
                score_reward += bl_reward_scores.mean().item() * 100

            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100
            score_r2 += torch.tensor(
                bl_r2, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
            score_r1 += torch.tensor(
                bl_r1, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100

    val_score = score / i
    score_r2 = score_r2 / i
    score_r1 = score_r1 / i
    if reward_func is not None:
        val_reward_score = score_reward / i
    else:
        val_reward_score = 0
    val_local_coh_score = 0
    print(
        'validation finished in {}                                    '.format(
            timedelta(seconds=int(time() - start))))
    print('validation score (rouge-l summ): {:.4f} ... '.format(val_score))
    print('validation r2: {:.4f} ... '.format(score_r2))
    print('validation r1: {:.4f} ... '.format(score_r1))
    if reward_func is not None:
        print('validation reward: {:.4f} ... '.format(val_reward_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(
            local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score, 'score_reward': val_reward_score}
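
sum_id2word above applies a copy-style back-off: whenever the decoder emits the unknown id, the source token with the largest attention weight at that step is used instead. A small self-contained sketch of that rule (the ids, the toy vocabulary, and the attention values are made up for illustration):

import torch

UNK_ID, END_ID = 0, 1
id2word_demo = {2: 'the', 3: 'cat', 4: 'sat'}

def decode_with_copy(src_tokens, step_ids, step_attns):
    out = []
    for tok_id, attn in zip(step_ids, step_attns):
        if tok_id == END_ID:
            break
        if tok_id == UNK_ID:
            # fall back to the most-attended source word
            out.append(src_tokens[int(torch.argmax(attn).item())])
        else:
            out.append(id2word_demo[tok_id])
    return out

src = ['obama', 'visited', 'paris']
ids = [2, UNK_ID, 4, END_ID]                 # 'the', <unk>, 'sat', <end>
attns = [torch.tensor([0.1, 0.2, 0.7]),
         torch.tensor([0.8, 0.1, 0.1]),      # <unk> copies 'obama'
         torch.tensor([0.2, 0.5, 0.3]),
         torch.tensor([0.3, 0.3, 0.4])]
print(decode_with_copy(src, ids, attns))     # ['the', 'obama', 'sat']
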
예제 #22
0
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda, args.docgraph,
            args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)

    if args.bert_stride > 0:
        assert args.bert_stride == agent._bert_stride
    # configure training setting
    assert args.stop > 0
    train_params = configure_training('adam', args.lr, args.clip, args.decay,
                                      args.batch, args.gamma, args.reward,
                                      args.stop, 'rouge-1')

    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type, args.max_bert_word,
                args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type, args.gold_key,
                args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise Exception('Not prepared reward')
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  verbose=True,
                                  factor=args.decay,
                                  min_lr=1e-5,
                                  patience=args.lr_p)

    if args.docgraph or args.paragraph:
        entity = True
    else:
        entity = False
    pipeline = SCPipeline(meta['net'], agent, abstractor, train_batcher,
                          val_batcher, optimizer, grad_fn, reward_fn, entity,
                          args.bert)

    trainer = BasicTrainer(pipeline,
                           args.path,
                           args.ckpt_freq,
                           args.patience,
                           scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
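
The scheduler above is created in 'max' mode, so it is presumably fed the validation score (the 'score' entry that the rl_validate-style functions return) and lowers the learning rate when that reward stops improving. A minimal sketch of that interaction; the model, loop, and score values are placeholders, not the repo's trainer:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(8, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=0.5, min_lr=1e-5, patience=2)

for epoch in range(5):
    # one stand-in training step
    loss = model(torch.randn(4, 8)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ... validate, then report the reward to the scheduler ...
    val_score = 0.30 + 0.01 * min(epoch, 2)  # stand-in for the ROUGE-based score
    scheduler.step(val_score)                # 'max' mode decays the LR when the score plateaus
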
예제 #23
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents, ext_sents_topic, ext_sents_probs, summ_indices = [], [], [], []
    #art_batch, topic_batch, abs_batch, topic_label_batch = next(loader)
    art_batch, topic_batch, abs_batch = next(loader)
    for raw_arts, topic in zip(art_batch, topic_batch):
        (inds, ms), bs = agent(raw_arts, topic)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        num_sents = len(raw_arts)
        ext_sents += [
            raw_arts[idx.item()] for idx in inds if idx.item() < num_sents
        ]
        #ext_sents_topic += [topic[idx.item()] for idx in inds if idx.item()<num_sents]
        ext_sents_probs += [
            m.probs[0][idx] for idx, m in zip(inds, ms)
            if idx.item() < num_sents
        ]
        summ_indices += [[
            idx.item() for idx in inds if idx.item() < num_sents
        ]]
    if ext_sents == []:
        #print('Reach the end')
        return None
    with torch.no_grad():
        summaries = abstractor(ext_sents)

    # Expand the topic dis of the reference
    #topic_label_expand = []
    #for label, inds  in zip(topic_label_batch, summ_indices):
    #    num_ext = len(inds)
    #    topic_label_expand += [label]*num_ext
    """
    legal_ext_num = len(ext_sents)
    #topic_label_expand = torch.stack(reconstruct_topic_dis(topic_label_expand, legal_ext_num)).squeeze(1).cuda()
    ext_sents_topic    = torch.stack(reconstruct_topic_dis_rl(ext_sents_topic, legal_ext_num)).squeeze(1).cuda()
    ext_sents_probs    = torch.stack(ext_sents_probs).squeeze(1).cuda()
    #rs_topic = 1 - ext_sents_probs * abs(topic_label_expand - ext_sents_topic).sum(dim=1) # by L1 distance
    #rs_topic = 0.01*(ext_sents_probs * abs(topic_label_expand * ext_sents_topic).sum(dim=1)) # by inner product
    rs_topic = 0.05*ext_sents_probs * (ext_sents_topic).sum(dim=1) # by inner product if m2
    rs_topic_avg = rs_topic.sum()/rs_topic.size(0)
    rs_topic_list = rs_topic.cpu().tolist()
 
    """

    ## collect the generated headline
    summaries_collect = [None] * len(summ_indices)
    cnt = 0
    i = 0
    for summ_inds in summ_indices:
        try:
            if len(summ_inds) != 0:
                summaries_collect[cnt] = summaries[i]
                i = i + len(summ_inds)
            else:
                summaries_collect[cnt] = [" "]
        except:
            pdb.set_trace()
        cnt += 1

    add_cls_loss = True
    if add_cls_loss == True:
        ## collect the generated headline
        summaries_collect = [None] * len(summ_indices)
        cnt = 0
        i = 0
        for summ_inds in summ_indices:
            try:
                if len(summ_inds) != 0:
                    summaries_collect[cnt] = summaries[i]
                    i = i + len(summ_inds)
                else:
                    summaries_collect[cnt] = [" "]
            except:
                pdb.set_trace()
            cnt += 1
        with open(
                '/home/yunzhu/Headline/FASum/PORLHG_v3/model/classifier/TEXT.Field',
                'rb') as f:
            #with open('/home/yunzhu/Headline/FASum/PORLHG_v3/model/classifier/TEXT_abs.Field', 'rb') as f:
            TEXT = dill.load(f)
        vocab_size = len(TEXT.vocab)
        word_embeddings = TEXT.vocab.vectors
        prediction = get_gen_score(summaries_collect, TEXT, vocab_size,
                                   word_embeddings)
        cls_score = (prediction[:, 1].sum() / prediction.size(0)).item()
    else:
        TEXT = None
        vocab_size = None
        word_embeddings = None
    #prediction = get_gen_score(summaries_collect, TEXT, vocab_size, word_embeddings)
    #cls_score = prediction[:,1].sum().item()

    i, i_topic = 0, 0
    rewards = []
    avg_reward = 0
    cnt = 0
    for inds, abss in zip(indices, abs_batch):
        if add_cls_loss == True:
            try:
                #cls_r = prediction[cnt,1].item()*0.2
                cls_r = prediction[cnt, 1].item()
            except:
                print('[Info] In rl.py code, cls_r=0')
                cls_r = 0
        else:
            cls_r = 0
        #rs_topic_ = (rs_topic[:len(inds)].sum()/len(inds)).item()
        rs_topic_ = 0
        rs = ([
            reward_fn(summaries[i + j], abss[j]) + cls_r + rs_topic_
            for j in range(min(len(inds) - 1, len(abss)))
        ] + [0 for _ in range(max(0,
                                  len(inds) - 1 - len(abss)))] +
              [
                  stop_coeff *
                  (stop_reward_fn(list(concat(summaries[i:i + len(inds) - 1])),
                                  list(concat(abss))))
              ])
        #list(concat(abss)))+cls_r)])  # + cls_r

        #try:
        #rs[:len(inds)-1] = np.add(rs[:len(inds)-1], rs_topic_list[i_topic:i_topic+len(inds)-1])
        #    rs[:-1] = np.add(rs[:-1], rs_topic_list[i:i+len(inds)-1])

        #except:
        #   pdb.set_trace()

        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        i_topic += len(inds)
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        cnt += 1

    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, prob, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-prob.log_prob(action) *
                      (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss.unsqueeze(0)] + losses,
                      [torch.ones(1).to('cuda')] * (1 + len(losses))
                      #[torch.ones(1).to(critic_loss.device)]*(1+len(losses))
                      )

    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    #log_dict['rs_topic'] = rs_topic_avg.item()
    if add_cls_loss == True:
        log_dict['cls_score'] = cls_score
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
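
The per-step rewards above are turned into returns with the usual reward-to-go recursion, R = r + gamma * R, scanning the reward list in reverse. That step in isolation:

def discounted_returns(rewards, gamma=0.99):
    """Reward-to-go: out[t] = rewards[t] + gamma * out[t+1]."""
    R = 0.0
    out = []
    for r in reversed(rewards):
        R = r + gamma * R
        out.insert(0, R)
    return out

print(discounted_returns([0.2, 0.1, 1.0], gamma=0.9))
# -> [1.1, 1.0, 1.0] (up to float rounding): 0.2 + 0.9*(0.1 + 0.9*1.0), 0.1 + 0.9*1.0, 1.0
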
예제 #24
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [
            raw_arts[idx.item()] for idx in inds if idx.item() < len(raw_arts)
        ]
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        rs = ([
            reward_fn(summaries[i + j], abss[j])
            for j in range(min(len(inds) - 1, len(abss)))
        ] + [0 for _ in range(max(0,
                                  len(inds) - 1 - len(abss)))] +
              [
                  stop_coeff *
                  stop_reward_fn(list(concat(summaries[i:i + len(inds) - 1])),
                                 list(concat(abss)))
              ])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action) *
                      (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
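
A reduced, self-contained version of the policy/critic update pattern used in these a2c_train_step variants: the discounted returns are standardized, the advantage is return minus baseline, each action's negative log-probability is weighted by advantage / (T*B), and the critic is regressed to the return with an MSE loss. All tensors below are toy stand-ins; the baseline is detached in the policy term here, which some of the variants above do via .item() and others skip:

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

returns = torch.tensor([1.1, 1.0, 1.0, 0.4])                   # discounted returns
reward = (returns - returns.mean()) / (returns.std() + float(np.finfo(np.float32).eps))

logits = torch.randn(4, 6, requires_grad=True)                 # 4 steps, 6 candidate sentences
dists = [Categorical(logits=logits[t]) for t in range(4)]
actions = torch.tensor([2, 0, 5, 1])
baseline = torch.randn(4, requires_grad=True)                  # critic predictions

policy_losses = [-d.log_prob(a) * ((r - b.detach()) / len(dists))
                 for d, a, r, b in zip(dists, actions, reward, baseline)]
critic_loss = F.mse_loss(baseline, reward)
(torch.stack(policy_losses).sum() + critic_loss).backward()    # gradients for both actor and critic
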
예제 #25
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    sample_time = 1
    time_variant = True
    gamma = 0.95
    opt.zero_grad()
    art_batch, abs_batch, extracts = next(loader)
    all_loss = []
    reward = 0
    advantage = 0
    n_articles = 0  # article counter (renamed: inner loops below reuse i as a loop variable)
    greedy_inputs = []
    sample_inputs = []
    sample_log_probs = []
    for idx, raw_arts in enumerate(art_batch):
        #print(len(raw_arts))
        #print(len(raw_arts[0]))
        greedy, samples, all_log_probs = agent(raw_arts,
                                               sample_time=sample_time)
        #print(greedy, samples, all_log_probs)
        #print(len(greedy), len(samples), len(all_log_probs))
        #exit()
        if time_variant:
            bs = []
            abss = abs_batch[idx]
            for _ind, gd in enumerate(greedy):
                #print('1',greedy_sents)
                greedy_sents = []
                ext_sent = []
                for ids in gd:
                    if ids < len(raw_arts):
                        if ids == 0:
                            if ext_sent:
                                greedy_sents.append(ext_sent)
                            ext_sent = []
                        else:
                            ext_sent += raw_arts[ids]
                if gd[-1] != 0 and ext_sent:
                    greedy_sents.append(ext_sent)
                #print('2',greedy_sents)
                #greedy_sents = [raw_arts[ind] for ind in gd]

                baseline = 0
                for i, sent in enumerate(greedy_sents):
                    #greedy_sents = [[word for sent in greedy_sents for word in sent]]
                    #greedy_sents = abstractor(greedy_sents)
                    #print('3',greedy_sents)
                    #exit()
                    with torch.no_grad():
                        #print(sent)
                        greedy_sent = abstractor([sent])
                        #print(greedy_sent)
                        #print('section')
                        #exit()
                        greedy_sent = sent_tokenize(' '.join(greedy_sent[0]))
                        greedy_sent = [
                            s.strip().split(' ') for s in greedy_sent
                        ]
                    #print('1', list(concat(greedy_sent)))
                    #exit()
                    if reward_fn.__name__ != 'compute_rouge_l_summ':
                        if _ind != len(greedy) - 1:
                            if i < len(abss):
                                baseline += compute_rouge_with_marginal_increase(
                                    reward_fn,
                                    greedy_sent,
                                    abss[i],
                                    _ind,
                                    gamma=gamma)
                        else:
                            if i < len(abss):
                                baseline += reward_fn(
                                    list(concat(greedy_sent)),
                                    list(concat(abss[i])))
                        #print('1', baseline)
                    else:
                        if _ind != len(greedy) - 1:
                            if i < len(abss):
                                baseline += compute_rouge_with_marginal_increase(
                                    reward_fn,
                                    greedy_sent,
                                    abss[i],
                                    _ind,
                                    gamma=gamma)
                        else:
                            if i < len(abss):
                                baseline += reward_fn(greedy_sent, abss[i])
                        #print('2', baseline)
                #print(baseline)
                bs.append(baseline)
                #exit()
            #print(greedy, len(greedy), len(bs))
            #print(len(greedy), len(bs))
            #print(samples)
            #exit()
            #sample_sents = [raw_arts[ind] for ind in samples[0]]
            sample_sents = []
            ext_sent = []
            for ids in samples[0]:
                if ids < len(raw_arts):
                    if ids == 0:
                        if ext_sent:
                            sample_sents.append(ext_sent)
                        ext_sent = []
                    else:
                        ext_sent += raw_arts[ids]
            if samples[0][-1] != 0 and ext_sent:  # check the sampled sequence (gd is the last greedy one)
                sample_sents.append(ext_sent)

            all_rewards = []
            for j, sent in enumerate(sample_sents):
                with torch.no_grad():
                    #sample_sents = [[word for sent in sample_sents for word in sent]]
                    sample_sent = abstractor([sent])
                    sample_sent = sent_tokenize(' '.join(sample_sent[0]))
                    sample_sent = [s.strip().split(' ') for s in sample_sent]

                #print('2', sample_sent)
                if reward_fn.__name__ != 'compute_rouge_l_summ':
                    #print(sample_sent, abss[j])
                    #exit()
                    #rewards = [reward_fn(list(concat(sample_sent[:i+1])), list(concat(abss[j]))) for i in range(len(sample_sent))]
                    rewards = []
                    for i in range(len(sample_sent)):
                        if j < len(abss):
                            #print('3', sample_sent[:+1])
                            #exit()
                            rewards.append(
                                reward_fn(list(concat(sample_sent[:i + 1])),
                                          list(concat(abss[j]))))
                        else:
                            rewards.append(0)
                    #print(rewards,len(rewards))
                    #exit()
                    for index in range(len(rewards)):
                        rwd = 0
                        for _index in range(len(rewards) - index):
                            if _index != 0:
                                rwd += (rewards[_index + index] -
                                        rewards[_index + index -
                                                1]) * math.pow(gamma, _index)
                            else:
                                rwd += rewards[_index + index]
                        all_rewards.append(rwd)
                    if j < len(abss):
                        all_rewards.append(
                            compute_rouge_n(list(concat(sample_sent[:j])),
                                            list(concat(abss[:j]))))
                    else:
                        all_rewards.append(
                            compute_rouge_n(list(concat(sample_sent[:j])),
                                            list(concat(abss))))
                else:
                    #rewards = [reward_fn(sample_sent[:i + 1], abss[j]) for i in
                    #           range(len(sample_sent))]
                    rewards = []
                    for i in range(len(sample_sent)):
                        if j < len(abss):
                            rewards.append(
                                reward_fn(list(concat(sample_sent[:i + 1])),
                                          list(concat(abss[j]))))
                        else:
                            rewards.append(0)
                    for index in range(len(rewards)):
                        rwd = 0
                        for _index in range(len(rewards) - index):
                            if _index != 0:
                                rwd += (rewards[_index + index] -
                                        rewards[_index + index -
                                                1]) * math.pow(gamma, _index)
                            else:
                                rwd += rewards[_index + index]
                        all_rewards.append(rwd)
                    all_rewards.append(
                        compute_rouge_n(list(concat(sample_sent)),
                                        list(concat(abss[j]))))
            # print('greedy:', greedy)
            # print('sample:', samples[0])
            # print('baseline:', bs)
            # print('rewars:', all_rewards)
            reward += bs[-1]
            advantage += (all_rewards[-1] - bs[-1])
            n_articles += 1
            advs = [
                torch.tensor([_bs - rwd],
                             dtype=torch.float).to(all_log_probs[0][0].device)
                for _bs, rwd in zip(bs, all_rewards)
            ]
            for log_prob, adv in zip(all_log_probs[0], advs):
                all_loss.append(log_prob * adv)
        else:
            greedy_sents = [raw_arts[ind] for ind in greedy]
            greedy_sents = [word for sent in greedy_sents for word in sent]
            greedy_inputs.append(greedy_sents)
            sample_sents = [raw_arts[ind] for ind in samples[0]]
            sample_sents = [word for sent in sample_sents for word in sent]
            sample_inputs.append(sample_sents)
            sample_log_probs.append(all_log_probs[0])
    if not time_variant:
        with torch.no_grad():
            greedy_outs = abstractor(greedy_inputs)
            sample_outs = abstractor(sample_inputs)
        for greedy_sents, sample_sents, log_probs, abss in zip(
                greedy_outs, sample_outs, sample_log_probs, abs_batch):
            greedy_sents = sent_tokenize(' '.join(greedy_sents))
            greedy_sents = [sent.strip().split(' ') for sent in greedy_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                bs = reward_fn(list(concat(greedy_sents)), list(concat(abss)))
            else:
                bs = reward_fn(greedy_sents, abss)
            sample_sents = sent_tokenize(' '.join(sample_sents))
            sample_sents = [sent.strip().split(' ') for sent in sample_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rwd = reward_fn(list(concat(sample_sents)), list(concat(abss)))
            else:
                rwd = reward_fn(sample_sents, abss)
            reward += bs
            advantage += (rwd - bs)
            n_articles += 1
            adv = torch.tensor([bs - rwd],
                               dtype=torch.float).to(log_probs[0].device)
            for log_prob in log_probs:
                all_loss.append(log_prob * adv)
    reward = reward / n_articles
    advantage = advantage / n_articles

    # backprop and update
    loss = torch.cat(all_loss, dim=0).mean()
    loss.backward()
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = reward
    log_dict['advantage'] = advantage
    log_dict['mse'] = 0
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
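
The time-variant return built above rewards each step with the ROUGE of its summary prefix plus the discounted marginal gains of the later prefixes (the rewards[k] - rewards[k-1] terms). That inner double loop as a small standalone function:

def marginal_increase_returns(prefix_rewards, gamma=0.95):
    """prefix_rewards[i] is the reward of the first i+1 summary sentences."""
    out = []
    for i in range(len(prefix_rewards)):
        r = prefix_rewards[i]
        for k in range(1, len(prefix_rewards) - i):
            r += (prefix_rewards[i + k] - prefix_rewards[i + k - 1]) * gamma ** k
        out.append(r)
    return out

# prefix ROUGE of 1, 2, 3 extracted sentences
print(marginal_increase_returns([0.20, 0.35, 0.40], gamma=0.95))
# step 0: 0.20 + 0.95*0.15 + 0.95**2*0.05 = 0.387625
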
예제 #26
0
def sc_train_step(agent,
                  abstractor,
                  loader,
                  opt,
                  grad_fn,
                  reward_fn=compute_rouge_l,
                  sample_time=1,
                  entity=False):
    gamma = 0.95
    opt.zero_grad()
    art_batch, abs_batch = next(loader)
    all_loss = []
    reward = 0
    advantage = 0
    i = 0
    greedy_inputs = []
    sample_inputs = []
    sample_log_probs = []
    for idx, raw_arts in enumerate(art_batch):
        greedy, samples, all_log_probs = agent(raw_arts,
                                               sample_time=sample_time)
        if agent.time_variant:
            bs = []
            abss = abs_batch[idx]
            for _ind, gd in enumerate(greedy):
                greedy_sents = [raw_arts[ind] for ind in gd]

                with torch.no_grad():
                    greedy_sents = [[
                        word for sent in greedy_sents for word in sent
                    ]]
                    greedy_sents = abstractor(greedy_sents)
                    greedy_sents = sent_tokenize(' '.join(greedy_sents[0]))
                    greedy_sents = [
                        sent.strip().split(' ') for sent in greedy_sents
                    ]
                if reward_fn.__name__ != 'compute_rouge_l_summ':
                    if _ind != len(greedy) - 1:
                        baseline = compute_rouge_with_marginal_increase(
                            reward_fn, greedy_sents, abss, _ind, gamma=gamma)
                    else:
                        baseline = reward_fn(list(concat(greedy_sents)),
                                             list(concat(abss)))
                else:
                    if _ind != len(greedy) - 1:
                        baseline = compute_rouge_with_marginal_increase(
                            reward_fn, greedy_sents, abss, _ind, gamma=gamma)
                    else:
                        baseline = reward_fn(greedy_sents, abss)
                bs.append(baseline)
            sample_sents = [raw_arts[ind] for ind in samples[0]]

            with torch.no_grad():
                sample_sents = [[
                    word for sent in sample_sents for word in sent
                ]]
                sample_sents = abstractor(sample_sents)
                sample_sents = sent_tokenize(' '.join(sample_sents[0]))
                sample_sents = [
                    sent.strip().split(' ') for sent in sample_sents
                ]

            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rewards = [
                    reward_fn(list(concat(sample_sents[:i + 1])),
                              list(concat(abss)))
                    for i in range(len(sample_sents))
                ]
                all_rewards = []
                for index in range(len(rewards)):
                    rwd = 0
                    for _index in range(len(rewards) - index):
                        if _index != 0:
                            rwd += (rewards[_index + index] -
                                    rewards[_index + index - 1]) * math.pow(
                                        gamma, _index)
                        else:
                            rwd += rewards[_index + index]
                    all_rewards.append(rwd)
                all_rewards.append(
                    compute_rouge_n(list(concat(sample_sents)),
                                    list(concat(abss))))
            else:
                rewards = [
                    reward_fn(sample_sents[:i + 1], abss)
                    for i in range(len(sample_sents))
                ]
                all_rewards = []
                for index in range(len(rewards)):
                    rwd = 0
                    for _index in range(len(rewards) - index):
                        if _index != 0:
                            rwd += (rewards[_index + index] -
                                    rewards[_index + index - 1]) * math.pow(
                                        gamma, _index)
                        else:
                            rwd += rewards[_index + index]
                    all_rewards.append(rwd)
                all_rewards.append(
                    compute_rouge_n(list(concat(sample_sents)),
                                    list(concat(abss))))
            # print('greedy:', greedy)
            # print('sample:', samples[0])
            # print('baseline:', bs)
            # print('rewars:', all_rewards)
            reward += bs[-1]
            advantage += (all_rewards[-1] - bs[-1])
            i += 1
            advs = [
                torch.tensor([_bs - rwd],
                             dtype=torch.float).to(all_log_probs[0][0].device)
                for _bs, rwd in zip(bs, all_rewards)
            ]
            for log_prob, adv in zip(all_log_probs[0], advs):
                all_loss.append(log_prob * adv)
        else:
            if entity:
                raw_arts = raw_arts[0]
            greedy_sents = [raw_arts[ind] for ind in greedy]
            greedy_sents = [word for sent in greedy_sents for word in sent]
            greedy_inputs.append(greedy_sents)
            sample_sents = [raw_arts[ind] for ind in samples[0]]
            sample_sents = [word for sent in sample_sents for word in sent]
            sample_inputs.append(sample_sents)
            sample_log_probs.append(all_log_probs[0])
    if not agent.time_variant:
        with torch.no_grad():
            greedy_outs = abstractor(greedy_inputs)
            sample_outs = abstractor(sample_inputs)
        for greedy_sents, sample_sents, log_probs, abss in zip(
                greedy_outs, sample_outs, sample_log_probs, abs_batch):
            greedy_sents = sent_tokenize(' '.join(greedy_sents))
            greedy_sents = [sent.strip().split(' ') for sent in greedy_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                bs = reward_fn(list(concat(greedy_sents)), list(concat(abss)))
            else:
                bs = reward_fn(greedy_sents, abss)
            sample_sents = sent_tokenize(' '.join(sample_sents))
            sample_sents = [sent.strip().split(' ') for sent in sample_sents]
            if reward_fn.__name__ != 'compute_rouge_l_summ':
                rwd = reward_fn(list(concat(sample_sents)), list(concat(abss)))
            else:
                rwd = reward_fn(sample_sents, abss)
            reward += bs
            advantage += (rwd - bs)
            i += 1
            adv = torch.tensor([bs - rwd],
                               dtype=torch.float).to(log_probs[0].device)
            for log_prob in log_probs:
                all_loss.append(log_prob * adv)
    reward = reward / i
    advantage = advantage / i

    # backprop and update
    loss = torch.cat(all_loss, dim=0).mean()
    loss.backward()
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = reward
    log_dict['advantage'] = advantage
    log_dict['mse'] = 0
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
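
sc_train_step uses a self-critical baseline: the greedy decode's reward is the baseline for the sampled decode, and the sample's log-probabilities are scaled by adv = baseline - sample reward, so minimizing log_prob * adv pushes up samples that beat the greedy output. A toy illustration of that sign convention (the rewards and log-probabilities are made up):

import torch

sample_log_probs = torch.tensor([-1.2, -0.7, -2.1], requires_grad=True)
greedy_reward = 0.31        # reward of the greedy (baseline) summary
sample_reward = 0.36        # reward of the sampled summary

adv = torch.tensor([greedy_reward - sample_reward])    # baseline - sample, as in the code above
loss = (sample_log_probs * adv).mean()
loss.backward()
print(sample_log_probs.grad)   # negative entries: a gradient step raises these log-probabilities
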
예제 #27
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):  #, wt_rge = wt_rge):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts, raw_abs in zip(art_batch, abs_batch):
        # print(raw_abs)
        (inds, ms), bs, raw_arts = agent(raw_arts,
                                         raw_abs_sents=raw_abs,
                                         n_abs=None)  #, n_abs=3
        # print(inds, len(raw_arts), len(inds))
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [
            raw_arts[idx.item()] for idx in inds if idx.item() < len(raw_arts)
        ]
    with torch.no_grad():
        if not ext_sents:
            summaries = []
        elif use_bart:
            summaries = get_bart_summaries(ext_sents, tokenizer, bart_model)
        else:
            summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    # print(indices)
    # print(abs_batch)
    cands = []
    refs = []
    F1s = None
    if reward_fn == compute_bert_score or reward_fn == compute_bleurt_score:
        # print("bertscore -reward")
        for inds, abss in zip(indices, abs_batch):
            for j in range(min(len(inds) - 1, len(abss))):
                cands.append(" ".join(summaries[i + j]))
                refs.append(" ".join(abss[j]))
            i += len(inds) - 1
        # print(len(cands), len(refs)) #around 120 each
        F1s = reward_fn(cands, refs)
        # print(F1s)

    i = 0
    t = 0
    avg_len = 0
    for inds, abss in zip(indices, abs_batch):
        # print(abss)
        # abss_tot = [x for y in abss for x in y]
        # print([j for j in range(min(len(inds)-1, len(abss)))])  # reward_fn may be compute_rouge_l or compute_rouge_n(n=2)  # PLEASE NOTE HERE
        x = len(inds) - 1
        if (reward_fn == compute_bert_score
                or reward_fn == compute_bleurt_score) and wt_rge >= 0.0001:
            rwd_lst = [(1 - wt_rge) * F1s[t + j] +
                       wt_rge * compute_rouge_n(summaries[i + j], abss[j], n=2)
                       for j in range(min(len(inds) - 1, len(abss)))]
            # print(rwd_lst)
            t += min(len(inds) - 1, len(abss))
        elif (reward_fn == compute_bert_score
              or reward_fn == compute_bleurt_score):
            rwd_lst = [
                F1s[t + j] for j in range(min(len(inds) - 1, len(abss)))
            ]
            # print(rwd_lst)
            t += min(len(inds) - 1, len(abss))
        elif wt_rge >= 0.0001:
            # print("hi", wt_rge)
            rwd_lst = [(1 - wt_rge) * reward_fn(summaries[i + j], abss[j]) +
                       wt_rge * compute_rouge_l(summaries[i + j], abss[j])
                       for j in range(min(len(inds) - 1, len(abss)))]
        elif reward_fn == presumm_reward or reward_fn == presumm_reward2 or reward_fn == presumm_reward3 or reward_fn == presumm_reward4:
            rwd_lst = reward_fn(summaries[i:i + len(inds) - 1], abss)
            x = 3
            # print("hi", x)
        elif reward_fn == "summ-rouge-l":
            tmp = list(concat(abss))
            rwd_lst = [
                compute_rouge_l(summaries[i + j], tmp, mode='r')
                for j in range(min(len(inds) - 1, 3))
            ]
        else:
            rwd_lst = [
                reward_fn(summaries[i + j], abss[j])
                for j in range(min(len(inds) - 1, len(abss)))
            ]  #abss_tot, mode=r
        rs = (
            rwd_lst + [0 for _ in range(max(0,
                                            len(inds) - 1 - len(rwd_lst)))
                       ]  #len(inds)-1-len(abss)
            + [
                stop_coeff * stop_reward_fn(list(concat(summaries[i:i + x])),
                                            list(concat(abss)))
            ])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        if reward_fn == presumm_reward or reward_fn == presumm_reward3 or reward_fn == presumm_reward4:
            rs[-1] = 0
        # if reward_fn == presumm_reward3:
        #     rs.pop()
        avg_len += (len(inds) - 1) / len(abs_batch)  # print(avg_reward)
        i += len(inds) - 1
        # compute discounted rewards
        # print(rs)
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    # print(avg_len)
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    # print(reward.mean())
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    avg_list = []  ##ARJUN
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = (r - b).item()  ##ARJUN # (r-b)
        avg_advantage += advantage
        avg_list.append(advantage)  ##ARJUN

    avg_list = torch.Tensor(avg_list).to(baselines[0].device)  ##ARJUN #
    avg_list = (avg_list - avg_list.mean()) / (
        avg_list.std() + float(np.finfo(np.float32).eps))  ##ARJUN #

    for action, p, advantage in zip(indices, probs, avg_list):  ##ARJUN
        losses.append(-p.log_prob(action) *
                      (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward).unsqueeze(dim=0)
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage / len(indices)  #removed .item()
    log_dict['mse'] = critic_loss.item()
    log_dict['avg_len'] = avg_len
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
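
This variant blends a model-based reward (BERTScore or BLEURT F1) with a ROUGE term through the weight wt_rge, falling back to the model score alone when the weight is effectively zero. The mixing rule in isolation (the scores below are made up):

def mixed_reward(model_f1, rouge_score, wt_rge=0.2):
    """(1 - wt_rge) * model-based F1 + wt_rge * ROUGE, as in the branches above."""
    if wt_rge >= 0.0001:
        return (1.0 - wt_rge) * model_f1 + wt_rge * rouge_score
    return model_f1

print(mixed_reward(model_f1=0.88, rouge_score=0.21, wt_rge=0.2))  # 0.8*0.88 + 0.2*0.21 = 0.746
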
예제 #28
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    indicesL = []
    new_indices = []
    probs = []
    new_probs = []
    baselines = []
    ext_sents = []
    art_batch, abs_batch, ext_batch = next(loader)
    raw_adv = []
    r = 0
    a = 0
    f = 0
    zeros = []
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts)
        baselines.append(bs)
        indices.append(inds)
        indicesL.append(len(inds))
        probs.append(ms)
        #ext_sents += [raw_arts[idx.item()]
        #              for idx in inds if idx.item() < len(raw_arts)]
        #print(ext_sents)
        #exit(0)
        inds_ = []
        ms_ = []
        ext_sent = []
        zero = []

        count = 0
        q = 0
        f = 0
        #print(inds)
        for idx in inds:
            #print('bx',bx)
            if idx.item() < len(raw_arts):
                if idx.item() == 0:
                    if ext_sent:
                        count += 1
                        ext_sents.append(ext_sent)
                        zero.append(f)
                        inds_.append(1)
                        raw_adv.append((r, a))
                        a += 1
                        q = 1
                        #print('1',a)
                    else:
                        if q == 1:
                            raw_adv.append((r, a - 1))
                        else:
                            raw_adv.append((r, a))
                        #print('a')
                    ext_sent = []
                    count = 0
                else:
                    ext_sent += raw_arts[idx.item()]
                    #print('b')
                    count += 1
                    raw_adv.append((r, a))
                    q = 0
            else:
                #print('c')
                if q == 1:
                    raw_adv.append((r, a - 1))
                else:
                    raw_adv.append((r, a))
            f += 1
        if inds[-1].item() != 0:
            if ext_sent:
                ext_sents.append(ext_sent)
                zero.append(f)
                inds_.append(1)
                a += 1
                #print('2',a)
            #ext_sent = []
        if not inds_:
            ext_sents.append(raw_arts[0])
            inds_.append(1)
            zero.append(f)
            a += 1
            #print('3',a)
        r += 1
        zeros.append(zero)
        new_indices.append(inds_)
        #print(len(inds),len(bs),len(inds_))
        #print(len(zero), len(inds_))
        #print(ms[0].probs)
        #exit()
    with torch.no_grad():
        summaries = abstractor(ext_sents)
    i = 0
    rewards = []
    avg_reward = 0
    set_reward, summ_reward = [], []
    for inds, abss in zip(new_indices, abs_batch):
        set_reward.append([
            reward_fn(summaries[i + j], abss[j])
            for j in range(min(len(inds) - 1, len(abss)))
        ])
        summ_reward.append(stop_coeff * stop_reward_fn(
            list(concat(summaries[i:i + len(inds) - 1])), list(concat(abss))))
        '''
        rs = ([reward_fn(summaries[i+j], abss[j])
              for j in range(min(len(inds)-1, len(abss)))]
              + [0 for _ in range(max(0, len(inds)-1-len(abss)))]
              + [stop_coeff*stop_reward_fn(
                  list(concat(summaries[i:i+len(inds)-1])),
                  list(concat(abss)))])
        if len(rs) != len(inds):				  
           print(len(rs),len(inds))
           print(indices)
           print(new_indices)
        assert len(rs) == len(inds)
        avg_reward += rs[-1]/stop_coeff
        i += len(inds)-1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        '''

    ext_rewards = []
    avg_ext_reward = 0
    for inds, exts, arts, setR, summR, zero in zip(indices, ext_batch,
                                                   art_batch, set_reward,
                                                   summ_reward, zeros):
        k = 0
        rs1 = []
        ext_zero = []
        for e in range(len(exts)):
            if exts[e] == 0:
                ext_zero.append(e)
        e0 = 0
        maxlen = 0
        for j in range(min(len(inds) - 1, len(exts))):
            #print(set(inds) & set(exts))
            maxlen += 1
            if j in zero and k < len(setR) and k < len(ext_zero) - 1:
                rs1.append(setR[k])
                k += 1
                e0 = ext_zero[k] + 1
            else:
                try:
                    if inds[j].item() < len(arts):
                        rs1.append(
                            compute_rouge_n(arts[inds[j].item()],
                                            arts[exts[e0]],
                                            n=1))
                    else:
                        rs1.append(0.0)
                    e0 += 1
                except:
                    maxlen -= 1
                    break
        rs2 = [0 for _ in range(len(inds) - 1 - maxlen)]
        stop = metrics(gt=exts, pred=inds, metrics_map=['MAP'])[0]
        rs3 = [stop_coeff * summR]
        rs = rs1 + rs2 + rs3
        assert len(rs) == len(inds)
        avg_ext_reward += rs[-1] / stop_coeff
        # compute discounted rewards
        R = 0
        disc_rs = []
        #print(rs)
        #print(rs[::-1])
        #exit()
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        ext_rewards += disc_rs
    #print(len(ext_rewards))
    #exit()
    indices = list(concat(indices))
    #print(len(indices),len(raw_adv))
    new_indices = list(concat(new_indices))
    #print(len(new_indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    #exit()
    # standardize rewards
    ext_reward = torch.Tensor(ext_rewards).to(baselines[0].device)
    ext_reward = (ext_reward - ext_reward.mean()) / (
        ext_reward.std() + float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for r, b, p, action in zip(ext_reward, baseline, probs, indices):
        advantage = r - b
        avg_advantage += advantage
        losses.append((-p.log_prob(action) *
                       (advantage / len(indices))))  # divide by T*B
    #exit()
    critic_loss = F.mse_loss(baseline, ext_reward)
    # backprop and update
    #print("[DEBUG]")
    #print(critic_loss)
    critic_loss = critic_loss.view([1])
    #print(critic_loss)
    autograd.backward(tensors=[critic_loss] + losses,
                      grad_tensors=[torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)))
    #exit()
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_ext_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
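
Several of the extraction loops above (here and in example #25) treat index 0 as a group separator and indices at or beyond the article length as the stop action, concatenating each group's sentences into one abstractor input. A standalone sketch of that grouping convention (the helper name and the toy article are illustrative):

def group_extractions(indices, article_sents):
    """Split extracted indices at 0-separators and merge each group's tokens."""
    groups, current = [], []
    for idx in indices:
        if idx >= len(article_sents):      # stop / out-of-range action: ignore
            continue
        if idx == 0:                       # separator: close the current group
            if current:
                groups.append(current)
            current = []
        else:
            current = current + article_sents[idx]
    if current:                            # flush the trailing group
        groups.append(current)
    return groups

art = [['<sep>'], ['a', 'b'], ['c'], ['d', 'e']]
print(group_extractions([1, 2, 0, 3, 7], art))   # [['a', 'b', 'c'], ['d', 'e']]
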
예제 #29
0
    def train_step(self, sample_time=1):
        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

        def sum_id2word(raw_article_sents, decs, attns, id2word):
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents

        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                   id2word)
        bl_scores = []

        for baseline, target in zip(greedy_sents, raw_targets):
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            bss_bleu = list(concat(bss))
            tgs_bleu = list(concat(tgs))
            bss_bleu = ' '.join(bss_bleu)
            tgs_bleu = ' '.join(tgs_bleu)
            #bl_score = compute_rouge_l_summ(bss, tgs)
            if self._bleu:
                bleu_scores = bleu(bss_bleu, tgs_bleu)
                bleu_score = (bleu_scores[0] + bleu_scores[1] +
                              bleu_scores[2] + bleu_scores[3])
                bl_score = bleu_score
            elif self.f1:
                bl_score = compute_f1(bss_bleu, tgs_bleu)
            else:
                bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs) + \
                       self._w8[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                        self._w8[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores,
                                 dtype=torch.float32,
                                 device=greedy_attns[0].device)

        samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        #loss_nll = - sp_seqLogProb.squeeze(2)
        loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
            sp_seqLogProb)
        sp_scores = []

        for sample, target in zip(sample_sents, raw_targets):
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            #sp_score = compute_rouge_l_summ(sps, tgs)
            sps_bleu = list(concat(sps))
            tgs_bleu = list(concat(tgs))
            sps_bleu = ' '.join(sps_bleu)
            tgs_bleu = ' '.join(tgs_bleu)
            if self._bleu:
                bleu_scores = bleu(sps_bleu, tgs_bleu)
                bleu_score = (bleu_scores[0] + bleu_scores[1] +
                              bleu_scores[2] + bleu_scores[3])
                sp_score = bleu_score
            elif self.f1:
                sp_score = compute_f1(sps_bleu, tgs_bleu)
            else:
                sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs) + \
                        self._w8[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                        self._w8[1]* compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores,
                                 dtype=torch.float32,
                                 device=greedy_attns[0].device)

        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)

        reward.requires_grad_(False)
        loss = reward.contiguous().detach() * loss_nll
        loss = loss.sum()
        full_length = _masks.data.float().sum()
        loss = loss / full_length

        loss.backward()

        log_dict = {}

        log_dict['reward'] = bl_scores.mean().item()

        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())
        self._opt.step()
        self._net.zero_grad()
        #torch.cuda.empty_cache()

        return log_dict
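# A minimal sketch of the self-critical update pattern used in the train_step above:
# the reward is score(sampled summary) - score(greedy summary), and it scales the
# masked negative log-likelihood of the sampled tokens. The tensors below are toy
# stand-ins; the real scores come from the weighted ROUGE / BLEU / F1 computations.
import torch

def self_critical_loss(sample_logprobs, mask, sample_score, greedy_score):
    # sample_logprobs, mask: [batch, time]; sample_score, greedy_score: [batch]
    reward = (sample_score - greedy_score).view(-1, 1)   # greedy baseline subtracted
    loss = -(sample_logprobs * mask) * reward.detach()   # REINFORCE with a self-critical baseline
    return loss.sum() / mask.sum()                       # normalize by the number of real tokens

# toy usage with fake per-token log-probabilities
logp = -torch.rand(2, 5)
mask = torch.ones(2, 5)
loss = self_critical_loss(logp, mask,
                          torch.tensor([0.40, 0.30]),    # sampled-summary scores
                          torch.tensor([0.35, 0.32]))    # greedy-summary scores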
Example #30
0
    def train_step(self, sample_time=1):
        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]
        def sum_id2word(raw_article_sents, decs, attns, id2word):
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents
        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)
        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
        bl_scores = []
        if self._coh_fn is not None:
            bl_coh_scores = []
            bl_coh_inputs = []
        if self._local_coh_fun is not None:
            bl_local_coh_scores = []
        for baseline, target in zip(greedy_sents, raw_targets):
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            if self._coh_fn is not None:
                bl_coh_inputs.append(bss)
                # if len(bss) > 1:
                #     input_args = (bss, ) + self._coh_fn
                #     coh_score = coherence_infer(*input_args) / 2
                # else:
                #     coh_score = 0
                # bl_coh_scores.append(coh_score)
            if self._local_coh_fun is not None:
                local_coh_score = self._local_coh_fun(bss)
                bl_local_coh_scores.append(local_coh_score)
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            #bl_score = compute_rouge_l_summ(bss, tgs)
            bl_score = (self._weights[2] * compute_rouge_l_summ(bss, tgs) + \
                        self._weights[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                        self._weights[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores, dtype=torch.float32, device=greedy_attns[0].device)
        if self._coh_fn is not None:
            input_args = (bl_coh_inputs,) + self._coh_fn
            bl_coh_scores = batch_global_infer(*input_args)
            bl_coh_scores = torch.tensor(bl_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
        if self._local_coh_fun is not None:
            bl_local_coh_scores = torch.tensor(bl_local_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)

        for _ in range(sample_time):
            samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
            sample_sents = sum_id2word(raw_articles, samples, sample_attns, id2word)
            sp_seqs = pack_seq(samples)
            _masks = (sp_seqs > PAD).float()
            sp_seqLogProb = pack_seq(seqLogProbs)
            #loss_nll = - sp_seqLogProb.squeeze(2)
            loss_nll = - sp_seqLogProb.squeeze(2) * _masks.detach().type_as(sp_seqLogProb)
            sp_scores = []
            if self._coh_fn is not None:
                sp_coh_scores = []
                sp_coh_inputs = []
            if self._local_coh_fun is not None:
                sp_local_coh_scores = []
            for sample, target in zip(sample_sents, raw_targets):
                sps = sent_tokenize(' '.join(sample))
                tgs = sent_tokenize(' '.join(target))
                if self._coh_fn is not None:
                    sp_coh_inputs.append(sps)
                    # if len(sps) > 1:
                    #     input_args = (sps,) + self._coh_fn
                    #     coh_score = coherence_infer(*input_args) / 2
                    # else:
                    #     coh_score = 0
                    # sp_coh_scores.append(coh_score)
                if self._local_coh_fun is not None:
                    local_coh_score = self._local_coh_fun(sps)
                    sp_local_coh_scores.append(local_coh_score)
                sps = [sp.split(' ') for sp in sps]
                tgs = [tg.split(' ') for tg in tgs]
                #sp_score = compute_rouge_l_summ(sps, tgs)
                sp_score = (self._weights[2] * compute_rouge_l_summ(sps, tgs) + \
                            self._weights[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                            self._weights[1] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
                sp_scores.append(sp_score)
            sp_scores = torch.tensor(sp_scores, dtype=torch.float32, device=greedy_attns[0].device)
            reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
            reward.requires_grad_(False)
            if self._coh_fn is not None:
                input_args = (sp_coh_inputs,) + self._coh_fn
                sp_coh_scores = batch_global_infer(*input_args)
                sp_coh_scores = torch.tensor(sp_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                reward_coh = sp_coh_scores.view(-1, 1) - bl_coh_scores.view(-1, 1)
                reward_coh.requires_grad_(False)
                reward = reward + self._coh_cof * reward_coh
            if self._local_coh_fun is not None:
                sp_local_coh_scores = torch.tensor(sp_local_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                reward_local = sp_local_coh_scores.view(-1, 1) - bl_local_coh_scores.view(-1, 1)
                reward_local.requires_grad_(False)
                reward = reward + self._local_co_coef * reward_local

            if _ == 0:
                loss = reward.contiguous().detach() * loss_nll
                loss = loss.sum()
                full_length = _masks.data.float().sum()
            else:
                loss += (reward.contiguous().detach() * loss_nll).sum()
                full_length += _masks.data.float().sum()

        loss = loss / full_length
        # backward and update ( and optional gradient monitoring )
        loss.backward()
        log_dict = {}
        if self._coh_fn is not None:
            log_dict['reward'] = bl_scores.mean().item() + self._coh_cof * bl_coh_scores.mean().item()
        else:
            log_dict['reward'] = bl_scores.mean().item()
        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())
        self._opt.step()
        self._net.zero_grad()
        torch.cuda.empty_cache()

        return log_dict
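# A minimal sketch of how the train_step above mixes the self-critical ROUGE reward
# with an auxiliary coherence reward through a coefficient; `coh_cof` and the score
# tensors are illustrative stand-ins for self._coh_cof and the batched coherence scores.
import torch

def mixed_reward(sp_scores, bl_scores, sp_coh, bl_coh, coh_cof=0.1):
    base = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)  # self-critical ROUGE term
    coh = sp_coh.view(-1, 1) - bl_coh.view(-1, 1)         # self-critical coherence term
    return base + coh_cof * coh

reward = mixed_reward(torch.tensor([0.40, 0.30]), torch.tensor([0.35, 0.32]),
                      torch.tensor([0.60, 0.55]), torch.tensor([0.50, 0.58]))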
Example #31
0
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0):
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    summ_indices = []
    art_batch, topic_batch, abs_batch = next(loader)
    for raw_arts, topic in zip(art_batch, topic_batch):
        (inds, ms), bs = agent(raw_arts, topic)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [
            raw_arts[idx.item()] for idx in inds if idx.item() < len(raw_arts)
        ]
        summ_indices += [[
            idx.item() for idx in inds if idx.item() < len(raw_arts)
        ]]

    with torch.no_grad():
        summaries = abstractor(ext_sents)

    ## collect the generated headline
    summaries_collect = [None] * len(summ_indices)
    cnt = 0
    i = 0
    for summ_inds in summ_indices:
        try:
            if len(summ_inds) != 0:
                summaries_collect[cnt] = summaries[i]
                i = i + len(summ_inds)
            else:
                summaries_collect[cnt] = [" "]
        except:
            pdb.set_trace()
        cnt += 1
    # optional classifier-based scoring of the generated headlines (disabled here)
    add_cls_loss = False
    if add_cls_loss:
        ## collect the generated headline
        summaries_collect = [None] * len(summ_indices)
        cnt = 0
        i = 0
        for summ_inds in summ_indices:
            try:
                if len(summ_inds) != 0:
                    summaries_collect[cnt] = summaries[i]
                    i = i + len(summ_inds)
                else:
                    summaries_collect[cnt] = [" "]
            except:
                pdb.set_trace()
            cnt += 1
        with open(
                '/home/yunzhu/Headline/FASum/FASRL/model/classifier/TEXT.Field',
                'rb') as f:
            TEXT = dill.load(f)
        vocab_size = len(TEXT.vocab)
        word_embeddings = TEXT.vocab.vectors
    else:
        TEXT = None
        vocab_size = None
        word_embeddings = None
    #prediction = get_gen_score(summaries_collect, TEXT, vocab_size, word_embeddings)
    #cls_score = prediction[:,1].sum().item()

    i = 0
    rewards = []
    avg_reward = 0
    cnt = 0
    for inds, abss in zip(indices, abs_batch):
        cls_r = 0
        rs = ([
            reward_fn(summaries[i + j], abss[j]) + cls_r
            for j in range(min(len(inds) - 1, len(abss)))
        ] + [0 for _ in range(max(0,
                                  len(inds) - 1 - len(abss)))] +
              [
                  stop_coeff *
                  (stop_reward_fn(list(concat(summaries[i:i + len(inds) - 1])),
                                  list(concat(abss))) + cls_r)
              ])

        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
        cnt += 1

    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))
    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(-p.log_prob(action) *
                      (advantage / len(indices)))  # divide by T*B
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss.unsqueeze(0)] + losses,
                      [torch.ones(1).to(critic_loss.device)] * (1 + len(losses)))

    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict
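# A minimal sketch of the actor-critic update used in a2c_train_step above: discounted
# rewards are standardized over the batch, the critic's value estimate is subtracted to
# form the advantage, the actor loss is -log_prob * advantage, and the critic is fit
# with MSE. The random tensors below are placeholders for real log-probs and values.
import torch
import torch.nn.functional as F

def a2c_losses(log_probs, rewards, baselines, eps=1e-8):
    # log_probs, rewards, baselines: 1-D tensors with one entry per extraction action
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)  # standardize rewards
    advantage = rewards - baselines.detach()                      # critic value as baseline
    actor_loss = (-log_probs * advantage).mean()                  # policy-gradient term
    critic_loss = F.mse_loss(baselines, rewards)                  # fit critic to the returns
    return actor_loss, critic_loss

actor_loss, critic_loss = a2c_losses(-torch.rand(4), torch.rand(4), torch.rand(4))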
Example #32
0
    def train_step(self, sample_time=1):
        torch.autograd.set_detect_anomaly(True)  # debugging aid; slows training

        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

        def sum_id2word(raw_article_sents, decs, attns, id2word):
            # the BERT and non-BERT code paths were identical, so a single body suffices
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == self._end:
                        break
                    elif id_[i] == self._unk:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents

        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        total_loss = None
        for i in range(self._accumulate_g_step):
            fw_args, bw_args = next(self._batches)
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if self._reward_fn is not None:
                questions = bw_args[3]
            targets = bw_args[4]

            # encode
            # attention, init_dec_states, nodes = self._net.encode_general(*fw_args)
            #fw_args += (attention, init_dec_states, nodes)
            # _init_dec_states = ((init_dec_states[0][0].clone(), init_dec_states[0][1].clone()), init_dec_states[1].clone())
            with torch.no_grad():
                # g_fw_args = fw_args + (attention, _init_dec_states, nodes, False)
                # greedies, greedy_attns = self._net.rl_step(*g_fw_args)
                greedies, greedy_attns = self._net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                       id2word)
            bl_scores = []
            if self._reward_fn is not None:
                bl_reward_scores = []
                bl_reward_inputs = []
            if self._local_coh_fun is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if self._bert:
                    text = ''.join(baseline)
                    baseline = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    baseline = baseline.strip().lower().split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    target = target.strip().lower().split(' ')

                bss = sent_tokenize(' '.join(baseline))
                tgs = sent_tokenize(' '.join(target))
                if self._reward_fn is not None:
                    bl_reward_inputs.append(bss)
                if self._local_coh_fun is not None:
                    local_coh_score = self._local_coh_fun(bss)
                    bl_local_coh_scores.append(local_coh_score)
                bss = [bs.split(' ') for bs in bss]
                tgs = [tg.split(' ') for tg in tgs]

                #bl_score = compute_rouge_l_summ(bss, tgs)
                bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs) + \
                           self._w8[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                            self._w8[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)

            # sample
            # s_fw_args = fw_args + (attention, init_dec_states, nodes, True)
            # samples, sample_attns, seqLogProbs = self._net.rl_step(*s_fw_args)
            fw_args += (self._ml_loss, )
            if self._ml_loss:
                samples, sample_attns, seqLogProbs, ml_logit = self._net.sample(
                    *fw_args)
            else:
                samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
            sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                       id2word)
            sp_seqs = pack_seq(samples)
            _masks = (sp_seqs > PAD).float()
            sp_seqLogProb = pack_seq(seqLogProbs)
            #loss_nll = - sp_seqLogProb.squeeze(2)
            loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
                sp_seqLogProb)
            sp_scores = []
            if self._reward_fn is not None:
                sp_reward_inputs = []
            for sample, target in zip(sample_sents, raw_targets):
                if self._bert:
                    text = ''.join(sample)
                    sample = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    sample = sample.strip().lower().split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    target = target.strip().lower().split(' ')

                sps = sent_tokenize(' '.join(sample))
                tgs = sent_tokenize(' '.join(target))
                if self._reward_fn is not None:
                    sp_reward_inputs.append(sps)

                sps = [sp.split(' ') for sp in sps]
                tgs = [tg.split(' ') for tg in tgs]
                #sp_score = compute_rouge_l_summ(sps, tgs)
                sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs) + \
                            self._w8[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                            self._w8[1]* compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
                sp_scores.append(sp_score)
            sp_scores = torch.tensor(sp_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)
            if self._reward_fn is not None:
                sp_reward_scores, bl_reward_scores = self._reward_fn.score_two_seqs(
                    questions, sp_reward_inputs, bl_reward_inputs)
                sp_reward_scores = torch.tensor(sp_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)
                bl_reward_scores = torch.tensor(bl_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)

            reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
            if self._reward_fn is not None:
                reward += self._reward_w8 * (sp_reward_scores.view(-1, 1) -
                                             bl_reward_scores.view(-1, 1))
            reward.requires_grad_(False)

            loss = reward.contiguous().detach() * loss_nll
            loss = loss.sum()
            full_length = _masks.data.float().sum()
            loss = loss / full_length
            if self._ml_loss:
                ml_loss = self._ml_criterion(ml_logit, targets)
                loss += self._ml_loss_w8 * ml_loss.mean()
            # if total_loss is None:
            #     total_loss = loss
            # else:
            #     total_loss += loss

            loss = loss / self._accumulate_g_step
            loss.backward()

        log_dict = {}
        if self._reward_fn is not None:
            log_dict['reward'] = bl_scores.mean().item()
            log_dict['question_reward'] = bl_reward_scores.mean().item()
            log_dict['sample_question_reward'] = sp_reward_scores.mean().item()
            log_dict['sample_reward'] = sp_scores.mean().item()
        else:
            log_dict['reward'] = bl_scores.mean().item()

        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())

        self._opt.step()
        self._net.zero_grad()
        #torch.cuda.empty_cache()

        return log_dict
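# A minimal sketch of the gradient-accumulation pattern in the train_step above: each
# micro-batch loss is divided by the number of accumulation steps, backward() is called
# per micro-batch so gradients add up, and a single optimizer step follows. The linear
# model and random data are placeholders, not part of the snippet.
import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_g_step = 4

model.train()
for _ in range(accumulate_g_step):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_g_step  # scale each micro-batch
    loss.backward()                                                       # gradients accumulate in .grad
opt.step()
model.zero_grad()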