Example #1
def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy", load_checkpoint_path=""):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError("Finetune checkpoint is missing; the evaluation task must load a finetuned model!")
    # Build the network in inference mode (is_training=False) and load the
    # finetuned weights into it.
    net_for_pretraining = network(bert_net_cfg, False, num_class)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)

    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "f1":
        callback = F1(False, num_class)
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")

    # Convert each required dataset column to a Tensor, run inference, and
    # feed the logits to the metric callback.
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in dataset.create_dict_iterator():
        input_data = []
        for i in columns_list:
            input_data.append(Tensor(data[i]))
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
        callback.update(logits, label_ids)
    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
Example #2
def do_eval(dataset=None,
            network=None,
            use_crf="",
            num_class=2,
            assessment_method="accuracy",
            data_file="",
            load_checkpoint_path="",
            vocab_file="",
            label2id_file="",
            tag_to_index=None):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError(
            "Finetune checkpoint is missing; the evaluation task must load a finetuned model!")
    if assessment_method == "clue_benchmark":
        bert_net_cfg.batch_size = 1
    net_for_pretraining = network(bert_net_cfg,
                                  False,
                                  num_class,
                                  use_crf=(use_crf.lower() == "true"),
                                  tag_to_index=tag_to_index)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)

    if assessment_method == "clue_benchmark":
        from src.cluener_evaluation import submit
        submit(model=model,
               path=data_file,
               vocab_file=vocab_file,
               use_crf=use_crf,
               label2id_file=label2id_file)
    else:
        if assessment_method == "accuracy":
            callback = Accuracy()
        elif assessment_method == "f1":
            callback = F1((use_crf.lower() == "true"), num_class)
        elif assessment_method == "mcc":
            callback = MCC()
        elif assessment_method == "spearman_correlation":
            callback = Spearman_Correlation()
        else:
            raise ValueError(
                "Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]"
            )

        # Convert each required dataset column to a Tensor, run inference, and
        # feed the logits to the metric callback.
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id,
                                   label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        eval_result_print(assessment_method, callback)
        print("==============================================================")
Example #3
if __name__ == "__main__":
    num_class = 41
    assessment_method = args.assessment_method.lower()
    use_crf = args.use_crf

    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "bf1":
        callback = F1((use_crf.lower() == "true"), num_class)
    elif assessment_method == "mf1":
        callback = F1((use_crf.lower() == "true"), num_labels=num_class, mode="MultiLabel")
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1, mcc, spearman_correlation]")

    # With CRF, the inference step writes one int32 score file per sequence
    # position, shaped (batch_size, num_class + 2); the two extra columns are
    # the CRF <START>/<STOP> transition tags. Collect the per-position outputs
    # into a nested tuple of Tensors.
    file_name = os.listdir(args.label_dir)
    for f in file_name:
        if use_crf.lower() == "true":
            logits = ()
            for j in range(bert_net_cfg.seq_length):
                f_name = f.split('.')[0] + '_' + str(j) + '.bin'
                data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32)
                data_tmp = data_tmp.reshape(args.batch_size, num_class + 2)
                logits += ((Tensor(data_tmp),),)
            # One extra file at index seq_length is read and converted to a Tensor.
            f_name = f.split('.')[0] + '_' + str(bert_net_cfg.seq_length) + '.bin'
            data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32).tolist()
            data_tmp = Tensor(data_tmp)
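
The third snippet is cut off at this point. Judging from the callback.update(logits, label_ids) pattern in Examples #1 and #2, the loop presumably finishes by reading the matching label file and updating the metric; the following is a speculative sketch of that remainder, not the original code.

# Speculative continuation, inferred from the update pattern in the first
# two examples; it would sit inside the "for f in file_name" loop above.
label_ids = np.fromfile(os.path.join(args.label_dir, f), np.int32)
label_ids = Tensor(label_ids.reshape(args.batch_size, bert_net_cfg.seq_length))
callback.update(logits, label_ids)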