def _report_rouge(self, gold_path, can_path):
    """Compute ROUGE for a candidate file against a reference (gold) file.

    NOTE(review): another ``_report_rouge`` is defined later in this file; if
    both live in the same class, the later definition replaces this one and
    the ``rouge_path`` branch below is never used. Confirm which is intended.

    Args:
        gold_path: path to the reference summaries, read as UTF-8.
        can_path: path to the candidate summaries, read as UTF-8.

    Returns:
        The result of ``test_rouge`` (presumably a dict of ROUGE scores).
    """
    self.logger.info("Calculating Rouge")
    # Use context managers so both file handles are closed once scoring is done.
    with codecs.open(can_path, encoding="utf-8") as candidates, \
            codecs.open(gold_path, encoding="utf-8") as references:
        if self.args.rouge_path is None:
            return test_rouge(candidates, references, 1)
        # A custom ROUGE directory is resolved relative to the current
        # working directory. The third argument's meaning (1 vs 0) is defined
        # by test_rouge; it is kept exactly as in the original.
        rouge_dir = os.path.join(os.getcwd(), self.args.rouge_path)
        return test_rouge(candidates, references, 0, rouge_dir=rouge_dir)
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """Run extractive summarization over ``test_iter`` and write the results.

    For each document, sentences are ranked by model score. Up to three
    sentences are kept, optionally with trigram blocking; with
    ``recall_eval`` the selection is not capped at three sentences but is
    then truncated to the reference length in tokens. Predictions and
    references are written one per line. ROUGE is optionally reported
    afterwards.

    Args:
        test_iter: iterable of batches. Each batch is read for ``src``,
            ``src_sent_labels``, ``clss``, ``mask_cls``, ``src_str`` and
            ``tgt_str``.
        step: step number used in the output file names and in ROUGE
            logging. ROUGE is skipped when ``step == -1``.
        cal_lead: request the LEAD baseline (not implemented).
        cal_oracle: request the ORACLE selection (not implemented).

    Output files (UTF-8):
        ``<result_path>_step<step>.candidate``: predicted summaries, with
            sentences joined by ``<q>``.
        ``<result_path>_step<step>.gold``: the matching reference summaries.

    Returns:
        ``Statistics``: loss statistics accumulated over all batches.

    Raises:
        NotImplementedError: if ``cal_lead`` or ``cal_oracle`` is set.
    """

    def _get_ngrams(n, text):
        # Set of all contiguous n-grams (as tuples) of the token list ``text``.
        return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

    def _block_tri(c, p):
        # True if candidate ``c`` shares any word trigram with a sentence in ``p``.
        tri_c = _get_ngrams(3, c.split())
        return any(tri_c & _get_ngrams(3, s.split()) for s in p)

    if cal_lead or cal_oracle:
        # Fail before the output files are opened so that existing results
        # are not truncated.
        raise NotImplementedError('not implemented!')

    self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    # Write explicitly as UTF-8: the files are read back as UTF-8 for ROUGE,
    # and the locale default encoding may differ or reject non-ASCII text.
    with open(can_path, 'w', encoding='utf-8') as save_pred, \
            open(gold_path, 'w', encoding='utf-8') as save_gold:
        with torch.no_grad():
            for batch in test_iter:
                labels = batch.src_sent_labels
                sent_scores, mask = self.model(batch.src, batch.clss, batch.mask_cls)

                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
                stats.update(Statistics(float(loss.cpu().data.numpy()), len(labels)))

                # Adding the mask is intended to rank real sentences above
                # padded slots. This presumably assumes scores lie in [0, 1]
                # and the mask is 0/1 -- TODO confirm.
                sent_scores = (sent_scores + mask.float()).cpu().data.numpy()
                selected_ids = np.argsort(-sent_scores, 1)

                gold = []
                pred = []
                for i, idx in enumerate(selected_ids):
                    src_sents = batch.src_str[i]
                    if len(src_sents) == 0:
                        continue
                    _pred = []
                    for j in idx[:len(src_sents)]:
                        if j >= len(src_sents):
                            continue
                        candidate = src_sents[j].strip()
                        if not (self.args.block_trigram and _block_tri(candidate, _pred)):
                            _pred.append(candidate)
                        if not self.args.recall_eval and len(_pred) == 3:
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        # Truncate to the reference length in whitespace tokens.
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])
                    pred.append(_pred)
                    gold.append(batch.tgt_str[i])

                for g in gold:
                    save_gold.write(g.strip() + '\n')
                for p in pred:
                    save_pred.write(p.strip() + '\n')

    # Both output files are closed (and flushed) at this point, so they can be
    # read back for scoring.
    if step != -1 and self.args.report_rouge:
        self.logger.info("Calculating Rouge")
        with codecs.open(can_path, encoding="utf-8") as candidates, \
                codecs.open(gold_path, encoding="utf-8") as references:
            rouges = test_rouge(candidates, references, 1)
        self.logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                             ('test/rouge2-F', 'rouge_2_f_score'),
                             ('test/rougeL-F', 'rouge_l_f_score')):
                self.tensorboard_writer.add_scalar(tag, rouges[key], step)
    self._report_step(0, step, valid_stats=stats)
    return stats
def _report_rouge(self, gold_path, can_path):
    """Compute ROUGE for a candidate file against a reference (gold) file.

    NOTE(review): this is a second definition of ``_report_rouge`` in this
    file. If both are in the same class, this one replaces the earlier one,
    and ``self.args.rouge_path`` is therefore never consulted. Confirm which
    variant is intended and remove the other.

    Args:
        gold_path: path to the reference summaries, read as UTF-8.
        can_path: path to the candidate summaries, read as UTF-8.

    Returns:
        The result of ``test_rouge`` (presumably a dict of ROUGE scores).
    """
    self.logger.info("Calculating Rouge")
    # Use context managers so both file handles are closed once scoring is done.
    with codecs.open(can_path, encoding="utf-8") as candidates, \
            codecs.open(gold_path, encoding="utf-8") as references:
        return test_rouge(candidates, references, 1)