Example #1
import collections
import math
import time

import torch
from tqdm import tqdm

# The following names are project-local and their module paths are not shown
# in this listing: Constants, NLLLoss, preprocess_batch, preprocess_rl_batch,
# record_log, dev_record_log.


class RLTrainer(object):
    def __init__(self, model, discriminator, loss, optimizer, translator,
                 logger, opt, training_data, validation_data, src_vocab,
                 tgt_vocab):
        self.model = model
        self.discriminator = discriminator
        self.loss = loss
        self.rl_loss = NLLLoss(opt, do_reduce=False)
        self.optimizer = optimizer
        self.translator = translator
        self.logger = logger
        self.opt = opt
        self.tgt_vocab = tgt_vocab

        self.training_data = training_data
        self.validation_data = validation_data

        self.separate = opt.answer == 'sep'
        self.answer = opt.answer == 'enc'
        self.sep_id = (src_vocab.lookup(Constants.SEP_WORD)
                       if self.separate else Constants.SEP)

        self.is_attn_mask = bool(opt.defined_slf_attn_mask)

        self.cntBatch, self.best_metric, self.best_bleu = 0, 0, 0
        self.best_ppl = math.exp(100)  # effectively infinite initial perplexity

    def cal_performance(self, loss_input):
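        ''' Compute the NLL loss and count correctly predicted non-pad tokens '''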

        loss = self.loss.cal_loss(loss_input)

        gold, pred = loss_input['gold'], loss_input['pred']

        pred = pred.contiguous().view(-1, pred.size(2))
        pred = pred.max(1)[1]

        gold = gold.contiguous().view(-1)
        non_pad_mask = gold.ne(Constants.PAD)

        n_correct = pred.eq(gold)
        n_correct = n_correct.masked_select(non_pad_mask).sum().item()

        return loss, n_correct, pred

    def cal_rl_loss(self, pred, decoded_text, flu_rl_inputs, rel_rl_inputs,
                    ans_rl_inputs, flu_discriminator, rel_discriminator,
                    ans_discriminator):
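        ''' Policy-gradient loss: discriminator rewards rescale the per-sample NLL '''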
        def _get_n_best(logits, n_best_size):
            # Return the top-n (index, score) pairs, sorted by score descending.
            index_and_score = sorted(enumerate(logits),
                                     key=lambda x: x[1],
                                     reverse=True)
            return index_and_score[:n_best_size]

        def _get_score(start_logits, end_logits, indexes):
            # Geometric mean of the best start/end probabilities over candidate
            # answer spans inside the given window.
            start_logits = _get_n_best(start_logits[indexes[0]:indexes[1]], 5)
            end_logits = _get_n_best(end_logits[indexes[0]:indexes[1]], 5)
            b_scores = [[0, 0]]

            for start in start_logits:
                for end in end_logits:
                    if start[0] <= end[0] and end[0] - start[0] < 64:  # TODO: magic number
                        score = math.sqrt(start[1] * end[1])
                        b_scores.append([score, end[0] - start[0] + 1])
            b_scores.sort(key=lambda x: x[0], reverse=True)
            return b_scores[0][0]

        batch_size, seq_length, vocab_size = pred.size()
        gold = decoded_text[:, 1:].contiguous()

        ##=== fluency ===##
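        # Fluency: per-sample perplexity under the LM discriminator; samples
        # whose perplexity falls below flu_alpha get a positive scale.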
        flu_reward = 0
        if 'fluency' in self.opt.rl:
            with torch.no_grad():
                output_pred_dicts = flu_discriminator(
                    flu_rl_inputs[0], attention_mask=flu_rl_inputs[1])

                reward_fct = NLLLoss(self.opt, do_reduce=False)

                lm_loss = reward_fct.cal_simple_nll(output_pred_dicts[0],
                                                    flu_rl_inputs[2])
                lm_loss = lm_loss.view(pred.size(0), -1).mean(-1)
                scores = torch.exp(lm_loss).to(pred.device)

            flu_reward = scores.data.sum().item()
            flu_scores_scale = self.opt.flu_alpha - scores.data

        ##=== relevance ===##
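        # Relevance: classifier probability that the generated question matches
        # the source, mapped through log(rel_alpha / (1 - p)).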
        rel_reward = 0
        if 'relevance' in self.opt.rl:
            with torch.no_grad():
                output = rel_discriminator(rel_rl_inputs[0],
                                           token_type_ids=rel_rl_inputs[1])
                # get the output logits for [CLS]
                logits = output[0].contiguous().to(pred.device)
            scores = torch.softmax(logits, dim=1)[:, 1].contiguous()

            rel_reward = scores.data.sum().item()
            rel_scores_scale = torch.log(self.opt.rel_alpha /
                                         (1 - scores.data + 1e-16))

        ##=== answerability ===##
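        # Answerability: best span confidence from the QA discriminator inside
        # the provided answer window, mapped like the relevance score.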
        ans_reward = 0
        if 'answerability' in self.opt.rl:
            with torch.no_grad():
                batch_start_logits, batch_end_logits = ans_discriminator(
                    ans_rl_inputs[0], ans_rl_inputs[1], ans_rl_inputs[2])
            scores = []
            for b in range(batch_start_logits.size(0)):
                start_logits = torch.softmax(batch_start_logits[b], dim=-1)
                end_logits = torch.softmax(batch_end_logits[b], dim=-1)
                score = _get_score(start_logits.detach().cpu().tolist(),
                                   end_logits.detach().cpu().tolist(),
                                   ans_rl_inputs[3][b])
                scores.append(score)
            scores = torch.tensor(scores, device=pred.device)

            ans_reward = scores.data.sum().item()
            ans_scores_scale = torch.log(self.opt.ans_alpha /
                                         (1 - scores.data + 1e-16))

        ##=== combination ===##
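        # Samples with a positive combined score count as "correct"; positives
        # are weighted by the fraction of correct samples and negatives by the
        # remainder, with the split capped at 1/3 vs. 2/3.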
        scores = 0
        if 'fluency' in self.opt.rl:
            scores += flu_scores_scale * self.opt.flu_gamma
        if 'relevance' in self.opt.rl:
            scores += rel_scores_scale * self.opt.rel_gamma
        if 'answerability' in self.opt.rl:
            scores += ans_scores_scale * self.opt.ans_gamma

        n_correct = scores.gt(0).float().sum().item()
        weights = [(batch_size - n_correct) / batch_size,
                   n_correct / batch_size]
        weights = [1 / 3, 2 / 3] if weights[0] > 1 / 3 else weights
        scores_scale_rgt = scores.gt(0).float() * scores * weights[1]
        scores_scale_wrg = scores.lt(0).float() * scores * weights[0]
        scores_scale = scores_scale_rgt + scores_scale_wrg

        log_prb = self.rl_loss.cal_simple_nll(
            pred.contiguous(), gold.contiguous()).view(batch_size, -1).mean(-1)
        loss = torch.sum(scores_scale * log_prb)

        return loss, [
            n_correct, batch_size, [flu_reward, rel_reward, ans_reward]
        ]

    def save_model(self, better, bleu):
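        ''' Save a checkpoint according to opt.save_mode; also snapshot new best-BLEU models '''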
        model_state_dict = (self.model.module.state_dict()
                            if len(self.opt.gpus) > 1 else
                            self.model.state_dict())
        model_state_dict = collections.OrderedDict(
            (x, y.cpu()) for x, y in model_state_dict.items())
        checkpoint = {
            'model state dict': model_state_dict,
            'options': self.opt
        }

        if self.opt.save_mode == 'all':
            model_name = self.opt.save_model + '_ppl_{ppl:2.5f}.chkpt'.format(
                ppl=self.best_ppl)
            torch.save(checkpoint, model_name)
        elif self.opt.save_mode == 'best':
            model_name = self.opt.save_model + '.chkpt'
            if better:
                torch.save(checkpoint, model_name)
                print('    - [Info] The checkpoint file has been updated.')

        if bleu != 'unk' and bleu > self.best_bleu:
            self.best_bleu = bleu
            model_name = self.opt.save_model + '_' + str(round(
                bleu * 100, 5)) + '_bleu4.chkpt'
            torch.save(checkpoint, model_name)

    def eval_step(self, device, epoch):
        ''' Epoch operation in evaluation phase '''
        self.model.eval()

        with torch.no_grad():
            max_length = 0
            total_nll_loss, n_word_total, n_word_correct = 0, 0, 0
            total_rl_loss, n_rl_correct = 0, [0, 0]
            flu_reward, rel_reward, ans_reward = 0, 0, 0
            valid_length = len(self.validation_data)
            eval_index_list = range(valid_length)

            for idx in tqdm(eval_index_list,
                            mininterval=2,
                            desc='  - (Validation) ',
                            leave=False):
                batch = self.validation_data[idx]
                inputs, max_length, gold, copy = preprocess_batch(
                    batch,
                    separate=self.separate,
                    enc_rnn=self.opt.enc_rnn != '',
                    dec_rnn=self.opt.dec_rnn != '',
                    feature=self.opt.feature,
                    dec_feature=self.opt.dec_feature,
                    answer=self.answer,
                    ans_feature=self.opt.ans_feature,
                    sep_id=self.sep_id,
                    copy=self.opt.copy,
                    attn_mask=self.is_attn_mask,
                    device=device)
                copy_gold, copy_switch = copy[0], copy[1]

                ### forward ###
                rst = self.model(inputs, max_length=max_length)

                loss_input = {}
                loss_input['pred'], loss_input['gold'] = rst['pred'], gold
                if self.opt.copy:
                    loss_input['copy_pred'] = rst['copy_pred']
                    loss_input['copy_gate'] = rst['copy_gate']
                    loss_input['copy_gold'] = copy_gold
                    loss_input['copy_switch'] = copy_switch
                if self.opt.coverage:
                    loss_input['coverage_pred'] = rst['coverage_pred']
                nll_loss, n_correct, _ = self.cal_performance(loss_input)

                rst = self.model(inputs,
                                 max_length=max_length,
                                 rl_type=self.opt.rl)

                flu_rl_inputs, flu_discriminator = None, None
                if 'fluency' in self.opt.rl:
                    flu_rl_inputs = preprocess_rl_batch(
                        rst['decoded_text'], rst['rand_decoded_text'],
                        'fluency', self.tgt_vocab,
                        self.opt.rl_device['fluency'])
                    flu_discriminator = self.discriminator['fluency']
                    flu_discriminator.eval()

                rel_rl_inputs, rel_discriminator = None, None
                if 'relevance' in self.opt.rl:
                    inputing = (inputs['encoder']['src_seq'],
                                rst['decoded_text'],
                                inputs['encoder']['lengths'],
                                rst['rand_decoded_text'])
                    rel_rl_inputs = preprocess_rl_batch(
                        inputing, None, 'relevance', self.tgt_vocab,
                        self.opt.rl_device['relevance'])
                    rel_discriminator = self.discriminator['relevance']
                    rel_discriminator.eval()

                ans_rl_inputs, ans_discriminator = None, None
                if 'answerability' in self.opt.rl:
                    inputing = (inputs['encoder']['src_seq'],
                                rst['decoded_text'],
                                inputs['encoder']['lengths'],
                                rst['rand_decoded_text'])
                    ans_rl_inputs, rand_rl_inputs = preprocess_rl_batch(
                        inputing, None, 'answerability', self.tgt_vocab,
                        self.opt.rl_device['answerability'])
                    ans_discriminator = self.discriminator['answerability']
                    ans_discriminator.eval()

                if self.opt.rl:  # assumed non-empty here: rl_loss is used unconditionally below
                    rl_loss, rl_n_correct = self.cal_rl_loss(
                        rst['pred'], rst['decoded_text'], flu_rl_inputs,
                        rel_rl_inputs, ans_rl_inputs, flu_discriminator,
                        rel_discriminator, ans_discriminator)

                non_pad_mask = gold.ne(Constants.PAD)
                n_word = non_pad_mask.sum().item()

                total_nll_loss += nll_loss.item()
                n_word_total += n_word
                n_word_correct += n_correct

                total_rl_loss += rl_loss.item()
                n_rl_correct[0] += rl_n_correct[0]
                n_rl_correct[1] += rl_n_correct[1]

                flu_reward += rl_n_correct[2][0]
                rel_reward += rl_n_correct[2][1]
                ans_reward += rl_n_correct[2][2]

            loss_per_word = total_nll_loss / n_word_total
            nll_accuracy = n_word_correct / n_word_total
            loss_per_sample = total_rl_loss / n_rl_correct[1]
            bleu = 'unk'
            perplexity = math.exp(min(loss_per_word, 16))
            rl_accuracy = n_rl_correct[0] / n_rl_correct[1]

            flu_reward /= n_rl_correct[1]
            rel_reward /= n_rl_correct[1]
            ans_reward /= n_rl_correct[1]

            if (perplexity <= self.opt.translate_ppl
                    or perplexity > self.best_ppl):
                if self.cntBatch % self.opt.translate_steps == 0:
                    bleu = self.translator.eval_all(self.model,
                                                    self.validation_data)

        return ([loss_per_word, loss_per_sample], nll_accuracy, rl_accuracy,
                bleu, [flu_reward, rel_reward, ans_reward])

    def train_epoch(self, device, epoch):
        ''' Epoch operation in training phase'''
        if self.opt.extra_shuffle and epoch > self.opt.curriculum:
            self.logger.info('Shuffling...')
            self.training_data.shuffle()

        self.model.train()

        total_rl_loss, total_nll_loss, n_sample_total, n_sample_correct = 0, 0, 0, 0
        n_word_total, n_word_correct, report_n_word_total, report_n_word_correct = 0, 0, 0, 0
        report_total_rl_loss, report_total_nll_loss, report_n_sample_total, report_n_sample_correct = 0, 0, 0, 0

        batch_order = torch.randperm(len(self.training_data))

        for idx in tqdm(range(len(self.training_data)),
                        mininterval=2,
                        desc='  - (Training)   ',
                        leave=False):

            batch_idx = batch_order[idx] if epoch > self.opt.curriculum else idx
            batch = self.training_data[batch_idx]

            ##### ==================== prepare data ==================== #####
            inputs, max_length, gold, copy = preprocess_batch(
                batch,
                separate=self.separate,
                enc_rnn=self.opt.enc_rnn != '',
                dec_rnn=self.opt.dec_rnn != '',
                feature=self.opt.feature,
                dec_feature=self.opt.dec_feature,
                answer=self.answer,
                ans_feature=self.opt.ans_feature,
                sep_id=self.sep_id,
                copy=self.opt.copy,
                attn_mask=self.is_attn_mask,
                device=device)
            copy_gold, copy_switch = copy[0], copy[1]

            ##### ==================== forward ==================== #####
            self.model.zero_grad()
            self.optimizer.zero_grad()

            rst = self.model(inputs,
                             max_length=max_length,
                             rl_type=self.opt.rl)

            ##### ==================== backward ==================== #####
            ##=== rl loss ===##
            flu_rl_inputs, flu_discriminator = None, None
            if 'fluency' in self.opt.rl:
                flu_rl_inputs = preprocess_rl_batch(
                    rst['decoded_text'], rst['rand_decoded_text'], 'fluency',
                    self.tgt_vocab, self.opt.rl_device['fluency'])
                flu_discriminator = self.discriminator['fluency']
                flu_discriminator.eval()

            rel_rl_inputs, rel_discriminator = None, None
            if 'relevance' in self.opt.rl:
                inputing = (inputs['encoder']['src_seq'], rst['decoded_text'],
                            inputs['encoder']['lengths'],
                            rst['rand_decoded_text'])
                rel_rl_inputs = preprocess_rl_batch(
                    inputing, None, 'relevance', self.tgt_vocab,
                    self.opt.rl_device['relevance'])
                rel_discriminator = self.discriminator['relevance']
                rel_discriminator.eval()

            ans_rl_inputs, ans_discriminator = None, None
            if 'answerability' in self.opt.rl:
                inputing = (inputs['encoder']['src_seq'], rst['decoded_text'],
                            inputs['encoder']['lengths'],
                            rst['rand_decoded_text'])
                ans_rl_inputs, rand_rl_inputs = preprocess_rl_batch(
                    inputing, None, 'answerability', self.tgt_vocab,
                    self.opt.rl_device['answerability'])
                ans_discriminator = self.discriminator['answerability']
                ans_discriminator.eval()

            if self.opt.rl:  # assumed non-empty here: rl_loss is used unconditionally below
                rl_loss, n_correct = self.cal_rl_loss(
                    rst['pred'], rst['decoded_text'], flu_rl_inputs,
                    rel_rl_inputs, ans_rl_inputs, flu_discriminator,
                    rel_discriminator, ans_discriminator)

            if len(self.opt.gpus) > 1:
                rl_loss = rl_loss.mean()  # mean() to average on multi-gpu.
            loss = rl_loss

            ###=== NLL loss ===##
            rst = self.model(inputs, max_length=max_length)

            loss_input = {}
            loss_input['pred'], loss_input['gold'] = rst['pred'], gold
            if self.opt.copy:
                loss_input['copy_pred'] = rst['copy_pred']
                loss_input['copy_gate'] = rst['copy_gate']
                loss_input['copy_gold'] = copy_gold
                loss_input['copy_switch'] = copy_switch
            if self.opt.coverage:
                loss_input['coverage_pred'] = rst['coverage_pred']

            nll_loss, word_correct, _ = self.cal_performance(loss_input)
            if len(self.opt.gpus) > 1:
                nll_loss = nll_loss.mean()  # mean() to average on multi-gpu.

            self.cntBatch += 1
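            # Mix the NLL loss back in every 4th batch, or whenever the RL loss
            # drifts strongly negative (both thresholds are heuristic).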
            if self.cntBatch % 4 == 0 or loss.item() < -10:
                loss = loss + nll_loss

            ##=== backward ===##
            if math.isnan(loss):
                print('loss catch NaN')
                import ipdb
                ipdb.set_trace()

            self.optimizer.backward(loss)
            self.optimizer.step()

            if math.isnan(self.model.generator.weight.data.contiguous().view(
                    -1).sum().item()):
                print('parameter catch NaN')
                import ipdb
                ipdb.set_trace()

            ##### ==================== note for epoch report & step report ==================== #####
            n_word = gold.ne(Constants.PAD).float().sum().item()
            total_nll_loss += nll_loss.item()
            n_word_total += n_word
            n_word_correct += word_correct

            total_rl_loss += rl_loss.item()
            n_sample_total += n_correct[1]
            n_sample_correct += n_correct[0]

            report_total_nll_loss += nll_loss.item()
            report_n_word_total += n_word
            report_n_word_correct += word_correct

            report_total_rl_loss += rl_loss.item()
            report_n_sample_total += n_correct[1]
            report_n_sample_correct += n_correct[0]

            ##### ==================== evaluation ==================== #####
            if self.cntBatch % self.opt.valid_steps == 0:
                ### ========== evaluation on dev ========== ###
                valid_loss, valid_nll_accu, valid_rl_accu, valid_bleu, rewards = self.eval_step(
                    device, epoch)
                valid_ppl = math.exp(min(valid_loss[0], 16))

                report_avg_nll_loss = report_total_nll_loss / report_n_word_total
                report_avg_rl_loss = report_total_rl_loss / report_n_sample_total
                report_avg_ppl = math.exp(min(report_avg_nll_loss, 16))
                report_avg_nll_accu = report_n_word_correct / report_n_word_total
                report_avg_rl_accu = report_n_sample_correct / report_n_sample_total

                better = False
                # metric = (valid_rl_accu if self.opt.rl in ['relevance', 'answerability', ''] else 1 / (valid_loss[1] + 1e-16)) / valid_ppl
                metric = valid_rl_accu / valid_ppl
                if metric >= self.best_metric:
                    self.best_metric = metric
                    better = True

                report_total_nll_loss, report_total_rl_loss = 0, 0
                report_n_word_total, report_n_sample_total = 0, 0
                report_n_word_correct, report_n_sample_correct = 0, 0

                ### ========== update learning rate ========== ###
                self.optimizer.update_learning_rate(better)

                record_log(self.opt.logfile_train,
                           step=self.cntBatch,
                           rl_loss=report_avg_rl_loss,
                           rl_accu=report_avg_rl_accu,
                           loss=report_avg_nll_loss,
                           accu=report_avg_nll_accu,
                           ppl=math.exp(min(report_avg_nll_loss, 16)),
                           bad_cnt=self.optimizer._bad_cnt,
                           lr=self.optimizer._learning_rate)
                dev_record_log(self.opt.logfile_dev,
                               step=self.cntBatch,
                               loss=valid_loss[0],
                               accu=valid_nll_accu,
                               ppl=valid_ppl,
                               bleu=valid_bleu,
                               rl_loss=valid_loss[1],
                               rl_accu=valid_rl_accu,
                               bad_cnt=self.optimizer._bad_cnt,
                               lr=self.optimizer._learning_rate,
                               flu=rewards[0],
                               rel=rewards[1],
                               ans=rewards[2])

                if self.opt.save_model:
                    self.save_model(better, valid_bleu)

                self.model.train()

        loss_per_word = total_nll_loss / n_word_total
        nll_accuracy = n_word_correct / n_word_total * 100
        rl_accuracy = n_sample_correct / n_sample_total * 100

        return math.exp(min(loss_per_word, 16)), nll_accuracy, rl_accuracy

    def train(self, device):
        ''' Start training '''
        self.logger.info(self.model)

        for epoch_i in range(self.opt.epoch):
            self.logger.info('')
            self.logger.info(' *  [ Epoch {0} ]:   '.format(epoch_i))
            start = time.time()
            ppl, nll_accu, rl_accu = self.train_epoch(device, epoch_i + 1)

            self.logger.info(
                ' *  - (Training)   ppl: {ppl: 8.5f}, accuracy: nll - {nll:3.3f} %; rl - {rl:3.3f} %'
                .format(ppl=ppl, nll=nll_accu, rl=rl_accu))
            print('                ' + str(time.time() - start) +
                  ' seconds for epoch ' + str(epoch_i))
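
A minimal usage sketch (not part of the original listing; every name below is a
hypothetical placeholder inferred from the constructor signature and from how
self.discriminator, self.optimizer, and self.translator are used above):

# Discriminators are looked up by RL criterion name.
discriminators = {
    'fluency': fluency_lm,              # language model scoring decoded text
    'relevance': relevance_classifier,  # [CLS] classifier over (source, question)
    'answerability': qa_model,          # span-prediction QA model
}
# optimizer must expose backward()/step()/zero_grad()/update_learning_rate();
# translator must expose eval_all(model, data).
trainer = RLTrainer(model, discriminators, loss, optimizer, translator,
                    logger, opt, training_data, validation_data,
                    src_vocab, tgt_vocab)
trainer.train(device)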