Ejemplo n.º 1
0
    def validate(self, valid_iter, step=0):
        """Evaluate the model on a validation iterator.

        Switches ``self.model`` to eval mode, runs every batch of
        ``valid_iter`` under ``torch.no_grad()`` and sums the mask-weighted
        loss of each batch into a ``Statistics`` accumulator.

        Args:
            valid_iter: iterable of batches exposing ``src``, ``labels``,
                ``segs``, ``clss``, ``mask`` and ``mask_cls``.
            step: step number passed through to ``self._report_step``.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        self.model.eval()
        running_stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src, labels, segs = batch.src, batch.labels, batch.segs
                clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls

                # The model hands back its own mask, which replaces the
                # batch mask for the loss weighting below.
                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                raw_loss = self.loss(sent_scores, labels.float())
                masked_loss = (raw_loss * mask.float()).sum()
                loss_value = float(masked_loss.cpu().data.numpy())
                running_stats.update(Statistics(loss_value, len(labels)))

            self._report_step(0, step, valid_stats=running_stats)
            return running_stats
Ejemplo n.º 2
0
    def validate(self, valid_iter):
        """Evaluate the model on a validation iterator.

        Switches ``self.model`` to eval mode and, for each batch, sums the
        mask-weighted loss into a ``Statistics`` accumulator. When
        ``self.args.structured`` is truthy, the last element of the model's
        first output is clamped to ``[1e-5, 1 - 1e-5]`` and scored instead of
        the raw output.

        Args:
            valid_iter: iterable of batches exposing ``src``, ``src_length``
                and ``labels``.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        self.model.eval()
        running_stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src, src_lengths, labels = batch.src, batch.src_length, batch.labels

                structured = self.args.structured
                # Both modes call the model the same way; they differ only in
                # how the first output is turned into predictions.
                predictions, mask = self.model(src, labels, src_lengths)
                if structured:
                    predictions = torch.clamp(predictions[-1], 1e-5, 1 - 1e-5)

                raw_loss = self.loss(predictions, labels)
                masked_loss = (raw_loss * mask.float()).sum()
                loss_value = float(masked_loss.cpu().data.numpy())
                running_stats.update(Statistics(loss_value, len(labels)))

            return running_stats
Ejemplo n.º 3
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Pick summary sentences for each test document and write them out.

        For every batch a sentence ranking is chosen per document:

        * ``cal_lead``: sentence indices in their original order;
        * ``cal_oracle``: indices whose label equals 1;
        * otherwise: indices sorted by descending ``sent_scores + mask`` from
          the model (the masked loss is also accumulated in this mode).

        Sentences are taken from ``batch.src_str`` following that ranking,
        optionally skipping candidates that share a trigram with an already
        chosen sentence (``self.args.block_trigram``). Selection stops at
        three sentences unless ``cal_oracle`` or ``self.args.recall_eval`` is
        set. One line per document is written to
        ``<result_path>_step<step>.candidate`` and ``.gold``.

        Args:
            test_iter: iterable of test batches.
            step: step number used in the output file names and reporting;
                ROUGE is skipped when it is ``-1``.
            cal_lead: use the lead ordering instead of the model.
            cal_oracle: use the labelled sentences instead of the model.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        def _ngram_set(n, tokens):
            # Every contiguous n-token window, as a set of tuples.
            return {tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1)}

        def _shares_trigram(candidate, chosen):
            # True when the candidate repeats a trigram of any chosen sentence.
            cand_tris = _ngram_set(3, candidate.split())
            return any(cand_tris & _ngram_set(3, sent.split()) for sent in chosen)

        if not (cal_lead or cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred, open(gold_path, 'w') as save_gold, torch.no_grad():
            for batch in test_iter:
                src, labels, segs = batch.src, batch.labels, batch.segs
                clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls

                gold_lines = []
                pred_lines = []

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [
                        [j for j in range(batch.clss.size(1)) if labels[i][j] == 1]
                        for i in range(batch.batch_size)
                    ]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                    raw_loss = self.loss(sent_scores, labels.float())
                    masked_loss = (raw_loss * mask.float()).sum()
                    loss_value = float(masked_loss.cpu().data.numpy())
                    stats.update(Statistics(loss_value, len(labels)))

                    ranked_scores = (sent_scores + mask.float()).cpu().data.numpy()
                    selected_ids = np.argsort(-ranked_scores, 1)

                for doc_idx, ranking in enumerate(selected_ids):
                    sentences = batch.src_str[doc_idx]
                    sentence_count = len(sentences)
                    if sentence_count == 0:
                        continue

                    chosen = []
                    for sent_idx in ranking[:sentence_count]:
                        if sent_idx >= sentence_count:
                            continue
                        candidate = sentences[sent_idx].strip()
                        blocked = self.args.block_trigram and _shares_trigram(candidate, chosen)
                        if not blocked:
                            chosen.append(candidate)

                        if not cal_oracle and not self.args.recall_eval and len(chosen) == 3:
                            break

                    summary_text = '<q>'.join(chosen)
                    # The full source text is appended to every candidate line.
                    summary_text = summary_text + "   original txt:   " + " ".join(sentences)
                    if self.args.recall_eval:
                        token_budget = len(batch.tgt_str[doc_idx].split())
                        summary_text = ' '.join(summary_text.split()[:token_budget])

                    pred_lines.append(summary_text)
                    gold_lines.append(batch.tgt_str[doc_idx])

                save_gold.writelines(line.strip() + '\n' for line in gold_lines)
                save_pred.writelines(line.strip() + '\n' for line in pred_lines)

        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
Ejemplo n.º 4
0
    def predict(self,
                test_iter,
                step,
                result_file='predicted_titles.csv',
                cal_lead=False,
                cal_oracle=False):
        """Select sentences for each document and save them as a CSV.

        For every batch a sentence ranking is chosen per document:

        * ``cal_lead``: sentence indices in their original order;
        * ``cal_oracle``: indices whose label equals 1;
        * otherwise: indices sorted by descending ``sent_scores + mask`` from
          the model (the masked loss is also accumulated in this mode).

        Selected sentences are joined with ``'<q>'``. The result is written
        to ``result_file`` with columns ``abstract`` (the concatenated source
        sentences) and ``title`` (the selected text).

        Args:
            test_iter: iterable of batches to predict on.
            step: step number used in the ``.results`` file name.
            result_file: path of the CSV file to write.
            cal_lead: use the lead ordering instead of the model.
            cal_oracle: use the labelled sentences instead of the model.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        def _ngram_set(n, tokens):
            # Every contiguous n-token window, as a set of tuples.
            return {tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1)}

        def _shares_trigram(candidate, chosen):
            # True when the candidate repeats a trigram of any chosen sentence.
            cand_tris = _ngram_set(3, candidate.split())
            return any(cand_tris & _ngram_set(3, sent.split()) for sent in chosen)

        if not (cal_lead or cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.results' % (self.args.result_path, step)
        pred = []
        source = []

        # The .results file is created/truncated here but nothing is written
        # to it; the predictions only go to the CSV below.
        with open(can_path, 'w'), torch.no_grad():
            for batch in test_iter:
                src, labels, segs = batch.src, batch.labels, batch.segs
                clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [
                        [j for j in range(batch.clss.size(1)) if labels[i][j] == 1]
                        for i in range(batch.batch_size)
                    ]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                    raw_loss = self.loss(sent_scores, labels.float())
                    masked_loss = (raw_loss * mask.float()).sum()
                    loss_value = float(masked_loss.cpu().data.numpy())
                    stats.update(Statistics(loss_value, len(labels)))

                    ranked_scores = (sent_scores + mask.float()).cpu().data.numpy()
                    selected_ids = np.argsort(-ranked_scores, 1)

                for doc_idx, ranking in enumerate(selected_ids):
                    sentences = batch.src_str[doc_idx]
                    sentence_count = len(sentences)
                    if sentence_count == 0:
                        continue

                    chosen = []
                    for sent_idx in ranking[:sentence_count]:
                        if sent_idx >= sentence_count:
                            continue
                        candidate = sentences[sent_idx].strip()
                        blocked = self.args.block_trigram and _shares_trigram(candidate, chosen)
                        if not blocked:
                            chosen.append(candidate)

                        if not cal_oracle and not self.args.recall_eval and len(chosen) == 3:
                            break

                    predicted_text = '<q>'.join(chosen)
                    if self.args.recall_eval:
                        token_budget = len(batch.tgt_str[doc_idx].split())
                        predicted_text = ' '.join(predicted_text.split()[:token_budget])

                    pred.append(predicted_text)
                    source.append(''.join(sentences))

        pd.DataFrame({'abstract': source, 'title': pred}).to_csv(result_file, index=False)

        return stats
Ejemplo n.º 5
0
    def summary(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Build an extractive summary for every document in ``test_iter``.

        For every batch a sentence ranking is chosen per document:

        * ``cal_lead``: sentence indices in their original order;
        * ``cal_oracle``: indices whose label equals 1;
        * otherwise: indices sorted by descending ``sent_scores + mask`` from
          the model.

        Sentences are taken from ``batch.src_str`` following that ranking,
        optionally skipping candidates that share a trigram with an already
        chosen sentence (``self.args.block_trigram``). Selection stops at
        three sentences unless ``cal_oracle`` or ``self.args.recall_eval`` is
        set. The summaries and references of each batch are printed.

        Args:
            test_iter: iterable of batches to summarise.
            step: accepted for signature parity with the other evaluation
                methods; not used here.
            cal_lead: use the lead ordering instead of the model.
            cal_oracle: use the labelled sentences instead of the model.

        Returns:
            list of str: one ``'<q>'``-joined summary per non-empty document,
            collected over *all* batches (an empty list for an empty
            iterator).
        """
        def _get_ngrams(n, text):
            # Every contiguous n-token window, as a set of tuples.
            return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

        def _block_tri(c, p):
            # True when candidate c repeats a trigram of any sentence in p.
            tri_c = _get_ngrams(3, c.split())
            return any(tri_c & _get_ngrams(3, s.split()) for s in p)

        if not cal_lead and not cal_oracle:
            self.model.eval()
        # Loss statistics are accumulated in model mode but never returned.
        stats = Statistics()

        # Collected across batches. Previously the list was re-created inside
        # the loop, so only the last batch was returned and an empty iterator
        # raised UnboundLocalError at the return statement.
        all_pred = []

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                labels = batch.labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls

                # Per-batch lists, used only for the progress print below.
                gold = []
                pred = []

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                    range(batch.batch_size)]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                    loss = self.loss(sent_scores, labels.float())
                    loss = (loss * mask.float()).sum()
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                    stats.update(batch_stats)

                    sent_scores = sent_scores + mask.float()
                    sent_scores = sent_scores.cpu().data.numpy()
                    selected_ids = np.argsort(-sent_scores, 1)

                for i, ranking in enumerate(selected_ids):
                    sentences = batch.src_str[i]
                    if len(sentences) == 0:
                        continue

                    _pred = []
                    for j in ranking[:len(sentences)]:
                        if j >= len(sentences):
                            continue
                        candidate = sentences[j].strip()
                        if self.args.block_trigram:
                            if not _block_tri(candidate, _pred):
                                _pred.append(candidate)
                        else:
                            _pred.append(candidate)

                        if (not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3:
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        # Truncate the summary to the reference's token count.
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                    pred.append(_pred)
                    gold.append(batch.tgt_str[i])

                print(' '.join(pred), ' '.join(gold))
                all_pred.extend(pred)

        return all_pred
Ejemplo n.º 6
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Select summary sentences for the test set and report precision/ROUGE.

        For every batch a sentence ranking is chosen per document:

        * ``cal_lead``: sentence indices in their original order;
        * ``cal_oracle``: indices whose label equals 1;
        * otherwise: indices sorted by descending model score, with positions
          where the returned mask is false set to ``-inf`` (the masked
          log-softmax loss is also accumulated in this mode).

        Sentences are taken from ``batch.src_str`` following that ranking,
        optionally skipping candidates that share a trigram with an already
        chosen sentence (``self.args.block_trigram``). Selection stops at
        three sentences unless ``cal_oracle`` or ``self.args.recall_eval`` is
        set. Predictions are handed to ``self._output_predicted_summaries``
        together with the ``*_initial.candidate`` / ``*_initial.gold`` paths,
        and ROUGE-1/2 are logged. Precision and detailed ROUGE are logged
        only when ``step != -1`` and the matching ``self.args`` flag is set.

        Args:
            test_iter: iterable of test batches.
            step: step number used in output file names and log messages.
            cal_lead: use the lead ordering instead of the model.
            cal_oracle: use the labelled sentences instead of the model.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        def _get_ngrams(n, text):
            # Every contiguous n-token window, as a set of tuples.
            return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

        def _block_tri(c, p):
            # True when candidate c repeats a trigram of any sentence in p.
            tri_c = _get_ngrams(3, c.split())
            return any(tri_c & _get_ngrams(3, s.split()) for s in p)

        if not cal_lead and not cal_oracle:
            self.model.eval()
        stats = Statistics()

        # dirname() is '' when result_path has no directory component, and
        # os.makedirs('') raises; exist_ok also avoids the check-then-create
        # race when several processes start at once.
        base_dir = os.path.dirname(self.args.result_path)
        if base_dir:
            os.makedirs(base_dir, exist_ok=True)

        can_path = '%s_step%d_initial.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d_initial.gold' % (self.args.result_path, step)

        all_pred_ids, all_gold_ids, all_doc_ids = [], [], []
        all_gold_texts, all_pred_texts = [], []

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                labels = batch.labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls
                doc_ids = batch.doc_id
                group_idxs = batch.groups

                # Entries <= -1 in label_seq are dropped (-1 is used as the
                # ignore/padding value for label_seq elsewhere in this code).
                oracle_ids = [{j for j in seq if j > -1} for seq in batch.label_seq.tolist()]

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                    range(batch.batch_size)]
                else:
                    sent_scores, mask = self.model(src, mask, segs, clss, mask_cls, group_idxs, candi_sent_masks=mask_cls, is_test=True)
                    # selected sentences in candi_masks can be set to 0
                    loss = -self.logsoftmax(sent_scores) * labels.float()  # batch_size, max_sent_count
                    loss = (loss * mask.float()).sum()

                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                    stats.update(batch_stats)

                    # Masked-out positions must sort last in the ranking.
                    sent_scores[mask==False] = float('-inf')
                    sent_scores = sent_scores.cpu().data.numpy()
                    selected_ids = np.argsort(-sent_scores, 1)

                for i, ranking in enumerate(selected_ids):
                    sentences = batch.src_str[i]
                    if len(sentences) == 0:
                        continue

                    _pred = []
                    _pred_ids = []
                    for j in ranking[:len(sentences)]:
                        if j >= len(sentences):
                            continue
                        candidate = sentences[j].strip()
                        if self.args.block_trigram:
                            if not _block_tri(candidate, _pred):
                                _pred.append(candidate)
                                _pred_ids.append(j)
                        else:
                            _pred.append(candidate)
                            _pred_ids.append(j)

                        if (not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3:
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        # Truncate the summary to the reference's token count.
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                    all_pred_texts.append(_pred)
                    all_pred_ids.append(_pred_ids)
                    all_gold_texts.append(batch.tgt_str[i])
                    all_gold_ids.append(oracle_ids[i])
                    all_doc_ids.append(doc_ids[i])

        macro_precision, micro_precision = self._output_predicted_summaries(
                all_doc_ids, all_pred_ids, all_gold_ids,
                all_pred_texts, all_gold_texts, can_path, gold_path)
        rouge1_arr, rouge2_arr = du.cal_rouge_score(all_pred_texts, all_gold_texts)
        rouge_1, rouge_2 = du.aggregate_rouge(rouge1_arr, rouge2_arr)
        logger.info('[PERF]At step %d: rouge1:%.2f rouge2:%.2f' % (
            step, rouge_1 * 100, rouge_2 * 100))

        if step != -1 and self.args.report_precision:
            macro_arr = ["P@%s:%.2f%%" % (i+1, macro_precision[i] * 100) for i in range(3)]
            micro_arr = ["P@%s:%.2f%%" % (i+1, micro_precision[i] * 100) for i in range(3)]
            logger.info('[PERF]MacroPrecision at step %d: %s' % (step, '\t'.join(macro_arr)))
            logger.info('[PERF]MicroPrecision at step %d: %s' % (step, '\t'.join(micro_arr)))

        if step != -1 and self.args.report_rouge:
            rouge_str, detail_rouge = test_rouge(self.args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('[PERF]Rouges at step %d: %s \n' % (step, rouge_str))
            result_path = '%s_step%d_initial.rouge' % (self.args.result_path, step)
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)
        self._report_step(0, step, valid_stats=stats)

        return stats
Ejemplo n.º 7
0
    def validate(self, valid_iter, step=0, valid_by_rouge=False):
        """Evaluate the model on a validation iterator.

        The per-batch "loss" fed into the statistics depends on the mode:

        * ``valid_by_rouge``: the negated sum of ``du.cal_rouge_doc`` over the
          documents of the batch, using the sentences returned by
          ``self.model.infer_sentences(batch, 3)``.
        * ``self.args.model_name == 'seq'``: a summed cross-entropy (or, with
          ``self.args.use_rouge_label``, a soft-label log-softmax loss) over
          the flattened sentence scores.
        * otherwise: a BCE / weighted-softmax / normalised-softmax loss
          selected by ``self.args.loss``, weighted by ``batch.candi_masks``.

        Args:
            valid_iter: iterable of validation batches.
            step: step number passed through to ``self._report_step``.
            valid_by_rouge: use negative ROUGE instead of a model loss.

        Returns:
            The ``Statistics`` instance accumulated over all batches.
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src, labels, segs = batch.src, batch.labels, batch.segs
                clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls
                #group_idxs, pair_masks = batch.groups, batch.pair_masks
                group_idxs = batch.groups
                soft_labels = batch.soft_labels
                candi_masks = batch.candi_masks
                if valid_by_rouge:
                    src_str, tgt_str = batch.src_str, batch.tgt_str
                    # add negative rouge score as loss to be used as a criterion
                    # presumably the 3 is the number of sentences to select; verify against infer_sentences
                    sel_sent_idxs, sel_sent_masks = self.model.infer_sentences(batch, 3)
                    sel_sent_idxs = sel_sent_idxs.tolist()
                    total_rouge = 0.
                    for i in range(len(sel_sent_idxs)):
                        rouge = du.cal_rouge_doc(src_str[i], tgt_str[i], sel_sent_idxs[i], sel_sent_masks[i])
                        total_rouge += rouge
                    loss = -total_rouge
                else:
                    if self.args.model_name == 'seq':
                        # NOTE(review): `pair_masks` is never assigned in this
                        # method (its assignment above is commented out), and
                        # `sel_sent_idxs` / `sel_sent_masks` are only assigned
                        # in the other branches, so this call fails with a
                        # NameError/UnboundLocalError unless a module-level
                        # `pair_masks` exists. Confirm the intended sources.
                        sent_scores, _ = self.model(src, mask, segs, clss, mask_cls, group_idxs,
                                pair_masks, sel_sent_idxs=sel_sent_idxs, sel_sent_masks=sel_sent_masks,
                                candi_sent_masks=candi_masks)
                        #batch, seq_len, sent_count
                        pred = sent_scores.contiguous().view(-1, sent_scores.size(2))
                        gold = batch.label_seq.contiguous().view(-1)
                        if self.args.use_rouge_label:
                            soft_labels = soft_labels.contiguous().view(-1, soft_labels.size(2))
                            #batch*seq_len, sent_count
                            log_prb = F.log_softmax(pred, dim=1)
                            non_pad_mask = gold.ne(-1) # padding value
                            sent_mask = mask_cls.unsqueeze(1).expand(-1,sent_scores.size(1),-1)
                            sent_mask = sent_mask.contiguous().view(-1, sent_scores.size(2))
                            loss = -((soft_labels * log_prb) * sent_mask.float()).sum(dim=1)
                            loss = loss.masked_select(non_pad_mask).sum()  # average later
                        else:
                            loss = F.cross_entropy(pred, gold, ignore_index=-1, reduction='sum')
                        # NOTE(review): on this branch `loss` is passed to
                        # Statistics without the float() conversion that the
                        # non-'seq' branch applies below.
                    else:
                        sel_sent_idxs, sel_sent_masks = batch.sel_sent_idxs, batch.sel_sent_masks
                        sent_scores, _ = self.model(src, mask, segs, clss, mask_cls, group_idxs, \
                                sel_sent_idxs=sel_sent_idxs,
                                sel_sent_masks=batch.sel_sent_masks, candi_sent_masks=candi_masks, is_test=True,
                                sel_sent_hit_map=batch.hit_map)
                        if self.args.use_rouge_label:
                            labels = soft_labels
                        if self.args.loss == "bce":
                            loss = self.bce_logits_loss(sent_scores, labels.float()) #pointwise
                        elif self.args.loss == "wsoftmax":
                            loss = -self.logsoftmax(sent_scores) * labels.float()
                            #weighted average
                        else:
                            # Normalise each row of labels to sum to 1; rows
                            # whose sum is 0 are left unchanged.
                            sum_labels = labels.sum(dim=-1).unsqueeze(-1).expand_as(labels)
                            labels = torch.where(sum_labels==0, labels, labels/sum_labels)
                            loss = -self.logsoftmax(sent_scores) * labels.float()

                        #batch_size, max_sent_count
                        loss = (loss*candi_masks.float()).sum()
                        loss = float(loss.cpu().data.numpy())
                batch_stats = Statistics(loss, len(labels))
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats