import os
import pickle
import random
import sys
import tarfile
import time
from typing import List

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm

# Repo-local names used below (NNCRF, Config, Instance, simple_batching,
# get_optimizer, lr_decay, evaluate_batch_insts, evaluate_batch_insts_e2e,
# write_results) are assumed to be imported from this project's own modules.


def train_one(config: Config, train_insts: List[Instance], dev_insts: List[Instance],
              model_name: str, test_insts: List[Instance] = None,
              config_name: str = None, result_filename: str = None) -> NNCRF:
    train_batches = batching_list_instances(config, train_insts)
    dev_batches = batching_list_instances(config, dev_insts)
    # Batch the test set the same way as dev so evaluate_model can iterate it.
    test_batches = batching_list_instances(config, test_insts) if test_insts else None
    model = NNCRF(config)
    model.train()
    optimizer = get_optimizer(config, model)
    epoch = config.num_epochs
    best_dev_f1 = -1
    saved_test_metrics = None
    for i in range(1, epoch + 1):
        epoch_loss = 0
        start_time = time.time()
        model.zero_grad()
        if config.optimizer.lower() == "sgd":
            optimizer = lr_decay(config, optimizer, i)
        for index in np.random.permutation(len(train_batches)):
            model.train()
            loss = model(*train_batches[index])
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        end_time = time.time()
        print("Epoch %d: %.5f, Time is %.2fs" % (i, epoch_loss, end_time - start_time), flush=True)

        model.eval()
        # evaluate_model returns [precision, recall, fscore, ...]; index 2 is the task-1 F1.
        dev_metrics = evaluate_model(config, model, dev_batches, "dev", dev_insts)
        if test_insts is not None:
            test_metrics = evaluate_model(config, model, test_batches, "test", test_insts)
        if dev_metrics[2] > best_dev_f1:
            print("saving the best model...")
            best_dev_f1 = dev_metrics[2]
            if test_insts is not None:
                saved_test_metrics = test_metrics
            torch.save(model.state_dict(), model_name)
            # Save the corresponding config as well.
            if config_name:
                with open(config_name, 'wb') as f:
                    pickle.dump(config, f)
            if result_filename:
                write_results(result_filename, test_insts)
        model.zero_grad()
    if test_insts is not None:
        print(f"The best dev F1: {best_dev_f1}")
        print(f"The corresponding test: {saved_test_metrics}")
    return model
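
# A minimal sketch of the kind of per-epoch schedule lr_decay typically applies
# for SGD in LSTM-CRF codebases: lr_t = lr_0 / (1 + decay * t). This is an
# assumption about lr_decay's behavior, not its actual definition in this repo.
def _decayed_lr_sketch(init_lr: float, decay_rate: float, epoch: int) -> float:
    # Learning rate shrinks monotonically with the epoch index.
    return init_lr / (1 + decay_rate * epoch)

assert abs(_decayed_lr_sketch(0.01, 0.05, 0) - 0.01) < 1e-12  # no decay at epoch 0
assert _decayed_lr_sketch(0.01, 0.05, 10) < 0.01              # strictly smaller later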
def batching_list_instances(config: Config, insts: List[Instance]):
    train_num = len(insts)
    batch_size = config.batch_size
    # Number of batches, rounding up when the last batch is partial.
    total_batch = train_num // batch_size + 1 if train_num % batch_size != 0 else train_num // batch_size
    batched_data = []
    for batch_id in range(total_batch):
        one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size]
        batched_data.append(simple_batching(config, one_batch_insts))
    return batched_data
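
# Self-contained check (pure Python, no repo dependencies) that the
# total_batch arithmetic above is exactly math.ceil(train_num / batch_size).
import math

def _num_batches(train_num: int, batch_size: int) -> int:
    return train_num // batch_size + (1 if train_num % batch_size != 0 else 0)

assert all(
    _num_batches(n, b) == math.ceil(n / b)
    for n in range(1, 50)
    for b in range(1, 10)
)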
# Note: a method of a reader/config-holding class (it references self.conf).
def create_batch_data(self, insts: List[Instance]):
    return simple_batching(self.conf, insts)
def evaluate_model(config: Config, model: NNCRF, batch_insts_ids, name: str, insts: List[Instance]):
    ## evaluation
    tp, fp, tn, fn = 0, 0, 0, 0
    # Each metric row is [num matched, total predicted, total gold entities].
    metrics = np.asarray([0, 0, 0], dtype=int)
    metrics_e2e = np.zeros((1, 3), dtype=int)
    batch_idx = 0
    batch_size = config.batch_size
    for batch in batch_insts_ids:
        one_batch_insts = insts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        processed_batched_data = simple_batching(config, batch)
        batch_max_scores, batch_max_ids, pair_ids = model.decode(processed_batched_data)
        metrics += evaluate_batch_insts(one_batch_insts, batch_max_ids,
                                        processed_batched_data[-6],
                                        processed_batched_data[2],
                                        config.idx2labels)
        metrics_e2e += evaluate_batch_insts_e2e(one_batch_insts, batch_max_ids,
                                                processed_batched_data[-6],
                                                processed_batched_data[2],
                                                config.idx2labels,
                                                processed_batched_data[-8],
                                                pair_ids,
                                                processed_batched_data[-1])
        word_seq_lens = processed_batched_data[2].tolist()
        for batch_id in range(batch_max_ids.size()[0]):
            length = word_seq_lens[batch_id]
            gold = processed_batched_data[-6][batch_id][:length]
            # Gold argument token positions: label ids 2-5 are the S/B/E/I tags.
            s_id = (gold == 2).nonzero()
            b_id = (gold == 3).nonzero()
            e_id = (gold == 4).nonzero()
            i_id = (gold == 5).nonzero()
            gold_id = torch.cat([s_id, b_id, e_id, i_id]).squeeze(1)
            gold_id, _ = gold_id.sort(0, descending=False)
            gold_id = gold_id[gold_id < processed_batched_data[-1][batch_id]]
            # An earlier variant intersected gold_id with prediction-side ids;
            # argu_id is kept for reference but is unused below.
            argu_id = torch.LongTensor(list(set(gold_id.tolist())))
            one_batch_insts[batch_id].gold2 = processed_batched_data[-3][batch_id].tolist()
            one_batch_insts[batch_id].pred2 = pair_ids[batch_id].squeeze(2).tolist()
            pred2 = pair_ids[batch_id].squeeze(2)
            gold2 = processed_batched_data[-3][batch_id]
            # Pairwise confusion counts: with gold/pred labels in {0, 1} and
            # padding marked negative, gold + pred == 2 only at true positives
            # and == 0 only at true negatives; negative sums are padding.
            gold_pairs = gold2.flatten()
            pred_pairs = pred2.flatten()
            sum_table = gold_pairs + pred_pairs
            sum_table_sliced = sum_table[sum_table >= 0]
            tp_tmp = len(sum_table_sliced[sum_table_sliced == 2])
            tn_tmp = len(sum_table_sliced[sum_table_sliced == 0])
            tp += tp_tmp
            tn += tn_tmp
            ones = len(gold_pairs[gold_pairs == 1])
            zeros = len(gold_pairs[gold_pairs == 0])
            fp += (zeros - tn_tmp)
            fn += (ones - tp_tmp)
        batch_idx += 1
    print('tp, fp, fn, tn: ', tp, fp, fn, tn)

    # Task 2 (pairwise) metrics.
    precision_2 = 1.0 * tp / (tp + fp) * 100 if tp + fp != 0 else 0
    recall_2 = 1.0 * tp / (tp + fn) * 100 if tp + fn != 0 else 0
    f1_2 = 2.0 * precision_2 * recall_2 / (precision_2 + recall_2) if precision_2 + recall_2 != 0 else 0
    acc = 1.0 * (tp + tn) / (fp + fn + tp + tn) * 100 if fp + fn + tp + tn != 0 else 0

    # Task 1 (sequence labeling) metrics.
    p, total_predict, total_entity = metrics[0], metrics[1], metrics[2]
    precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
    recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
    fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0

    # End-to-end metrics: zero denominators are replaced with sys.maxsize so the
    # vectorized divisions below yield ~0 instead of dividing by zero.
    p_e2e, total_predict_e2e, total_entity_e2e = metrics_e2e[:, 0], metrics_e2e[:, 1], metrics_e2e[:, 2]
    total_predict_e2e[total_predict_e2e == 0] = sys.maxsize
    total_entity_e2e[total_entity_e2e == 0] = sys.maxsize
    precision_e2e = p_e2e * 1.0 / total_predict_e2e * 100
    recall_e2e = p_e2e * 1.0 / total_entity_e2e * 100
    sum_e2e = precision_e2e + recall_e2e
    sum_e2e[sum_e2e == 0] = sys.maxsize
    fscore_e2e = 2.0 * precision_e2e * recall_e2e / sum_e2e

    print("Task1: ", p, total_predict, total_entity)
    print("Task1: [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision, recall, fscore), flush=True)
    print("Task2: [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f, acc: %.2f" % (name, precision_2, recall_2, f1_2, acc), flush=True)
    percs = [0.9]
    # percs = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
    for i in range(len(percs)):
        print("Overall ", percs[i],
              ": [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision_e2e[i], recall_e2e[i], fscore_e2e[i]),
              flush=True)
    return [precision, recall, fscore, precision_2, recall_2, f1_2, acc,
            precision_e2e, recall_e2e, fscore_e2e]
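
# Minimal sketch of the pairwise-confusion trick used in evaluate_model: with
# gold/pred in {0, 1} and padding stored as a negative value (the exact padding
# value is an assumption here), gold + pred equals 2 only at true positives and
# 0 only at true negatives. Toy tensors only, not repo data.
_gold = torch.tensor([1, 0, 1, 0, -2])   # -2 marks a padded cell
_pred = torch.tensor([1, 1, 0, 0, -2])
_sum = (_gold + _pred)[(_gold + _pred) >= 0]          # mask padding
_tp = int((_sum == 2).sum())                          # both gold and pred are 1
_tn = int((_sum == 0).sum())                          # both gold and pred are 0
_fp = int((_gold == 0).sum()) - _tn                   # pred 1 where gold is 0
_fn = int((_gold == 1).sum()) - _tp                   # pred 0 where gold is 1
assert (_tp, _tn, _fp, _fn) == (1, 1, 1, 1)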
def train_model(config: Config, epoch: int, train_insts: List[Instance],
                dev_insts: List[Instance], test_insts: List[Instance]):
    model = NNCRF(config)
    # Count trainable parameters (numel handles tensors of any rank, unlike
    # multiplying the first two dimensions).
    num_param = sum(p.numel() for p in model.parameters())
    print(num_param)
    optimizer = get_optimizer(config, model)
    train_num = len(train_insts)
    print("number of instances: %d" % (train_num))
    print(colored("[Shuffled] Shuffle the training instance ids", "red"))
    random.shuffle(train_insts)
    batched_data = batching_list_instances(config, train_insts)
    dev_batches = batching_list_instances(config, dev_insts)
    test_batches = batching_list_instances(config, test_insts)

    best_dev = [-1, 0, -1]   # [task-1 F1, best epoch, best e2e F1]
    best_test = [-1, 0, -1]

    model_folder = config.model_folder
    res_folder = "results"
    if os.path.exists("model_files/" + model_folder):
        raise FileExistsError(
            f"The folder model_files/{model_folder} exists. Please either delete it or create a new one "
            f"to avoid overwriting it.")
    model_path = f"model_files/{model_folder}/lstm_crf.m"
    config_path = f"model_files/{model_folder}/config.conf"
    res_path = f"{res_folder}/{model_folder}.results"
    print("[Info] The model will be saved to: %s.tar.gz" % (model_folder))
    os.makedirs(f"model_files/{model_folder}", exist_ok=True)  # create model folder; no error if it exists
    os.makedirs(res_folder, exist_ok=True)

    no_incre_dev = 0
    for i in tqdm(range(1, epoch + 1), desc="Epoch"):
        epoch_loss = 0
        start_time = time.time()
        model.zero_grad()
        if config.optimizer.lower() == "sgd":
            optimizer = lr_decay(config, optimizer, i)
        for index in np.random.permutation(len(batched_data)):
            processed_batched_data = simple_batching(config, batched_data[index])
            model.train()
            loss = model(*processed_batched_data)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        end_time = time.time()
        print("Epoch %d: %.5f, Time is %.2fs" % (i, epoch_loss, end_time - start_time), flush=True)

        model.eval()
        dev_metrics = evaluate_model(config, model, dev_batches, "dev", dev_insts)
        test_metrics = evaluate_model(config, model, test_batches, "test", test_insts)
        # Model selection on the best end-to-end (task 2) dev F1. Alternative
        # criteria: dev_metrics[2] > best_dev[0] for task 1 alone, or a
        # combined task 1 & task 2 rule.
        if np.max(dev_metrics[-1]) > best_dev[-1]:
            print("saving the best model...")
            no_incre_dev = 0
            best_dev[0] = dev_metrics[2]
            best_dev[-1] = np.max(dev_metrics[-1])
            best_dev[1] = i
            best_test[0] = test_metrics[2]
            best_test[-1] = np.max(test_metrics[-1])
            best_test[1] = i
            torch.save(model.state_dict(), model_path)
            # Save the corresponding config as well.
            with open(config_path, 'wb') as f:
                pickle.dump(config, f)
            write_results(res_path, test_insts)
        else:
            no_incre_dev += 1
        model.zero_grad()
        if no_incre_dev >= config.max_no_incre:
            print("early stop because there are %d epochs not increasing f1 on dev" % no_incre_dev)
            break

    print("Archiving the best Model...")
    with tarfile.open(f"model_files/{model_folder}/{model_folder}.tar.gz", "w:gz") as tar:
        tar.add(f"model_files/{model_folder}", arcname=os.path.basename(model_folder))
    print("Finished archiving the models")

    print("The best dev: %.2f" % (best_dev[0]))
    print("The corresponding test: %.2f" % (best_test[0]))
    print("Final testing.")
    model.load_state_dict(torch.load(model_path))
    model.eval()
    evaluate_model(config, model, test_batches, "test", test_insts)
    write_results(res_path, test_insts)
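
# Self-contained sketch of the early-stopping rule used in train_model: stop
# once the dev score has failed to improve for max_no_incre consecutive
# epochs. Toy scores only; the helper name is hypothetical.
def _early_stop_epoch(dev_scores, max_no_incre):
    best, no_incre = float("-inf"), 0
    for ep, score in enumerate(dev_scores, start=1):
        if score > best:
            best, no_incre = score, 0   # improvement resets the counter
        else:
            no_incre += 1
        if no_incre >= max_no_incre:
            return ep                   # epoch at which training would break
    return None                         # never triggered

# Dev F1 peaks at epoch 2; three non-improving epochs trigger the stop at epoch 5.
assert _early_stop_epoch([0.1, 0.3, 0.3, 0.2, 0.25], max_no_incre=3) == 5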