def train(data):
    """Train a SeqModel for ``data.iter`` epochs over ``data.train_idx``.

    Builds the model and optimizer from the ``data`` configuration object,
    runs mini-batch training with per-epoch shuffling, and evaluates on the
    dev and test sets after every epoch.

    Args:
        data: configuration/data holder providing hyper-parameters
            (``lr``, ``momentum``, ``weight_decay``, ``batch_size``,
            ``iter``, ``optimizer``, ``use_cuda``) and the training corpus
            (``train_text``, ``train_idx``).

    Raises:
        ValueError: if ``data.optimizer`` is not one of SGD/Adagrad/Adam.
    """
    print("Training model...")
    model = SeqModel(data)
    if data.use_cuda:
        model.cuda()
    # Optimize only trainable parameters (e.g. skip frozen embeddings).
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    if data.optimizer == "SGD":
        optimizer = optim.SGD(parameters, lr=data.lr, momentum=data.momentum,
                              weight_decay=data.weight_decay)
    elif data.optimizer == "Adagrad":
        optimizer = optim.Adagrad(parameters, lr=data.lr,
                                  weight_decay=data.weight_decay)
    elif data.optimizer == "Adam":
        optimizer = optim.Adam(parameters, lr=data.lr,
                               weight_decay=data.weight_decay)
    else:
        # Fail fast: the original only printed a message and later crashed
        # with an unbound `optimizer`.
        raise ValueError("Optimizer Error: {} optimizer is not support.".format(
            data.optimizer))
    for idx in range(data.iter):
        epoch_start = time.time()
        train_num = len(data.train_text)
        # Ceiling division: a trailing partial batch still gets a block.
        batch_block = (train_num + data.batch_size - 1) // data.batch_size
        correct = total = total_loss = 0
        random.shuffle(data.train_idx)
        # Re-enable training mode every epoch; evaluate() below may have
        # switched the model to eval mode (dropout off) — TODO confirm.
        model.train()
        model.zero_grad()
        for block in range(batch_block):
            left = block * data.batch_size
            right = min(left + data.batch_size, train_num)
            instance = data.train_idx[left:right]
            batch_word, batch_word_len, word_recover, batch_char, batch_char_len, \
                char_recover, batch_label, mask = generate_batch(instance,
                                                                 data.use_cuda)
            loss, seq = model.forward(batch_word, batch_word_len, batch_char,
                                      batch_char_len, char_recover, batch_label,
                                      mask)
            right_token, total_token = predict_check(seq, batch_label, mask)
            correct += right_token
            total += total_token
            # loss.item() replaces the deprecated loss.data[0], which raises
            # on PyTorch >= 0.5 (0-dim tensors can no longer be indexed).
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        epoch_end = time.time()
        print("Epoch:{}. Time:{}. Loss:{}. acc:{}".format(
            idx, epoch_end - epoch_start, total_loss, correct / total))
        evaluate(data, model, "dev", idx)
        evaluate(data, model, "test", idx)
    print("Finish.")
def train(data):
    """Train a SeqModel (word/char/feature/BERT inputs) and evaluate each epoch.

    Builds the model and optimizer from the ``data`` configuration object,
    then for ``data.iter`` epochs: shuffles the training indices, runs
    mini-batch training via ``model.neg_log_likehood`` (optionally with
    external dictionary features when ``data.out_dict`` is set), prints the
    epoch loss/accuracy, and evaluates on dev, test, and — when
    ``data.oov_file`` is given — the OOV set.

    Args:
        data: configuration/data holder providing hyper-parameters
            (``lr``, ``momentum``, ``weight_decay``, ``batch_size``,
            ``iter``, ``optimizer``, ``use_cuda``, ``use_bert``,
            ``out_dict``, ``oov_file``) and the training corpus
            (``train_text``, ``train_idx``).

    Raises:
        ValueError: if ``data.optimizer`` is not one of SGD/Adagrad/Adam.
    """
    print("Training model...")
    data.show_data_summary()
    model = SeqModel(data)
    if data.use_cuda:
        model.cuda()
    # Optimize only trainable parameters (e.g. skip frozen embeddings).
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    if data.optimizer == "SGD":
        optimizer = optim.SGD(parameters, lr=data.lr, momentum=data.momentum,
                              weight_decay=data.weight_decay)
    elif data.optimizer == "Adagrad":
        optimizer = optim.Adagrad(parameters, lr=data.lr,
                                  weight_decay=data.weight_decay)
    elif data.optimizer == "Adam":
        optimizer = optim.Adam(parameters, lr=data.lr,
                               weight_decay=data.weight_decay)
    else:
        # Fail fast: the original only printed a message and later crashed
        # with an unbound `optimizer`.
        raise ValueError("Optimizer Error: {} optimizer is not support.".format(
            data.optimizer))
    for idx in range(data.iter):
        epoch_start = time.time()
        train_num = len(data.train_text)
        # Ceiling division: a trailing partial batch still gets a block.
        batch_block = (train_num + data.batch_size - 1) // data.batch_size
        correct = total = total_loss = 0
        random.shuffle(data.train_idx)
        # Re-enable training mode every epoch; evaluate() below presumably
        # switches the model to eval mode — TODO confirm.
        model.train()
        model.zero_grad()
        for block in range(batch_block):
            left = block * data.batch_size
            right = min(left + data.batch_size, train_num)
            instance = data.train_idx[left:right]
            (batch_word, batch_feat, batch_word_len, word_recover, batch_char,
             batch_char_len, char_recover, batch_label, mask, batch_bert) = \
                generate_batch(instance, data.use_cuda, data.use_bert)
            # Optional external-dictionary features for this batch.
            batch_dict = None
            if data.out_dict:
                batch_dict = generate_dict_feature(instance, data, data.use_cuda)
            loss, seq = model.neg_log_likehood(batch_word, batch_feat,
                                               batch_word_len, batch_char,
                                               batch_char_len, char_recover,
                                               batch_label, mask, batch_dict,
                                               batch_bert)
            right_token, total_token = predict_check(seq, batch_label, mask)
            correct += right_token
            total += total_token
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        epoch_end = time.time()
        print("Epoch:{}. Time:{}. Loss:{}. acc:{}".format(
            idx, epoch_end - epoch_start, total_loss, correct / total))
        evaluate(data, model, "dev", idx)
        evaluate(data, model, "test", idx)
        if data.oov_file is not None:
            evaluate(data, model, "oov", idx)
    print("Finish.")