Code example #1
File: main.py  Project: byecc/pytorch_ner
# Assumes `import time`, `import random` and `import torch.optim as optim`;
# SeqModel, generate_batch, predict_check and evaluate are defined elsewhere
# in the byecc/pytorch_ner project.
def train(data):
    print("Training model...")
    model = SeqModel(data)
    if data.use_cuda:
        model.cuda()
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    if data.optimizer == "SGD":
        optimizer = optim.SGD(parameters,
                              lr=data.lr,
                              momentum=data.momentum,
                              weight_decay=data.weight_decay)
    elif data.optimizer == "Adagrad":
        optimizer = optim.Adagrad(parameters,
                                  lr=data.lr,
                                  weight_decay=data.weight_decay)
    elif data.optimizer == "Adam":
        optimizer = optim.Adam(parameters,
                               lr=data.lr,
                               weight_decay=data.weight_decay)
    else:
        print("Optimizer Error: {} optimizer is not support.".format(
            data.optimizer))
    for idx in range(data.iter):
        epoch_start = temp_start = time.time()
        train_num = len(data.train_text)
        if train_num % data.batch_size == 0:
            batch_block = train_num // data.batch_size
        else:
            batch_block = train_num // data.batch_size + 1
        correct = total = total_loss = 0
        random.shuffle(data.train_idx)
        for block in range(batch_block):
            left = block * data.batch_size
            right = left + data.batch_size
            if right > train_num:
                right = train_num
            instance = data.train_idx[left:right]
            batch_word, batch_word_len, word_recover, batch_char, batch_char_len, char_recover, batch_label, mask = generate_batch(
                instance, data.use_cuda)
            loss, seq = model(batch_word, batch_word_len, batch_char,
                              batch_char_len, char_recover, batch_label,
                              mask)
            right_token, total_token = predict_check(seq, batch_label, mask)
            correct += right_token
            total += total_token
            total_loss += loss.item()  # .item() replaces the deprecated loss.data[0]
            loss.backward()
            optimizer.step()
            model.zero_grad()
        epoch_end = time.time()
        print("Epoch:{}. Time:{}. Loss:{}. acc:{}".format(
            idx, epoch_end - epoch_start, total_loss, correct / total))
        evaluate(data, model, "dev", idx)
        evaluate(data, model, "test", idx)
        print("Finish.")
Code example #2
File: main.py  Project: byecc/pytorch_ner
# Assumes the same imports as code example #1; generate_dict_feature and the
# model's neg_log_likehood method are also defined elsewhere in the project.
def train(data):
    print("Training model...")
    data.show_data_summary()
    model = SeqModel(data)
    if data.use_cuda:
        model.cuda()
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    if data.optimizer == "SGD":
        optimizer = optim.SGD(parameters,
                              lr=data.lr,
                              momentum=data.momentum,
                              weight_decay=data.weight_decay)
    elif data.optimizer == "Adagrad":
        optimizer = optim.Adagrad(parameters,
                                  lr=data.lr,
                                  weight_decay=data.weight_decay)
    elif data.optimizer == "Adam":
        optimizer = optim.Adam(parameters,
                               lr=data.lr,
                               weight_decay=data.weight_decay)
    else:
        print("Optimizer Error: {} optimizer is not support.".format(
            data.optimizer))
    # if data.out_dict:
    #     external_dict(data.out_dict, data.train_file)
    #     external_dict(data.out_dict, data.dev_file)
    #     external = external_dict(data.out_dict, data.test_file)
    #     external_dict(data.out_dict,data.oov_file)
    #     #print(len(external))
    #     #with open('../data/ali_7k/external_dict','w',encoding='utf-8') as fout:
    #     #    for e in external:
    #     #        fout.write(e+'\n')
    for idx in range(data.iter):
        # data.mask_entity(data.iter, idx)
        epoch_start = temp_start = time.time()
        # if data.optimizer == "SGD":
        #    optimizer = lr_decay(optimizer,idx,data.lr_decay,data.lr)
        train_num = len(data.train_text)
        if train_num % data.batch_size == 0:
            batch_block = train_num // data.batch_size
        else:
            batch_block = train_num // data.batch_size + 1
        correct = total = total_loss = 0
        random.shuffle(data.train_idx)
        model.train()
        model.zero_grad()
        for block in range(batch_block):
            left = block * data.batch_size
            right = left + data.batch_size
            if right > train_num:
                right = train_num
            instance = data.train_idx[left:right]
            batch_word, batch_feat, batch_word_len, word_recover, batch_char, batch_char_len, char_recover, batch_label, mask, batch_bert \
                = generate_batch(instance, data.use_cuda, data.use_bert)
            batch_dict = None
            if data.out_dict:
                batch_dict = generate_dict_feature(instance, data,
                                                   data.use_cuda)
            loss, seq = model.neg_log_likehood(batch_word, batch_feat,
                                               batch_word_len, batch_char,
                                               batch_char_len, char_recover,
                                               batch_label, mask, batch_dict,
                                               batch_bert)
            right_token, total_token = predict_check(seq, batch_label, mask)
            correct += right_token
            total += total_token
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        epoch_end = time.time()
        print("Epoch:{}. Time:{}. Loss:{}. acc:{}".format(
            idx, epoch_end - epoch_start, total_loss, correct / total))
        # torch.save(model,data.model_save_dir+'/model'+str(idx)+'.pkl')
        evaluate(data, model, "dev", idx)
        evaluate(data, model, "test", idx)
        if data.oov_file is not None:
            evaluate(data, model, "oov", idx)
        print("Finish.")