import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# device, parallel, tokenize_alphabets, bert_tokenizer, clean_up,
# clean_up_v204_ft, split2words, make_statistics and rouge12l are expected to
# be defined at module level before these evaluation routines run.


def test_alphaBert_lstm(DS_model, dloader, threshold=0.5, is_clean_up=True,
                        ep=0, train=False, mean_max='mean', rouge=False,
                        parallel=parallel):
    """Evaluate the LSTM baseline and dump predictions/ROUGE scores to CSV."""
    if not train:
        DS_model.to(device)
        if parallel:
            DS_model = torch.nn.DataParallel(DS_model)
    DS_model.eval()
    out_pred_res = []
    mIOU = []
    ROC_IOU_ = []
    ROC_threshold = torch.linspace(0, 1, 100).to(device)
    all_pred_trg = {'pred': [], 'trg': []}
    rouge_set = []
    with torch.no_grad():
        for batch_idx, sample in enumerate(dloader):
            src = sample['src_token'].float().to(device)
            trg = sample['trg'].float().to(device)
            att_mask = sample['mask_padding'].float().to(device)
            origin_len = sample['origin_seq_length'].to(device)
            pred_prop = DS_model(x=src, x_lengths=origin_len)
            if is_clean_up:
                pred_prop = clean_up(src, pred_prop, mean_max=mean_max,
                                     tokenize_alphabets=tokenize_alphabets)
            for i, src_ in enumerate(src):
                all_pred_trg['pred'].append(pred_prop[i][:origin_len[i]].cpu())
                all_pred_trg['trg'].append(trg[i][:origin_len[i]].cpu())
                if rouge:
                    src_split, src_isword = split2words(
                        src_, rouge=rouge,
                        tokenize_alphabets=tokenize_alphabets)
                    reference = []
                    hypothesis = []
                    for j in range(len(src_split)):
                        if src_isword[j] > 0:
                            if trg[i][src_split[j]][0].cpu() > threshold:
                                reference.append(
                                    tokenize_alphabets.convert_idx2str(
                                        src_[src_split[j]]))
                            if pred_prop[i][src_split[j]][0].cpu() > threshold:
                                hypothesis.append(
                                    tokenize_alphabets.convert_idx2str(
                                        src_[src_split[j]]))
                    rouge_set.append((hypothesis, reference))
            # mIOU += IOU_ACC(pred_prop, trg, origin_len, threshold)
            # ROC_IOU_.append(ROC(pred_prop, trg, origin_len, ROC_threshold))
            pred_selected = pred_prop > threshold
            trg_selected = trg > threshold
            for i, src_ in enumerate(src):
                a_ = tokenize_alphabets.convert_idx2str(src_[:origin_len[i]])
                s_ = tokenize_alphabets.convert_idx2str(src_[pred_selected[i]])
                t_ = tokenize_alphabets.convert_idx2str(src_[trg_selected[i]])
                # print(a_, pred_prop[i], s_, t_)
                out_pred_res.append((a_, s_, t_, pred_prop[i]))
            print(batch_idx, len(dloader))
    out_pd_res = pd.DataFrame(out_pred_res)
    out_pd_res.to_csv('test_pred.csv', sep=',')
    make_statistics(all_pred_trg, ep=ep)
    DS_model.train()
    if rouge:
        rouge_res = rouge12l(rouge_set)
        rouge_res_pd = pd.DataFrame(rouge_res)
        rouge_res_pd.to_csv('./iou_pic/lstm/rouge_res.csv', index=False)
        rouge_res_np = np.array(rouge_res_pd)
        pd.DataFrame(rouge_res_np.mean(axis=0)).to_csv(
            './iou_pic/lstm/rouge_res_mean.csv', index=False)
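# The commented-out IOU_ACC call in test_alphaBert_lstm suggests a
# per-sequence intersection-over-union between the thresholded prediction and
# target masks. The helper below is a minimal sketch of that idea under those
# assumptions; it is not the repository's actual IOU_ACC implementation.
def iou_acc_sketch(pred_prop, trg, origin_len, threshold):
    """Hypothetical per-sequence IoU over thresholded token masks."""
    ious = []
    for i, n in enumerate(origin_len):
        pred_mask = pred_prop[i][:n] > threshold
        trg_mask = trg[i][:n] > threshold
        intersection = (pred_mask & trg_mask).sum().item()
        union = (pred_mask | trg_mask).sum().item()
        # Treat two empty masks as a perfect match to avoid dividing by zero.
        ious.append(intersection / union if union > 0 else 1.0)
    return ious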
def test_alphaBert_head(DS_model, dloader, threshold=0.5, is_clean_up=True,
                        ep=0, train=False, mean_max='mean', rouge=False):
    """Evaluate the fine-tuned head that scores '|' leading tokens."""
    if not train:
        DS_model.to(device)
        DS_model = torch.nn.DataParallel(DS_model)
    DS_model.eval()
    out_pred_res = []
    all_pred_trg = {'pred': [], 'trg': []}
    rouge_set = []
    leading_token_idx = tokenize_alphabets.alphabet2idx['|']
    padding_token_idx = tokenize_alphabets.alphabet2idx[' ']
    with torch.no_grad():
        for batch_idx, sample in enumerate(dloader):
            src = sample['src_token'].float().to(device)
            trg = sample['trg'].float().to(device)
            att_mask = sample['mask_padding'].float().to(device)
            origin_len = sample['origin_seq_length'].to(device)
            bs = src.shape
            # The model returns a 1-tuple in 'finehead' mode.
            pred_prop_bin, = DS_model(input_ids=src, attention_mask=att_mask,
                                      out='finehead')
            # max_pred_prop, pred_prop = pred_prop_bin.view(*bs, -1).max(dim=2)
            cls_idx = src == leading_token_idx
            pred_prop_bin_softmax = nn.Softmax(dim=-1)(
                pred_prop_bin.view(*bs, -1))
            pred_prop = pred_prop_bin_softmax[:, :, 1]
            pred_selected = pred_prop > threshold
            trg_selected = trg > threshold
            for i, src_ in enumerate(src):
                all_pred_trg['pred'].append(pred_prop[i][cls_idx[i]].cpu())
                all_pred_trg['trg'].append(trg[i][cls_idx[i]].cpu())
                if rouge:
                    # src_split, src_isword = split2words(src_, rouge=rouge)
                    reference = []
                    hypothesis = []
                    isselect_pred = False
                    isselect_trg = False
                    for j, wp in enumerate(src_):
                        if wp == leading_token_idx:
                            # A leading token opens a segment when its score
                            # clears the threshold; a separator closes the
                            # previously open segment.
                            if isselect_pred:
                                hypothesis.append(padding_token_idx)
                            isselect_pred = bool(pred_prop[i][j] > threshold)
                            if isselect_trg:
                                reference.append(padding_token_idx)
                            isselect_trg = bool(trg[i][j] > 0)
                        else:
                            if isselect_pred:
                                hypothesis.append(int(wp.item()))
                            if isselect_trg:
                                reference.append(int(wp.item()))
                    hypothesis = tokenize_alphabets.convert_idx2str(hypothesis)
                    reference = tokenize_alphabets.convert_idx2str(reference)
                    rouge_set.append((hypothesis.split(), reference.split()))
                    a_ = tokenize_alphabets.convert_idx2str(
                        src_[:origin_len[i]])
                    s_ = hypothesis
                    t_ = reference
                else:
                    a_ = tokenize_alphabets.convert_idx2str(
                        src_[:origin_len[i]])
                    s_ = tokenize_alphabets.convert_idx2str(
                        src_[pred_selected[i]])
                    t_ = tokenize_alphabets.convert_idx2str(
                        src_[trg_selected[i]])
                out_pred_res.append((a_, s_, t_))
            print(batch_idx, len(dloader))
    out_pd_res = pd.DataFrame(out_pred_res)
    out_pd_res.to_csv('./iou_pic/test_pred.csv', sep=',')
    if not train:
        make_statistics(all_pred_trg, ep=ep)
    if rouge:
        rouge_res = rouge12l(rouge_set)
        rouge_res_pd = pd.DataFrame(rouge_res)
        rouge_res_pd.to_csv('./iou_pic/rouge_res.csv', index=False)
        rouge_res_np = np.array(rouge_res_pd)
        pd.DataFrame(rouge_res_np.mean(axis=0)).to_csv(
            './iou_pic/rouge_res_mean.csv', index=False)
    DS_model.train()
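# The marker-driven loop in test_alphaBert_head reduces to one rule: a leading
# token ('|') carries the selection score for the characters that follow it,
# up to the next marker, and a separator is emitted between segments. A
# standalone sketch of that rule on plain Python lists (toy indices, not the
# repository's tokenizer) makes the control flow easy to test in isolation.
def decode_selected_segments(tokens, scores, leading_idx, sep_idx,
                             threshold=0.5):
    """Collect tokens whose most recent leading marker scored above threshold."""
    selected = []
    keep = False
    for tok, score in zip(tokens, scores):
        if tok == leading_idx:
            if keep:
                selected.append(sep_idx)  # close the previously open segment
            keep = score > threshold
        elif keep:
            selected.append(tok)
    return selected

# Example: markers at positions 0 and 3; only the first scores above 0.5, so
# only the first segment's tokens survive (plus its closing separator):
# decode_selected_segments([9, 1, 2, 9, 3, 4], [0.9, 0, 0, 0.2, 0, 0],
#                          leading_idx=9, sep_idx=0) -> [1, 2, 0]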
def test_BERT(DS_model, dloader, threshold=0.5, is_clean_up=True, ep=0,
              train=False, mean_max='mean', rouge=False, parallel=parallel):
    """Evaluate the BERT baseline, using [CLS] (id 101) as the word marker."""
    if not train:
        DS_model.to(device)
        if parallel:
            DS_model = torch.nn.DataParallel(DS_model)
    DS_model.eval()
    out_pred_res = []
    mIOU = []
    ROC_IOU_ = []
    ROC_threshold = torch.linspace(0, 1, 100).to(device)
    all_pred_trg = {'pred': [], 'trg': []}
    rouge_set = []
    with torch.no_grad():
        for batch_idx, sample in enumerate(dloader):
            src = sample['src_token'].float().to(device)
            trg = sample['trg'].float().to(device)
            att_mask = sample['mask_padding'].float().to(device)
            origin_len = sample['origin_seq_length'].to(device)
            bs = src.shape
            pred_prop_ = DS_model(input_ids=src.long(),
                                  attention_mask=att_mask)
            cls_idx = src == 101  # 101 is [CLS] in the standard BERT vocab
            max_pred_prop, pred_prop2 = pred_prop_.view(*bs, -1).max(dim=2)
            pred_prop_bin_softmax = nn.Softmax(dim=-1)(
                pred_prop_.view(*bs, -1))
            pred_prop = pred_prop_bin_softmax[:, :, 1]
            for i, src_ in enumerate(src):
                all_pred_trg['pred'].append(pred_prop[i][cls_idx[i]].cpu())
                all_pred_trg['trg'].append(trg[i][cls_idx[i]].cpu())
                if rouge:
                    reference = []
                    hypothesis = []
                    isselect_pred = False
                    isselect_trg = False
                    for j, wp in enumerate(src_):
                        if wp == 101:
                            # A [CLS] marker carries the selection score for
                            # the word-piece tokens that follow it.
                            isselect_pred = bool(pred_prop[i][j] > threshold)
                            isselect_trg = bool(trg[i][j] > 0)
                        else:
                            if isselect_pred and wp > 0:
                                hypothesis.append(int(wp.item()))
                            if isselect_trg and wp > 0:
                                reference.append(int(wp.item()))
                    hypothesis = bert_tokenizer.convert_tokens_to_string(
                        bert_tokenizer.convert_ids_to_tokens(hypothesis))
                    reference = bert_tokenizer.convert_tokens_to_string(
                        bert_tokenizer.convert_ids_to_tokens(reference))
                    rouge_set.append((hypothesis.split(), reference.split()))
                    a_ = bert_tokenizer.convert_ids_to_tokens(
                        src_[:origin_len[i]].long().cpu().numpy())
                    a_ = bert_tokenizer.convert_tokens_to_string(a_)
                    out_pred_res.append((a_, hypothesis, reference))
            print(batch_idx, len(dloader))
    out_pd_res = pd.DataFrame(out_pred_res)
    out_pd_res.to_csv('./iou_pic/bert/test_pred.csv', sep=',')
    make_statistics(all_pred_trg, ep=ep)
    DS_model.train()
    if rouge:
        rouge_res = rouge12l(rouge_set)
        rouge_res_pd = pd.DataFrame(rouge_res)
        rouge_res_pd.to_csv('./iou_pic/bert/rouge_res.csv', index=False)
        rouge_res_np = np.array(rouge_res_pd)
        pd.DataFrame(rouge_res_np.mean(axis=0)).to_csv(
            './iou_pic/bert/rouge_res_mean.csv', index=False)
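# test_BERT relies on a module-level bert_tokenizer and hard-codes token id
# 101, which is [CLS] in the standard bert-base vocabularies. A plausible
# construction is sketched below; the checkpoint name is an assumption, not
# taken from this repository.
# from transformers import BertTokenizer
# bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# assert bert_tokenizer.cls_token_id == 101  # matches the hard-coded id above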
def test_alphaBert(DS_model, dloader, threshold=0.5, is_clean_up=True, ep=0,
                   train=False, mean_max='mean', rouge=False):
    """Evaluate alphaBert's fine-tuned binary head, optionally with ROUGE."""
    if not train:
        DS_model.to(device)
        DS_model = torch.nn.DataParallel(DS_model)
    DS_model.eval()
    out_pred_res = []
    all_pred_trg = {'pred': [], 'trg': []}
    rouge_set = []
    with torch.no_grad():
        for batch_idx, sample in enumerate(dloader):
            src = sample['src_token'].float().to(device)
            trg = sample['trg'].float().to(device)
            att_mask = sample['mask_padding'].float().to(device)
            origin_len = sample['origin_seq_length'].to(device)
            bs = src.shape
            prediction_scores, (pooled_output, head_outputs) = DS_model(
                input_ids=src, attention_mask=att_mask)
            pred_prop_bin = pooled_output[0].view(*bs, -1)
            if is_clean_up:
                pred_prop = clean_up_v204_ft(
                    src, pred_prop_bin,
                    tokenize_alphabets=tokenize_alphabets, mean_max=mean_max)
            else:
                pred_prop_value, pred_prop = pred_prop_bin.max(dim=2)
                pred_prop = pred_prop.float()
            pred_selected = pred_prop > threshold
            trg_selected = trg > threshold
            for i, src_ in enumerate(src):
                if rouge:
                    src_split, src_isword = split2words(
                        src_, tokenize_alphabets=tokenize_alphabets,
                        rouge=rouge)
                    reference = []
                    hypothesis = []
                    trg_ = []
                    pred_ = []
                    for j in range(len(src_split)):
                        if src_isword[j] > 0:
                            if trg[i][src_split[j]][0].cpu() > threshold:
                                reference.append(
                                    tokenize_alphabets.convert_idx2str(
                                        src_[src_split[j]]))
                            if pred_prop[i][src_split[j]][0].cpu() > threshold:
                                hypothesis.append(
                                    tokenize_alphabets.convert_idx2str(
                                        src_[src_split[j]]))
                            trg_.append(trg[i][src_split[j]][0].cpu())
                            pred_.append(pred_prop[i][src_split[j]][0].cpu())
                    rouge_set.append((hypothesis, reference))
                    all_pred_trg['trg'].append(torch.tensor(trg_))
                    all_pred_trg['pred'].append(torch.tensor(pred_))
                    a_ = tokenize_alphabets.convert_idx2str(
                        src_[:origin_len[i]])
                    s_ = ' '.join(hypothesis)
                    t_ = ' '.join(reference)
                else:
                    all_pred_trg['pred'].append(
                        pred_prop[i][:origin_len[i]].cpu())
                    all_pred_trg['trg'].append(trg[i][:origin_len[i]].cpu())
                    a_ = tokenize_alphabets.convert_idx2str(
                        src_[:origin_len[i]])
                    s_ = tokenize_alphabets.convert_idx2str(
                        src_[pred_selected[i]])
                    t_ = tokenize_alphabets.convert_idx2str(
                        src_[trg_selected[i]])
                out_pred_res.append((a_, s_, t_, pred_prop[i]))
            print(batch_idx, len(dloader))
    out_pd_res = pd.DataFrame(out_pred_res)
    out_pd_res.to_csv('./iou_pic/test_pred.csv', sep=',')
    make_statistics(all_pred_trg, ep=ep)
    DS_model.train()
    if rouge:
        rouge_res = rouge12l(rouge_set)
        rouge_res_pd = pd.DataFrame(rouge_res)
        rouge_res_pd.to_csv('./iou_pic/rouge_res.csv', index=False)
        rouge_res_np = np.array(rouge_res_pd)
        pd.DataFrame(rouge_res_np.mean(axis=0)).to_csv(
            './iou_pic/rouge_res_mean.csv', index=False)
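# A hypothetical driver tying the evaluation entry points together. The model
# objects, dataset, and checkpoint path are placeholders for illustration and
# do not come from this repository.
# if __name__ == '__main__':
#     from torch.utils.data import DataLoader
#     eval_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False)
#     alphabert_model.load_state_dict(
#         torch.load('alphabert_checkpoint.pth', map_location=device))
#     test_alphaBert(alphabert_model, eval_loader, threshold=0.5,
#                    is_clean_up=True, rouge=True)
#     test_BERT(bert_model, eval_loader, threshold=0.5, rouge=True)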