import os
import pickle
import random
import sys
import tarfile
import time

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm

# Config, Instance, NNCRF, simple_batching, batching_list_instances, lr_decay,
# get_optimizer, write_results, evaluate_batch_insts and evaluate_batch_insts_e2e
# are assumed to come from the surrounding project modules.


def train_one(config: Config,
              train_insts: List[Instance],
              dev_insts: List[Instance],
              model_name: str,
              test_insts: List[Instance] = None,
              config_name: str = None,
              result_filename: str = None) -> NNCRF:
    train_batches = batching_list_instances(config, train_insts)
    dev_batches = batching_list_instances(config, dev_insts)
    test_batches = batching_list_instances(config, test_insts) if test_insts else None
    model = NNCRF(config)
    model.train()
    optimizer = get_optimizer(config, model)
    epoch = config.num_epochs
    best_dev_f1 = -1
    saved_test_metrics = None
    for i in range(1, epoch + 1):
        epoch_loss = 0
        start_time = time.time()
        model.zero_grad()
        if config.optimizer.lower() == "sgd":
            optimizer = lr_decay(config, optimizer, i)
        for index in np.random.permutation(len(train_batches)):
            model.train()
            loss = model(*train_batches[index])
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()
        end_time = time.time()
        print("Epoch %d: %.5f, Time is %.2fs" %
              (i, epoch_loss, end_time - start_time),
              flush=True)

        model.eval()
        # metric is [precision, recall, f_score]
        dev_metrics = evaluate_model(config, model, dev_batches, "dev",
                                     dev_insts)
        if test_insts is not None:
            test_metrics = evaluate_model(config, model, test_batches, "test",
                                          test_insts)
        if dev_metrics[2] > best_dev_f1:
            print("saving the best model...")
            best_dev_f1 = dev_metrics[2]
            if test_insts is not None:
                saved_test_metrics = test_metrics
            torch.save(model.state_dict(), model_name)
            # Save the corresponding config as well.
            if config_name:
                with open(config_name, 'wb') as f:
                    pickle.dump(config, f)
            if result_filename:
                write_results(result_filename, test_insts)
        model.zero_grad()
    if test_insts is not None:
        print(f"The best dev F1: {best_dev_f1}")
        print(f"The corresponding test: {saved_test_metrics}")
    return model
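

# The per-epoch loop in train_one follows the standard PyTorch pattern:
# zero the gradients, compute the loss, backpropagate, step the optimizer.
# The sketch below shows that pattern in isolation; it is a minimal example,
# assuming only a model whose forward pass returns a scalar loss and a list
# of pre-built batches (run_epoch itself is not part of the original code).
def run_epoch(model, batches, optimizer):
    model.train()
    epoch_loss = 0.0
    for batch in batches:
        optimizer.zero_grad()
        loss = model(*batch)          # forward pass returns the loss
        epoch_loss += loss.item()
        loss.backward()               # accumulate gradients
        optimizer.step()              # update parameters
    return epoch_loss

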
def batching_list_instances(config: Config, insts: List[Instance]):
    train_num = len(insts)
    batch_size = config.batch_size
    total_batch = (train_num + batch_size - 1) // batch_size  # ceiling division
    batched_data = []
    for batch_id in range(total_batch):
        one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) *
                                batch_size]
        batched_data.append(simple_batching(config, one_batch_insts))

    return batched_data
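

# For reference, batching_list_instances splits the instance list into
# ceil(len(insts) / batch_size) consecutive chunks and runs simple_batching on
# each chunk. A self-contained sketch of the same chunking logic (plain lists
# stand in for Instance objects; chunk is a hypothetical helper, not part of
# the original code):
def chunk(items, batch_size):
    total = (len(items) + batch_size - 1) // batch_size  # ceiling division
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(total)]

# chunk(list(range(10)), 4) -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
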
Example #3
def create_batch_data(self, insts: List[Instance]):
    return simple_batching(self.conf, insts)
Example #4
def evaluate_model(config: Config, model: NNCRF, batch_insts_ids, name: str,
                   insts: List[Instance]):
    ## evaluation
    tp, fp, tn, fn = 0, 0, 0, 0
    # metrics, metrics_e2e = np.asarray([0, 0, 0], dtype=int), np.asarray([0, 0, 0], dtype=int)
    metrics, metrics_e2e = np.asarray([0, 0, 0], dtype=int), np.zeros(
        (1, 3), dtype=int)
    pair_metrics = np.asarray([0, 0, 0], dtype=int)
    batch_idx = 0
    batch_size = config.batch_size
    # print('insts',len(insts))
    for batch in batch_insts_ids:
        # print('batch_idx * batch_size:(batch_idx + 1) * batch_size', batch_idx* batch_size,(batch_idx + 1) * batch_size )
        one_batch_insts = insts[batch_idx * batch_size:(batch_idx + 1) *
                                batch_size]

        processed_batched_data = simple_batching(config, batch)
        # print(len(one_batch_insts))
        batch_max_scores, batch_max_ids, pair_ids = model.decode(
            processed_batched_data)

        metrics += evaluate_batch_insts(one_batch_insts, batch_max_ids,
                                        processed_batched_data[-6],
                                        processed_batched_data[2],
                                        config.idx2labels)
        # print(processed_batched_data[-1])
        metrics_e2e += evaluate_batch_insts_e2e(
            one_batch_insts, batch_max_ids, processed_batched_data[-6],
            processed_batched_data[2], config.idx2labels,
            processed_batched_data[-8], pair_ids, processed_batched_data[-1])

        word_seq_lens = processed_batched_data[2].tolist()
        for batch_id in range(batch_max_ids.size()[0]):
            # print('batch_max_ids[batch_id]:  ',batch_max_ids[batch_id].size(),batch_max_ids[batch_id])
            length = word_seq_lens[batch_id]
            # prediction = batch_max_ids[batch_id][:length]
            # prediction = torch.flip(prediction,dims = [0])

            gold = processed_batched_data[-6][batch_id][:length]
            # gold = torch.flip(gold, dims=[0])

            # s_id = (prediction == 2).nonzero()
            # b_id = (prediction == 3).nonzero()
            # e_id = (prediction == 4).nonzero()
            # i_id = (prediction == 5).nonzero()
            # pred_id = torch.cat([s_id, b_id, e_id, i_id]).squeeze(1)
            # pred_id,_ = pred_id.sort(0, descending=False)
            # pred_id = pred_id[pred_id < processed_batched_data[-1][batch_id]]

            s_id = (gold == 2).nonzero()
            b_id = (gold == 3).nonzero()
            e_id = (gold == 4).nonzero()
            i_id = (gold == 5).nonzero()
            gold_id = torch.cat([s_id, b_id, e_id, i_id]).squeeze(1)
            gold_id, _ = gold_id.sort(0, descending=False)
            gold_id = gold_id[gold_id < processed_batched_data[-1][batch_id]]

            # argu_id = torch.LongTensor(list(set(gold_id.tolist()).intersection(set(pred_id.tolist()))))
            argu_id = torch.LongTensor(list(set(gold_id.tolist())))
            # print('gold_id', gold_id, 'pred_id', pred_id, 'argu_id', argu_id)

            # print(pair_ids[batch_id].size(), batch[-3][batch_id].size())
            one_batch_insts[batch_id].gold2 = processed_batched_data[-3][
                batch_id].tolist()
            one_batch_insts[batch_id].pred2 = pair_ids[batch_id].squeeze(
                2).tolist()

            # print(one_batch_insts[batch_id].gold2)
            # print(torch.sum(one_batch_insts[batch_id].pred2, dim=1))

            # pred2 = one_batch_insts[batch_id].pred2[argu_id]
            pred2 = pair_ids[batch_id].squeeze(2)
            # gold2 = one_batch_insts[batch_id].gold2[argu_id]
            gold2 = processed_batched_data[-3][batch_id]

            # print('argu_id:  ',argu_id.size(),argu_id)
            # print('one_batch_insts[batch_id].pred2:  ',one_batch_insts[batch_id].pred2.size(),one_batch_insts[batch_id].pred2)

            gold_pairs = gold2.flatten()
            pred_pairs = pred2.flatten()

            # print(gold_pairs,pred_pairs)
            sum_table = gold_pairs + pred_pairs
            # print(sum_table.size(),sum_table[:100])
            sum_table_sliced = sum_table[sum_table >= 0]
            # print(sum_table_sliced.size(),sum_table_sliced)
            tp_tmp = len(sum_table_sliced[sum_table_sliced == 2])
            tn_tmp = len(sum_table_sliced[sum_table_sliced == 0])
            tp += tp_tmp
            tn += tn_tmp
            ones = len(gold_pairs[gold_pairs == 1])
            zeros = len(gold_pairs[gold_pairs == 0])
            fp += (zeros - tn_tmp)
            fn += (ones - tp_tmp)
            # print(tp,tp_tmp,tn,tn_tmp,ones,zeros,fp,fn)

        batch_idx += 1
    print('tp, fp, fn, tn: ', tp, fp, fn, tn)
    precision_2 = 1.0 * tp / (tp + fp) * 100 if tp + fp != 0 else 0
    recall_2 = 1.0 * tp / (tp + fn) * 100 if tp + fn != 0 else 0
    f1_2 = 2.0 * precision_2 * recall_2 / (
        precision_2 + recall_2) if precision_2 + recall_2 != 0 else 0
    acc = 1.0 * (tp + tn) / (fp + fn + tp +
                             tn) * 100 if fp + fn + tp + tn != 0 else 0
    p, total_predict, total_entity = metrics[0], metrics[1], metrics[2]
    precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
    recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
    fscore = 2.0 * precision * recall / (
        precision + recall) if precision != 0 or recall != 0 else 0

    p_e2e, total_predict_e2e, total_entity_e2e = (metrics_e2e[:, 0],
                                                  metrics_e2e[:, 1],
                                                  metrics_e2e[:, 2])
    # precision_e2e = p_e2e * 1.0 / total_predict_e2e * 100 if total_predict_e2e != 0 else 0
    # recall_e2e = p_e2e * 1.0 / total_entity_e2e * 100 if total_entity_e2e != 0 else 0
    # fscore_e2e = 2.0 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e) if precision_e2e != 0 or recall_e2e != 0 else 0
    total_predict_e2e[total_predict_e2e == 0] = sys.maxsize
    total_entity_e2e[total_entity_e2e == 0] = sys.maxsize

    precision_e2e = p_e2e * 1.0 / total_predict_e2e * 100
    recall_e2e = p_e2e * 1.0 / total_entity_e2e * 100

    sum_e2e = precision_e2e + recall_e2e
    sum_e2e[sum_e2e == 0] = sys.maxsize
    fscore_e2e = 2.0 * precision_e2e * recall_e2e / sum_e2e

    print("Task1: ", p, total_predict, total_entity)
    # print("Overall: ", p_e2e, total_predict_e2e, total_entity_e2e)

    print("Task1: [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" %
          (name, precision, recall, fscore),
          flush=True)
    print(
        "Task2: [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f, acc: %.2f" %
        (name, precision_2, recall_2, f1_2, acc),
        flush=True)
    percs = [0.9]
    #percs = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
    for i in range(len(percs)):
        print("Overall ",
              percs[i],
              ": [%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" %
              (name, precision_e2e[i], recall_e2e[i], fscore_e2e[i]),
              flush=True)
    return [
        precision, recall, fscore, precision_2, recall_2, f1_2, acc,
        precision_e2e, recall_e2e, fscore_e2e
    ]
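
# The Task2 numbers printed above are ordinary binary-classification metrics
# derived from the tp/fp/fn/tn counts. A small, self-contained restatement of
# those formulas (prf1_acc is a hypothetical helper; the example values below
# are illustrative, not model outputs):
def prf1_acc(tp, fp, fn, tn):
    precision = 100.0 * tp / (tp + fp) if tp + fp else 0.0
    recall = 100.0 * tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    acc = 100.0 * (tp + tn) / (tp + fp + fn + tn) if tp + fp + fn + tn else 0.0
    return precision, recall, f1, acc

# prf1_acc(tp=8, fp=2, fn=4, tn=86) -> (80.0, 66.67, 72.73, 94.0)  (rounded)
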
Example #5
def train_model(config: Config, epoch: int, train_insts: List[Instance],
                dev_insts: List[Instance], test_insts: List[Instance]):
    model = NNCRF(config)
    # Count the model parameters (numel covers tensors of any rank).
    num_param = sum(p.numel() for p in model.parameters())
    print("number of parameters: %d" % num_param)
    optimizer = get_optimizer(config, model)
    train_num = len(train_insts)
    print("number of instances: %d" % (train_num))
    print(colored("[Shuffled] Shuffle the training instance ids", "red"))
    random.shuffle(train_insts)

    batched_data = batching_list_instances(config, train_insts)
    dev_batches = batching_list_instances(config, dev_insts)
    test_batches = batching_list_instances(config, test_insts)

    best_dev = [-1, 0, -1]
    best_test = [-1, 0, -1]

    model_folder = config.model_folder
    res_folder = "results"
    if os.path.exists("model_files/" + model_folder):
        raise FileExistsError(
            f"The folder model_files/{model_folder} already exists. Please delete it "
            f"or choose a new model folder to avoid overwriting it.")
    model_path = f"model_files/{model_folder}/lstm_crf.m"
    config_path = f"model_files/{model_folder}/config.conf"
    res_path = f"{res_folder}/{model_folder}.results"
    print("[Info] The model will be saved to: %s.tar.gz" % (model_folder))
    os.makedirs(f"model_files/{model_folder}",
                exist_ok=True)  ## create model files. not raise error if exist
    os.makedirs(res_folder, exist_ok=True)
    no_incre_dev = 0
    for i in tqdm(range(1, epoch + 1), desc="Epoch"):
        epoch_loss = 0
        start_time = time.time()
        model.zero_grad()
        if config.optimizer.lower() == "sgd":
            optimizer = lr_decay(config, optimizer, i)
        for index in np.random.permutation(len(batched_data)):
            processed_batched_data = simple_batching(config,
                                                     batched_data[index])
            model.train()
            loss = model(*processed_batched_data)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            model.zero_grad()

        end_time = time.time()
        print("Epoch %d: %.5f, Time is %.2fs" %
              (i, epoch_loss, end_time - start_time),
              flush=True)

        model.eval()
        dev_metrics = evaluate_model(config, model, dev_batches, "dev",
                                     dev_insts)
        test_metrics = evaluate_model(config, model, test_batches, "test",
                                      test_insts)
        # print(test_insts.prediction)
        # if dev_metrics[2] > best_dev[0] or (dev_metrics[2] == best_dev[0] and dev_metrics[-1] > best_dev[-1]): # task 1 & task 2
        if np.max(dev_metrics[-1]) > best_dev[-1]:  # task 2
            # if dev_metrics[2] > best_dev[0]: # task 1
            print("saving the best model...")
            no_incre_dev = 0
            best_dev[0] = dev_metrics[2]
            best_dev[-1] = np.max(dev_metrics[-1])
            best_dev[1] = i
            best_test[0] = test_metrics[2]
            best_test[-1] = np.max(test_metrics[-1])
            best_test[1] = i
            torch.save(model.state_dict(), model_path)
            # Save the corresponding config as well.
            with open(config_path, 'wb') as f:
                pickle.dump(config, f)
            write_results(res_path, test_insts)
        else:
            no_incre_dev += 1
        model.zero_grad()
        if no_incre_dev >= config.max_no_incre:
            print(
                "Early stopping: the dev F1 has not improved for %d epochs." %
                no_incre_dev)
            break

    print("Archiving the best Model...")
    with tarfile.open(f"model_files/{model_folder}/{model_folder}.tar.gz",
                      "w:gz") as tar:
        tar.add(f"model_files/{model_folder}",
                arcname=os.path.basename(model_folder))

    print("Finished archiving the models")

    print("The best dev: %.2f" % (best_dev[0]))
    print("The corresponding test: %.2f" % (best_test[0]))
    print("Final testing.")
    model.load_state_dict(torch.load(model_path))
    model.eval()
    evaluate_model(config, model, test_batches, "test", test_insts)
    write_results(res_path, test_insts)
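

# Since train_model pickles the Config next to the saved weights, restoring a
# trained model for later inference mirrors that layout. A sketch under the
# assumption that the same model_files/<folder> paths as above are used
# (load_trained is a hypothetical helper, not part of the original code):
def load_trained(config_path, model_path):
    with open(config_path, 'rb') as f:
        config = pickle.load(f)
    model = NNCRF(config)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return config, model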