def test(
    C,
    logger,
    dataset,
    models,
    loss_func,
    generator,
    mode="valid",
    epoch_id=0,
    run_name="0",
    need_generated=False,
):
    device, batch_size, batch_numb, models = before_test(C, logger, dataset, models)

    pbar = tqdm(range(batch_numb), ncols=70)
    avg_loss = 0
    generated = ""
    for batch_id in pbar:
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(
            data, device=tc.device(C.device))

        with tc.no_grad():
            model, preds, loss, partial_generated = get_output(
                C, logger, models, device, loss_func, generator,
                sents, ents, anss, data_ent)
        generated += partial_generated
        avg_loss += float(loss) / len(models)

        pbar.set_description_str("(Test )Epoch {0}".format(epoch_id))
        pbar.set_postfix_str("loss = %.4f (avg = %.4f)" %
                             (float(loss), avg_loss / (batch_id + 1)))

    micro_f1, macro_f1 = get_evaluate(C, logger, mode, generated, generator, dataset)

    logger.log(
        "-----Epoch {} tested. Micro F1 = {:.2f}% , Macro F1 = {:.2f}% , loss = {:.4f}"
        .format(epoch_id, micro_f1, macro_f1, avg_loss / batch_numb))
    logger.log("\n")

    fitlog.add_metric(micro_f1, step=epoch_id, name="({0})micro f1".format(run_name))
    fitlog.add_metric(macro_f1, step=epoch_id, name="({0})macro f1".format(run_name))

    if need_generated:
        return micro_f1, macro_f1, avg_loss, generated
    return micro_f1, macro_f1, avg_loss
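# A hypothetical usage sketch of `test` above (the names below are placeholders,
# not part of this repo); it only shows the call and return shape:
#
#   micro_f1, macro_f1, loss_sum, generated = test(
#       C, logger, valid_data, [model], loss_func, generator,
#       mode="valid", epoch_id=0, run_name="0", need_generated=True)
#
# Note that the returned loss is the per-batch sum; the progress bar and the log
# line divide it by the number of batches to report an average.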
def compare(C, logger, dataset, models_1, models_2, generator):
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    golden = write_keyfile(dataset, generator)

    models_1 = models_1.eval()
    models_2 = models_2.eval()

    batch_size = 8
    batch_numb = (len(dataset) // batch_size) + int((len(dataset) % batch_size) != 0)

    #----- generate -----
    readable_info = ""
    json_info = []
    all_generated_1 = ""
    all_generated_2 = ""
    for batch_id in tqdm(range(batch_numb), ncols=70, desc="Generating..."):
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(
            data, device=tc.device(C.device))

        with tc.no_grad():
            preds_1 = models_1(sents, ents, output_preds=True)
            preds_2 = models_2(sents, ents, output_preds=True)

        #----- get generated output -----
        ans_rels = [[(u, v) for u, v, t in bat]
                    for bat in anss] if C.gene_in_data else None
        generated_1 = generator(preds_1, data_ent, ans_rels=ans_rels,
                                split_generate=True)
        generated_2 = generator(preds_2, data_ent, ans_rels=ans_rels,
                                split_generate=True)
        all_generated_1 += "".join(generated_1)
        all_generated_2 += "".join(generated_2)

        for text_id in range(len(data)):
            #----- form data structure -----
            # text
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:  # remove padding
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: turn token spans into character offsets and span contents
            tmp_ents = ents[text_id]
            for i in range(len(tmp_ents)):
                tmp_ents[i].append(
                    tokenizer.decode(tmp_sents[tmp_ents[i][0]:tmp_ents[i][1]]))
                tmp_ents[i][0] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][0]])
                ) + 1  # number of characters before the span (+1 is for space)
                tmp_ents[i][1] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][1]]))  # characters before the span end
                tmp_ents[i] = [i] + tmp_ents[i]

            # golden answer
            tmp_anss = anss[text_id]
            for i in range(len(tmp_anss)):
                tmp_anss[i][2] = relations[tmp_anss[i][2]]
            golden_ans = tmp_anss

            # model 1 output
            got_ans_1 = []
            for x in filter(lambda x: x, generated_1[text_id].strip().split("\n")):
                reg = "(.*)\\(.*\\.(\\d*)\\,.*\\.(\\d*)(.*)\\)"
                rel_type, u, v, rev = re.findall(reg, x)[0]
                assert (not rev) or (rev == ",REVERSE")
                if rev:
                    u, v = v, u
                got_ans_1.append([int(u) - 1, int(v) - 1, rel_type])

            # model 2 output
            got_ans_2 = []
            for x in filter(lambda x: x, generated_2[text_id].strip().split("\n")):
                reg = "(.*)\\(.*\\.(\\d*)\\,.*\\.(\\d*)(.*)\\)"
                rel_type, u, v, rev = re.findall(reg, x)[0]
                assert (not rev) or (rev == ",REVERSE")
                if rev:
                    u, v = v, u
                got_ans_2.append([int(u) - 1, int(v) - 1, rel_type])

            tmp_ents_s = beautiful_str(["id", "l", "r", "content"], tmp_ents)

            if (not C.gene_in_data) or (not C.gene_no_rel):
                golden_ans_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], golden_ans)
                got_ans_1_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], got_ans_1)
                got_ans_2_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], got_ans_2)

                readable_info += ("text-%d:\n%s\n\nentities:%s\n\ngolden relations:%s"
                                  "\n\nmodel output-1:%s\n\nmodel output-2:%s\n\n\n") % (
                    batch_id * batch_size + text_id + 1, text, tmp_ents_s,
                    golden_ans_s, got_ans_1_s, got_ans_2_s)

                json_info.append({
                    "text-id": batch_id * batch_size + text_id + 1,
                    "text": text,
                    "entitys": intize(tmp_ents, [0, 1, 2]),
                    "golden_ans": intize(golden_ans, [0, 1]),
                    "got_ans_1": intize(got_ans_1, [0, 1]),
                    "got_ans_2": intize(got_ans_2, [0, 1]),
                })
            else:
                # ensure gold and generated cover exactly the same entity pairs
                try:
                    assert [x[:2] for x in golden_ans] == [x[:2] for x in got_ans_1]
                    assert [x[:2] for x in golden_ans] == [x[:2] for x in got_ans_2]
                except AssertionError:
                    pdb.set_trace()

                all_ans = []
                for _ins_i in range(len(golden_ans)):
                    all_ans.append([
                        golden_ans[_ins_i][0],
                        golden_ans[_ins_i][1],
                        golden_ans[_ins_i][2],
                        got_ans_1[_ins_i][2],
                        got_ans_2[_ins_i][2],
                    ])
                all_ans_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "golden", "model 1", "model 2"], all_ans)

                readable_info += "text-%d:\n%s\n\nentities:%s\n\noutputs:%s\n\n\n" % (
                    batch_id * batch_size + text_id + 1, text, tmp_ents_s, all_ans_s)

                json_info.append({
                    "text-id": batch_id * batch_size + text_id + 1,
                    "text": text,
                    "entitys": intize(tmp_ents, [0, 1, 2]),
                    "relations": intize(all_ans, [0, 1]),
                })

    os.makedirs(os.path.dirname(C.gene_file), exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write(readable_info)
    with open(C.gene_file + ".json", "w", encoding="utf-8") as fil:
        json.dump(json_info, fil)

    print("score (model 1): %.4f %.4f" % get_f1(
        golden, all_generated_1, is_file_content=True,
        no_rel_name=generator.get_no_rel_name()))
    print("score (model 2): %.4f %.4f" % get_f1(
        golden, all_generated_2, is_file_content=True,
        no_rel_name=generator.get_no_rel_name()))
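# Self-contained illustration (not part of the original code) of the relation-line
# format parsed in `compare` above. Each generated line looks like
# "REL_TYPE(<doc>.<u>,<doc>.<v>)", optionally with a ",REVERSE" suffix that swaps
# the two 1-based entity indices; the document id "X96-1" below is made up.
def _demo_parse_relation_line(line):
    reg = "(.*)\\(.*\\.(\\d*)\\,.*\\.(\\d*)(.*)\\)"  # same pattern as above
    rel_type, u, v, rev = re.findall(reg, line)[0]
    assert (not rev) or (rev == ",REVERSE")
    if rev:
        u, v = v, u
    return [int(u) - 1, int(v) - 1, rel_type]

# _demo_parse_relation_line("USAGE(X96-1.3,X96-1.5)")          -> [2, 4, "USAGE"]
# _demo_parse_relation_line("USAGE(X96-1.5,X96-1.3,REVERSE)")  -> [2, 4, "USAGE"]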
def generate_output(C, logger, dataset, models, generator):
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    if models is not None:
        if isinstance(models, tc.nn.Module):
            models = [models]
        for i in range(len(models)):
            models[i] = models[i].eval()

    batch_size = 8
    batch_numb = (len(dataset) // batch_size) + int((len(dataset) % batch_size) != 0)
    device = tc.device(C.device)

    readable_info = ""
    model_output = []
    dataset_info = []
    all_generated = ""

    #----- generate -----
    pbar = tqdm(range(batch_numb), ncols=70)
    generated = ""
    for batch_id in pbar:
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(
            data, device=tc.device(C.device))

        if models is not None:
            with tc.no_grad():
                preds = [0 for _ in range(len(models))]
                for i, model in enumerate(models):
                    old_device = next(model.parameters()).device
                    model = model.to(device)
                    preds[i] = model(sents, ents)
                    model = model.to(old_device)  # if the model was on CPU, move it back there after generating

            #----- get generated output (only possible when there are models) -----
            ans_rels = [[(u, v) for u, v, t in bat]
                        for bat in anss] if C.gene_in_data else None
            generated, pred = generator(preds, data_ent, ans_rels=ans_rels,
                                        give_me_pred=True, split_generate=True)
            all_generated += "".join(generated)

        for text_id in range(len(data)):
            #----- form data structure -----
            # text
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:  # remove padding
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: turn token spans into character offsets and span contents
            tmp_ents = ents[text_id]
            for i in range(len(tmp_ents)):
                tmp_ents[i].append(
                    tokenizer.decode(tmp_sents[tmp_ents[i][0]:tmp_ents[i][1]]))
                tmp_ents[i][0] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][0]])
                ) + 1  # number of characters before the span (+1 is for space)
                tmp_ents[i][1] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][1]]))  # characters before the span end
                tmp_ents[i] = [i] + tmp_ents[i]

            # golden answer
            tmp_anss = anss[text_id]
            for i in range(len(tmp_anss)):
                tmp_anss[i][2] = relations[tmp_anss[i][2]]
            golden_ans = tmp_anss

            # model output
            if models is not None:
                got_ans = []
                for x in filter(lambda x: x, generated[text_id].strip().split("\n")):
                    reg = "(.*)\\(.*\\.(\\d*)\\,.*\\.(\\d*)(.*)\\)"
                    rel_type, u, v, rev = re.findall(reg, x)[0]
                    assert (not rev) or (rev == ",REVERSE")
                    if rev:
                        u, v = v, u
                    got_ans.append([int(u) - 1, int(v) - 1, rel_type])

            doc_id = batch_id * batch_size + text_id + 1  # global document index
            if models is not None:
                tmp_pred = pred[text_id]
                for u, v, _ in got_ans:
                    model_output.append({
                        "doc_id": doc_id,
                        "ent0_id": u,
                        "ent1_id": v,
                        "list_of_prob": [float(x) for x in tmp_pred[u][v]],
                    })
            dataset_info.append({
                "doc_id": doc_id,
                "text": text,
                "entity_set": [[int(idx), int(l), int(r), cont]
                               for idx, l, r, cont in tmp_ents],
                "list_of_relations": [[x[0], x[1], relations.index(x[2])]
                                      for x in golden_ans],
            })

            tmp_ents = beautiful_str(["id", "l", "r", "content"], tmp_ents)
            golden_ans = beautiful_str(["ent0 id", "ent1 id", "relation type"],
                                       golden_ans)
            if models is not None:
                got_ans = beautiful_str(["ent0 id", "ent1 id", "relation type"],
                                        got_ans)
            else:
                got_ans = "None"

            readable_info += ("text-%d:\n%s\n\nentities:%s\n\ngolden relations:%s"
                              "\n\nmodel(edge-aware) output:%s\n\n\n") % (
                doc_id, text, tmp_ents, golden_ans, got_ans)

        pbar.set_description_str("(Generate)")

    os.makedirs(os.path.dirname(C.gene_file), exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write(readable_info)
    with open(C.gene_file + ".generate.txt", "w", encoding="utf-8") as fil:
        fil.write(all_generated)
    with open(C.gene_file + ".model.json", "w", encoding="utf-8") as fil:
        json.dump(model_output, fil)
    with open(C.gene_file + ".dataset.json", "w", encoding="utf-8") as fil:
        json.dump(dataset_info, fil)
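# A hypothetical invocation of `generate_output` (placeholder names, not from the
# repo); with models=None only the dataset-side outputs carry content:
#
#   generate_output(C, logger, test_data, [model], generator)
#
# It writes four files next to C.gene_file (paths taken from the code above):
#   <gene_file>.txt           human-readable texts / entities / relations
#   <gene_file>.generate.txt  raw generated relation lines
#   <gene_file>.model.json    per-pair probability vectors from the model(s)
#   <gene_file>.dataset.json  decoded texts, entity spans and golden relations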
def train(C, logger, train_data, valid_data, loss_func, generator,
          n_rel_typs, run_name="0", test_data=None):
    (batch_numb, device), (model, optimizer, scheduler) = before_train(
        C, logger, train_data, n_rel_typs)

    #----- iterate each epoch -----
    best_epoch = -1
    best_metric = -1
    for epoch_id in range(C.epoch_numb):
        pbar = tqdm(range(batch_numb), ncols=70)
        avg_loss = 0
        for batch_id in pbar:
            #----- get data -----
            data = train_data[batch_id * C.batch_size:(batch_id + 1) * C.batch_size]
            sents, ents, anss, data_ent = get_data_from_batch(data, device=device)

            loss, pred = update_batch(C, logger, model, optimizer, scheduler,
                                      loss_func, sents, ents, anss, data_ent)

            avg_loss += float(loss)
            fitlog.add_loss(value=float(loss),
                            step=epoch_id * batch_numb + batch_id,
                            name="({0})train loss".format(run_name))

            pbar.set_description_str("(Train)Epoch %d" % epoch_id)
            pbar.set_postfix_str("loss = %.4f (avg = %.4f)" %
                                 (float(loss), avg_loss / (batch_id + 1)))
        logger.log("Epoch %d ended. avg_loss = %.4f" %
                   (epoch_id, avg_loss / batch_numb))

        micro_f1, macro_f1, test_loss = test(
            C, logger, valid_data, model, loss_func, generator,
            "valid", epoch_id, run_name)

        if C.valid_metric in ["macro*micro", "micro*macro"]:
            metric = macro_f1 * micro_f1
        elif C.valid_metric == "macro":
            metric = macro_f1
        elif C.valid_metric == "micro":
            metric = micro_f1
        else:
            assert False, "bad C.valid_metric: %s" % C.valid_metric

        if best_metric < metric:  # checkpoint the best model so far
            best_epoch = epoch_id
            best_metric = metric
            with open(C.tmp_file_name + ".model" + "." + str(run_name), "wb") as fil:
                pickle.dump(model, fil)

        model = model.train()

    if not C.no_valid:  # reload the best valid model
        with open(C.tmp_file_name + ".model" + "." + str(run_name), "rb") as fil:
            model = pickle.load(fil)
        logger.log("reloaded best model at epoch %d" % best_epoch)

    if test_data is not None:
        final_micro_f1, final_macro_f1, final_test_loss = test(
            C, logger, test_data, model, loss_func, generator,
            "test", epoch_id, run_name)

    return model, best_metric
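# A hypothetical invocation of `train` (placeholder names, not from the repo),
# for reference:
#
#   model, best_metric = train(C, logger, train_data, valid_data,
#                              loss_func, generator, n_rel_typs=len(relations),
#                              run_name="0", test_data=test_data)
#
# With C.valid_metric = "micro*macro", the epoch whose valid micro_f1 * macro_f1
# is highest gets pickled to C.tmp_file_name and reloaded before the final test
# pass (skipped when C.no_valid is set).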
def gene_golden(C, logger, dataset, generator):
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    batch_size = 8
    batch_numb = (len(dataset) // batch_size) + int((len(dataset) % batch_size) != 0)

    #----- generate -----
    readable_info = ""
    json_info = []
    for batch_id in tqdm(range(batch_numb), ncols=70, desc="Generating..."):
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(
            data, device=tc.device(C.device))

        for text_id in range(len(data)):
            #----- form data structure -----
            # text
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:  # remove padding
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: turn token spans into character offsets and span contents
            tmp_ents = ents[text_id]
            for i in range(len(tmp_ents)):
                tmp_ents[i].append(
                    tokenizer.decode(tmp_sents[tmp_ents[i][0]:tmp_ents[i][1]]))
                tmp_ents[i][0] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][0]])
                ) + 1  # number of characters before the span (+1 is for space)
                tmp_ents[i][1] = len(
                    tokenizer.decode(tmp_sents[1:tmp_ents[i][1]]))  # characters before the span end
                tmp_ents[i] = [i] + tmp_ents[i]

            # golden answer
            tmp_anss = anss[text_id]
            for i in range(len(tmp_anss)):
                tmp_anss[i][2] = relations[tmp_anss[i][2]]
            golden_ans = tmp_anss

            tmp_ents_s = beautiful_str(["id", "l", "r", "content"], tmp_ents)
            golden_ans_s = beautiful_str(["ent 0 id", "ent 1 id", "relation"],
                                         golden_ans)

            readable_info += "text-%d:\n%s\n\nentities:%s\n\noutputs:%s\n\n\n" % (
                batch_id * batch_size + text_id + 1, text, tmp_ents_s, golden_ans_s)

            json_info.append({
                "text-id": batch_id * batch_size + text_id + 1,
                "text": text,
                "entitys": intize(tmp_ents, [0, 1, 2]),
                "relations": intize(golden_ans, [0, 1]),
            })

    os.makedirs(os.path.dirname(C.gene_file), exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write(readable_info)
    with open(C.gene_file + ".json", "w", encoding="utf-8") as fil:
        json.dump(json_info, fil)
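# Illustration (NOT from the original code) of the character-offset conversion used
# in the functions above: a token-level entity span [l, r) is mapped to character
# offsets into the decoded text by measuring the length of the decoded prefix. A
# toy whitespace-joining decode stands in for BertTokenizer.decode here.
def _demo_char_offsets():
    decode = lambda toks: " ".join(toks)
    # mimics [CLS] t1 ... tn [SEP]; the entity span covers tokens 2..3, i.e. [2, 4)
    sents = ["[CLS]", "deep", "parsing", "helps", "[SEP]"]
    l, r = 2, 4
    content = decode(sents[l:r])          # "parsing helps"
    char_l = len(decode(sents[1:l])) + 1  # characters before the span (+1 for the joining space)
    char_r = len(decode(sents[1:r]))      # characters before the span end
    text = decode(sents[1:-1])            # "deep parsing helps"
    assert text[char_l:char_r] == content
    return char_l, char_r, content       # (5, 18, "parsing helps")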