def eval(model, iterator, fname, write):
    """Evaluate trigger extraction over `iterator`.

    Runs the model without gradients, accumulates per-sentence trigger
    predictions, prints precision/recall/F1 for trigger classification and
    trigger identification, and — when `write` is truthy — dumps the
    per-token predictions followed by the metric summary into `fname`.

    Returns:
        str: the formatted metric summary.
    """
    model.eval()
    words_all, triggers_all, triggers_hat_all = [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_2d, triggers_2d, entities_3d, postags_2d, adj, seqlen_1d, words, triggers = batch
            trigger_logits, trigger_hat_2d = model.predict_triggers(
                tokens_2d=tokens_2d,
                entities_3d=entities_3d,
                postags_2d=postags_2d,
                seqlen_1d=seqlen_1d,
                adjm=adj)
            words_all.extend(words)
            triggers_all.extend(triggers)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())

    triggers_true, triggers_pred = [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all)):
            # Drop padding positions, then map label ids back to strings.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            for w, t, t_h in zip(words, triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write("\n")

    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[trigger identification]')
    # Identification ignores the predicted type: keep only (sent, start, end).
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)

    final = fname
    if write:
        with open(final, 'w') as fout:
            # BUGFIX: the original `open("temp", "r").read()` never closed
            # the handle; read via a context manager instead.
            with open("temp", "r") as fin:
                result = fin.read()
            fout.write("{}\n".format(result))
            fout.write(metric)
    os.remove("temp")
    return metric
def predict_triggers(self, tokens_x_2d, entities_x_3d, postags_x_2d, head_indexes_2d, triggers_y_2d, arguments_2d): tokens_x_2d = torch.LongTensor(tokens_x_2d).to(self.device) # postags_x_2d = torch.LongTensor(postags_x_2d).to(self.device) triggers_y_2d = torch.LongTensor(triggers_y_2d).to(self.device) head_indexes_2d = torch.LongTensor(head_indexes_2d).to(self.device) # postags_x_2d = self.postag_embed(postags_x_2d) # entity_x_2d = self.entity_embed(entities_x_3d) if self.training: self.bert.train() encoded_layers, _ = self.bert(tokens_x_2d) enc = encoded_layers[-1] else: self.bert.eval() with torch.no_grad(): encoded_layers, _ = self.bert(tokens_x_2d) enc = encoded_layers[-1] # x = torch.cat([enc, entity_x_2d, postags_x_2d], 2) # x = self.fc1(enc) # x: [batch_size, seq_len, hidden_size] x = enc # logits = self.fc2(x + enc) batch_size = tokens_x_2d.shape[0] for i in range(batch_size): x[i] = torch.index_select(x[i], 0, head_indexes_2d[i]) trigger_logits = self.fc_trigger(x) trigger_hat_2d = trigger_logits.argmax(-1) argument_hidden, argument_keys = [], [] for i in range(batch_size): candidates = arguments_2d[i]['candidates'] golden_entity_tensors = {} for j in range(len(candidates)): e_start, e_end, e_type_str = candidates[j] golden_entity_tensors[candidates[j]] = x[i, e_start:e_end, ].mean( dim=0) predicted_triggers = find_triggers([ idx2trigger[trigger] for trigger in trigger_hat_2d[i].tolist() ]) for predicted_trigger in predicted_triggers: t_start, t_end, t_type_str = predicted_trigger event_tensor = x[i, t_start:t_end, ].mean(dim=0) for j in range(len(candidates)): e_start, e_end, e_type_str = candidates[j] entity_tensor = golden_entity_tensors[candidates[j]] argument_hidden.append( torch.cat([event_tensor, entity_tensor])) argument_keys.append((i, t_start, t_end, t_type_str, e_start, e_end, e_type_str)) return trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys
def train(model, iterator, optimizer, hp):
    """Train `model` for one epoch over `iterator`.

    Per batch: predict triggers (CRF loss); whenever at least one predicted
    trigger matches a gold event, also predict arguments and add the argument
    loss scaled by `hp.LOSS_alpha`. Running trigger/argument P/R/F1 are
    printed every 100 steps and the accumulators reset.

    Recurring variables:
      argument_keys   : correctly-predicted triggers paired with gold entities
      arguments_2d    : gold trigger -> gold roles (gold entities) per sentence
      arguments_y_2d  : role labels fed to the CRF, [dim0, seq_len]
      argument_hat_1d : raw CRF decoding result
      argument_hat_2d : argument_hat_1d rewritten as dicts via argument_keys
      trigger_hat_2d  : CRF-predicted triggers
      triggers_y_2d   : gold triggers
    """
    model.train()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []

    for i, batch in enumerate(iterator):
        tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, adjm = batch
        optimizer.zero_grad()

        ## CRF loss, gold trigger labels, predicted triggers,
        ## entity-event concat tensors, 7-tuple keys
        trigger_loss, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
            tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
            postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
            triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d, adjm=adjm)

        if len(argument_keys) > 0:
            argument_loss, arguments_y_2d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                argument_hidden, argument_keys, arguments_2d, adjm)
            loss = trigger_loss + hp.LOSS_alpha * argument_loss
        else:
            loss = trigger_loss

        # BUGFIX: gradients exist only after backward(); the original called
        # clip_grad_norm_ *before* backward(), so clipping never affected the
        # freshly computed gradients. Correct order: backward -> clip -> step.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()

        #### running training metrics ####
        words_all.extend(words_2d)
        triggers_all.extend(triggers_2d)
        triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
        arguments_all.extend(arguments_2d)
        if len(argument_keys) > 0:
            arguments_hat_all.extend(argument_hat_2d)
        else:
            # No matched triggers: pad predictions with empty event dicts so
            # the five accumulator lists stay aligned.
            batch_size = len(arguments_2d)
            argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
            arguments_hat_all.extend(argument_hat_2d)

        # NOTE(review): this loop re-walks *all* sentences accumulated since
        # the last reset on every batch, so earlier sentences are appended to
        # the true/pred lists repeatedly — the monitoring F1 is therefore
        # biased. Kept as-is to preserve the original reporting behavior.
        for ii, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all, arguments_all,
                    arguments_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(ii, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(ii, *item) for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((ii, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((ii, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

        if i % 100 == 0:  # monitoring
            trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
            argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred)
            ## reset the accumulators every 100 steps
            words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
            triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
            #########################
            if len(argument_keys) > 0:
                print("【识别到事件】step: {}, loss: {:.3f}, trigger_loss:{:.3f}, argument_loss:{:.3f}".format(
                          i, loss.item(), trigger_loss.item(), argument_loss.item()),
                      '【trigger】 P={:.3f} R={:.3f} F1={:.3f}'.format(
                          trigger_p, trigger_r, trigger_f1),
                      '【argument】 P={:.3f} R={:.3f} F1={:.3f}'.format(
                          argument_p, argument_r, argument_f1))
            else:
                print("【未识别到事件】step: {}, loss: {:.3f} ".format(i, loss.item()),
                      '【trigger】 P={:.3f} R={:.3f} F1={:.3f}'.format(
                          trigger_p, trigger_r, trigger_f1))
def predict_triggers(self, tokens_x_2d, entities_x_3d, postags_x_2d, head_indexes_2d, triggers_y_2d, arguments_2d, adjm):
    """CRF-based trigger prediction.

    Encodes tokens with `self.PreModel`, gathers head-word positions, scores
    trigger labels through `tri_fc1` + leaky ReLU, and decodes with
    `tri_CRF1`. Also collects, for each *correctly* predicted trigger, the
    gold argument spans into `argument_keys`.

    Returns:
        trigger_loss: CRF negative log-likelihood loss.
        triggers_y_2d: gold trigger ids on the model device.
        trigger_hat_2d: CRF-decoded trigger predictions.
        sen_mask_id: always empty here (kept for interface compatibility).
        argument_keys: dict {(sent_idx, t_start, t_end, t_type): [(a_start, a_end, a_type_idx), ...]}.
    """
    # Commented-out experimental N-gram embedding helper, kept from the original:
    # def get_Ngram_emb(self,emb,N):
    #
    # batch_size, SEN_LEN, hidden_size = emb.size()
    # hidden_size = hidden_size*2
    # x = torch.zeros([batch_size, SEN_LEN, hidden_size],dtype = emb.dtype)
    #
    # for i in range(batch_size):
    #
    # for j in range(SEN_LEN):
    #
    # x[i,j]=emb[i,max(j-N,0):min(j+N,SEQ_LEN-1)].mean(dim=0)
    #
    # for j in range(SEN_LEN):
    #
    # cnnfeature=self.NgramCNN.forward(emb[:, max(j - N, 0):min(j + N, SEQ_LEN - 1),:])# [batch_size,hidden_size]
    # Nmax, _ = emb[:,max(j-N,0):min(j+N,SEQ_LEN-1),:].max(dim=1)# [batch_size,hidden_size]
    # x[:,j,:] = torch.cat([cnnfeature,Nmax],dim=-1)
    #
    # return x.to(self.device)

    ## token ids [batch_size, seq_length]
    tokens_x_2d = torch.LongTensor(tokens_x_2d).to(self.device)
    ## trigger label ids [batch_size, seq_length]
    triggers_y_2d = torch.LongTensor(triggers_y_2d).to(self.device)
    ## [batch_size, seq_length]
    # Effective sentence length per example — presumably head_indexes_2d holds
    # monotonically increasing head positions, so max() is the length
    # (TODO confirm against the data pipeline).
    xlen = [max(x) for x in head_indexes_2d]
    head_indexes_2d = torch.LongTensor(head_indexes_2d).to(self.device)

    # Encoder runs with gradients only during training.
    if self.training:
        self.PreModel.train()
        x_1, _ = self.PreModel(tokens_x_2d)
    else:
        self.PreModel.eval()
        with torch.no_grad():
            x_1, _ = self.PreModel(tokens_x_2d)

    batch_size = tokens_x_2d.shape[0]
    SEQ_LEN = x_1.size()[1]
    # [CLS] token
    # sen_emb = torch.unsqueeze(x_1[:,0,:],dim=1).repeat(1, SEQ_LEN, 1)  # [batch,1,hidden_size]
    # Recompose word-level vectors (e.g. merging sub-tokens) into a fresh
    # tensor; gathering also drops the [CLS] position.
    x = torch.zeros(x_1.size(), dtype=x_1.dtype).to(self.device)
    for i in range(batch_size):
        ## slice re-gather: reorders positions and removes [CLS]
        x[i] = torch.index_select(x_1[i], 0, head_indexes_2d[i])

    # Binary padding mask for the CRF: 1 on real positions, 0 on padding.
    mask = numpy.zeros(shape=[batch_size, SEQ_LEN], dtype=numpy.uint8)
    for i in range(len(xlen)):
        mask[i, :xlen[i]] = 1
    mask = torch.ByteTensor(mask).to(self.device)
    self.mask = mask

    ## [batch_size, SEQ_LEN, hidden_size*2]
    # n_gram_emb = get_Ngram_emb(self,x,5)
    ## emb = torch.cat([x,sen_emb,n_gram_emb],dim=-1) #hidden_size*3
    #emb = torch.cat([x, n_gram_emb], dim=-1) # hidden_size*3
    emb = x  # [batch_size, seq_len, hidden_size]
    trigger_logits1 = self.tri_fc1(emb)
    trigger_logits1 = nn.functional.leaky_relu_(
        trigger_logits1)  # x: [batch_size, seq_len, trigger_size + 2 ]

    ## tri_CRF1 ##
    trigger_loss = self.tri_CRF1.neg_log_likelihood_loss(
        feats=trigger_logits1, mask=mask, tags=triggers_y_2d)
    _, trigger_hat_2d = self.tri_CRF1.forward(feats=trigger_logits1, mask=mask)

    # Cache intermediates for later stages (e.g. argument prediction).
    self.emb = emb
    self.tri_result = trigger_hat_2d

    argument_keys = {}  # correctly predicted triggers -> their gold roles
    sen_mask_id = []
    for i in range(batch_size):
        ## list of [trigger start, trigger end, event type (34 classes)]
        predicted_triggers = find_triggers([
            self.idx2trigger[trigger]
            for trigger in trigger_hat_2d[i].tolist()
        ])
        for predicted_trigger in predicted_triggers:
            ## predicted trigger start / end / event-type string
            t_start, t_end, t_type_str = predicted_trigger
            ## only when the predicted trigger exactly matches a gold event
            if (t_start, t_end, t_type_str) in arguments_2d[i]['events']:
                for (a_start, a_end, a_type_idx) in arguments_2d[i]['events'][(
                        t_start, t_end, t_type_str)]:
                    if (i, t_start, t_end, t_type_str) in argument_keys:
                        argument_keys[(i, t_start, t_end, t_type_str)].append(
                            (a_start, a_end, a_type_idx))
                    else:
                        argument_keys[(i, t_start, t_end, t_type_str)] = []
                        argument_keys[(i, t_start, t_end, t_type_str)].append(
                            (a_start, a_end, a_type_idx))
            # else:  # when the predicted trigger is wrong
            # argument_keys[(i, t_start, t_end, t_type_str)] = []
    return trigger_loss, triggers_y_2d, trigger_hat_2d, sen_mask_id, argument_keys
def eval(model, iterator, fname):
    """Evaluate joint trigger + argument extraction (adjacency-matrix variant).

    Collects trigger and argument predictions over `iterator`, prints P/R/F1
    for trigger/argument classification and identification, and writes a
    report (per-token predictions plus gold/predicted argument dicts and the
    metric summary) to a file named `fname` suffixed with the F1 scores.

    Returns:
        (metric_str, trigger_f1, argument_f1)
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, adjm = batch
            trigger_loss, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d,
                adjm=adjm)
            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)
            if len(argument_keys) > 0:
                argument_loss, arguments_y_2d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                    argument_hidden, argument_keys, arguments_2d, adjm)
                arguments_hat_all.extend(argument_hat_2d)
                # if i == 0:
                # print("=====sanity check for triggers======")
                # print('triggers_y_2d[0]:', triggers_y_2d[0])
                # print("trigger_hat_2d[0]:", trigger_hat_2d[0])
                # print("=======================")
                # print("=====sanity check for arguments======")
                # print('arguments_y_2d[0]:', arguments_y_2d[0])
                # print('argument_hat_1d[0]:', argument_hat_1d[0])
                # print("arguments_2d[0]:", arguments_2d)
                # print("argument_hat_2d[0]:", argument_hat_2d)
                # print("=======================")
            else:
                # No matched trigger: pad with empty event dicts to keep the
                # five accumulator lists aligned.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w', encoding="utf-8") as fout:
        for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all, arguments_all,
                    arguments_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            # words[1:-1] skips the [CLS]/[SEP]-style boundary tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))

            # NOTE(review): the blocks below mutate arguments['events'] /
            # arguments_hat['events'] IN PLACE (role id -> role string) for
            # report readability; safe only because the metric tuples were
            # already extracted above.
            arg_write = arguments['events']
            for arg_key in arg_write:
                arg = arg_write[
                    arg_key]  # list, e.g. [(0, 5, 25), (8, 19, 17), (20, 21, 29)]
                for ii, tup in enumerate(arg):
                    arg[ii] = (tup[0], tup[1], idx2argument[tup[2]]
                               )  # convert role id to string
                arg_write[arg_key] = arg
            arghat_write = arguments_hat['events']
            for arg_key in arghat_write:
                arg = arghat_write[
                    arg_key]  # list, e.g. [(0, 5, 25), (8, 19, 17), (20, 21, 29)]
                for ii, tup in enumerate(arg):
                    arg[ii] = (tup[0], tup[1], idx2argument[tup[2]]
                               )  # convert role id to string
                arghat_write[arg_key] = arg
            fout.write('#真实值#\t{}\n'.format(arg_write))
            fout.write('#预测值#\t{}\n'.format(arghat_write))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    print('[trigger identification]')
    # Identification drops the predicted type: (sent, start, end) only.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    print('[argument identification]')
    # Identification drops the role label (last element).
    arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)

    final = fname + ".trigger-F%.2f argument-F%.2f" % (trigger_f1, argument_f1)
    with open(final, 'w', encoding="utf-8") as fout:
        result = open("temp", "r", encoding="utf-8").read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")
    return metric, trigger_f1, argument_f1
def eval(model, iterator, fname):
    """Evaluate a joint trigger-entity model called as `model(test, labels)`.

    Unlike the other eval variants, the batch is a (test, labels) pair and
    words / gold triggers / arguments are read positionally from `test`
    (indices 3, 4 and -1). Argument tuples here deliberately omit the
    sentence index and trigger span: they are keyed by event type only.

    Returns:
        (metric_str, trigger_f1, argument_f1)
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        # for i, batch in enumerate(iterator):
        for i, (test, labels) in enumerate(iterator):
            trigger_logits, trigger_entities_hat_2d, triggers_y_2d, argument_hidden_logits, arguments_y_1d, argument_hidden_hat_1d, argument_hat_2d, argument_keys = model(
                test, labels)
            # Positional layout of `test`: index 3 = words, 4 = gold trigger
            # tags, -1 = gold arguments (presumably fixed by the dataloader —
            # TODO confirm).
            words_all.extend(test[3])
            triggers_all.extend(test[4])
            triggers_hat_all.extend(
                trigger_entities_hat_2d.cpu().numpy().tolist())
            arguments_2d = test[-1]
            arguments_all.extend(arguments_2d)
            if len(argument_keys) > 0:
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # Keep accumulator lists aligned with empty event dicts.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w', encoding='utf-8') as fout:
        for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all, arguments_all,
                    arguments_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger_entities[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            # find_triggers here returns (triggers, entities) pairs.
            triggers_true_, entities_true = find_triggers(
                triggers[:len(words)])
            triggers_pred_, entities_pred = find_triggers(triggers_hat)
            triggers_true.extend([(i, *item) for item in triggers_true_])
            triggers_pred.extend([(i, *item) for item in triggers_pred_])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append(
                        (t_type_str, a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                # Skip predictions that fall outside the sentence.
                if t_start >= len(words) or t_end >= len(words):
                    continue
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    if a_start >= len(words) or a_end >= len(words):
                        continue
                    arguments_pred.append(
                        (t_type_str, a_start, a_end, a_type_idx))

            for w, t, t_h in zip(words, triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    print('[trigger identification]')
    # Drop predicted type for identification scoring.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    print('[argument identification]')
    # Drop the role label for identification scoring.
    arguments_true = [(item[0], item[1], item[2]) for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2]) for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)

    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    with open(final, 'w', encoding='utf-8') as fout:
        result = open("temp", "r", encoding='utf-8').read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")
    return metric, trigger_f1, argument_f1
def eval(model, iterator, fname):
    """Trigger-only evaluation for the context-sentence (prev/next) model.

    Builds per-example masks/length matrices for the previous and next
    sentences, runs trigger prediction, prints trigger classification /
    identification P/R/F1, writes a report to `fname` suffixed with the
    scores, and returns the metrics as a dict.

    NOTE(review): the main unpack of `batch` is partially commented out, so
    `tokens_x_2d`, `entities_x_3d`, `postags_x_2d`, `triggers_y_2d`,
    `arguments_2d`, `head_indexes_2d`, `words_2d` and `triggers_2d` appear
    undefined in this function as written — presumably this variant was
    mid-refactor. Verify against the dataloader before relying on it.
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            # tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d, \
            pre_sent_tokens_x, next_sent_tokens_x, pre_sent_len, next_sent_len, maxlen = batch
            # maxlen = max(seqlens_1d)
            # pre_sent_len_max = max(pre_sent_len)
            # next_sent_len_max = max(next_sent_len)

            # Per-example 0/1 flag tensors ([maxlen, 768]) marking real
            # positions of the previous/next sentence, plus broadcastable
            # length rows ([768]) — 768 is presumably the encoder hidden
            # size (TODO confirm).
            pre_sent_flags = []
            next_sent_flags = []
            pre_sent_len_mat = []
            next_sent_len_mat = []
            # NOTE(review): these inner loops rebind the batch index `i`;
            # harmless here because `i` is not read again in this scope, but
            # fragile.
            for i in pre_sent_len:
                tmp = [[1] * 768] * i + [[0] * 768] * (maxlen - i)
                pre_sent_flags.append(tmp)
                pre_sent_len_mat.append([i] * 768)
            for i in next_sent_len:
                tmp = [[1] * 768] * i + [[0] * 768] * (maxlen - i)
                next_sent_flags.append(tmp)
                next_sent_len_mat.append([i] * 768)

            # trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
            trigger_logits, triggers_y_2d, trigger_hat_2d = model.predict_triggers(
                tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d,
                pre_sent_tokens_x=pre_sent_tokens_x,
                next_sent_tokens_x=next_sent_tokens_x,
                pre_sent_flags=pre_sent_flags,
                next_sent_flags=next_sent_flags,
                pre_sent_len_mat=pre_sent_len_mat,
                next_sent_len_mat=next_sent_len_mat)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

    triggers_true, triggers_pred = [], []
    with open('temp', 'w', encoding='utf-8') as fout:
        for i, (words, triggers, triggers_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            # words[1:-1] skips the boundary ([CLS]/[SEP]-style) tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[trigger identification]')
    # Identification ignores the predicted type.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)

    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    # Dict form of the same metrics — this variant returns the dict, not the
    # formatted string.
    metric_2 = {
        "trigger classification": [trigger_p, trigger_r, trigger_f1],
        "trigger identification": [trigger_p_, trigger_r_, trigger_f1_]
    }
    with open(final, 'w') as fout:
        result = open("temp", "r").read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")
    return metric_2
def eval(model, iterator, fname):
    """Evaluate joint trigger + argument extraction (no adjacency matrix).

    Collects trigger and argument predictions over `iterator`, prints P/R/F1
    for trigger/argument classification and identification, and writes the
    per-token report plus the metric summary to `fname` suffixed with the
    trigger scores.

    Returns:
        str: the formatted metric summary.
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d = batch
            trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d)
            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)
            if len(argument_keys) > 0:
                argument_logits, arguments_y_1d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(
                    argument_hidden, argument_keys, arguments_2d)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No matched trigger in this batch: pad with empty event dicts
                # to keep the accumulator lists aligned.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all, arguments_all,
                    arguments_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((i, t_start, t_end, t_type_str,
                                           a_start, a_end, a_type_idx))

            # words[1:-1] skips the boundary ([CLS]/[SEP]-style) tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    print('[trigger identification]')
    # Identification ignores the predicted type.
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    print('[argument identification]')
    # Identification drops the role label (last element).
    arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5])
                      for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(
        arguments_true, arguments_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_,
                                                 argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        argument_p_, argument_r_, argument_f1_)

    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    with open(final, 'w') as fout:
        result = open("temp", "r").read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")
    return metric
def eval_module(model, iterator, fname, module, idx2argument):
    """Evaluate a single argument-role module network with a relaxed metric.

    Like `eval`, but arguments are routed through `module_predict_arguments`
    and `meta_classifier` for the given `module` (a role name). Gold argument
    tuples are kept only when their role matches `module` and are relabelled
    with the module-internal positive class id 2; predicted tuples keep their
    raw predicted id (see inline comments from the original authors).

    Returns:
        (metric_str, argument_f1, num_proposed, num_correct, num_gold)

    NOTE(review): the 'temp' report file is written but never removed here —
    the cleanup block at the end is commented out.
    """
    model.eval()

    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d = batch
            trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys, trigger_info, auxiliary_feature = model.module.predict_triggers(
                tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d)
            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)
            if len(argument_keys) > 0:
                # Module-specific argument scoring, then a meta classifier
                # that combines trigger info and auxiliary features.
                argument_logits, arguments_y_1d, argument_hat_1d, argument_hat_2d = model.module.module_predict_arguments(
                    argument_hidden, argument_keys, arguments_2d, module)
                module_decisions_logit, module_decisions_y, argument_hat_2d = model.module.meta_classifier(
                    argument_keys, arguments_2d, trigger_info, argument_logits,
                    argument_hat_1d, auxiliary_feature, module)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # Keep accumulator lists aligned with empty event dicts.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(
                zip(words_all, triggers_all, triggers_hat_all, arguments_all,
                    arguments_hat_all)):
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    # strict metric
                    #arguments_true.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))
                    # relaxed metric: keep only this module's role; 2 is the
                    # module-internal positive class id.
                    if idx2argument[a_type_idx] == module:
                        arguments_true.append(
                            (i, t_type_str, a_start, a_end, 2))
                    #else:
                    # arguments_true.append((i, t_type_str, a_start, a_end, 1))
            #print(arguments_hat)
            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    # stric metric
                    # arguments_pred.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))
                    # relaxed metric
                    #if idx2argument[a_type_idx] == module:
                    arguments_pred.append(
                        (i, t_type_str, a_start, a_end, a_type_idx
                         ))  # 2 is the specific argument idx in module network
                    # else:
                    # print(idx2argument[a_type_idx])
                    # arguments_pred.append((i, t_type_str, a_start, a_end, 1))
            # if len(arguments_pred) == 0:
            # print('---batch {} -----'.format(i))
            # print(arguments_hat)

            # words[1:-1] skips the boundary ([CLS]/[SEP]-style) tokens.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[argument classification]')
    # num_flag=True additionally returns the raw proposed/correct/gold counts.
    argument_p, argument_r, argument_f1, num_proposed, num_correct, num_gold = calc_metric(
        arguments_true, arguments_pred, num_flag=True)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r,
                                                 argument_f1))

    #print('[trigger identification]')
    # triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    # triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    # trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred)
    #print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))
    #print('[argument identification]')
    # strcit metric
    #arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_true]
    #arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_pred]
    # relax metric
    # arguments_true = [(item[0], item[1], item[2], item[3]) for item in arguments_true]
    # arguments_pred = [(item[0], item[1], item[2], item[3]) for item in arguments_pred]
    # argument_p_, argument_r_, argument_f1_ = calc_metric(arguments_true, arguments_pred)
    #print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_, argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(
        trigger_p, trigger_r, trigger_f1)
    # metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p, argument_r, argument_f1)
    # metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p_, trigger_r_, trigger_f1_)
    # metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p_, argument_r_, argument_f1_)
    # final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    # with open(final, 'w') as fout:
    # result = open("temp", "r").read()
    # fout.write("{}\n".format(result))
    # fout.write(metric)
    # os.remove("temp")
    return metric, argument_f1, num_proposed, num_correct, num_gold  #,arguments_true, arguments_pred