def compute_reward_old(pred_str,
                       pred_sent_list,
                       trg_str,
                       trg_sent_list,
                       reward_type,
                       regularization_factor=0.0,
                       regularization_type=0,
                       entropy=None):
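    """Blend a ROUGE-based reward with an optional entropy regularizer.

    reward_type 0: 0.3 * ROUGE-1(F) + 0.2 * ROUGE-2(F) + 0.5 * ROUGE-L(summ, F).
    reward_type 1: summary-level ROUGE-L(F) only.
    With regularization_type == 2 and a non-zero regularization_factor, the
    entropy is mixed in: reward = (1 - factor) * rouge + factor * entropy.

    Usage sketch (assumes the ROUGE helpers used below are importable):
        pred = 'the cat sat on the mat'.split()
        ref = 'a cat sat on a mat'.split()
        r = compute_reward_old(pred, [pred], ref, [ref], reward_type=1)
    """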

    if regularization_type == 1:
        raise ValueError("Not implemented.")
    elif regularization_type == 2:
        regularization = entropy
    else:
        regularization = 0.0

    if reward_type == 0:
        tmp_reward = 0.3 * compute_rouge_n(
            pred_str, trg_str, n=1, mode='f') + 0.2 * compute_rouge_n(
                pred_str, trg_str, n=2, mode='f') + 0.5 * compute_rouge_l_summ(
                    pred_sent_list, trg_sent_list, mode='f')
    elif reward_type == 1:
        tmp_reward = compute_rouge_l_summ(pred_sent_list,
                                          trg_sent_list,
                                          mode='f')
    else:
        raise ValueError

    # Add the regularization term to the reward only if regularization type != 0
    if regularization_type == 0 or regularization_factor == 0:
        reward = tmp_reward
    else:
        reward = (1 - regularization_factor
                  ) * tmp_reward + regularization_factor * regularization
    return reward

def rouge_l_reward(pred_str_list, pred_sent_2d_list, trg_str_list,
                   trg_sent_2d_list, batch_size, device):
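    """Per-example summary-level ROUGE-L (F1) rewards for a batch, returned as a FloatTensor on `device`."""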
    #reward = np.zeros(batch_size)
    reward = []
    for idx, (pred_str, pred_sent_list, trg_str, trg_sent_list) in enumerate(
            zip(pred_str_list, pred_sent_2d_list, trg_str_list,
                trg_sent_2d_list)):
        #reward[idx] = compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f')
        reward.append(
            compute_rouge_l_summ(pred_sent_list, trg_sent_list, mode='f'))
    return torch.FloatTensor(reward).to(device)  # tensor: [batch_size]

def rl_validate(net, val_batches, coherence_func=None, coh_coef=0.01,
                local_coh_func=None, local_coh_coef=0.005):
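    """Greedy-decode validation: reports the mean ROUGE-L reward (x100) plus optional global/local coherence rewards."""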
    print('running validation ... ', end='')
    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]
    def sum_id2word(raw_article_sents, decs, attns):
        dec_sents = []
        for i, raw_words in enumerate(raw_article_sents):
            dec = []
            for id_, attn in zip(decs, attns):
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    dec.append(argmax(raw_words, attn[i]))
                else:
                    dec.append(id2word[id_[i].item()])
            dec_sents.append(dec)
        return dec_sents
    net.eval()
    start = time()
    i = 0
    score = 0
    score_coh = 0
    score_local_coh = 0
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
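            # greedy-decode the batch and map token ids back to words (copying the attended source word for UNKs)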
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            if coherence_func is not None:
                bl_coh_scores = []
                bl_coh_inputs = []
            if local_coh_func is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                bss = sent_tokenize(' '.join(baseline))
                if coherence_func is not None:
                    bl_coh_inputs.append(bss)
                    # if len(bss) > 1:
                    #     input_args = (bss,) + coherence_func
                    #     coh_score = coherence_infer(*input_args) / 2
                    # else:
                    #     coh_score = 0
                    # bl_coh_scores.append(coh_score)
                if local_coh_func is not None:
                    local_coh_score = local_coh_func(bss)
                    bl_local_coh_scores.append(local_coh_score)
                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_scores.append(bl_score)
            # print('blscore:', bl_score)
            # print('baseline:', bss)
            # print('target:', tgs)

            bl_scores = torch.tensor(bl_scores, dtype=torch.float32, device=greedy_attns[0].device)
            if coherence_func is not None:
                input_args = (bl_coh_inputs,) + coherence_func
                bl_coh_scores = batch_global_infer(*input_args)
                bl_coh_scores = torch.tensor(bl_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                score_coh += bl_coh_scores.mean().item() * 100 * coh_coef
            if local_coh_func is not None:
                bl_local_coh_scores = torch.tensor(bl_local_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                score_local_coh += bl_local_coh_scores.mean().item() * local_coh_coef * 100
            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100

    val_score = score / i
    if coherence_func is not None:
        val_coh_score = score_coh / i
    else:
        val_coh_score = 0
    val_local_coh_score = 0
    print(
        'validation finished in {}                                    '.format(
            timedelta(seconds=int(time()-start)))
    )
    print('validation reward: {:.4f} ... '.format(val_score))
    if coherence_func is not None:
        print('validation coherence reward: {:.4f} ... '.format(val_coh_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score + val_coh_score + val_local_coh_score}
    def train_step(self, sample_time=1):
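        """One self-critical policy-gradient step: sampled rollouts are rewarded with a weighted ROUGE mix (plus optional coherence terms) relative to a greedy baseline."""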
        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]
        def sum_id2word(raw_article_sents, decs, attns, id2word):
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents
        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)
        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
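        # greedy rollout (no grad) provides the self-critical baseline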
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns, id2word)
        bl_scores = []
        if self._coh_fn is not None:
            bl_coh_scores = []
            bl_coh_inputs = []
        if self._local_coh_fun is not None:
            bl_local_coh_scores = []
        for baseline, target in zip(greedy_sents, raw_targets):
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            if self._coh_fn is not None:
                bl_coh_inputs.append(bss)
                # if len(bss) > 1:
                #     input_args = (bss, ) + self._coh_fn
                #     coh_score = coherence_infer(*input_args) / 2
                # else:
                #     coh_score = 0
                # bl_coh_scores.append(coh_score)
            if self._local_coh_fun is not None:
                local_coh_score = self._local_coh_fun(bss)
                bl_local_coh_scores.append(local_coh_score)
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            #bl_score = compute_rouge_l_summ(bss, tgs)
            bl_score = (self._weights[2] * compute_rouge_l_summ(bss, tgs) + \
                        self._weights[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                        self._weights[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores, dtype=torch.float32, device=greedy_attns[0].device)
        if self._coh_fn is not None:
            input_args = (bl_coh_inputs,) + self._coh_fn
            bl_coh_scores = batch_global_infer(*input_args)
            bl_coh_scores = torch.tensor(bl_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
        if self._local_coh_fun is not None:
            bl_local_coh_scores = torch.tensor(bl_local_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)

        # print('bl:', bl_scores)
        # print('bl_coh:', bl_coh_scores)

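        # multinomial sampling rollouts; REINFORCE advantage = sample reward - greedy baseline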
        for _ in range(sample_time):
            samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
            sample_sents = sum_id2word(raw_articles, samples, sample_attns, id2word)
            sp_seqs = pack_seq(samples)
            _masks = (sp_seqs > PAD).float()
            sp_seqLogProb = pack_seq(seqLogProbs)
            #loss_nll = - sp_seqLogProb.squeeze(2)
            loss_nll = - sp_seqLogProb.squeeze(2) * _masks.detach().type_as(sp_seqLogProb)
            sp_scores = []
            if self._coh_fn is not None:
                sp_coh_scores = []
                sp_coh_inputs = []
            if self._local_coh_fun is not None:
                sp_local_coh_scores = []
            for sample, target in zip(sample_sents, raw_targets):
                sps = sent_tokenize(' '.join(sample))
                tgs = sent_tokenize(' '.join(target))
                if self._coh_fn is not None:
                    sp_coh_inputs.append(sps)
                    # if len(sps) > 1:
                    #     input_args = (sps,) + self._coh_fn
                    #     coh_score = coherence_infer(*input_args) / 2
                    # else:
                    #     coh_score = 0
                    # sp_coh_scores.append(coh_score)
                if self._local_coh_fun is not None:
                    local_coh_score = self._local_coh_fun(sps)
                    sp_local_coh_scores.append(local_coh_score)
                sps = [sp.split(' ') for sp in sps]
                tgs = [tg.split(' ') for tg in tgs]
                #sp_score = compute_rouge_l_summ(sps, tgs)
                sp_score = (self._weights[2] * compute_rouge_l_summ(sps, tgs) + \
                            self._weights[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                            self._weights[1] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
                sp_scores.append(sp_score)
            sp_scores = torch.tensor(sp_scores, dtype=torch.float32, device=greedy_attns[0].device)
            reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
            reward.requires_grad_(False)
            if self._coh_fn is not None:
                input_args = (sp_coh_inputs,) + self._coh_fn
                sp_coh_scores = batch_global_infer(*input_args)
                sp_coh_scores = torch.tensor(sp_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                reward_coh = sp_coh_scores.view(-1, 1) - bl_coh_scores.view(-1, 1)
                reward_coh.requires_grad_(False)
                reward = reward + self._coh_cof * reward_coh
            if self._local_coh_fun is not None:
                sp_local_coh_scores = torch.tensor(sp_local_coh_scores, dtype=torch.float32, device=greedy_attns[0].device)
                reward_local = sp_local_coh_scores.view(-1, 1) - bl_local_coh_scores.view(-1, 1)
                reward_local.requires_grad_(False)
                reward = reward + self._local_co_coef * reward_local

            if _ == 0:
                loss = reward.contiguous().detach() * loss_nll
                loss = loss.sum()
                full_length = _masks.data.float().sum()
            else:
                loss += (reward.contiguous().detach() * loss_nll).sum()
                full_length += _masks.data.float().sum()

        # print('sp:', sp_scores)
        # print('sp_coh:', sp_coh_scores)
        loss = loss / full_length
        # backward and update ( and optional gradient monitoring )
        loss.backward()
        log_dict = {}
        if self._coh_fn is not None:
            log_dict['reward'] = bl_scores.mean().item() + self._coh_cof * bl_coh_scores.mean().item()
        else:
            log_dict['reward'] = bl_scores.mean().item()
        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())
        self._opt.step()
        self._net.zero_grad()
        torch.cuda.empty_cache()

        return log_dict
Example #5
    def train_step(self, sample_time=1):
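        """Self-critical step with gradient accumulation: weighted ROUGE reward, an optional question-based reward, and an optional mixed-in ML loss."""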
        torch.autograd.set_detect_anomaly(True)

        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

        def sum_id2word(raw_article_sents, decs, attns, id2word):
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == self._end:
                        break
                    elif id_[i] == self._unk:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents

        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        total_loss = None
        for i in range(self._accumulate_g_step):
            fw_args, bw_args = next(self._batches)
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if self._reward_fn is not None:
                questions = bw_args[3]
            targets = bw_args[4]

            # encode
            # attention, init_dec_states, nodes = self._net.encode_general(*fw_args)
            #fw_args += (attention, init_dec_states, nodes)
            # _init_dec_states = ((init_dec_states[0][0].clone(), init_dec_states[0][1].clone()), init_dec_states[1].clone())
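            # greedy rollout (no grad) serves as the self-critical baseline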
            with torch.no_grad():
                # g_fw_args = fw_args + (attention, _init_dec_states, nodes, False)
                # greedies, greedy_attns = self._net.rl_step(*g_fw_args)
                greedies, greedy_attns = self._net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                       id2word)
            bl_scores = []
            if self._reward_fn is not None:
                bl_reward_scores = []
                bl_reward_inputs = []
            if self._local_coh_fun is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if self._bert:
                    text = ''.join(baseline)
                    baseline = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    baseline = baseline.strip().lower().split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    target = target.strip().lower().split(' ')

                bss = sent_tokenize(' '.join(baseline))
                tgs = sent_tokenize(' '.join(target))
                if self._reward_fn is not None:
                    bl_reward_inputs.append(bss)
                if self._local_coh_fun is not None:
                    local_coh_score = self._local_coh_fun(bss)
                    bl_local_coh_scores.append(local_coh_score)
                bss = [bs.split(' ') for bs in bss]
                tgs = [tg.split(' ') for tg in tgs]

                #bl_score = compute_rouge_l_summ(bss, tgs)
                bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs) + \
                           self._w8[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                            self._w8[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)

            # sample
            # s_fw_args = fw_args + (attention, init_dec_states, nodes, True)
            # samples, sample_attns, seqLogProbs = self._net.rl_step(*s_fw_args)
            fw_args += (self._ml_loss, )
            if self._ml_loss:
                samples, sample_attns, seqLogProbs, ml_logit = self._net.sample(
                    *fw_args)
            else:
                samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
            sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                       id2word)
            sp_seqs = pack_seq(samples)
            _masks = (sp_seqs > PAD).float()
            sp_seqLogProb = pack_seq(seqLogProbs)
            #loss_nll = - sp_seqLogProb.squeeze(2)
            loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
                sp_seqLogProb)
            sp_scores = []
            if self._reward_fn is not None:
                sp_reward_inputs = []
            for sample, target in zip(sample_sents, raw_targets):
                if self._bert:
                    text = ''.join(sample)
                    sample = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    sample = sample.strip().lower().split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        self._tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=self._tokenizer.errors)
                    target = target.strip().lower().split(' ')

                sps = sent_tokenize(' '.join(sample))
                tgs = sent_tokenize(' '.join(target))
                if self._reward_fn is not None:
                    sp_reward_inputs.append(sps)

                sps = [sp.split(' ') for sp in sps]
                tgs = [tg.split(' ') for tg in tgs]
                #sp_score = compute_rouge_l_summ(sps, tgs)
                sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs) + \
                            self._w8[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                            self._w8[1]* compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
                sp_scores.append(sp_score)
            sp_scores = torch.tensor(sp_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)
            if self._reward_fn is not None:
                sp_reward_scores, bl_reward_scores = self._reward_fn.score_two_seqs(
                    questions, sp_reward_inputs, bl_reward_inputs)
                sp_reward_scores = torch.tensor(sp_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)
                bl_reward_scores = torch.tensor(bl_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)

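            # self-critical advantage: sampled reward minus greedy baseline (plus an optional question-reward term)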
            reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)
            if self._reward_fn is not None:
                reward += self._reward_w8 * (sp_reward_scores.view(-1, 1) -
                                             bl_reward_scores.view(-1, 1))
            reward.requires_grad_(False)

            loss = reward.contiguous().detach() * loss_nll
            loss = loss.sum()
            full_length = _masks.data.float().sum()
            loss = loss / full_length
            if self._ml_loss:
                ml_loss = self._ml_criterion(ml_logit, targets)
                loss += self._ml_loss_w8 * ml_loss.mean()
            # if total_loss is None:
            #     total_loss = loss
            # else:
            #     total_loss += loss

            loss = loss / self._accumulate_g_step
            loss.backward()

        log_dict = {}
        if self._reward_fn is not None:
            log_dict['reward'] = bl_scores.mean().item()
            log_dict['question_reward'] = bl_reward_scores.mean().item()
            log_dict['sample_question_reward'] = sp_reward_scores.mean().item()
            log_dict['sample_reward'] = sp_scores.mean().item()
        else:
            log_dict['reward'] = bl_scores.mean().item()

        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())

        self._opt.step()
        self._net.zero_grad()
        #torch.cuda.empty_cache()

        return log_dict
    def coll(tokenizer, batch):
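        """Collate fill-in-the-blank questions with the context and answer choices into padded BERT-style id tensors; also computes ROUGE of the context against the abstract."""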
        blank = '[unused0]'
        questions, context, _ids, abstract = list(
            filter(bool, list(zip(*batch))))
        system = context
        # print('q:', questions)
        # print('c:', context)
        rouges = {}  # stays empty when no reference abstract is available
        if len(abstract) > 0:
            rouges = {
                'RLr':
                compute_rouge_l_summ(
                    [_c.lower().split(' ') for _c in context[0]],
                    abstract[0],
                    mode='r'),
                'RLf':
                compute_rouge_l_summ(
                    [_c.lower().split(' ') for _c in context[0]],
                    abstract[0],
                    mode='f'),
                'R1':
                compute_rouge_n(' '.join(context[0]).split(' '),
                                list(concat(abstract[0])),
                                mode='f',
                                n=1),
                'R2':
                compute_rouge_n(' '.join(context[0]).split(' '),
                                list(concat(abstract[0])),
                                mode='f',
                                n=2)
            }
        if len(questions[0]) != 0:
            questions = questions[0]
            context = context[0]
            context = ' '.join(context)

            choicess = [[
                question['answer'], question['choice1'], question['choice2'],
                question['choice3']
            ] for question in questions]
            questions = [
                question['question'].replace('<\\blank>', blank)
                for question in questions
            ]
            questions = [[
                tokenizer.tokenize(qp.lower()) for qp in question.split(blank)
            ] for question in questions]
            new_questions = []
            for question in questions:
                new_q = ['[CLS]']
                for q in question:
                    new_q += q + [blank]
                new_q.pop()
                new_questions.append(new_q)
            questions = new_questions
            contexts = [['[SEP]'] + tokenizer.tokenize(context.lower())
                        for _ in range(len(questions))]
            choicess = [[[tokenizer.tokenize(c.lower()) for c in choice]
                         for choice in choices] for choices in choicess]
            choicess = [[['[SEP]'] + choice[0] + ['[SEP]'] +
                         choice[1] if len(choice) == 2 else ['[SEP]'] +
                         choice[0] for choice in choices]
                        for choices in choicess]
            _inputs = [[
                tokenizer.convert_tokens_to_ids(
                    (question + context + choice)[:MAX_LEN])
                for choice in choices
            ]
                       for question, context, choices in zip(
                           questions, contexts, choicess)]
            _inputs = pad_batch_tensorize_3d(_inputs, pad=0, cuda=False)
        else:
            _inputs = []

        return (_inputs, rouges, system, abstract)
Example #7
def rl_validate(net,
                val_batches,
                reward_func=None,
                reward_coef=0.01,
                local_coh_func=None,
                local_coh_coef=0.005,
                bert=False):
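    """Greedy-decode validation: reports ROUGE-L/1/2 (x100) and, if given, a question-based reward."""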
    print('running validation ... ', end='')
    if bert:
        tokenizer = net._bert_model._tokenizer
        end = tokenizer.encoder[tokenizer._eos_token]
        unk = tokenizer.encoder[tokenizer._unk_token]

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    def sum_id2word(raw_article_sents, decs, attns):
        if bert:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == end:
                        break
                    elif id_[i] == unk:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        else:
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
        return dec_sents

    net.eval()
    start = time()
    i = 0
    score = 0
    score_reward = 0
    score_local_coh = 0
    score_r2 = 0
    score_r1 = 0
    bl_r2 = []
    bl_r1 = []
    with torch.no_grad():
        for fw_args, bw_args in val_batches:
            raw_articles = bw_args[0]
            id2word = bw_args[1]
            raw_targets = bw_args[2]
            if reward_func is not None:
                questions = bw_args[3]
            greedies, greedy_attns = net.greedy(*fw_args)
            greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns)
            bl_scores = []
            if reward_func is not None:
                bl_coh_inputs = []
            if local_coh_func is not None:
                bl_local_coh_scores = []
            for baseline, target in zip(greedy_sents, raw_targets):
                if bert:
                    text = ''.join(baseline)
                    baseline = bytearray([
                        tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=tokenizer.errors)
                    baseline = baseline.split(' ')
                    text = ''.join(target)
                    target = bytearray([
                        tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=tokenizer.errors)
                    target = target.split(' ')

                bss = sent_tokenize(' '.join(baseline))
                if reward_func is not None:
                    bl_coh_inputs.append(bss)

                bss = [bs.split(' ') for bs in bss]
                tgs = sent_tokenize(' '.join(target))
                tgs = [tg.split(' ') for tg in tgs]
                bl_score = compute_rouge_l_summ(bss, tgs)
                bl_r1.append(
                    compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1))
                bl_r2.append(
                    compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
                bl_scores.append(bl_score)
            bl_scores = torch.tensor(bl_scores,
                                     dtype=torch.float32,
                                     device=greedy_attns[0].device)

            if reward_func is not None:
                bl_reward_scores = reward_func.score(questions, bl_coh_inputs)
                bl_reward_scores = torch.tensor(bl_reward_scores,
                                                dtype=torch.float32,
                                                device=greedy_attns[0].device)
                score_reward += bl_reward_scores.mean().item() * 100

            reward = bl_scores.mean().item()
            i += 1
            score += reward * 100
            score_r2 += torch.tensor(
                bl_r2, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100
            score_r1 += torch.tensor(
                bl_r1, dtype=torch.float32,
                device=greedy_attns[0].device).mean().item() * 100

    val_score = score / i
    score_r2 = score_r2 / i
    score_r1 = score_r1 / i
    if reward_func is not None:
        val_reward_score = score_reward / i
    else:
        val_reward_score = 0
    val_local_coh_score = 0
    print(
        'validation finished in {}                                    '.format(
            timedelta(seconds=int(time() - start))))
    print('validation reward: {:.4f} ... '.format(val_score))
    print('validation r2: {:.4f} ... '.format(score_r2))
    print('validation r1: {:.4f} ... '.format(score_r1))
    if reward_func is not None:
        print('validation question reward: {:.4f} ... '.format(val_reward_score))
    if local_coh_func is not None:
        val_local_coh_score = score_local_coh / i
        print('validation {} reward: {:.4f} ... '.format(
            local_coh_func.__name__, val_local_coh_score))
    print('n_data:', i)
    return {'score': val_score, 'score_reward': val_reward_score}
Example #8
    def train_step(self, sample_time=1):
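        """Single self-critical step; the per-example reward is BLEU, F1, or a weighted ROUGE mix, depending on the configured flags."""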
        def argmax(arr, keys):
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

        def sum_id2word(raw_article_sents, decs, attns, id2word):
            dec_sents = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents.append(dec)
            return dec_sents

        def pack_seq(seq_list):
            return torch.cat([_.unsqueeze(1) for _ in seq_list], 1)

        # forward pass of model
        self._net.train()
        #self._net.zero_grad()
        fw_args, bw_args = next(self._batches)
        raw_articles = bw_args[0]
        id2word = bw_args[1]
        raw_targets = bw_args[2]
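        # greedy rollout (no grad) provides the self-critical baseline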
        with torch.no_grad():
            greedies, greedy_attns = self._net.greedy(*fw_args)
        greedy_sents = sum_id2word(raw_articles, greedies, greedy_attns,
                                   id2word)
        bl_scores = []

        for baseline, target in zip(greedy_sents, raw_targets):
            bss = sent_tokenize(' '.join(baseline))
            tgs = sent_tokenize(' '.join(target))
            bss = [bs.split(' ') for bs in bss]
            tgs = [tg.split(' ') for tg in tgs]
            bss_bleu = list(concat(bss))
            tgs_bleu = list(concat(tgs))
            bss_bleu = ' '.join(bss_bleu)
            tgs_bleu = ' '.join(tgs_bleu)
            #bl_score = compute_rouge_l_summ(bss, tgs)
            if self._bleu:
                bleu_scores = bleu(bss_bleu, tgs_bleu)
                bleu_score = (bleu_scores[0] + bleu_scores[1] +
                              bleu_scores[2] + bleu_scores[3])
                bl_score = bleu_score
            elif self.f1:
                bl_score = compute_f1(bss_bleu, tgs_bleu)
            else:
                bl_score = (self._w8[2] * compute_rouge_l_summ(bss, tgs) + \
                       self._w8[0] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=1) + \
                        self._w8[1] * compute_rouge_n(list(concat(bss)), list(concat(tgs)), n=2))
            bl_scores.append(bl_score)
        bl_scores = torch.tensor(bl_scores,
                                 dtype=torch.float32,
                                 device=greedy_attns[0].device)

        samples, sample_attns, seqLogProbs = self._net.sample(*fw_args)
        sample_sents = sum_id2word(raw_articles, samples, sample_attns,
                                   id2word)
        sp_seqs = pack_seq(samples)
        _masks = (sp_seqs > PAD).float()
        sp_seqLogProb = pack_seq(seqLogProbs)
        #loss_nll = - sp_seqLogProb.squeeze(2)
        loss_nll = -sp_seqLogProb.squeeze(2) * _masks.detach().type_as(
            sp_seqLogProb)
        sp_scores = []

        for sample, target in zip(sample_sents, raw_targets):
            sps = sent_tokenize(' '.join(sample))
            tgs = sent_tokenize(' '.join(target))
            sps = [sp.split(' ') for sp in sps]
            tgs = [tg.split(' ') for tg in tgs]
            #sp_score = compute_rouge_l_summ(sps, tgs)
            sps_bleu = list(concat(sps))
            tgs_bleu = list(concat(tgs))
            sps_bleu = ' '.join(sps_bleu)
            tgs_bleu = ' '.join(tgs_bleu)
            # bl_score = compute_rouge_l_summ(bss, tgs)
            if self._bleu:
                bleu_scores = bleu(sps_bleu, tgs_bleu)
                bleu_score = (bleu_scores[0] + bleu_scores[1] +
                              bleu_scores[2] + bleu_scores[3])
                sp_score = bleu_score
            elif self.f1:
                sp_score = compute_f1(sps_bleu, tgs_bleu)
            else:
                sp_score = (self._w8[2] * compute_rouge_l_summ(sps, tgs) + \
                        self._w8[0] * compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=1) + \
                        self._w8[1]* compute_rouge_n(list(concat(sps)), list(concat(tgs)), n=2))
            sp_scores.append(sp_score)
        sp_scores = torch.tensor(sp_scores,
                                 dtype=torch.float32,
                                 device=greedy_attns[0].device)

        reward = sp_scores.view(-1, 1) - bl_scores.view(-1, 1)

        reward.requires_grad_(False)
        loss = reward.contiguous().detach() * loss_nll
        loss = loss.sum()
        full_length = _masks.data.float().sum()
        loss = loss / full_length

        loss.backward()

        log_dict = {}

        log_dict['reward'] = bl_scores.mean().item()

        if self._grad_fn is not None:
            log_dict.update(self._grad_fn())
        self._opt.step()
        self._net.zero_grad()
        #torch.cuda.empty_cache()

        return log_dict
Example #9
def multi_reward(x, y):
    return (compute_rouge_l_summ(x, y) +
            compute_rouge_n(list_cat(x), list_cat(y), n=2) +
            1 / 3 * compute_rouge_n(list_cat(x), list_cat(y), n=1))
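# Usage sketch (assumption: list_cat flattens a list of token lists, as it is used elsewhere in this file):
#   multi_reward([['a', 'cat', 'sat']], [['a', 'cat', 'sat', 'down']])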
Example #10
def sc_validate(agent, extractor, loader):
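    """Validate the mixed extract/rewrite pipeline: sentences chosen by the extractor are copied or rewritten by the agent, then mean ROUGE-L/1/2 (x100) is reported against the references."""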
    agent.eval()
    start = time()
    print('start running validation...')
    ave_reward = 0
    ave_reward_1 = 0
    ave_reward_2 = 0
    batch_num = 0
    for art_batch, abs_batch, ext_batch in loader:
        ext_sents = []
        sum_sents = []
        indices = []
        for raw_arts in art_batch:
            inds = extractor(raw_arts)
            inds = [i.item() for i in inds]
            indices.append(inds)
            ext_sents += [raw_arts[idx] for idx in inds if idx < len(raw_arts)]
            sum_sents += [
                raw_arts[idx - len(raw_arts)] for idx in inds
                if len(raw_arts) <= idx < len(raw_arts) * 2
            ]
        #--------batch decode--------
        if sum_sents:
            dec_greedy = agent(sum_sents)
            count_ext = 0
            count_sum = 0
            dec_greedy_mix = []
            for indice in indices:
                max_step = indice[-1] / 2
                for ind in indice:
                    if ind < max_step:
                        dec_greedy_mix.append(ext_sents[count_ext])
                        count_ext += 1
                    elif max_step <= ind < max_step * 2:
                        dec_greedy_mix.append(dec_greedy[count_sum])
                        count_sum += 1
        else:
            dec_greedy_mix = ext_sents
        i = 0
        for ex, abss in zip(indices, abs_batch):
            ex = ex[:-1]
            ave_reward += compute_rouge_l_summ(dec_greedy_mix[i:i + len(ex)],
                                               abss)
            ave_reward_1 += compute_rouge_n(list_cat(dec_greedy_mix[i:i +
                                                                    len(ex)]),
                                            list(concat(abss)),
                                            n=1)
            ave_reward_2 += compute_rouge_n(list_cat(dec_greedy_mix[i:i +
                                                                    len(ex)]),
                                            list(concat(abss)),
                                            n=2)
            i += len(ex)
            batch_num += 1
        assert i == len(dec_greedy_mix)
    ave_reward /= (batch_num / 100)
    ave_reward_1 /= (batch_num / 100)
    ave_reward_2 /= (batch_num / 100)
    print('finished in {}! avg reward: {:.2f} rouge-1: {:.2f} rouge-2: {:.2f}'.
          format(timedelta(seconds=int(time() - start)), ave_reward,
                 ave_reward_1, ave_reward_2))
    return {'reward': ave_reward}
Example #11
def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    if ''.join(art_sents[0]) != '<E>':
        art_sents.insert(0, '<E>')
    indices = list(range(len(art_sents)))
    new_abs_sents = []
    for j in range(len(abs_sents)):
        rouges = list(
            map(compute_rouge_l(reference=abs_sents[j], mode='r'),
                art_sents[1:]))
        rouges.insert(0, 0)
        ext = max(indices, key=lambda i: rouges[i])
        max_scores = rouges[ext]
        max_exts = collections.Counter(rouges)[max_scores]
        if max_exts != 1:
            max_inds = []
            rouge_f = []
            for idx, score in enumerate(rouges):
                if idx in indices:
                    if score == max_scores:
                        max_inds.append(idx)
                        rouge_f.append(
                            compute_rouge_l_summ(art_sents[idx],
                                                 abs_sents[j],
                                                 mode='f'))
            maxrouge = max(list(range(len(max_inds))),
                           key=lambda i: rouge_f[i])
            ext = max_inds[maxrouge]
        if ext == 0:
            ext = 1
        new_art_sents = []
        new_art_sents.append('<E>')
        for i in range(1, len(art_sents)):
            #print(art_sents[i])
            if i < ext:
                new_art_sents.append(art_sents[i] + art_sents[ext])
            elif i > ext:
                new_art_sents.append(art_sents[ext] + art_sents[i])
            else:
                new_art_sents.append(art_sents[ext])
        new_rouges = list(
            map(compute_rouge_l_summ(refs=abs_sents[j], mode='fr'),
                new_art_sents[1:]))
        new_rouges_f = [fr[0] for fr in new_rouges]
        new_rouges_r = [fr[1] for fr in new_rouges]
        new_rouges_f.insert(0, 0)
        new_rouges_r.insert(0, 0)
        new_ext = max(indices, key=lambda i: new_rouges_f[i])
        if new_ext == 0:
            new_ext = 1
        #if ext == new_ext or rouges[ext] >= new_rouges[new_ext]:
        if ext == new_ext or rouges[ext] >= new_rouges_r[new_ext]:
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[ext])
        elif ext < new_ext:
            extracted.append(ext)
            extracted.append(new_ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])
        else:
            extracted.append(new_ext)
            extracted.append(ext)
            extracted.append(0)
            scores.append(new_rouges_f[new_ext])

        #reduce duplication: ab->A bc->B, abc->AB
        new_abs_sents.append(abs_sents[j])
        index = findindex(extracted, 0)
        #dic = collections.Counter(extracted)
        while (len(index) >= 2):
            #print('in')
            if len(index) == 2:
                overlap = list(
                    set(extracted[:index[-2]])
                    & set(extracted[index[-2] + 1:index[-1]]))
                l = len(overlap)
                if l > 0:
                    new = list(
                        set(extracted[:index[-2]]).union(
                            set(extracted[index[-2] + 1:index[-1]])))
                    new.sort()
                    del extracted[:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
            else:
                overlap = list(
                    set(extracted[index[-3] + 1:index[-2]])
                    & set(extracted[index[-2] + 1:index[-1]]))
                l = len(overlap)
                if l > 0:
                    new = list(
                        set(extracted[index[-3] + 1:index[-2]]).union(
                            set(extracted[index[-2] + 1:index[-1]])))
                    new.sort()
                    del extracted[index[-3] + 1:index[-1] + 1]
                    extracted = extracted + new
                    extracted.append(0)
                    new_sent = new_abs_sents[-2] + new_abs_sents[-1]
                    del new_abs_sents[-2:]
                    new_abs_sents.append(new_sent)
                    index = findindex(extracted, 0)
                else:
                    break
        if len(index) >= 2:
            if len(index) == 2:
                for idx in extracted[:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:  # idx not (or no longer) in indices
                        continue
            if len(index) > 2:
                for idx in extracted[index[-3]:index[-2]]:
                    try:
                        indices.remove(idx)
                    except ValueError:  # idx not (or no longer) in indices
                        continue
        if not indices:
            break
    length = len(new_abs_sents)
    for i in range(length):
        new_abs_sents.insert(i + (i + 1), ['<E>'])
    return extracted, scores, new_abs_sents, art_sents
Example #12
def a2c_train_step(agent,
                   abstractor,
                   loader,
                   opt,
                   grad_fn,
                   gamma=0.99,
                   reward_fn=compute_rouge_l,
                   stop_reward_fn=compute_rouge_n(n=1),
                   stop_coeff=1.0,
                   step=0,
                   reward_plan=1):
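    """One advantage actor-critic step for the extract/rewrite agent: per-step rewards are marginal ROUGE-L gains plus a final stop reward, discounted by gamma, with a learned critic as baseline."""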
    opt.zero_grad()
    indices = []
    probs = []
    baselines = []
    ext_sents = []
    sum_sents = []
    art_batch, abs_batch = next(loader)
    for raw_arts in art_batch:
        (inds, ms), bs = agent(raw_arts, step=step)
        baselines.append(bs)
        indices.append(inds)
        probs.append(ms)
        ext_sents += [
            raw_arts[idx.item()] for idx in inds if idx.item() < len(raw_arts)
        ]
        sum_sents += [
            raw_arts[idx.item() - len(raw_arts)] for idx in inds
            if len(raw_arts) <= idx.item() < len(raw_arts) * 2
        ]
    if sum_sents:
        with torch.no_grad():
            sum_sents = abstractor(sum_sents)
        # mix the exts and sums
        count_ext = 0
        count_sum = 0
        summaries = []
        for indice in indices:
            max_step = indice[-1].item() / 2
            for ind in indice[:-1]:
                if ind.item() < max_step:
                    summaries.append(ext_sents[count_ext])
                    count_ext += 1
                elif max_step <= ind.item() < max_step * 2:
                    summaries.append(sum_sents[count_sum])
                    count_sum += 1
    else:
        summaries = ext_sents
    assert len(summaries) == len(ext_sents) + len(sum_sents)
    copy_rate = len(ext_sents) / len(summaries)

    i = 0
    rewards = []
    avg_reward = 0
    for inds, abss in zip(indices, abs_batch):
        # plain
        if reward_plan == 0:
            reward_save = [
                compute_rouge_l_summ(summaries[i:i + j + 1], abss)
                for j in range(min(len(inds) - 1, len(abss)))
            ]
            rs = ([(reward_save[j] -
                    reward_save[j - 1]) if j > 0 else reward_save[j]
                   for j in range(min(len(inds) - 1, len(abss)))] +
                  [0 for _ in range(max(0,
                                        len(inds) - 1 - len(abss)))] +
                  [
                      stop_coeff * stop_reward_fn(
                          list(concat(summaries[i:i + len(inds) - 1])),
                          list(concat(abss)))
                  ])
        # margin
        elif reward_plan == 1:
            reward_save = [
                compute_rouge_l_summ(summaries[i:i + j + 1], abss)
                for j in range(len(inds) - 1)
            ]
            rs = ([(reward_save[j] -
                    reward_save[j - 1]) if j > 0 else reward_save[j]
                   for j in range(len(inds) - 1)] + [
                       stop_coeff * stop_reward_fn(
                           list(concat(summaries[i:i + len(inds) - 1])),
                           list(concat(abss)))
                   ])
        assert len(rs) == len(inds)
        avg_reward += rs[-1] / stop_coeff
        i += len(inds) - 1
        # compute discounted rewards
        R = 0
        disc_rs = []
        for r in rs[::-1]:
            R = r + gamma * R
            disc_rs.insert(0, R)
        rewards += disc_rs
    indices = list(concat(indices))
    probs = list(concat(probs))
    baselines = list(concat(baselines))

    # standardize rewards
    reward = torch.Tensor(rewards).to(baselines[0].device)
    reward = (reward - reward.mean()) / (reward.std() +
                                         float(np.finfo(np.float32).eps))
    baseline = torch.cat(baselines).squeeze()
    avg_advantage = 0
    losses = []
    for action, p, r, b in zip(indices, probs, reward, baseline):
        advantage = r - b
        avg_advantage += advantage
        losses.append(
            -p.log_prob(action) * (advantage / len(indices))
        )  # divide by T*B  + 1e-2*torch.exp(p.log_prob(action))*p.log_prob(action)
    critic_loss = F.mse_loss(baseline, reward)
    # backprop and update
    autograd.backward([critic_loss] + losses,
                      [torch.ones(1).to(critic_loss.device)] *
                      (1 + len(losses)))
    grad_log = grad_fn()
    opt.step()
    log_dict = {}
    log_dict.update(grad_log)
    log_dict['reward'] = avg_reward / len(art_batch)
    log_dict['advantage'] = avg_advantage.item() / len(indices)
    log_dict['mse'] = critic_loss.item()
    log_dict['copy_rate'] = copy_rate
    assert not math.isnan(log_dict['grad_norm'])
    return log_dict