Example #1
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                labels = batch.src_sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls

                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
                batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                         len(labels))
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example #2
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                labels = batch.src_sent_labels.float()
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls

                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
                loss = self.loss(sent_scores, labels)
                loss = (loss * mask.float()).sum() / mask.float().sum()

                # baogs: report accuracy
                abs_scores, abs_ids = torch.topk(sent_scores, 3, dim=1)
                abs_mask = (abs_scores > 0).float()
                n_sents = abs_mask.sum().item()
                n_correct = torch.sum(
                    torch.gather(labels, 1, abs_ids) * abs_mask).item()
                batch_stats = Statistics(loss.item() * batch.batch_size,
                                         batch.batch_size, n_sents, n_correct)
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
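A minimal, self-contained sketch of the top-3 accuracy bookkeeping that Example #2 adds on top of Example #1, with toy tensors standing in for the model outputs (sizes and values here are illustrative only):

import torch

# toy stand-ins: 2 documents, 5 sentence scores each
sent_scores = torch.tensor([[0.9, -0.2, 0.7, 0.1, -0.5],
                            [0.3,  0.8, -0.1, 0.6, 0.2]])
labels = torch.tensor([[1., 0., 1., 0., 0.],
                       [0., 1., 0., 1., 0.]])

# pick the 3 highest-scoring sentences per document
abs_scores, abs_ids = torch.topk(sent_scores, 3, dim=1)
# only count picks whose score is positive, as in Example #2
abs_mask = (abs_scores > 0).float()
n_sents = abs_mask.sum().item()
# a pick counts as correct when the gold label at that index is 1
n_correct = torch.sum(torch.gather(labels, 1, abs_ids) * abs_mask).item()
print(n_correct, n_sents)  # 4.0 6.0 -> 2 of 3 picks correct per document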
Example #3
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        # Set model in validating mode.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        all_preds = []
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        batch_size = batch.p_pair.size(0)
                        batch_stats, p_scores, n_scores = self._main(
                            batch, batch_size, is_train=False)
                        stats.update(batch_stats)

                        scores = []
                        preds = []
                        for p_score, n_score in zip(p_scores, n_scores):
                            p = p_score.cpu().data.numpy()
                            n = n_score.cpu().data.numpy()
                            scores.append(str(p) + '\t' + str(n))
                            preds.append(int(p > n))
                            all_preds.append(int(p > n))

                        for i in range(len(scores)):
                            save_gold.write(scores[i] + '\n')
                        for i in range(len(preds)):
                            save_pred.write(str(preds[i]) + '\n')
        print("**********************")
        print(sum(all_preds), len(all_preds))
        print("ACC: ", sum(all_preds) / float(len(all_preds)))
        print("**********************")
        self._report_step(0, step, valid_stats=stats)

        return stats
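Example #3 defines the `_get_ngrams`/`_block_tri` helpers without ever calling them (Examples #5 and #8 actually use them). A standalone sketch of what the trigram blocking does, runnable on toy sentences:

def _get_ngrams(n, text):
    # all n-grams of a token list, as a set of tuples
    return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

def _block_tri(c, p):
    # True if candidate sentence c shares any trigram with an already selected sentence
    tri_c = _get_ngrams(3, c.split())
    return any(tri_c & _get_ngrams(3, s.split()) for s in p)

selected = ["the cat sat on the mat"]
print(_block_tri("a dog sat on the mat today", selected))           # True: repeats "sat on the"
print(_block_tri("an entirely different sentence here", selected))  # False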
Example #4
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                batch_size = batch.p_pair.size(0)
                batch_stats, _, _ = self._main(batch,
                                               batch_size,
                                               is_train=False)
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
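Every snippet on this page funnels batch losses through a `Statistics` accumulator before calling `_report_step`. A minimal stand-in for the subset of its interface used here, assuming only a summed loss and a document count (the real OpenNMT-style class also tracks accuracy and timing; `avg_loss` is an illustrative name, not the library's):

class Statistics:
    # minimal stand-in for the nmt.Statistics accumulator used above
    def __init__(self, loss=0.0, n_docs=0):
        self.loss = loss
        self.n_docs = n_docs

    def update(self, other):
        # fold one batch's statistics into the running totals
        self.loss += other.loss
        self.n_docs += other.n_docs

    def avg_loss(self):
        return self.loss / self.n_docs if self.n_docs else 0.0

stats = Statistics()
stats.update(Statistics(loss=12.5, n_docs=8))
stats.update(Statistics(loss=9.0, n_docs=8))
print(stats.avg_loss())  # 1.34375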
Example #5
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        # Set model in validating mode.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    ct = 0
                    for batch in test_iter:
                        src = batch.src
                        labels = batch.src_sent_labels
                        segs = batch.segs
                        clss = batch.clss
                        mask = batch.mask_src
                        mask_cls = batch.mask_cls

                        gold = []
                        pred = []

                        if (cal_lead):
                            selected_ids = [list(range(batch.clss.size(1)))
                                            ] * batch.batch_size
                        elif (cal_oracle):
                            selected_ids = [[
                                j for j in range(batch.clss.size(1))
                                if labels[i][j] == 1
                            ] for i in range(batch.batch_size)]
                        else:
                            sent_scores, mask = self.model(
                                src, segs, clss, mask, mask_cls)

                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            batch_stats = Statistics(
                                float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            selected_ids = np.argsort(-sent_scores, 1)
                        # selected_ids = np.sort(selected_ids,1)
                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if (len(batch.src_str[i]) == 0):
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                if (j >= len(batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if (self.args.block_trigram):
                                    if (not _block_tri(candidate, _pred)):
                                        _pred.append(candidate)
                                else:
                                    _pred.append(candidate)

                                if ((not cal_oracle)
                                        and (not self.args.recall_eval)
                                        and len(_pred) == 3):
                                    break

                            _pred = '<q>'.join(_pred)
                            if (self.args.recall_eval):
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

                        for i in range(len(gold)):
                            save_gold.write(str(ct) + '\n')
                            save_gold.write(gold[i].strip() + '\n')
                            save_pred.write(str(ct) + '\n')
                            save_pred.write(pred[i].strip() + '\n')
                            ct += 1
        if (step != -1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
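The core of Example #5's `else` branch: rank sentences by score, walk the ranking, drop candidates that repeat a trigram of an already chosen sentence, and stop at three. The same loop on toy data (sentences and scores are made up; the helpers are repeated from the sketch after Example #3):

import numpy as np

def _get_ngrams(n, text):
    return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

def _block_tri(c, p):
    tri_c = _get_ngrams(3, c.split())
    return any(tri_c & _get_ngrams(3, s.split()) for s in p)

src_sents = ["the cat sat on the mat",
             "a dog sat on the mat today",
             "stock markets fell sharply",
             "rain is expected tomorrow"]
sent_scores = np.array([0.9, 0.85, 0.6, 0.4])

_pred = []
for j in np.argsort(-sent_scores):   # best score first
    candidate = src_sents[j].strip()
    if not _block_tri(candidate, _pred):
        _pred.append(candidate)
    if len(_pred) == 3:
        break
print('<q>'.join(_pred))  # sentence 2 is blocked: it repeats "sat on the mat"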
Example #6
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter == 1:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                labels = batch.src_sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls

                if self.args.ext_sum_dec:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls, labels)  # B, tgt_len custom_num
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    # e.g., max_src_nsents=100, weight_up=20
                    weight = torch.linspace(start=1, end=self.args.weight_up, steps=self.args.max_src_nsents).type_as(sent_scores)
                    # self.max_class = max(self.max_class,torch.max(labels_id+1).item())
                    # weight = weight[:self.max_class]
                    weight = weight[:sent_scores.size(-1)]
                    # weight = torch.ones(self.args.max_src_nsents)
                    loss = F.nll_loss(
                        F.log_softmax(
                            sent_scores.view(-1, sent_scores.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        labels_id.view(-1),  # bsz sent
                        weight=weight,
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(sent_scores, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    # torch.equal returns a Python bool; element-wise torch.eq is needed here
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)  # B, custom_N
                    loss = self.loss(sent_scores, labels.float())
                    loss = (loss * mask.float()).sum()
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    _, prediction = torch.topk(sent_scores, k=tgt_len)
                    prediction, _ = torch.sort(prediction)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                if self.args.acc_reporter == 1:
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(labels))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
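Example #6's `ext_sum_dec` branch scores each of three target slots over source-sentence positions and trains with an NLL whose per-class weight grows linearly with sentence position, so later positions carry more weight when `weight_up > 1`. A small runnable sketch of just that loss, with made-up sizes:

import torch
import torch.nn.functional as F

B, tgt_len, n_sents = 2, 3, 10                       # made-up sizes
sent_scores = torch.randn(B, tgt_len, n_sents)       # one score vector per target slot
labels_id = torch.randint(0, n_sents, (B, tgt_len))  # gold sentence indices

# class weight rises linearly from 1 (first sentence) to weight_up (last)
weight_up = 20.0
weight = torch.linspace(1, weight_up, steps=n_sents)

loss = F.nll_loss(
    F.log_softmax(sent_scores.view(-1, n_sents), dim=-1),
    labels_id.view(-1),
    weight=weight,
    reduction='sum',
    ignore_index=-1,   # padded slots marked -1 would be skipped
)
print(loss)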
Example #7
    def validate_rouge_baseline(self, valid_iter_fct, step=0, valid_gl_stats=None, write_scores_to_pickle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        preds = {}
        preds_with_idx = {}
        golds = {}
        can_path = '%s_step%d.source' % (self.args.result_path, step)
        gold_path = '%s_step%d.target' % (self.args.result_path, step)

        if step == self.best_val_step:
            can_path = '%s_step%d.source' % (self.args.result_path_test, step)
            gold_path = '%s_step%d.target' % (self.args.result_path_test, step)

        save_pred = open(can_path, 'w')
        save_gold = open(gold_path, 'w')
        sent_scores_whole = {}
        sent_sects_whole_pred = {}
        sent_sects_whole_true = {}
        sent_labels_true = {}
        sent_numbers_whole = {}
        paper_srcs = {}
        paper_tgts = {}
        sent_sect_wise_rg_whole = {}
        sent_sections_txt_whole = {}
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()
        best_model_saved = False
        best_recall_model_saved = False

        valid_iter = valid_iter_fct()

        with torch.no_grad():
            for batch in tqdm(valid_iter):
                src = batch.src
                labels = batch.src_sent_labels
                sent_labels = batch.sent_labels

                if self.rg_predictor:
                    sent_true_rg = batch.src_sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls
                p_id = batch.paper_id
                segment_src = batch.src_str
                paper_tgt = batch.tgt_str
                sent_sect_wise_rg = batch.sent_sect_wise_rg
                sent_sections_txt = batch.sent_sections_txt
                sent_numbers = batch.sent_numbers

                sent_sect_labels = batch.sent_sect_labels
                if self.is_joint:
                    if not self.rg_predictor:
                        sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = self.model(src, segs, clss,
                                                                                                     mask, mask_cls,
                                                                                                     sent_labels,
                                                                                                     sent_sect_labels)
                    else:
                        sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = self.model(src, segs, clss,
                                                                                                     mask, mask_cls,
                                                                                                     sent_true_rg,
                                                                                                     sent_sect_labels)
                    acc, _ = self._get_mertrics(sent_sect_scores, sent_sect_labels, mask=mask,
                                                task='sent_sect')

                    batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                             loss_sect=float(loss_sect.cpu().data.numpy().sum()),
                                             loss_sent=float(loss_sent.cpu().data.numpy().sum()),
                                             n_docs=len(labels),
                                             n_acc=batch.batch_size,
                                             RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                                             accuracy=acc)

                else:
                    if not self.rg_predictor:
                        sent_scores, mask, loss, _, _ = self.model(src, segs, clss, mask, mask_cls, sent_labels,
                                                                   sent_sect_labels=None, is_inference=True)
                    else:
                        sent_scores, mask, loss, _, _ = self.model(src, segs, clss, mask, mask_cls, sent_true_rg,
                                                                   sent_sect_labels=None, is_inference=True)

                    # sent_scores = (section_rg.unsqueeze(1).expand_as(sent_scores).to(device='cuda')*100) * sent_scores

                    batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                             RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                                             n_acc=batch.batch_size,
                                             n_docs=len(labels))

                stats.update(batch_stats)

                sent_scores = sent_scores + mask.float()
                sent_scores = sent_scores.cpu().data.numpy()

                for idx, p_id in enumerate(p_id):
                    p_id = p_id.split('___')[0]

                    if p_id not in sent_scores_whole.keys():
                        masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                        masked_scores = masked_scores[np.nonzero(masked_scores)]

                        masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()

                        masked_sent_labels_true = masked_sent_labels_true[np.nonzero(masked_sent_labels_true)].flatten()
                        masked_sent_labels_true = (masked_sent_labels_true - 1)

                        sent_scores_whole[p_id] = masked_scores
                        sent_labels_true[p_id] = masked_sent_labels_true.cpu()

                        masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()

                        masked_sents_sections_true = masked_sents_sections_true[
                            np.nonzero(masked_sents_sections_true)].flatten()
                        masked_sents_sections_true = (masked_sents_sections_true - 1)
                        sent_sects_whole_true[p_id] = masked_sents_sections_true.cpu()

                        if self.is_joint:
                            masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                                sent_sect_scores[idx]).float()
                            masked_scores_sects = masked_scores_sects[torch.abs(masked_scores_sects).sum(dim=1) != 0]
                            sent_sects_whole_pred[p_id] = torch.max(self.softmax(masked_scores_sects), 1)[1].cpu()

                        paper_srcs[p_id] = segment_src[idx]
                        if sent_numbers[0] is not None:
                            sent_numbers_whole[p_id] = sent_numbers[idx]
                            # sent_tokens_count_whole[p_id] = sent_tokens_count[idx]
                        paper_tgts[p_id] = paper_tgt[idx]
                        sent_sect_wise_rg_whole[p_id] = sent_sect_wise_rg[idx]
                        sent_sections_txt_whole[p_id] = sent_sections_txt[idx]


                    else:
                        masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                        masked_scores = masked_scores[np.nonzero(masked_scores)]

                        masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()
                        masked_sent_labels_true = masked_sent_labels_true[np.nonzero(masked_sent_labels_true)].flatten()
                        masked_sent_labels_true = (masked_sent_labels_true - 1)

                        sent_scores_whole[p_id] = np.concatenate((sent_scores_whole[p_id], masked_scores), 0)
                        sent_labels_true[p_id] = np.concatenate((sent_labels_true[p_id], masked_sent_labels_true.cpu()),
                                                                0)

                        masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()
                        masked_sents_sections_true = masked_sents_sections_true[
                            np.nonzero(masked_sents_sections_true)].flatten()
                        masked_sents_sections_true = (masked_sents_sections_true - 1)
                        sent_sects_whole_true[p_id] = np.concatenate(
                            (sent_sects_whole_true[p_id], masked_sents_sections_true.cpu()), 0)

                        if self.is_joint:
                            masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                                sent_sect_scores[idx]).float()
                            masked_scores_sects = masked_scores_sects[
                                torch.abs(masked_scores_sects).sum(dim=1) != 0]
                            sent_sects_whole_pred[p_id] = np.concatenate(
                                (sent_sects_whole_pred[p_id], torch.max(self.softmax(masked_scores_sects), 1)[1].cpu()),
                                0)

                        paper_srcs[p_id] = np.concatenate((paper_srcs[p_id], segment_src[idx]), 0)
                        if sent_numbers[0] is not None:
                            sent_numbers_whole[p_id] = np.concatenate((sent_numbers_whole[p_id], sent_numbers[idx]), 0)
                            # sent_tokens_count_whole[p_id] = np.concatenate(
                            #     (sent_tokens_count_whole[p_id], sent_tokens_count[idx]), 0)

                        sent_sect_wise_rg_whole[p_id] = np.concatenate(
                            (sent_sect_wise_rg_whole[p_id], sent_sect_wise_rg[idx]), 0)
                        sent_sections_txt_whole[p_id] = np.concatenate(
                            (sent_sections_txt_whole[p_id], sent_sections_txt[idx]), 0)


        PRED_LEN = self.args.val_pred_len
        acum_f_sent_labels = 0
        acum_p_sent_labels = 0
        acum_r_sent_labels = 0
        acc_total = 0
        for p_idx, (p_id, sent_scores) in enumerate(sent_scores_whole.items()):
            # sent_true_labels = pickle.load(open("sent_labels_files/pubmedL/val.labels.p", "rb"))
            # section_textual = np.array(section_textual)
            paper_sent_true_labels = np.array(sent_labels_true[p_id])
            if self.is_joint:
                sent_sects_true = np.array(sent_sects_whole_true[p_id])
                sent_sects_pred = np.array(sent_sects_whole_pred[p_id])

            sent_scores = np.array(sent_scores)
            p_src = np.array(paper_srcs[p_id])

            # selected_ids_unsorted = np.argsort(-sent_scores, 0)
            def _word_count(s):
                # strip punctuation before counting tokens
                for ch in '.,():-;*':
                    s = s.replace(ch, '')
                return len(s.split())

            # keep sentences longer than 5 and shorter than 100 words
            keep_ids = sorted(idx for idx, s in enumerate(p_src)
                              if 5 < _word_count(s) < 100)

            # top_sent_indexes = top_sent_indexes[top_sent_indexes]
            p_src = p_src[keep_ids]
            sent_scores = sent_scores[keep_ids]
            paper_sent_true_labels = paper_sent_true_labels[keep_ids]

            sent_scores = np.asarray([s - 1.00 for s in sent_scores])

            selected_ids_unsorted = np.argsort(-sent_scores, 0)

            _pred = []
            for j in selected_ids_unsorted:
                if (j >= len(p_src)):
                    continue
                candidate = p_src[j].strip()
                if True:
                    # if (not _block_tri(candidate, _pred)):
                    _pred.append((candidate, j))

                if (len(_pred) == PRED_LEN):
                    break
            _pred = sorted(_pred, key=lambda x: x[1])
            _pred_final_str = '<q>'.join([x[0] for x in _pred])

            preds[p_id] = _pred_final_str
            golds[p_id] = paper_tgts[p_id]
            preds_with_idx[p_id] = _pred
            if p_idx > 10:
                f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred])
                if self.is_joint:
                    acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred, [p[1] for p in _pred])
                    acc_total += acc_whole

            else:
                f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred], print_few=True, p_id=p_id)
                if self.is_joint:
                    acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred, [p[1] for p in _pred],
                                                       print_few=True, p_id=p_id)
                    acc_total += acc_whole

            acum_f_sent_labels += f
            acum_p_sent_labels += p
            acum_r_sent_labels += r

        for id, pred in preds.items():
            save_pred.write(pred.strip().replace('<q>', ' ') + '\n')
            save_gold.write(golds[id].replace('<q>', ' ').strip() + '\n')

        print(f'Gold: {gold_path}')
        print(f'Prediction: {can_path}')

        r1, r2, rl = self._report_rouge(preds.values(), golds.values())
        stats.set_rl(r1, r2, rl)
        logger.info("F-score: %4.4f, Prec: %4.4f, Recall: %4.4f" % (
        acum_f_sent_labels / len(sent_scores_whole), acum_p_sent_labels / len(sent_scores_whole),
        acum_r_sent_labels / len(sent_scores_whole)))
        if self.is_joint:
            logger.info("Section Accuracy: %4.4f" % (acc_total / len(sent_scores_whole)))


        stats.set_ir_metrics(acum_f_sent_labels / len(sent_scores_whole),
                             acum_p_sent_labels / len(sent_scores_whole),
                             acum_r_sent_labels / len(sent_scores_whole))
        self.valid_rgls.append((r2 + rl) / 2)
        self._report_step(0, step,
                          self.model.uncertainty_loss._sigmas_sq[0] if self.is_joint else 0,
                          self.model.uncertainty_loss._sigmas_sq[1] if self.is_joint else 0,
                          valid_stats=stats)

        if len(self.valid_rgls) > 0:
            if self.min_rl < self.valid_rgls[-1]:
                self.min_rl = self.valid_rgls[-1]
                best_model_saved = True

        return stats, best_model_saved, best_recall_model_saved
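Example #7 evaluates papers that were split into chunks; chunk ids share a paper id before the '___' separator, and per-chunk sentence scores are stitched back into one vector per paper. The reassembly pattern in isolation, on toy ids and scores (no model involved):

import numpy as np

chunks = [("paper1___0", np.array([0.9, 0.1])),
          ("paper1___1", np.array([0.4, 0.7])),
          ("paper2___0", np.array([0.2, 0.8]))]

sent_scores_whole = {}
for chunk_id, scores in chunks:
    p_id = chunk_id.split('___')[0]   # strip the chunk suffix
    if p_id not in sent_scores_whole:
        sent_scores_whole[p_id] = scores
    else:
        # later chunks of the same paper are appended in order
        sent_scores_whole[p_id] = np.concatenate((sent_scores_whole[p_id], scores), 0)

print(sent_scores_whole["paper1"])  # [0.9 0.1 0.4 0.7]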
Example #8
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                # src = batch.src
                # labels = batch.src_sent_labels
                # segs = batch.segs
                # clss = batch.clss
                # mask = batch.mask_src
                # mask_cls = batch.mask_cls

                # sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
                if self.args.jigsaw == 'jigsaw_lab':  # jigsaw_lab 3.31 23:38: noticed validate had not been updated earlier; rerun in the morning to check
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s)
                    # bsz, sent, max-sent_num
                    # mask = batch.mask_cls_s[:, :, None].float()
                    # loss = self.loss(sent_scores, batch.poss_s.float())
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        batch.poss_s.view(-1),  # bsz sent
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    # torch.equal returns a Python bool; element-wise torch.eq is needed here
                    accuracy = torch.div(
                        torch.sum(torch.eq(prediction, batch.poss_s).float() * batch.mask_cls_s.float()),
                        torch.sum(batch.mask_cls_s.float())) * len(logits)
                elif self.args.jigsaw == 'jigsaw_dec':  # jigsaw decoder
                    poss_s = batch.poss_s
                    mask_poss = torch.eq(poss_s, -1)
                    poss_s.masked_fill_(mask_poss, 1e4)
                    # poss_s[i] [5,1,4,0,2,3,-1,-1]->[5,1,4,0,2,3,1e4,1e4]
                    dec_labels = torch.argsort(poss_s, dim=1)  # dec_labels[i] [3,1,xxx,6,7]
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s,
                                        dec_labels)
                    final_dec_labels = dec_labels.masked_fill(mask_poss, -1)  # final_dec_labels[i] [3,1,xxx,-1,-1]
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        final_dec_labels.view(-1),  # bsz sent
                        reduction='sum',
                        ignore_index=-1,
                    )

                    # loss = (loss * batch.mask_cls_s.float()).sum()
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    # final_dec_labels is the local computed above, not a batch attribute
                    accuracy = torch.div(
                        torch.sum(torch.eq(prediction, final_dec_labels).float() * batch.mask_cls_s.float()),
                        torch.sum(batch.mask_cls_s.float())) * len(logits)


                # loss = self.loss(sent_scores, labels.float())
                # loss = (loss * mask.float()).sum()
                if self.args.acc_reporter:
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(batch.poss_s))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(batch.poss_s))
                stats.update(batch_stats)

            self._report_step(0, step, valid_stats=stats)
            return stats
    def test1(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        # Set model in validating mode.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        output = ''
        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls
                gold = []
                pred = []
                if (cal_lead):
                    selected_ids = [list(range(batch.clss.size(1)))
                                    ] * batch.batch_size
                elif (cal_oracle):
                    labels = batch.src_sent_labels
                    selected_ids = [[
                        j for j in range(batch.clss.size(1))
                        if labels[i][j] == 1
                    ] for i in range(batch.batch_size)]
                else:
                    # logger.info("src:%s, segs:%s, clss:%s, mask:%s, mask_cls:%s" % (
                    #    src, segs, clss, mask, mask_cls))
                    sent_scores, mask = self.model(src, segs, clss, mask,
                                                   mask_cls)

                    if (hasattr(batch, 'src_sent_labels')):
                        labels = batch.src_sent_labels
                        # compute the loss while sent_scores is still a tensor,
                        # before it is converted to a numpy array below
                        loss = self.loss(sent_scores, labels.float())
                        loss = (loss * mask.float()).sum()
                        batch_stats = Statistics(
                            float(loss.cpu().data.numpy()), len(labels))
                        stats.update(batch_stats)

                    sent_scores = sent_scores + mask.float()
                    sent_scores = sent_scores.cpu().data.numpy()
                    selected_ids = np.argsort(-sent_scores, 1)

                for i, idx in enumerate(selected_ids):
                    _pred = []
                    if (len(batch.src_str[i]) == 0):
                        continue
                    for j in selected_ids[i][:len(batch.src_str[i])]:
                        if (j >= len(batch.src_str[i])):
                            continue
                        candidate = batch.src_str[i][j].strip()
                        if (self.args.block_trigram):
                            if (not _block_tri(candidate, _pred)):
                                _pred.append(candidate)
                        else:
                            _pred.append(candidate)

                        if ((not cal_oracle) and (not self.args.recall_eval)
                                and len(_pred) == 3):
                            break

                    _pred = ' '.join(_pred)
                    if (self.args.recall_eval):
                        _pred = ' '.join(
                            _pred.split()[:len(batch.tgt_str[i].split())])

                    pred.append(_pred)
                    gold.append(batch.tgt_str[i])

        return ' '.join(pred)
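Example #8's `jigsaw_dec` branch converts per-sentence position labels `poss_s` into decoder targets: padding entries (-1) are filled with a large value so they sort last, `argsort` recovers which sentence belongs at each output step, and padding is restored as -1 so the loss's `ignore_index` skips it. The transformation on the toy row from the comments above:

import torch

# position of each sentence in the original document; -1 marks padding
poss_s = torch.tensor([[5, 1, 4, 0, 2, 3, -1, -1]])

mask_poss = torch.eq(poss_s, -1)
# push padding past every real position before sorting
filled = poss_s.masked_fill(mask_poss, 10000)
# dec_labels[t] = index of the sentence that belongs at output step t
dec_labels = torch.argsort(filled, dim=1)
# restore -1 so nll_loss(ignore_index=-1) skips padded steps
final_dec_labels = dec_labels.masked_fill(mask_poss, -1)

print(dec_labels)        # e.g. tensor([[3, 1, 4, 5, 2, 0, 6, 7]]); tie order among padding slots is irrelevant
print(final_dec_labels)  # tensor([[ 3,  1,  4,  5,  2,  0, -1, -1]])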