def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

            loss = self.loss(sent_scores, labels.float())
            loss = (loss * mask.float()).sum()
            batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
            stats.update(batch_stats)
        self._report_step(0, step, valid_stats=stats)
        return stats
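# A minimal standalone sketch (not part of the trainer) of the masked-loss
# pattern used above: per-sentence BCE is computed without reduction, then
# padding positions are zeroed via the CLS mask before summing. The helper
# name and tensor shapes below are illustrative assumptions.
import torch
import torch.nn as nn

def masked_bce_sum(sent_scores, labels, mask_cls):
    # sent_scores, labels, mask_cls: (batch, n_sents)
    loss_fct = nn.BCELoss(reduction='none')
    loss = loss_fct(sent_scores, labels.float())
    return (loss * mask_cls.float()).sum()

scores = torch.sigmoid(torch.randn(2, 4))
labels = torch.tensor([[1., 0., 1., 0.], [0., 1., 0., 0.]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(masked_bce_sum(scores, labels, mask))  # scalar loss over real sentences only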
def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            src = batch.src
            labels = batch.src_sent_labels.float()
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

            loss = self.loss(sent_scores, labels)
            loss = (loss * mask.float()).sum() / mask.float().sum()

            # baogs: report accuracy
            abs_scores, abs_ids = torch.topk(sent_scores, 3, dim=1)
            abs_mask = (abs_scores > 0).float()
            n_sents = abs_mask.sum().item()
            n_correct = torch.sum(torch.gather(labels, 1, abs_ids) * abs_mask).item()

            batch_stats = Statistics(loss.item() * batch.batch_size,
                                     batch.batch_size, n_sents, n_correct)
            stats.update(batch_stats)
        self._report_step(0, step, valid_stats=stats)
        return stats
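# Illustrative sketch of the top-k accuracy bookkeeping above: gather the gold
# 0/1 labels at the top-3 predicted indices and count hits among positively
# scored picks. The tensors below are made-up examples, not repo data.
import torch

scores = torch.tensor([[0.9, 0.1, 0.8, 0.2],
                       [0.3, 0.7, 0.6, 0.1]])
labels = torch.tensor([[1., 0., 1., 0.],
                       [0., 1., 0., 1.]])
top_scores, top_ids = torch.topk(scores, 3, dim=1)
hit_mask = (top_scores > 0).float()  # only count positively scored picks
n_correct = (torch.gather(labels, 1, top_ids) * hit_mask).sum().item()
print(n_correct)  # 3.0: sentences 0 and 2 hit in row 0, sentence 1 in row 1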
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """ Test model.
        test_iter: test data iterator
    Returns:
        :obj:`nmt.Statistics`: test loss statistics
    """

    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Set model in testing mode (lead/oracle baselines need no model).
    if (not cal_lead and not cal_oracle):
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    all_preds = []
    with open(can_path, 'w') as save_pred:
        with open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                for batch in test_iter:
                    batch_size = batch.p_pair.size(0)
                    batch_stats, p_scores, n_scores = self._main(
                        batch, batch_size, is_train=False)
                    stats.update(batch_stats)

                    scores = []
                    preds = []
                    for i, idx in enumerate(p_scores):
                        p = p_scores[i].cpu().data.numpy()
                        n = n_scores[i].cpu().data.numpy()
                        scores.append(str(p) + '\t' + str(n))
                        preds.append(int(p > n))
                        all_preds.append(int(p > n))
                    for i in range(len(scores)):
                        save_gold.write(scores[i] + '\n')
                    for i in range(len(preds)):
                        save_pred.write(str(preds[i]) + '\n')

    print("**********************")
    print(sum(all_preds), len(all_preds))
    print("ACC: ", sum(all_preds) / float(len(all_preds)))
    print("**********************")
    self._report_step(0, step, valid_stats=stats)
    return stats
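# Tiny sketch of the pairwise accuracy printed above: a pair counts as correct
# when the positive example outscores the negative one. Scores are made up.
p_scores = [0.8, 0.4, 0.9]
n_scores = [0.3, 0.6, 0.1]
all_preds = [int(p > n) for p, n in zip(p_scores, n_scores)]
print("ACC: ", sum(all_preds) / float(len(all_preds)))  # 2 of 3 pairs correct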
def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            batch_size = batch.p_pair.size(0)
            batch_stats, _, _ = self._main(batch, batch_size, is_train=False)
            stats.update(batch_stats)
        self._report_step(0, step, valid_stats=stats)
        return stats
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """ Test model.
        test_iter: test data iterator
    Returns:
        :obj:`nmt.Statistics`: test loss statistics
    """

    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Set model in testing mode (lead/oracle baselines need no model).
    if (not cal_lead and not cal_oracle):
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    with open(can_path, 'w') as save_pred:
        with open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                ct = 0
                for batch in test_iter:
                    src = batch.src
                    labels = batch.src_sent_labels
                    segs = batch.segs
                    clss = batch.clss
                    mask = batch.mask_src
                    mask_cls = batch.mask_cls

                    gold = []
                    pred = []

                    if (cal_lead):
                        selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                    elif (cal_oracle):
                        selected_ids = [[j for j in range(batch.clss.size(1))
                                         if labels[i][j] == 1]
                                        for i in range(batch.batch_size)]
                    else:
                        sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                        loss = self.loss(sent_scores, labels.float())
                        loss = (loss * mask.float()).sum()
                        batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                                 len(labels))
                        stats.update(batch_stats)

                        sent_scores = sent_scores + mask.float()
                        sent_scores = sent_scores.cpu().data.numpy()
                        selected_ids = np.argsort(-sent_scores, 1)
                    # selected_ids = np.sort(selected_ids, 1)
                    for i, idx in enumerate(selected_ids):
                        _pred = []
                        if (len(batch.src_str[i]) == 0):
                            continue
                        for j in selected_ids[i][:len(batch.src_str[i])]:
                            if (j >= len(batch.src_str[i])):
                                continue
                            candidate = batch.src_str[i][j].strip()
                            if (self.args.block_trigram):
                                if (not _block_tri(candidate, _pred)):
                                    _pred.append(candidate)
                            else:
                                _pred.append(candidate)

                            if ((not cal_oracle) and (not self.args.recall_eval)
                                    and len(_pred) == 3):
                                break

                        _pred = '<q>'.join(_pred)
                        if (self.args.recall_eval):
                            _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                        pred.append(_pred)
                        gold.append(batch.tgt_str[i])

                    for i in range(len(gold)):
                        save_gold.write(str(ct) + '\n')
                        save_gold.write(gold[i].strip() + '\n')
                        save_pred.write(str(ct) + '\n')
                        save_pred.write(pred[i].strip() + '\n')
                        ct += 1

    if (step != -1 and self.args.report_rouge):
        rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
        logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
    self._report_step(0, step, valid_stats=stats)

    return stats
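# Standalone demo (made-up sentences) of the trigram-blocking heuristic
# defined above: a candidate is skipped if it shares any trigram with an
# already-selected sentence, which reduces redundancy in the summary.
def _get_ngrams(n, text):
    return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

def _block_tri(c, p):
    tri_c = _get_ngrams(3, c.split())
    return any(len(tri_c & _get_ngrams(3, s.split())) > 0 for s in p)

selected = ["the cat sat on the mat"]
print(_block_tri("a cat sat on the rug", selected))     # True: shares "cat sat on"
print(_block_tri("dogs chase the red ball", selected))  # False: no shared trigram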
def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    if self.args.acc_reporter == 1:
        stats = acc_reporter.Statistics()
    else:
        stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            if self.args.ext_sum_dec:
                sent_scores, mask = self.model(src, segs, clss, mask,
                                               mask_cls, labels)  # B, tgt_len, custom_num
                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                labels_id, _ = torch.sort(labels_id)
                # Linearly increasing position weight, e.g. max_src_nsents=100, weight_up=20.
                weight = torch.linspace(start=1, end=self.args.weight_up,
                                        steps=self.args.max_src_nsents).type_as(sent_scores)
                # self.max_class = max(self.max_class, torch.max(labels_id + 1).item())
                # weight = weight[:self.max_class]
                weight = weight[:sent_scores.size(-1)]
                # weight = torch.ones(self.args.max_src_nsents)
                loss = F.nll_loss(
                    F.log_softmax(
                        sent_scores.view(-1, sent_scores.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    labels_id.view(-1),  # bsz * sent
                    weight=weight,
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(sent_scores, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('train prediction: %s |label %s '
                                % (str(prediction), str(labels_id)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
            else:
                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)  # B, custom_N
                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()

                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                labels_id, _ = torch.sort(labels_id)
                _, prediction = torch.topk(sent_scores, k=tgt_len)
                prediction, _ = torch.sort(prediction)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('train prediction: %s |label %s '
                                % (str(prediction), str(labels_id)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)

            if self.args.acc_reporter == 1:
                batch_stats = acc_reporter.Statistics(
                    float(loss.cpu().data.numpy()), accuracy, len(labels))
            else:
                batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
            stats.update(batch_stats)
        self._report_step(0, step, valid_stats=stats)
        return stats
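# Minimal sketch (illustrative sizes) of the position-weighted NLL used in the
# ext_sum_dec branch above: later sentence positions get linearly larger
# weights, so the loss does not collapse onto the frequent early positions.
import torch
import torch.nn.functional as F

n_positions, weight_up = 10, 20.0
logits = torch.randn(2, 3, n_positions)          # B=2, tgt_len=3 picks over 10 positions
targets = torch.tensor([[0, 3, 7], [1, 2, 9]])   # gold sentence positions per pick
weight = torch.linspace(1, weight_up, steps=n_positions)
loss = F.nll_loss(F.log_softmax(logits.view(-1, n_positions), dim=-1),
                  targets.view(-1), weight=weight, reduction='sum', ignore_index=-1)
print(loss)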
def validate_rouge_baseline(self, valid_iter_fct, step=0, valid_gl_stats=None,
                            write_scores_to_pickle=False):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    preds = {}
    preds_with_idx = {}
    golds = {}
    can_path = '%s_step%d.source' % (self.args.result_path, step)
    gold_path = '%s_step%d.target' % (self.args.result_path, step)
    if step == self.best_val_step:
        can_path = '%s_step%d.source' % (self.args.result_path_test, step)
        gold_path = '%s_step%d.target' % (self.args.result_path_test, step)

    save_pred = open(can_path, 'w')
    save_gold = open(gold_path, 'w')
    sent_scores_whole = {}
    sent_sects_whole_pred = {}
    sent_sects_whole_true = {}
    sent_labels_true = {}
    sent_numbers_whole = {}
    paper_srcs = {}
    paper_tgts = {}
    sent_sect_wise_rg_whole = {}
    sent_sections_txt_whole = {}

    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()
    best_model_saved = False
    best_recall_model_saved = False

    valid_iter = valid_iter_fct()
    with torch.no_grad():
        for batch in tqdm(valid_iter):
            src = batch.src
            labels = batch.src_sent_labels
            sent_labels = batch.sent_labels
            if self.rg_predictor:
                sent_true_rg = batch.src_sent_labels
            else:
                sent_labels = batch.sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls
            p_id = batch.paper_id
            segment_src = batch.src_str
            paper_tgt = batch.tgt_str
            sent_sect_wise_rg = batch.sent_sect_wise_rg
            sent_sections_txt = batch.sent_sections_txt
            sent_numbers = batch.sent_numbers
            sent_sect_labels = batch.sent_sect_labels

            if self.is_joint:
                if not self.rg_predictor:
                    sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = \
                        self.model(src, segs, clss, mask, mask_cls, sent_labels,
                                   sent_sect_labels)
                else:
                    sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = \
                        self.model(src, segs, clss, mask, mask_cls, sent_true_rg,
                                   sent_sect_labels)
                acc, _ = self._get_mertrics(sent_sect_scores, sent_sect_labels,
                                            mask=mask, task='sent_sect')

                batch_stats = Statistics(
                    loss=float(loss.cpu().data.numpy().sum()),
                    loss_sect=float(loss_sect.cpu().data.numpy().sum()),
                    loss_sent=float(loss_sent.cpu().data.numpy().sum()),
                    n_docs=len(labels),
                    n_acc=batch.batch_size,
                    RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                    accuracy=acc)
            else:
                if not self.rg_predictor:
                    sent_scores, mask, loss, _, _ = self.model(
                        src, segs, clss, mask, mask_cls, sent_labels,
                        sent_sect_labels=None, is_inference=True)
                else:
                    sent_scores, mask, loss, _, _ = self.model(
                        src, segs, clss, mask, mask_cls, sent_true_rg,
                        sent_sect_labels=None, is_inference=True)
                # sent_scores = (section_rg.unsqueeze(1).expand_as(sent_scores).to(device='cuda') * 100) * sent_scores
                batch_stats = Statistics(
                    loss=float(loss.cpu().data.numpy().sum()),
                    RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                    n_acc=batch.batch_size,
                    n_docs=len(labels))

            stats.update(batch_stats)
            sent_scores = sent_scores + mask.float()
            sent_scores = sent_scores.cpu().data.numpy()

            for idx, p_id in enumerate(p_id):
                p_id = p_id.split('___')[0]
                if p_id not in sent_scores_whole.keys():
                    masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                    masked_scores = masked_scores[np.nonzero(masked_scores)]

                    masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()
                    masked_sent_labels_true = masked_sent_labels_true[
                        np.nonzero(masked_sent_labels_true)].flatten()
                    masked_sent_labels_true = (masked_sent_labels_true - 1)

                    sent_scores_whole[p_id] = masked_scores
                    sent_labels_true[p_id] = masked_sent_labels_true.cpu()

                    masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()
                    masked_sents_sections_true = masked_sents_sections_true[
                        np.nonzero(masked_sents_sections_true)].flatten()
                    masked_sents_sections_true = (masked_sents_sections_true - 1)
                    sent_sects_whole_true[p_id] = masked_sents_sections_true.cpu()

                    if self.is_joint:
                        masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                            sent_sect_scores[idx]).float()
                        masked_scores_sects = masked_scores_sects[
                            torch.abs(masked_scores_sects).sum(dim=1) != 0]
                        sent_sects_whole_pred[p_id] = torch.max(
                            self.softmax(masked_scores_sects), 1)[1].cpu()

                    paper_srcs[p_id] = segment_src[idx]
                    if sent_numbers[0] is not None:
                        sent_numbers_whole[p_id] = sent_numbers[idx]
                    # sent_tokens_count_whole[p_id] = sent_tokens_count[idx]
                    paper_tgts[p_id] = paper_tgt[idx]
                    sent_sect_wise_rg_whole[p_id] = sent_sect_wise_rg[idx]
                    sent_sections_txt_whole[p_id] = sent_sections_txt[idx]
                else:
                    masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                    masked_scores = masked_scores[np.nonzero(masked_scores)]

                    masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()
                    masked_sent_labels_true = masked_sent_labels_true[
                        np.nonzero(masked_sent_labels_true)].flatten()
                    masked_sent_labels_true = (masked_sent_labels_true - 1)

                    sent_scores_whole[p_id] = np.concatenate(
                        (sent_scores_whole[p_id], masked_scores), 0)
                    sent_labels_true[p_id] = np.concatenate(
                        (sent_labels_true[p_id], masked_sent_labels_true.cpu()), 0)

                    masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()
                    masked_sents_sections_true = masked_sents_sections_true[
                        np.nonzero(masked_sents_sections_true)].flatten()
                    masked_sents_sections_true = (masked_sents_sections_true - 1)
                    sent_sects_whole_true[p_id] = np.concatenate(
                        (sent_sects_whole_true[p_id], masked_sents_sections_true.cpu()), 0)

                    if self.is_joint:
                        masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                            sent_sect_scores[idx]).float()
                        masked_scores_sects = masked_scores_sects[
                            torch.abs(masked_scores_sects).sum(dim=1) != 0]
                        sent_sects_whole_pred[p_id] = np.concatenate(
                            (sent_sects_whole_pred[p_id],
                             torch.max(self.softmax(masked_scores_sects), 1)[1].cpu()), 0)

                    paper_srcs[p_id] = np.concatenate((paper_srcs[p_id], segment_src[idx]), 0)
                    if sent_numbers[0] is not None:
                        sent_numbers_whole[p_id] = np.concatenate(
                            (sent_numbers_whole[p_id], sent_numbers[idx]), 0)
                    # sent_tokens_count_whole[p_id] = np.concatenate(
                    #     (sent_tokens_count_whole[p_id], sent_tokens_count[idx]), 0)
                    sent_sect_wise_rg_whole[p_id] = np.concatenate(
                        (sent_sect_wise_rg_whole[p_id], sent_sect_wise_rg[idx]), 0)
                    sent_sections_txt_whole[p_id] = np.concatenate(
                        (sent_sections_txt_whole[p_id], sent_sections_txt[idx]), 0)

    PRED_LEN = self.args.val_pred_len
    acum_f_sent_labels = 0
    acum_p_sent_labels = 0
    acum_r_sent_labels = 0
    acc_total = 0
    for p_idx, (p_id, sent_scores) in enumerate(sent_scores_whole.items()):
        # sent_true_labels = pickle.load(open("sent_labels_files/pubmedL/val.labels.p", "rb"))
        # section_textual = np.array(section_textual)
        paper_sent_true_labels = np.array(sent_labels_true[p_id])
        if self.is_joint:
            sent_sects_true = np.array(sent_sects_whole_true[p_id])
            sent_sects_pred = np.array(sent_sects_whole_pred[p_id])

        sent_scores = np.array(sent_scores)
        p_src = np.array(paper_srcs[p_id])
        # selected_ids_unsorted = np.argsort(-sent_scores, 0)

        # Keep sentences whose token count (after stripping punctuation)
        # is strictly between 5 and 100.
        keep_ids = [idx for idx, s in enumerate(p_src)
                    if 5 < len(s.replace('.', '').replace(',', '')
                               .replace('(', '').replace(')', '')
                               .replace('-', '').replace(':', '')
                               .replace(';', '').replace('*', '')
                               .split()) < 100]
        keep_ids = sorted(keep_ids)

        # top_sent_indexes = top_sent_indexes[top_sent_indexes]
        p_src = p_src[keep_ids]
        sent_scores = sent_scores[keep_ids]
        paper_sent_true_labels = paper_sent_true_labels[keep_ids]

        sent_scores = np.asarray([s - 1.00 for s in sent_scores])
        selected_ids_unsorted = np.argsort(-sent_scores, 0)

        _pred = []
        for j in selected_ids_unsorted:
            if (j >= len(p_src)):
                continue
            candidate = p_src[j].strip()
            if True:
                # if (not _block_tri(candidate, _pred)):
                _pred.append((candidate, j))
            if (len(_pred) == PRED_LEN):
                break
        _pred = sorted(_pred, key=lambda x: x[1])
        _pred_final_str = '<q>'.join([x[0] for x in _pred])

        preds[p_id] = _pred_final_str
        golds[p_id] = paper_tgts[p_id]
        preds_with_idx[p_id] = _pred

        if p_idx > 10:
            f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred])
            if self.is_joint:
                acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred,
                                                   [p[1] for p in _pred])
                acc_total += acc_whole
        else:
            f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred],
                                      print_few=True, p_id=p_id)
            if self.is_joint:
                acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred,
                                                   [p[1] for p in _pred],
                                                   print_few=True, p_id=p_id)
                acc_total += acc_whole

        acum_f_sent_labels += f
        acum_p_sent_labels += p
        acum_r_sent_labels += r

    for id, pred in preds.items():
        save_pred.write(pred.strip().replace('<q>', ' ') + '\n')
        save_gold.write(golds[id].replace('<q>', ' ').strip() + '\n')

    print(f'Gold: {gold_path}')
    print(f'Prediction: {can_path}')

    r1, r2, rl = self._report_rouge(preds.values(), golds.values())
    stats.set_rl(r1, r2, rl)
    logger.info("F-score: %4.4f, Prec: %4.4f, Recall: %4.4f" % (
        acum_f_sent_labels / len(sent_scores_whole),
        acum_p_sent_labels / len(sent_scores_whole),
        acum_r_sent_labels / len(sent_scores_whole)))
    if self.is_joint:
        logger.info("Section Accuracy: %4.4f" % (acc_total / len(sent_scores_whole)))
    stats.set_ir_metrics(acum_f_sent_labels / len(sent_scores_whole),
                         acum_p_sent_labels / len(sent_scores_whole),
                         acum_r_sent_labels / len(sent_scores_whole))
    self.valid_rgls.append((r2 + rl) / 2)
    self._report_step(0, step,
                      self.model.uncertainty_loss._sigmas_sq[0] if self.is_joint else 0,
                      self.model.uncertainty_loss._sigmas_sq[1] if self.is_joint else 0,
                      valid_stats=stats)

    if len(self.valid_rgls) > 0:
        if self.min_rl < self.valid_rgls[-1]:
            self.min_rl = self.valid_rgls[-1]
            best_model_saved = True

    return stats, best_model_saved, best_recall_model_saved
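# Illustrative sketch of the length filter above: drop very short (<=5 tokens)
# and very long (>=100 tokens) sentences before ranking, counting tokens after
# stripping punctuation. The sentences below are made-up examples.
import numpy as np

PUNCT = '.,():-;*'

def token_count(s):
    for ch in PUNCT:
        s = s.replace(ch, '')
    return len(s.split())

p_src = np.array(["too short.",
                  "this sentence easily clears the minimum token threshold for ranking.",
                  "word " * 120])
keep_ids = [i for i, s in enumerate(p_src) if 5 < token_count(s) < 100]
print(keep_ids)  # [1]: only the mid-length sentence survives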
def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    if self.args.acc_reporter:
        stats = acc_reporter.Statistics()
    else:
        stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            # src = batch.src
            # labels = batch.src_sent_labels
            # segs = batch.segs
            # clss = batch.clss
            # mask = batch.mask_src
            # mask_cls = batch.mask_cls
            # sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
            if self.args.jigsaw == 'jigsaw_lab':
                # jigsaw_lab 3.31 23:38: noticed validate() had not been updated
                # earlier; rerun in the morning to check.
                logits = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                    batch.mask_src_s, batch.mask_cls_s)  # bsz, sent, max_sent_num
                # mask = batch.mask_cls_s[:, :, None].float()
                # loss = self.loss(sent_scores, batch.poss_s.float())
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    batch.poss_s.view(-1),  # bsz * sent
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('train prediction: %s |label %s '
                                % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, batch.poss_s) * batch.mask_cls_s),
                    torch.sum(batch.mask_cls_s)) * len(logits)
            elif self.args.jigsaw == 'jigsaw_dec':  # jigsaw decoder
                poss_s = batch.poss_s
                mask_poss = torch.eq(poss_s, -1)
                poss_s.masked_fill_(mask_poss, 1e4)
                # poss_s[i]: [5,1,4,0,2,3,-1,-1] -> [5,1,4,0,2,3,1e4,1e4]
                dec_labels = torch.argsort(poss_s, dim=1)
                # dec_labels[i]: [3,1,xxx,6,7]
                logits = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                    batch.mask_src_s, batch.mask_cls_s, dec_labels)
                final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
                # final_dec_labels[i]: [3,1,xxx,-1,-1]
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    final_dec_labels.view(-1),  # bsz * sent
                    reduction='sum',
                    ignore_index=-1,
                )
                # loss = (loss * batch.mask_cls_s.float()).sum()
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('train prediction: %s |label %s '
                                % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, final_dec_labels) * batch.mask_cls_s),
                    torch.sum(batch.mask_cls_s)) * len(logits)

            # loss = self.loss(sent_scores, labels.float())
            # loss = (loss * mask.float()).sum()
            if self.args.acc_reporter:
                batch_stats = acc_reporter.Statistics(
                    float(loss.cpu().data.numpy()), accuracy, len(batch.poss_s))
            else:
                batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                         len(batch.poss_s))
            stats.update(batch_stats)
        self._report_step(0, step, valid_stats=stats)
        return stats
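# Standalone sketch (made-up tensor) of the jigsaw-decoder label derivation
# above: pad positions (-1) are pushed to the end with a large fill value,
# argsort recovers which sentence belongs at each output slot, and pads become
# ignore_index for the loss.
import torch

poss_s = torch.tensor([[5, 1, 4, 0, 2, 3, -1, -1]])  # shuffled positions, -1 = pad
mask_poss = torch.eq(poss_s, -1)
filled = poss_s.masked_fill(mask_poss, 10000)         # [5,1,4,0,2,3,10000,10000]
dec_labels = torch.argsort(filled, dim=1)             # sentence index per output slot
final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
print(final_dec_labels)  # tensor([[ 3,  1,  4,  5,  2,  0, -1, -1]])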
def test1(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """ Run inference on test data.
        test_iter: test data iterator
    Returns:
        the predicted summary as a single string
    """

    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Set model in testing mode (lead/oracle baselines need no model).
    if (not cal_lead and not cal_oracle):
        self.model.eval()
    stats = Statistics()

    with torch.no_grad():
        for batch in test_iter:
            src = batch.src
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            gold = []
            pred = []

            if (cal_lead):
                selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
            elif (cal_oracle):
                labels = batch.src_sent_labels
                selected_ids = [[j for j in range(batch.clss.size(1))
                                 if labels[i][j] == 1]
                                for i in range(batch.batch_size)]
            else:
                # logger.info("src:%s, segs:%s, clss:%s, mask:%s, mask_cls:%s" % (
                #     src, segs, clss, mask, mask_cls))
                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                # Compute the loss on the tensor scores before they are
                # converted to numpy for ranking.
                if (hasattr(batch, 'src_sent_labels')):
                    labels = batch.src_sent_labels
                    loss = self.loss(sent_scores, labels.float())
                    loss = (loss * mask.float()).sum()
                    batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                             len(labels))
                    stats.update(batch_stats)

                sent_scores = sent_scores + mask.float()
                sent_scores = sent_scores.cpu().data.numpy()
                selected_ids = np.argsort(-sent_scores, 1)

            for i, idx in enumerate(selected_ids):
                _pred = []
                if (len(batch.src_str[i]) == 0):
                    continue
                for j in selected_ids[i][:len(batch.src_str[i])]:
                    if (j >= len(batch.src_str[i])):
                        continue
                    candidate = batch.src_str[i][j].strip()
                    if (self.args.block_trigram):
                        if (not _block_tri(candidate, _pred)):
                            _pred.append(candidate)
                    else:
                        _pred.append(candidate)

                    if ((not cal_oracle) and (not self.args.recall_eval)
                            and len(_pred) == 3):
                        break

                _pred = ' '.join(_pred)
                if (self.args.recall_eval):
                    _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                pred.append(_pred)
                gold.append(batch.tgt_str[i])

            # Single-input inference: return after the first batch.
            return ' '.join(pred)