import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AlbertModel, BertModel

# Relies on module-level config (tokenizer, device, pretrained, learning_rate,
# epoch, batch sizes, resume_train) and on helpers defined elsewhere in the
# repo (DataGenerator, BertBidaf, BertDST, get_state, SLOT, slotgate2id).


def train(infile):
    if not resume_train:
        encode_model = AlbertModel.from_pretrained(pretrained).to(device)
        model = BertBidaf(tokenizer, encode_model, device)
    else:
        model = torch.load("dureder_model")
    model = model.to(device)
    data_gen = DataGenerator(infile, tokenizer, device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for e in range(epoch):
        batch_cnt = 0
        for batch in data_gen.batchIter(train_batch_size):
            batch_cnt += 1
            batch_pair_ids = batch["batch_pair_ids"]
            batch_token_type_ids = batch["batch_token_type_ids"]
            batch_attention_mask = batch["batch_attention_mask"]
            batch_start = batch["batch_start"]
            batch_end = batch["batch_end"]
            batch_context_len = batch["batch_context_len"]
            batch_question_len = batch["batch_question_len"]
            p1, p2 = model(batch_pair_ids, batch_token_type_ids, batch_attention_mask,
                           batch_context_len, batch_question_len)
            optimizer.zero_grad()
            # Span loss: cross entropy on the start and end position distributions.
            batch_loss = criterion(p1, batch_start) + criterion(p2, batch_end)
            print(e, batch_cnt * train_batch_size, batch_loss.item())
            # print(get_batch_predict_answer(batch_pair_ids, p1, p2))
            # print(batch_ans)
            batch_loss.backward()
            optimizer.step()
            # Checkpoint every 20 batches, and once more after training ends.
            if batch_cnt % 20 == 0:
                torch.save(model, 'dureder_model')
    torch.save(model, 'dureder_model')
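
# The commented-out get_batch_predict_answer call above is not defined in this
# listing. A minimal sketch of what it could look like, assuming p1/p2 are
# [batch, seq_len] start/end logits and the tokenizer follows the Hugging Face
# interface; it uses a greedy per-position argmax rather than a joint span search.
def get_batch_predict_answer(batch_pair_ids, p1, p2):
    answers = []
    for ids, start_logits, end_logits in zip(batch_pair_ids, p1, p2):
        start = int(torch.argmax(start_logits))
        # Constrain the end position to fall at or after the chosen start.
        end = start + int(torch.argmax(end_logits[start:]))
        answers.append(tokenizer.decode(ids[start:end + 1]))
    return answers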
def evaluation(infile, tokenizer, model, device, epoch):
    slot_turn_acc, joint_acc, slot_F1_pred, slot_F1_count = 0, 0, 0, 0
    len_test_data = 0
    model.eval()
    data_gen = DataGenerator(infile, tokenizer, device)
    wall_times = []
    for batch in data_gen.batchIter(eval_batch_size):
        batch_content_ids = batch["batch_content_ids"]
        batch_token_type_ids = batch["batch_token_type_ids"]
        batch_attention_mask = batch["batch_attention_mask"]
        batch_gold_state = batch["batch_gold_state"]
        start = time.perf_counter()
        with torch.no_grad():
            domain_score, slot_pointer_prob, slot_gate_prob, slot_pointer, slot_gate, start_prob, end_prob = model(
                batch_content_ids, batch_token_type_ids, batch_attention_mask)
            state_list = get_state(batch_content_ids, slot_pointer, slot_gate, start_prob, end_prob)
        end = time.perf_counter()
        wall_times.append(end - start)
        for pred_state, gold_state in zip(state_list, batch_gold_state):
            if set(pred_state) == set(gold_state):
                joint_acc += 1
            len_test_data += 1
            # Compute prediction slot accuracy
            temp_acc = compute_acc(set(gold_state), set(pred_state), SLOT)
            slot_turn_acc += temp_acc
            # Compute prediction F1 score
            temp_f1, temp_r, temp_p, count = compute_prf(gold_state, pred_state)
            slot_F1_pred += temp_f1
            slot_F1_count += count
    joint_acc_score = joint_acc / len_test_data
    turn_acc_score = slot_turn_acc / len_test_data
    slot_F1_score = slot_F1_pred / slot_F1_count
    latency = np.mean(wall_times) * 1000
    print("------------------------------")
    print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
    print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
    print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
    print("Latency Per Prediction : %f ms" % latency)
    print("-----------------------------\n")
    scores = {
        'epoch': epoch,
        'joint_acc': joint_acc_score,
        'slot_acc': turn_acc_score,
        'slot_f1': slot_F1_score,
    }
    return scores
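
# compute_acc and compute_prf are not defined in this listing. Minimal sketches
# in the style of the TRADE DST metrics, assuming gold/pred states are
# collections of "domain-slot-value" strings and SLOT is the full slot ontology.
def compute_acc(gold, pred, slot_list):
    # Slot accuracy: penalize gold slots that were missed and predicted
    # slots that are not in the gold state.
    miss_gold = len(gold - pred)
    wrong_pred = len(pred - gold)
    return (len(slot_list) - miss_gold - wrong_pred) / len(slot_list)


def compute_prf(gold, pred):
    # Per-turn precision/recall/F1 over "domain-slot-value" strings; the
    # return order matches the (f1, recall, precision, count) unpacking above.
    if len(gold) == 0:
        # Empty gold state: perfect only if the prediction is empty too.
        score = 1.0 if len(pred) == 0 else 0.0
        return score, score, score, 1
    tp = len(set(gold) & set(pred))
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return f1, recall, precision, 1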
def train(infile, test_file):
    encode_model = BertModel.from_pretrained(pretrained)
    model = BertDST(tokenizer, encode_model, device)
    model = model.to(device)
    data_gen = DataGenerator(infile, tokenizer, device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    crossEntropyLoss = nn.CrossEntropyLoss()
    mseLoss = nn.MSELoss()  # reduce/size_average are deprecated; reduction='mean' is the default
    bceLoss = nn.BCELoss()
    model.train()
    for e in range(epoch):
        batch_cnt = 0
        for batch in data_gen.batchIter(train_batch_size):
            batch_cnt += 1
            batch_content_ids = batch["batch_content_ids"]
            batch_token_type_ids = batch["batch_token_type_ids"]
            batch_attention_mask = batch["batch_attention_mask"]
            batch_domain_id = batch["batch_domain_id"]              # [batch, domain]
            batch_slot_pointer_id = batch["batch_slot_pointer_id"]  # [batch, slot]
            batch_slot_gate_id = batch["batch_slot_gate_id"]        # [batch, max_slot]
            batch_value_start = batch["batch_value_start"]          # [batch, max_predict]
            batch_value_end = batch["batch_value_end"]              # [batch, max_predict]
            batch_gold_state = batch["batch_gold_state"]
            domain_score, slot_pointer_prob, slot_gate_prob, slot_pointer, slot_gate, start_prob, end_prob = model(
                batch_content_ids, batch_token_type_ids, batch_attention_mask,
                batch_slot_pointer_id, batch_slot_gate_id)
            optimizer.zero_grad()
            # Domain classification loss.
            batch_loss = crossEntropyLoss(domain_score, batch_domain_id)
            # Slot-pointer loss; BCE is used here instead of MSE.
            # batch_loss += mseLoss(slot_pointer_prob, batch_slot_pointer_id.float())
            batch_loss += bceLoss(slot_pointer_prob, batch_slot_pointer_id.float())
            # Slot-gate loss, masked so padded slots do not contribute.
            if batch_slot_gate_id.size(1) > 0:
                slot_gate_loss = masked_cross_entropy(slot_gate_prob, batch_slot_gate_id,
                                                      slotgate2id["pad"])
                batch_loss += slot_gate_loss
            # Span-extraction loss for slot values copied from the context.
            if start_prob.size(1) > 0:
                start_loss = masked_cross_entropy(start_prob, batch_value_start, 0)
                end_loss = masked_cross_entropy(end_prob, batch_value_end, 0)
                batch_loss += start_loss
                batch_loss += end_loss
            # Debug output: state decoded with the gold pointers/gates
            # (teacher forcing) versus the gold state.
            pred_states = get_state(batch_content_ids, batch_slot_pointer_id,
                                    batch_slot_gate_id, start_prob, end_prob)
            print([sorted(x) for x in pred_states])
            print([sorted(x) for x in batch_gold_state])
            print(e, batch_cnt * train_batch_size, batch_loss.item())
            print("=========================================")
            batch_loss.backward()
            optimizer.step()
        evaluation(test_file, tokenizer, model, device, e)
        model.train()  # evaluation() switches to eval mode; switch back for the next epoch
        torch.save(model, 'model_save')
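
# masked_cross_entropy is the one loss helper not defined in this listing.
# A minimal sketch using F from the imports above, assuming the model returns
# unnormalized logits of shape [batch, n, num_classes] (the *_prob names
# notwithstanding), targets of shape [batch, n], and that target positions
# equal to pad_id should be ignored.
def masked_cross_entropy(logits, target, pad_id):
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # [batch * n, num_classes]
        target.reshape(-1),                   # [batch * n]
        ignore_index=pad_id,
    )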