    def _report_rouge(self, gold_path, can_path):
        """Compute ROUGE scores for the candidate file against the gold file."""
        self.logger.info("Calculating Rouge")
        candidates = codecs.open(can_path, encoding="utf-8")
        references = codecs.open(gold_path, encoding="utf-8")
        if self.args.rouge_path is None:
            results_dict = test_rouge(candidates, references, 1)
        else:
            results_dict = test_rouge(candidates,
                                      references,
                                      0,
                                      rouge_dir=os.path.join(
                                          os.getcwd(), self.args.rouge_path))
        return results_dict
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Evaluate the model on the test set and write candidate/gold summaries.
            test_iter: test data iterator
        Returns:
            :obj:`nmt.Statistics`: test loss statistics
        """

        # Collect the set of n-grams in a token list (used for trigram blocking).
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

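        # True if candidate sentence c shares any trigram with a sentence
        # already selected in p; used for trigram blocking to reduce redundancy.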
        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

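        # Put the model in evaluation mode unless scoring lead/oracle baselines.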
        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

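        # Candidate (predicted) and gold (reference) summary files for this step.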
        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        src = batch.src
                        labels = batch.src_sent_labels
                        #segs = batch.segs
                        clss = batch.clss
                        #mask = batch.mask_src
                        mask_cls = batch.mask_cls

                        gold = []
                        pred = []

                        if (cal_lead):
                            print('not implemented!')
                            exit(1)
                            #selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                        elif (cal_oracle):
                            print('not implemented!')
                            exit(1)
                            #selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                            #range(batch.batch_size)]
                        else:
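                            # Score each sentence with the model and accumulate
                            # the masked loss into the running statistics.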
                            sent_scores, mask = self.model(src, clss, mask_cls)

                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            batch_stats = Statistics(
                                float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

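                            # Offset scores by the mask so padded positions rank
                            # below real sentences, then sort indices by descending score.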
                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            selected_ids = np.argsort(-sent_scores, 1)
                        # selected_ids = np.sort(selected_ids,1)
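                        # Build each summary by taking sentences in ranked order.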
                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if (len(batch.src_str[i]) == 0):
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                if (j >= len(batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if (self.args.block_trigram):
                                    if (not _block_tri(candidate, _pred)):
                                        _pred.append(candidate)
                                else:
                                    _pred.append(candidate)

                                if ((not cal_oracle)
                                        and (not self.args.recall_eval)
                                        and len(_pred) == 3):
                                    break

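                            # Join the selected sentences with the '<q>' delimiter.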
                            _pred = '<q>'.join(_pred)
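                            # For recall evaluation, truncate the prediction to the
                            # reference length (in tokens).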
                            if (self.args.recall_eval):
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

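                        # One summary per line in the gold and candidate files.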
                        for i in range(len(gold)):
                            save_gold.write(gold[i].strip() + '\n')
                        for i in range(len(pred)):
                            save_pred.write(pred[i].strip() + '\n')
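        # Report ROUGE on the written files unless disabled or step is -1.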
        if (step != -1 and self.args.report_rouge):
            #raise NotImplementedError
            self.logger.info("Calculating Rouge")
            candidates = codecs.open(can_path, encoding="utf-8")
            references = codecs.open(gold_path, encoding="utf-8")
            rouges = test_rouge(candidates, references, 1)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                self.tensorboard_writer.add_scalar('test/rouge1-F',
                                                   rouges['rouge_1_f_score'],
                                                   step)
                self.tensorboard_writer.add_scalar('test/rouge2-F',
                                                   rouges['rouge_2_f_score'],
                                                   step)
                self.tensorboard_writer.add_scalar('test/rougeL-F',
                                                   rouges['rouge_l_f_score'],
                                                   step)

        self._report_step(0, step, valid_stats=stats)

        return stats