Example #1
0
def test(
    C,
    logger,
    dataset,
    models,
    loss_func,
    generator,
    mode="valid",
    epoch_id=0,
    run_name="0",
    need_generated=False,
):
    """Evaluate ``models`` on ``dataset`` and report micro / macro F1.

    Every batch is run through ``get_output`` under ``tc.no_grad()``; the
    generated text of all batches is concatenated and scored with
    ``get_evaluate``. Results are written to ``logger`` and to ``fitlog``.

    Args:
        C: Configuration object (``C.device`` is read here; the object is also
            forwarded to the helpers).
        logger: Object with a ``log(str)`` method.
        dataset: Sliceable sequence of examples.
        models: Model(s) to evaluate; normalised by ``before_test``.
        loss_func: Loss function forwarded to ``get_output``.
        generator: Output generator forwarded to ``get_output`` and
            ``get_evaluate``.
        mode: Split name forwarded to ``get_evaluate`` (e.g. ``"valid"``).
        epoch_id: Epoch index used in log messages and as the fitlog step.
        run_name: Tag included in the fitlog metric names.
        need_generated: If true, also return the concatenated generated text.

    Returns:
        ``(micro_f1, macro_f1, avg_loss)``, or
        ``(micro_f1, macro_f1, avg_loss, generated)`` when ``need_generated``
        is set. NOTE: ``avg_loss`` is the *sum* over batches of
        ``loss / len(models)`` (only the logged value is divided by the number
        of batches); it is returned unchanged for backward compatibility.
    """
    device, batch_size, batch_numb, models = before_test(
        C, logger, dataset, models)
    data_device = tc.device(C.device)  # loop-invariant

    pbar = tqdm(range(batch_numb), ncols=70)
    pbar.set_description_str("(Test )Epoch {0}".format(epoch_id))
    avg_loss = 0
    generated_parts = []
    for batch_id in pbar:
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(data,
                                                          device=data_device)

        with tc.no_grad():
            _, _, loss, partial_generated = get_output(
                C, logger, models, device, loss_func, generator, sents, ents,
                anss, data_ent)
        generated_parts.append(partial_generated)
        avg_loss += float(loss) / len(models)

        pbar.set_postfix_str("loss = %.4f (avg = %.4f)" %
                             (float(loss), avg_loss / (batch_id + 1)))

    generated = "".join(generated_parts)
    micro_f1, macro_f1 = get_evaluate(C, logger, mode, generated, generator,
                                      dataset)

    # max(..., 1) avoids ZeroDivisionError when the dataset is empty.
    logger.log(
        "-----Epoch {} tested. Micro F1 = {:.2f}% , Macro F1 = {:.2f}% , loss = {:.4f}"
        .format(epoch_id, micro_f1, macro_f1, avg_loss / max(batch_numb, 1)))
    logger.log("\n")

    fitlog.add_metric(micro_f1,
                      step=epoch_id,
                      name="({0})micro f1".format(run_name))
    fitlog.add_metric(macro_f1,
                      step=epoch_id,
                      name="({0})macro f1".format(run_name))

    if need_generated:
        return micro_f1, macro_f1, avg_loss, generated

    return micro_f1, macro_f1, avg_loss
Example #2
0
def compare(C, logger, dataset, models_1, models_2, generator):
    """Run two models side by side on ``dataset`` and dump a comparison report.

    For every example, this writes the decoded text, its entities, the golden
    relations and the relations produced by each model to
    ``C.gene_file + ".txt"`` (human readable) and ``C.gene_file + ".json"``.
    Finally, it prints the two scores returned by ``get_f1`` for each model
    against the golden key file.

    Args:
        C: Configuration object; reads ``device``, ``gene_in_data``,
            ``gene_no_rel`` and ``gene_file``.
        logger: Not used by this function (kept for signature parity).
        dataset: Sliceable sequence of examples.
        models_1, models_2: Models called as
            ``model(sents, ents, output_preds=True)``.
        generator: Callable turning predictions into relation text; must also
            provide ``get_no_rel_name()``.

    Raises:
        ValueError: if a generated line cannot be parsed, or if (with both
            ``C.gene_in_data`` and ``C.gene_no_rel`` set) the entity pairs
            produced by a model differ from the golden pairs.
    """
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    golden = write_keyfile(dataset, generator)

    models_1 = models_1.eval()
    models_2 = models_2.eval()

    batch_size = 8
    batch_numb = (len(dataset) + batch_size - 1) // batch_size  # ceil division
    device = tc.device(C.device)

    rel_pattern = re.compile(r"(.*)\(.*\.(\d*)\,.*\.(\d*)(.*)\)")

    def parse_generated(gen_text):
        """Parse one document's generated text into [ent0, ent1, rel_type] rows.

        Entity ids in the text are 1-based and converted to 0-based; a
        ",REVERSE" suffix swaps the two entities.
        """
        parsed = []
        for line in gen_text.strip().split("\n"):
            if not line:
                continue
            matches = rel_pattern.findall(line)
            if not matches:
                raise ValueError(
                    "cannot parse generated relation line: %r" % line)
            rel_type, u, v, rev = matches[0]
            if rev and rev != ",REVERSE":
                raise ValueError("unexpected suffix %r in generated line %r" %
                                 (rev, line))
            if rev:
                u, v = v, u
            parsed.append([int(u) - 1, int(v) - 1, rel_type])
        return parsed

    #----- gene -----
    readable_parts = []
    json_info = []
    generated_parts_1 = []
    generated_parts_2 = []

    for batch_id in tqdm(range(batch_numb), ncols=70, desc="Generating..."):
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(data, device=device)

        with tc.no_grad():
            preds_1 = models_1(sents, ents, output_preds=True)
            preds_2 = models_2(sents, ents, output_preds=True)

            #----- get generated output -----
            # When C.gene_in_data is set, the golden (ent0, ent1) pairs are
            # passed to the generator; otherwise ans_rels is None.
            ans_rels = [[(u, v) for u, v, t in bat]
                        for bat in anss] if C.gene_in_data else None
            generated_1 = generator(preds_1,
                                    data_ent,
                                    ans_rels=ans_rels,
                                    split_generate=True)
            generated_2 = generator(preds_2,
                                    data_ent,
                                    ans_rels=ans_rels,
                                    split_generate=True)

        generated_parts_1.append("".join(generated_1))
        generated_parts_2.append("".join(generated_2))

        for text_id in range(len(data)):
            text_no = batch_id * batch_size + text_id + 1  # 1-based global index

            #----- form data structure -----
            # text: drop trailing padding (id 0), then the first/last tokens
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: each row becomes [index, start, end, surface text],
            # where start/end are the character lengths of the decoded prefix
            # (start gets +1 for the separating space).
            tmp_ents = ents[text_id]
            for i, ent in enumerate(tmp_ents):
                l_tok, r_tok = ent[0], ent[1]
                ent.append(tokenizer.decode(tmp_sents[l_tok:r_tok]))
                ent[0] = len(tokenizer.decode(tmp_sents[1:l_tok])) + 1
                ent[1] = len(tokenizer.decode(tmp_sents[1:r_tok]))
                tmp_ents[i] = [i] + ent

            # golden answer: replace relation ids by their names (in place)
            golden_ans = anss[text_id]
            for ans in golden_ans:
                ans[2] = relations[ans[2]]

            # model outputs
            got_ans_1 = parse_generated(generated_1[text_id])
            got_ans_2 = parse_generated(generated_2[text_id])

            tmp_ents_s = beautiful_str(["id", "l", "r", "content"], tmp_ents)

            if (not C.gene_in_data) or (not C.gene_no_rel):
                golden_ans_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], golden_ans)
                got_ans_1_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], got_ans_1)
                got_ans_2_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "relation type"], got_ans_2)

                readable_parts.append(
                    "text-%d:\n%s\n\nentitys:%s\n\ngolden relations:%s\n\n"
                    "model output-1:%s\n\nmodel output-2:%s\n\n\n" %
                    (text_no, text, tmp_ents_s, golden_ans_s, got_ans_1_s,
                     got_ans_2_s))

                json_info.append({
                    "text-id": text_no,
                    "text": text,
                    "entitys": intize(tmp_ents, [0, 1, 2]),
                    "golden_ans": intize(golden_ans, [0, 1]),
                    "got_ans_1": intize(got_ans_1, [0, 1]),
                    "got_ans_2": intize(got_ans_2, [0, 1]),
                })
            else:
                # Outputs must cover exactly the golden entity pairs, in the
                # same order, so that they can be aligned row by row.
                golden_pairs = [x[:2] for x in golden_ans]
                for model_name, got in (("model 1", got_ans_1),
                                        ("model 2", got_ans_2)):
                    if [x[:2] for x in got] != golden_pairs:
                        raise ValueError(
                            "text-%d: entity pairs generated by %s do not "
                            "match the golden answer" % (text_no, model_name))

                all_ans = [[gold[0], gold[1], gold[2], out_1[2], out_2[2]]
                           for gold, out_1, out_2 in zip(
                               golden_ans, got_ans_1, got_ans_2)]

                all_ans_s = beautiful_str(
                    ["ent 0 id", "ent 1 id", "golden", "model 1", "model 2"],
                    all_ans)
                readable_parts.append(
                    "text-%d:\n%s\n\nentitys:%s\n\noutputs:%s\n\n\n" %
                    (text_no, text, tmp_ents_s, all_ans_s))

                json_info.append({
                    "text-id": text_no,
                    "text": text,
                    "entitys": intize(tmp_ents, [0, 1, 2]),
                    "relations": intize(all_ans, [0, 1]),
                })

    out_dir = os.path.dirname(C.gene_file)
    if out_dir:  # os.makedirs("") raises; a bare filename means the CWD
        os.makedirs(out_dir, exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write("".join(readable_parts))
    with open(C.gene_file + ".json", "w", encoding="utf-8") as fil:
        json.dump(json_info, fil)

    no_rel_name = generator.get_no_rel_name()
    print("score (model 1): %.4f %.4f" %
          get_f1(golden,
                 "".join(generated_parts_1),
                 is_file_content=True,
                 no_rel_name=no_rel_name))
    print("score (model 2): %.4f %.4f" %
          get_f1(golden,
                 "".join(generated_parts_2),
                 is_file_content=True,
                 no_rel_name=no_rel_name))
def generate_output(C, logger, dataset, models, generator):
    """Write golden annotations (and, optionally, model predictions) to disk.

    Output files (all prefixed by ``C.gene_file``):
        ``.txt``          human-readable dump per document;
        ``.generate.txt`` concatenated raw generator output;
        ``.model.json``   one record per predicted relation with its
                          probability list (empty when ``models`` is None);
        ``.dataset.json`` text, entities and golden relations per document.

    Documents are identified by a 1-based ``doc_id`` that is global over the
    whole dataset, so records in ``.model.json`` and ``.dataset.json`` can be
    joined on it.

    Args:
        C: Configuration object; reads ``device``, ``gene_in_data`` and
            ``gene_file``.
        logger: Not used by this function (kept for signature parity).
        dataset: Sliceable sequence of examples.
        models: A ``tc.nn.Module``, a sequence of modules, or None to dump only
            the golden data.
        generator: Callable turning predictions into relation text.

    Raises:
        ValueError: if a generated line cannot be parsed.
    """
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    if models is not None:
        if isinstance(models, tc.nn.Module):
            models = [models]
        # Build a new list: works for tuples and leaves the caller's container
        # untouched.
        models = [model.eval() for model in models]

    batch_size = 8
    batch_numb = (len(dataset) + batch_size - 1) // batch_size  # ceil division
    device = tc.device(C.device)

    rel_pattern = re.compile(r"(.*)\(.*\.(\d*)\,.*\.(\d*)(.*)\)")

    def parse_generated(gen_text):
        """Parse one document's generated text into [ent0, ent1, rel_type] rows.

        Entity ids in the text are 1-based and converted to 0-based; a
        ",REVERSE" suffix swaps the two entities.
        """
        parsed = []
        for line in gen_text.strip().split("\n"):
            if not line:
                continue
            matches = rel_pattern.findall(line)
            if not matches:
                raise ValueError(
                    "cannot parse generated relation line: %r" % line)
            rel_type, u, v, rev = matches[0]
            if rev and rev != ",REVERSE":
                raise ValueError("unexpected suffix %r in generated line %r" %
                                 (rev, line))
            if rev:
                u, v = v, u
            parsed.append([int(u) - 1, int(v) - 1, rel_type])
        return parsed

    readable_parts = []
    model_output = []
    dataset_info = []
    generated_parts = []

    #----- gene -----
    pbar = tqdm(range(batch_numb), ncols=70)
    pbar.set_description_str("(Generate)")

    for batch_id in pbar:
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, data_ent = get_data_from_batch(data, device=device)

        if models is not None:
            with tc.no_grad():
                preds = []
                for model in models:
                    # Run on C.device, then move the model back to the device
                    # it was on before (e.g. a CPU-resident model stays on CPU).
                    old_device = next(model.parameters()).device
                    model = model.to(device)
                    preds.append(model(sents, ents))
                    model.to(old_device)

                #----- get generated output -----
                # When C.gene_in_data is set, the golden (ent0, ent1) pairs
                # are passed to the generator; otherwise ans_rels is None.
                ans_rels = [[(u, v) for u, v, t in bat]
                            for bat in anss] if C.gene_in_data else None
                generated, pred = generator(preds,
                                            data_ent,
                                            ans_rels=ans_rels,
                                            give_me_pred=True,
                                            split_generate=True)
            generated_parts.append("".join(generated))

        for text_id in range(len(data)):
            doc_id = batch_id * batch_size + text_id + 1  # 1-based global id

            #----- form data structure -----
            # text: drop trailing padding (id 0), then the first/last tokens
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: each row becomes [index, start, end, surface text],
            # where start/end are the character lengths of the decoded prefix
            # (start gets +1 for the separating space).
            tmp_ents = ents[text_id]
            for i, ent in enumerate(tmp_ents):
                l_tok, r_tok = ent[0], ent[1]
                ent.append(tokenizer.decode(tmp_sents[l_tok:r_tok]))
                ent[0] = len(tokenizer.decode(tmp_sents[1:l_tok])) + 1
                ent[1] = len(tokenizer.decode(tmp_sents[1:r_tok]))
                tmp_ents[i] = [i] + ent

            # golden answer: replace relation ids by their names (in place)
            golden_ans = anss[text_id]
            for ans in golden_ans:
                ans[2] = relations[ans[2]]

            # model output
            if models is not None:
                got_ans = parse_generated(generated[text_id])
                tmp_pred = pred[text_id]
                for u, v, _ in got_ans:
                    model_output.append({
                        "doc_id": doc_id,
                        "ent0_id": u,
                        "ent1_id": v,
                        "list_of_prob": [float(x) for x in tmp_pred[u][v]],
                    })
                got_ans_s = beautiful_str(
                    ["ent0 id", "ent1 id", "relation type"], got_ans)
            else:
                got_ans_s = "None"

            dataset_info.append({
                "doc_id": doc_id,
                "text": text,
                "entity_set": [[int(idx), int(left), int(right), cont]
                               for idx, left, right, cont in tmp_ents],
                "list_of_relations":
                [[x[0], x[1], relations.index(x[2])] for x in golden_ans],
            })

            tmp_ents_s = beautiful_str(["id", "l", "r", "content"], tmp_ents)
            golden_ans_s = beautiful_str(
                ["ent0 id", "ent1 id", "relation type"], golden_ans)

            readable_parts.append(
                "text-%d:\n%s\n\nentitys:%s\n\ngolden relations:%s\n\n"
                "model(edge-aware) output:%s\n\n\n" %
                (doc_id, text, tmp_ents_s, golden_ans_s, got_ans_s))

    out_dir = os.path.dirname(C.gene_file)
    if out_dir:  # os.makedirs("") raises; a bare filename means the CWD
        os.makedirs(out_dir, exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write("".join(readable_parts))
    with open(C.gene_file + ".generate.txt", "w", encoding="utf-8") as fil:
        fil.write("".join(generated_parts))
    with open(C.gene_file + ".model.json", "w", encoding="utf-8") as fil:
        json.dump(model_output, fil)
    with open(C.gene_file + ".dataset.json", "w", encoding="utf-8") as fil:
        json.dump(dataset_info, fil)
def train(C,
          logger,
          train_data,
          valid_data,
          loss_func,
          generator,
          n_rel_typs,
          run_name="0",
          test_data=None):
    """Train a model, keep the best validation checkpoint, and optionally test it.

    After every epoch the model is evaluated on ``valid_data`` with ``test``.
    Whenever the selected validation metric improves, the model is pickled to
    ``C.tmp_file_name + ".model." + str(run_name)``. Unless ``C.no_valid`` is
    set, that best checkpoint is reloaded at the end. If ``test_data`` is given,
    the final model is evaluated on it (results are logged by ``test``).

    Args:
        C: Configuration object; reads ``valid_metric`` (one of
            ``"macro*micro"``, ``"micro*macro"``, ``"macro"``, ``"micro"``),
            ``epoch_numb``, ``batch_size``, ``tmp_file_name`` and ``no_valid``.
        logger: Object with a ``log(str)`` method.
        train_data, valid_data, test_data: Sliceable sequences of examples.
        loss_func: Loss function forwarded to ``update_batch`` / ``test``.
        generator: Output generator forwarded to ``test``.
        n_rel_typs: Number of relation types, forwarded to ``before_train``.
        run_name: Tag used in checkpoint file names and fitlog series names.

    Returns:
        ``(model, best_metric)``; ``best_metric`` stays -1 if no epoch ran.

    Raises:
        ValueError: if ``C.valid_metric`` is not a supported value. This is
            checked before any training happens.
    """
    metric_fns = {
        "macro*micro": lambda micro, macro: macro * micro,
        "micro*macro": lambda micro, macro: macro * micro,
        "macro": lambda micro, macro: macro,
        "micro": lambda micro, macro: micro,
    }
    if C.valid_metric not in metric_fns:
        raise ValueError("unsupported valid_metric %r (expected one of %s)" %
                         (C.valid_metric, sorted(metric_fns)))
    select_metric = metric_fns[C.valid_metric]

    (batch_numb, device), (model, optimizer,
                           scheduler) = before_train(C, logger, train_data,
                                                     n_rel_typs)
    checkpoint_path = C.tmp_file_name + ".model" + "." + str(run_name)

    #----- iterate each epoch -----
    best_epoch = -1
    best_metric = -1
    epoch_id = 0  # keeps the final test call well-defined if epoch_numb == 0
    for epoch_id in range(C.epoch_numb):

        pbar = tqdm(range(batch_numb), ncols=70)
        pbar.set_description_str("(Train)Epoch %d" % (epoch_id))
        avg_loss = 0
        for batch_id in pbar:
            #----- get data -----
            data = train_data[batch_id * C.batch_size:(batch_id + 1) *
                              C.batch_size]
            sents, ents, anss, data_ent = get_data_from_batch(data,
                                                              device=device)

            loss, _ = update_batch(C, logger, model, optimizer, scheduler,
                                   loss_func, sents, ents, anss, data_ent)

            avg_loss += float(loss)
            fitlog.add_loss(value=float(loss),
                            step=epoch_id * batch_numb + batch_id,
                            name="({0})train loss".format(run_name))

            pbar.set_postfix_str("loss = %.4f (avg = %.4f)" %
                                 (float(loss), avg_loss / (batch_id + 1)))
        # max(..., 1) avoids ZeroDivisionError on an empty training set.
        logger.log("Epoch %d ended. avg_loss = %.4f" %
                   (epoch_id, avg_loss / max(batch_numb, 1)))

        micro_f1, macro_f1, _ = test(
            C,
            logger,
            valid_data,
            model,
            loss_func,
            generator,
            "valid",
            epoch_id,
            run_name,
        )

        metric = select_metric(micro_f1, macro_f1)
        if best_metric < metric:
            best_epoch = epoch_id
            best_metric = metric
            with open(checkpoint_path, "wb") as fil:
                pickle.dump(model, fil)

        # test() may switch the model to eval mode; restore training mode.
        model = model.train()

    if not C.no_valid:  # reload best model
        if best_epoch >= 0:
            # Only reload a checkpoint written by *this* run; otherwise a
            # stale file from an earlier run could be picked up.
            with open(checkpoint_path, "rb") as fil:
                model = pickle.load(fil)  # load best valid model
            logger.log("reloaded best model at epoch %d" % best_epoch)
        else:
            logger.log("no checkpoint was saved; keeping the current model")

    if test_data is not None:
        test(
            C,
            logger,
            test_data,
            model,
            loss_func,
            generator,
            "test",
            epoch_id,
            run_name,
        )

    return model, best_metric
Example #5
0
def gene_golden(C, logger, dataset, generator):
    """Dump the golden annotations of ``dataset`` in readable and JSON form.

    For every example, this writes the decoded text, its entities and its
    golden relations to ``C.gene_file + ".txt"`` and ``C.gene_file + ".json"``.

    Args:
        C: Configuration object; reads ``device`` and ``gene_file``.
        logger: Not used by this function (kept for signature parity).
        dataset: Sliceable sequence of examples.
        generator: Not used by this function (kept for signature parity).
    """
    #----- determine some arguments and prepare model -----
    bert_type = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_type)

    batch_size = 8
    batch_numb = (len(dataset) + batch_size - 1) // batch_size  # ceil division
    device = tc.device(C.device)

    #----- gene -----
    readable_parts = []
    json_info = []

    for batch_id in tqdm(range(batch_numb), ncols=70, desc="Generating..."):
        #----- get data -----
        data = dataset[batch_id * batch_size:(batch_id + 1) * batch_size]
        sents, ents, anss, _ = get_data_from_batch(data, device=device)

        for text_id in range(len(data)):
            text_no = batch_id * batch_size + text_id + 1  # 1-based global index

            #----- form data structure -----
            # text: drop trailing padding (id 0), then the first/last tokens
            tmp_sents = sents[text_id]
            while tmp_sents[-1] == 0:
                tmp_sents = tmp_sents[:-1]
            text = tokenizer.decode(tmp_sents[1:-1])

            # entities: each row becomes [index, start, end, surface text],
            # where start/end are the character lengths of the decoded prefix
            # (start gets +1 for the separating space).
            tmp_ents = ents[text_id]
            for i, ent in enumerate(tmp_ents):
                l_tok, r_tok = ent[0], ent[1]
                ent.append(tokenizer.decode(tmp_sents[l_tok:r_tok]))
                ent[0] = len(tokenizer.decode(tmp_sents[1:l_tok])) + 1
                ent[1] = len(tokenizer.decode(tmp_sents[1:r_tok]))
                tmp_ents[i] = [i] + ent

            # golden answer: replace relation ids by their names (in place)
            golden_ans = anss[text_id]
            for ans in golden_ans:
                ans[2] = relations[ans[2]]

            tmp_ents_s = beautiful_str(["id", "l", "r", "content"], tmp_ents)
            golden_ans_s = beautiful_str(["ent 0 id", "ent 1 id", "relation"],
                                         golden_ans)

            readable_parts.append(
                "text-%d:\n%s\n\nentitys:%s\n\noutputs:%s\n\n\n" %
                (text_no, text, tmp_ents_s, golden_ans_s))

            json_info.append({
                "text-id": text_no,
                "text": text,
                "entitys": intize(tmp_ents, [0, 1, 2]),
                "relations": intize(golden_ans, [0, 1]),
            })

    out_dir = os.path.dirname(C.gene_file)
    if out_dir:  # os.makedirs("") raises; a bare filename means the CWD
        os.makedirs(out_dir, exist_ok=True)
    with open(C.gene_file + ".txt", "w", encoding="utf-8") as fil:
        fil.write("".join(readable_parts))
    with open(C.gene_file + ".json", "w", encoding="utf-8") as fil:
        json.dump(json_info, fil)