# Example #1 (page-scrape artifact; converted to a comment so the module imports cleanly)
import model
import config
import torch
import tools
import json
import _pickle
from transformers import AutoTokenizer, AutoModel

# Enable cuDNN autotuning — picks the fastest convolution algorithms when
# input sizes are static — and release any cached CUDA memory up front.
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

if __name__ == "__main__":

    # Inference setup: build the coref model in eval mode and restore the
    # best ("max") checkpoint pair saved by train() below — the fine-tuned
    # transformer weights and the coref head's state dict.
    c = config.best_config

    coref_model = model.CorefModel(c).eval().to(c["device"])
    tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])

    transformer_model = AutoModel.from_pretrained(
        c["checkpoint_path"] + ".transformer.max").eval().to(c["device"])
    checkpoint = torch.load(c["checkpoint_path"] + ".max")
    coref_model.load_state_dict(checkpoint["model"])

    # data format: [[[sentence1, sentence2, ...], [cluster1, cluster2, ...]], ...]
    # cluster: [[span_start_loc/gap_end_loc, span_len/gap_len], ...]

    # Inline (self-made) example data:
    # val_data = [
    #     {
    #         "sentences": [["打", "雷", "了", "怎", "么", "发", "短", "信", "安", "慰", "女", "朋", "友", "?", "打", "雷", "时", "还", "给", "她", "发", "?"]],
    #         "clusters": [[[10, 12], [19, 19]]],
    # NOTE(review): this __main__ section appears truncated in this chunk —
    # the commented-out example above ends mid-structure and the actual
    # inference call is missing; recover it from the original file.
def train() -> None:
    """Train the coreference model end-to-end.

    Builds the tokenizer, transformer encoder, and coref head from
    ``config.best_config``; loads training/validation documents; then loops
    indefinitely over the training set, jointly optimizing the coref head
    (Adam) and the transformer (AdamW, separate learning rate). Every
    ``c["eval_frequency"]`` steps a checkpoint is written and the model is
    evaluated on the validation set; the best model by F1 is saved under
    ``c["checkpoint_path"] + ".max"`` (plus ``.transformer.max``), which is
    exactly what the ``__main__`` inference section above loads.

    NOTE(review): this function uses names that are not imported in the
    visible header of this chunk — ``copy``, ``time``, ``AdamW``,
    ``RandomSampler``, ``SequentialSampler``, ``find_clusters`` — presumably
    imported/defined elsewhere in the file; verify before running.
    """

    c = config.best_config
    tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])
    transformer_model = AutoModel.from_pretrained(
        c["transformer_model_name"]).to(c["device"])
    # Separate optimizer (and learning rate) for the transformer encoder.
    transformer_optim = AdamW(transformer_model.parameters(),
                              lr=c["transformer_lr"])
    transformer_model.train()

    print("preparing data...")

    # ========= Inline (self-made) example data =========
    # NOTE(review): this hand-built example is immediately overwritten by the
    # file-based section below, so it is dead code as written — only one of
    # the two data sections should be active at a time.
    train_data = [{
        "sentences": [[
            "打", "雷", "了", "怎", "么", "发", "短", "信", "安", "慰", "女", "朋", "友",
            "?", "打", "雷", "时", "还", "给", "她", "发", "?"
        ]],
        "clusters": [[[10, 12], [19, 19]]],
        "speaker_ids": [[
            "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a",
            "a", "b", "b", "b", "b", "b", "b", "b", "b"
        ]],
        "sentence_map":
        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]],
        "subtoken_map": [[
            1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10, 11, 12, 13, 14,
            15, 16
        ]],
        "genre":
        "dummy_genre",
        "doc_key":
        "dummy_data"
    }]

    val_data = copy.deepcopy(train_data)

    # ========= end: inline (self-made) example data =========

    # ========= Load pre-generated JSON-lines data files =========
    train_data = list()
    with open(c["train_file_path"], "r", encoding="utf-8") as fd:
        for line in fd:
            item = json.loads(line.strip())
            train_data.append(item)

    val_data = list()
    with open(c["val_file_path"], "r", encoding="utf-8") as fd:
        for line in fd:
            item = json.loads(line.strip())
            val_data.append(item)

    # ========= end: load pre-generated data files =========

    # Tokenize every document up front, skipping empty ones.
    tokenized_train_data = list()
    for data_i in train_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_train_data.append(
                (tools.tokenize_example(data_i, tokenizer, c)))

    tokenized_val_data = list()
    for data_i in val_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_val_data.append(
                (tools.tokenize_example(data_i, tokenizer, c)))

    coref_model = model.CorefModel(c).to(c["device"])
    optimizer = torch.optim.Adam(coref_model.parameters(),
                                 lr=c["lr"],
                                 weight_decay=c["weight_decay"])
    accumulated_loss = 0.0
    max_f1 = 0.0

    print("start training...")

    init_time = time.time()
    steps = 0
    # NOTE(review): no visible termination condition in this chunk — the loop
    # runs until externally stopped (or a break exists past the visible end).
    while True:
        # RandomSampler over a SequentialSampler: the inner sampler only
        # supplies a sized index range; the outer one shuffles each epoch.
        for idx in RandomSampler(SequentialSampler(tokenized_train_data)):
            steps += 1
            sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, genre = tokenized_train_data[
                idx]
            # Truncate long documents to fit in memory during training.
            if len(sentences_ids) > c["max_training_sentences"]:
                sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map = tools.truncate_example(
                    sentences_ids, sentences_masks, sentences_valid_masks,
                    clusters, speaker_ids, sentence_map, subtoken_map,
                    c["max_training_sentences"])

            top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end = coref_model(
                sentences_ids, sentences_masks, sentences_valid_masks,
                speaker_ids, sentence_map, subtoken_map, genre,
                transformer_model)

            # Map every gold mention (start, end) to its cluster index so the
            # model's top-m spans can be labeled.
            num_spans = len(top_m_spans_start)
            index_clusters = dict()
            for i, cluster in enumerate(clusters):
                for loc in cluster:
                    index_clusters[tuple(loc)] = i
            top_m_spans_cluster_idx = list()  # size: m
            for i in range(num_spans):
                # find_clusters presumably returns -1 when the span is not a
                # gold mention — see the `!= -1` mask below; TODO confirm.
                cluster_idx = find_clusters(
                    (top_m_spans_start[i].item(), top_m_spans_end[i].item()),
                    index_clusters)
                top_m_spans_cluster_idx.append(cluster_idx)
            top_m_spans_cluster_idx = torch.LongTensor(
                top_m_spans_cluster_idx).to(device=torch.device(c["device"]))
            top_m_spans_in_gold_clusters = (top_m_spans_cluster_idx != -1)
            # gold labels, size: m * k — True where a candidate antecedent is
            # in the same gold cluster as the span.
            top_antecedents_label = (
                top_m_spans_cluster_idx[top_antecedents_index].t() ==
                top_m_spans_cluster_idx).t()
            # Spans outside any gold cluster, and masked antecedent slots,
            # get no positive label.
            top_antecedents_label[~top_m_spans_in_gold_clusters] = False
            top_antecedents_label[~torch.gather(
                top_m_spans_masks.to(device=torch.device(c["device"])
                                     ), 1, top_antecedents_index)] = False
            # gold labels with dummy label, size: m * (k+1) — the dummy
            # ("no antecedent") column is True iff no real antecedent is.
            top_antecedents_label = torch.cat(
                (top_antecedents_label,
                 ~torch.sum(top_antecedents_label, dim=1).bool().view(-1, 1)),
                dim=1)

            # Marginal log-likelihood loss: log(False) = -inf knocks
            # non-gold antecedents out of the logsumexp.
            gold_scores = top_antecedents_score + torch.log(
                top_antecedents_label.float())  # size: m * (k+1)
            marginalized_gold_scores = torch.logsumexp(gold_scores,
                                                       1)  # size: m
            log_norm = torch.logsumexp(top_antecedents_score, 1)  # size: m

            loss = torch.sum(log_norm - marginalized_gold_scores)

            # Step both optimizers together.
            optimizer.zero_grad()
            transformer_optim.zero_grad()
            loss.backward()
            optimizer.step()
            transformer_optim.step()

            accumulated_loss += loss.item()

            if steps % c["report_frequency"] == 0:
                total_time = time.time() - init_time
                print("[%d] loss=%.2f, steps/s=%.4f" %
                      (steps, accumulated_loss / c["report_frequency"],
                       steps / total_time))
                accumulated_loss = 0.0

            if steps % c["eval_frequency"] == 0:
                # Save a checkpoint every c["eval_frequency"] steps.
                coref_model.eval()
                transformer_model.eval()
                torch.save(
                    {
                        "model": coref_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "steps": steps
                    }, c["checkpoint_path"] + "." + str(steps))
                transformer_model.save_pretrained(c["checkpoint_path"] +
                                                  ".transformer." + str(steps))
                # Evaluation is best-effort: a failure is logged but does not
                # abort training.
                try:
                    p, r, f = coref_model.evaluate(tokenized_val_data,
                                                   transformer_model)
                    if f >= max_f1:
                        max_f1 = f
                        # New best F1 — overwrite the ".max" checkpoints that
                        # the inference section loads.
                        torch.save(
                            {
                                "model": coref_model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "steps": steps
                            }, c["checkpoint_path"] + ".max")
                        transformer_model.save_pretrained(
                            c["checkpoint_path"] + ".transformer.max")
                    print(
                        "evaluation result:\np:%.4f,r:%.4f,f:%.4f(max f:%.4f)"
                        % (p, r, f, max_f1))
                except Exception as e:
                    print("Error: evaluation error:", e)
                # Back to training mode after checkpoint/eval.
                coref_model.train()
                transformer_model.train()
                torch.cuda.empty_cache()