def evaluation(self):
    """Evaluate the model on the dev split and print the aggregated metrics.

    The dataset/loader pair is selected from ``self.hyper.is_bert``:

    * ``"nyt_bert_tokenizer"``, ``"nyt11_bert_tokenizer"`` and
      ``"nyt10_bert_tokenizer"`` use ``Selection_bert_Nyt_Dataset`` with
      ``Selection_bert_loader``;
    * ``"bert_bilstem_crf"`` uses ``Selection_Nyt_Dataset`` with
      ``Selection_loader``;
    * anything else uses ``Selection_Dataset`` with ``Selection_loader``.

    Batches of 100 are run through ``self.model`` in eval mode with gradients
    disabled. The non-underscore entries of ``self.metrics.get_metric()`` are
    then printed on one line, terminated by ``" ||"``.
    """
    bert_tokenizer_modes = (
        "nyt_bert_tokenizer",
        "nyt11_bert_tokenizer",
        "nyt10_bert_tokenizer",
    )
    mode = self.hyper.is_bert
    if mode in bert_tokenizer_modes:
        dataset_cls, loader_cls = Selection_bert_Nyt_Dataset, Selection_bert_loader
    elif mode == "bert_bilstem_crf":
        dataset_cls, loader_cls = Selection_Nyt_Dataset, Selection_loader
    else:
        dataset_cls, loader_cls = Selection_Dataset, Selection_loader

    dev_data = dataset_cls(self.hyper, self.hyper.dev)
    dev_loader = loader_cls(dev_data, batch_size=100, pin_memory=True)

    self.metrics.reset()
    self.model.eval()

    progress = tqdm(enumerate(BackgroundGenerator(dev_loader)), total=len(dev_loader))
    with torch.no_grad():
        for _, batch in progress:
            prediction = self.model(batch, is_train=False)
            # NOTE(review): the metric is given the whole output dict as both
            # arguments; an older call passed
            # (output['selection_triplets'], output['spo_gold']).
            # Confirm this matches the metric's current signature.
            self.metrics(prediction, prediction)

        scores = self.metrics.get_metric()
        # Entries whose name starts with "_" are internal and not printed.
        summary = ', '.join(
            "%s: %.4f" % (key, val)
            for key, val in scores.items()
            if not key.startswith("_")
        )
        print(summary + " ||")
def train(self):
    """Train for ``hyper.epoch_num`` epochs, saving a checkpoint after each one.

    Each batch performs a zero_grad / forward / backward / optimizer step
    cycle, and the progress bar shows the model-provided description. After
    saving an epoch, a dev evaluation runs when ``epoch`` is a multiple of
    ``hyper.print_epoch`` and greater than 3.
    """
    train_data = Selection_Dataset(self.hyper, self.hyper.train)
    train_loader = Selection_loader(train_data, batch_size=self.hyper.train_batch, pin_memory=True)

    for epoch in range(self.hyper.epoch_num):
        self.model.train()
        progress = tqdm(enumerate(BackgroundGenerator(train_loader)), total=len(train_loader))
        for _, batch in progress:
            self.optimizer.zero_grad()
            result = self.model(batch, is_train=True)
            result['loss'].backward()
            self.optimizer.step()
            describe = result['description']
            progress.set_description(describe(epoch, self.hyper.epoch_num))

        self.save_model(epoch)
        # The modulo test is evaluated first, as in the original condition.
        on_print_epoch = epoch % self.hyper.print_epoch == 0
        if on_print_epoch and epoch > 3:
            self.evaluation()
def evaluation(self):
    """Evaluate triplet extraction and NER tagging on the dev split.

    Runs ``self.model`` in eval mode under ``torch.no_grad()`` over
    ``hyper.dev`` with batches of ``hyper.eval_batch``. Triplet and NER
    metrics are printed on one line. Each metric name is abbreviated to its
    first character, and names starting with "_" are skipped.
    """
    def format_scores(result):
        # Build "<first char of name>: <value>" pairs for the public entries.
        return ', '.join([
            "%s: %.4f" % (name[0], value)
            for name, value in result.items()
            if not name.startswith("_")
        ])

    dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
    loader = Selection_loader(dev_set, batch_size=self.hyper.eval_batch, pin_memory=True)

    # Reset BOTH accumulators. Only triplet_metrics used to be reset, so
    # unless get_metric() clears state itself, NER counts would carry over
    # between successive evaluation() calls during training.
    # (Assumes ner_metrics exposes the same reset() as triplet_metrics.)
    self.triplet_metrics.reset()
    self.ner_metrics.reset()
    self.model.eval()

    pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))
    with torch.no_grad():
        for _, sample in pbar:
            output = self.model(sample, is_train=False)
            self.triplet_metrics(output['selection_triplets'], output['spo_gold'])
            # NOTE(review): this call passes (gold, predicted), the reverse of
            # the triplet call above. Confirm against ner_metrics' signature.
            self.ner_metrics(output['gold_tags'], output['decoded_tag'])

        triplet_result = self.triplet_metrics.get_metric()
        ner_result = self.ner_metrics.get_metric()
        print('Triplets-> ' + format_scores(triplet_result) + ' ||'
              + 'NER->' + format_scores(ner_result))
def postcheck(self):
    """Run the model over the xgb training split and dump per-batch score lists.

    The dataset is built from ``hyper.xgb_train_root`` with ``is_xgb=True``.
    Results go through ``self.save_rc``, which is first reset to
    ``'pos_neg_score.json'``. For each batch the model's ``neg_list`` is
    saved. When that same batch's output carries a non-empty ``pos_list``,
    it is saved alongside.
    """
    xgb_set = Selection_Dataset(self.hyper, self.hyper.xgb_train_root, is_xgb=True)
    loader = Selection_loader(xgb_set, batch_size=self.hyper.eval_batch, pin_memory=True)
    self.model.eval()
    pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))
    self.save_rc.reset('pos_neg_score.json')

    with torch.no_grad():
        for _, sample in pbar:
            output = self.model(sample, is_train=False)
            neg_list = output['neg_list']
            # Decide per batch. The previous version latched a flag on the
            # first batch that had a pos_list. Every later batch then saved
            # the last seen pos_list, even one belonging to an earlier batch,
            # next to its own neg_list.
            pos_list = output.get('pos_list')
            if pos_list:
                self.save_rc.save(neg_list, pos_list)
            else:
                self.save_rc.save(neg_list)
def evaluation(self):
    """Score the binary selection head on the dev split and print the metrics.

    Model outputs go through a sigmoid and are thresholded at 0.5. Both the
    predictions and ``sample.selection_id`` are converted to nested Python int
    lists before being fed to ``self.selection_metrics``. Metric names are
    printed abbreviated to their first character.
    """
    def as_int_list(tensor):
        # Move to host and cast element-wise to int before converting to lists.
        return np.array(tensor.cpu().numpy(), dtype=int).tolist()

    dev_data = Selection_Dataset(self.hyper, self.hyper.dev)
    dev_loader = Selection_loader(dev_data, batch_size=self.hyper.eval_batch, pin_memory=True)
    self.selection_metrics.reset()
    self.model.eval()

    progress = tqdm(enumerate(BackgroundGenerator(dev_loader)), total=len(dev_loader))
    with torch.no_grad():
        for _, batch in progress:
            logits = self.model(batch, is_train=False)
            predicted = torch.sigmoid(logits) > 0.5
            gold = batch.selection_id
            self.selection_metrics(as_int_list(predicted), as_int_list(gold))

        scores = self.selection_metrics.get_metric()
        parts = [
            "%s: %.4f" % (key[0], val)
            for key, val in scores.items()
            if not key.startswith("_")
        ]
        print('Triplets-> ' + ', '.join(parts))
def train(self): if self.hyper.is_bert == "bert_bilstem_crf": train_set = Selection_Nyt_Dataset(self.hyper, self.hyper.train) loader = Selection_loader(train_set, batch_size=100, pin_memory=True) elif self.hyper.is_bert == "nyt_bert_tokenizer" or \ self.hyper.is_bert == "nyt11_bert_tokenizer" or \ self.hyper.is_bert == "nyt10_bert_tokenizer": train_set = Selection_bert_Nyt_Dataset(self.hyper, self.hyper.train) loader = Selection_bert_loader(train_set, batch_size=100, pin_memory=True) else: train_set = Selection_Dataset(self.hyper, self.hyper.train) loader = Selection_loader(train_set, batch_size=100, pin_memory=True) for epoch in range(self.hyper.epoch_num): self.model.train() pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader)) for batch_idx, sample in pbar: self.optimizer.zero_grad() output = self.model(sample, is_train=True) loss = output['loss'] loss.backward() self.optimizer.step() pbar.set_description(output['description']( epoch, self.hyper.epoch_num)) self.save_model(epoch) if epoch > 3: self.evaluation() #if epoch >= 6: # self.evaluation() '''
def evaluation(self):
    """Evaluate triplet extraction and NER on the dev split and save predictions.

    Before iterating, an ``SPO_searcher`` built from ``self.hyper`` is
    installed on the model via ``_init_spo_search``, and ``self.save_rc`` is
    reset to ``'dev.json'``. Each batch's ``selection_triplets`` are passed to
    ``save_rc.save``. Triplet and NER metrics are printed on one line, with
    metric names abbreviated to their first character.
    """
    def format_scores(result):
        # Build "<first char of name>: <value>" pairs for the public entries.
        return ', '.join([
            "%s: %.4f" % (name[0], value)
            for name, value in result.items()
            if not name.startswith("_")
        ])

    dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
    loader = Selection_loader(dev_set, batch_size=self.hyper.eval_batch, pin_memory=True)

    # Reset BOTH accumulators. Only triplet_metrics used to be reset, so
    # unless get_metric() clears state itself, NER counts would carry over
    # between calls. (Assumes ner_metrics has the same reset() API.)
    self.triplet_metrics.reset()
    self.ner_metrics.reset()
    self.model.eval()
    self.model._init_spo_search(SPO_searcher(self.hyper))

    pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))
    self.save_rc.reset('dev.json')

    with torch.no_grad():
        for _, sample in pbar:
            output = self.model(sample, is_train=False)
            self.triplet_metrics(output['selection_triplets'], output['spo_gold'])
            self.ner_metrics(output['gold_tags'], output['decoded_tag'])
            self.save_rc.save(output['selection_triplets'])

        triplet_result = self.triplet_metrics.get_metric()
        ner_result = self.ner_metrics.get_metric()
        print('Triplets-> ' + format_scores(triplet_result) + ' ||'
              + 'NER->' + format_scores(ner_result))
def train_pso(self, epoch=None):
    """Train the PSO model with gradient clipping and a warmup LR schedule.

    Args:
        epoch: epoch to resume from. ``None`` (or any falsy value) starts at 0.
            Training runs every epoch from ``epoch`` through
            ``hyper.epoch_num`` inclusive, saving a checkpoint after each.

    Side effects:
        Replaces ``self.scheduler`` with a new scheduler from
        ``self._scheduler`` sized for the whole remaining run.
    """
    train_set = Selection_Dataset(self.hyper, self.hyper.train)
    loader = Selection_loader(train_set, batch_size=self.hyper.train_batch, pin_memory=True)
    if not epoch:
        epoch = 0

    # The scheduler is stepped once per batch, and len(loader) is already the
    # number of batches per epoch (it is also the tqdm total below). The old
    # formula divided it by train_batch again, which shrank the scheduler's
    # horizon by a factor of the batch size. The while-loop is inclusive, so
    # it runs epoch_num - epoch + 1 epochs.
    num_epochs = self.hyper.epoch_num - epoch + 1
    num_train_steps = int(len(loader) * num_epochs)
    num_warmup_steps = int(num_train_steps * self.hyper.warmup_prop)
    self.scheduler = self._scheduler(self.optimizer, num_warmup_steps, num_train_steps)

    while epoch <= self.hyper.epoch_num:
        # Nothing in the batch loop switches the model to eval mode, so one
        # call per epoch is enough.
        self.model.train()
        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))
        for _, sample in pbar:
            output = self.model(sample, is_train=True)
            loss = output['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.model.zero_grad()
            pbar.set_description(output['description'](epoch, self.hyper.epoch_num))
        self.save_model(epoch)
        epoch += 1
def train(self, epoch=None):
    """Train from ``epoch`` (default 0) through ``hyper.epoch_num`` inclusive.

    Each batch does zero_grad / forward / backward / optimizer step, and the
    progress bar shows the model-provided description. A checkpoint is saved
    after every epoch. Periodic dev evaluation is currently disabled here.
    """
    train_data = Selection_Dataset(self.hyper, self.hyper.train)
    train_loader = Selection_loader(train_data, batch_size=self.hyper.train_batch, pin_memory=True)

    current = epoch if epoch else 0
    while current <= self.hyper.epoch_num:
        self.model.train()
        progress = tqdm(enumerate(BackgroundGenerator(train_loader)), total=len(train_loader))
        for _, batch in progress:
            self.optimizer.zero_grad()
            result = self.model(batch, is_train=True)
            result['loss'].backward()
            self.optimizer.step()
            progress.set_description(result['description'](current, self.hyper.epoch_num))
        self.save_model(current)
        current += 1
def evaluation(self):
    """Evaluate on the dev split, write a prediction CSV, and print the scores.

    One row per dev sentence is written to ``./output/sub00.csv`` as
    ``<row id>,<predicate>,<subject>,<object>``, with row ids starting at
    8001:

    * if the model produced triplets for the sentence, the first one is used;
    * otherwise, if the decoded tag sequence has fewer than two 'B' tags, the
      row is ``Other,Other,Other``;
    * otherwise the predicate is ``Other``, the subject is the token at the
      last 'B' tag and the object is the token at the first 'B' tag.

    Triplet and NER metrics are then printed on one line, with metric names
    abbreviated to their first character.
    """
    import os

    def format_scores(result):
        # Build "<first char of name>: <value>" pairs for the public entries.
        return ', '.join([
            "%s: %.4f" % (name[0], value)
            for name, value in result.items()
            if not name.startswith("_")
        ])

    dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
    loader = Selection_loader(dev_set, batch_size=self.hyper.eval_batch, pin_memory=True)

    # Reset BOTH accumulators. Only triplet_metrics used to be reset, so
    # unless get_metric() clears state itself, NER counts would carry over
    # between calls. (Assumes ner_metrics has the same reset() API.)
    self.triplet_metrics.reset()
    self.ner_metrics.reset()
    self.model.eval()

    pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

    out_path = './output/sub00.csv'
    # open() would fail with FileNotFoundError if the directory is missing.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Explicit utf-8 so sentence tokens are written the same way on every
    # platform, not with the locale's default encoding.
    with torch.no_grad(), open(out_path, 'w', encoding='utf-8') as out_file:
        row_id = 8001
        for _, sample in pbar:
            text_list = sample.text
            output = self.model(sample, is_train=False)
            self.triplet_metrics(output['selection_triplets'], output['spo_gold'])
            self.ner_metrics(output['gold_tags'], output['decoded_tag'])

            for i, tags in enumerate(output['decoded_tag']):
                triplets = output['selection_triplets'][i]
                if len(triplets) != 0:
                    first = triplets[0]
                    fields = [first['predicate'], first['subject'], first['object']]
                elif tags.count('B') < 2:
                    fields = ['Other', 'Other', 'Other']
                else:
                    first_b = tags.index('B')
                    # Index of the last 'B', computed on the tag sequence itself.
                    last_b = len(tags) - 1 - list(reversed(tags)).index('B')
                    fields = ['Other', text_list[i][last_b], text_list[i][first_b]]
                out_file.write(str(row_id) + ',' + ','.join(fields) + '\n')
                row_id += 1

        triplet_result = self.triplet_metrics.get_metric()
        ner_result = self.ner_metrics.get_metric()
        print('Triplets-> ' + format_scores(triplet_result) + ' ||'
              + 'NER->' + format_scores(ner_result))
def evaluation_pso(self):
    """Evaluate the PSO pipeline on the dev split and print the scores.

    Behaviour depends on ``self.model_name``:

    * ``'pso_1'``: runs ``self.model`` and scores only the predicate head
      with ``self.p_metrics``.
    * ``'pso_2'``: runs ``self.model_p`` for predicates, rewrites each batch
      via ``dev_set.schema_transformer``, then runs ``self.model`` for
      triplets. It scores both heads and saves predictions through
      ``self.save_err`` and ``self.save_rc`` (reset to ``'dev.json'``).
    * any other value: only the shared setup runs and nothing is printed.

    Metric names are printed abbreviated to their first character.
    """
    def format_scores(result):
        # Build "<first char of name>: <value>" pairs for the public entries.
        return ', '.join([
            "%s: %.4f" % (name[0], value)
            for name, value in result.items()
            if not name.startswith("_")
        ])

    dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
    loader = Selection_loader(dev_set, batch_size=self.hyper.eval_batch, pin_memory=True)
    self.p_metrics.reset()
    self.triplet_metrics.reset()
    self.model.eval()
    if self.model_name == 'pso_2':
        self.model_p.eval()

    # Predicate gold labels and model outputs, collected batch by batch.
    # Nothing in this method reads them after collection; they appear to be
    # kept for offline ROC analysis (see the disabled roc_auc_class call).
    all_labels, all_logits = [], []
    self.model._init_spo_search(SPO_searcher(self.hyper))
    pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

    if self.model_name == 'pso_1':
        with torch.no_grad():
            for _, sample in pbar:
                output = self.model(sample, is_train=False)
                self.p_metrics(output['p_decode'], output['p_golden'])
                all_labels.extend(output['p_golden'])
                all_logits.extend(output['p_logits'])

            p_result = self.p_metrics.get_metric()
            print('P-> ' + format_scores(p_result))

    elif self.model_name == 'pso_2':
        self.save_rc.reset('dev.json')
        with torch.no_grad():
            for batch_ndx, sample in pbar:
                output_p = self.model_p(sample, is_train=False)
                self.p_metrics(output_p['p_decode'], output_p['p_golden'])
                # Gold labels go to all_labels, as in the pso_1 branch. These
                # two lists used to be filled the other way round here.
                # NOTE(review): p_decode stands in for logits here; it is not
                # known whether output_p provides 'p_logits'.
                all_labels.extend(output_p['p_golden'])
                all_logits.extend(output_p['p_decode'])

                # Rewrites `sample` using the predicted predicates before the
                # triplet model sees it.
                dev_set.schema_transformer(output_p, sample)
                output = self.model(sample, is_train=False)

                self.triplet_metrics(output['selection_triplets'], output['spo_gold'])
                self.save_err.save(batch_ndx, output['selection_triplets'], output['spo_gold'])
                self.save_rc.save(output['selection_triplets'])

            p_result = self.p_metrics.get_metric()
            triplet_result = self.triplet_metrics.get_metric()
            print('Triplets-> ' + format_scores(triplet_result))
            print('P-> ' + format_scores(p_result))