Example #1
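Evaluation routine from a MindSpore BERT fine-tuning script, apparently for sequence labelling (note the CRF option and the CLUENER submit helper): it loads a fine-tuned checkpoint into the supplied network, then either produces CLUE benchmark submissions or iterates over the dataset, feeding each batch's logits to the selected metric callback (Accuracy, F1, MCC, or Spearman correlation).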
def do_eval(dataset=None,
            network=None,
            use_crf="",
            num_class=2,
            assessment_method="accuracy",
            data_file="",
            load_checkpoint_path="",
            vocab_file="",
            label2id_file="",
            tag_to_index=None):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError(
            "Finetune model missed, evaluation task must load finetune model!")
    if assessment_method == "clue_benchmark":
        bert_net_cfg.batch_size = 1
    net_for_pretraining = network(bert_net_cfg,
                                  False,
                                  num_class,
                                  use_crf=(use_crf.lower() == "true"),
                                  tag_to_index=tag_to_index)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)

    if assessment_method == "clue_benchmark":
        from src.cluener_evaluation import submit
        submit(model=model,
               path=data_file,
               vocab_file=vocab_file,
               use_crf=use_crf,
               label2id_file=label2id_file)
    else:
        if assessment_method == "accuracy":
            callback = Accuracy()
        elif assessment_method == "f1":
            callback = F1((use_crf.lower() == "true"), num_class)
        elif assessment_method == "mcc":
            callback = MCC()
        elif assessment_method == "spearman_correlation":
            callback = Spearman_Correlation()
        else:
            raise ValueError(
                "Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]"
            )

        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id,
                                   label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        eval_result_print(assessment_method, callback)
        print("==============================================================")
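For orientation: tag_to_index is only meaningful when use_crf is "true". Below is a minimal sketch of how a caller might derive it from label2id_file, assuming that file is a JSON object mapping label strings to integer ids and that the CRF layer expects two extra transition tags appended; both are assumptions, and the helper name is hypothetical. Two extra tags would also be consistent with the num_class + 2 reshape in Example #4's CRF branch.

import json

def build_tag_to_index(label2id_file):
    """Hypothetical helper: load a JSON label->id mapping and append the
    start/stop tags a CRF layer typically needs (assumed convention)."""
    with open(label2id_file, "r", encoding="utf-8") as f:
        tag_to_index = json.load(f)           # e.g. {"O": 0, "B-PER": 1, ...}
    max_id = max(tag_to_index.values())
    tag_to_index["<START>"] = max_id + 1      # assumed CRF start-tag name
    tag_to_index["<STOP>"] = max_id + 2       # assumed CRF stop-tag name
    return tag_to_index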
Example #2
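Standalone evaluation of a distilled TinyBERT student classifier: the function configures the Ascend or GPU execution context, loads the student checkpoint (renaming its parameters so they match BertModelCLS), builds the task-distillation ("td") evaluation dataset, and accumulates accuracy over all batches.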
def do_eval_standalone():
    """
    do eval standalone
    """
    ckpt_file = args_opt.load_td1_ckpt_path
    if ckpt_file == '':
        raise ValueError("Student ckpt file should not be None")
    if args_opt.device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target,
                            device_id=args_opt.device_id)
    elif args_opt.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args_opt.device_target)
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    eval_model = BertModelCLS(td_student_net_cfg,
                              False,
                              task.num_labels,
                              0.0,
                              phase_type="student")
    param_dict = load_checkpoint(ckpt_file)
    new_param_dict = {}
    # Rename student checkpoint parameters to the names the eval network expects:
    # 'tinybert_*' -> 'bert_*', then drop a leading 'bert.' prefix.
    for key, value in param_dict.items():
        new_key = re.sub('tinybert_', 'bert_', key)
        new_key = re.sub('^bert.', '', new_key)
        new_param_dict[new_key] = value
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    eval_dataset = create_tinybert_dataset('td',
                                           batch_size=eval_cfg.batch_size,
                                           device_num=1,
                                           rank=0,
                                           do_shuffle="false",
                                           data_dir=args_opt.eval_data_dir,
                                           schema_dir=args_opt.schema_dir)
    print('eval dataset size: ', eval_dataset.get_dataset_size())
    print('eval dataset batch size: ', eval_dataset.get_batch_size())

    callback = Accuracy()
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in eval_dataset.create_dict_iterator(num_epochs=1):
        input_data = []
        for i in columns_list:
            input_data.append(data[i])
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = eval_model(input_ids, token_type_id, input_mask)
        callback.update(logits[3], label_ids)
    acc = callback.acc_num / callback.total_num
    print("======================================")
    print("============== acc is {}".format(acc))
    print("======================================")
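The parameter-renaming loop is the one non-obvious step above. The self-contained demo below applies the same two substitutions to invented key names (they are not real TinyBERT parameter names) just to show their effect:

import re

def rename_student_keys(param_dict):
    """Same renaming as in do_eval_standalone: 'tinybert_' -> 'bert_',
    then strip a leading 'bert.' prefix."""
    new_param_dict = {}
    for key, value in param_dict.items():
        new_key = re.sub('tinybert_', 'bert_', key)
        new_key = re.sub('^bert.', '', new_key)
        new_param_dict[new_key] = value
    return new_param_dict

# Invented keys, for illustration only.
demo = rename_student_keys({
    "bert.tinybert_encoder.layers.0.attention.query.weight": 1,
    "bert.dense_1.weight": 2,
})
# demo == {"bert_encoder.layers.0.attention.query.weight": 1,
#          "dense_1.weight": 2}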
Example #3
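A simpler variant of Example #1 for sequence classification: no CRF and no CLUE-benchmark branch, only checkpoint loading, metric selection, and the batch-wise predict/update loop.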
def do_eval(dataset=None,
            network=None,
            num_class=2,
            assessment_method="accuracy",
            load_checkpoint_path=""):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError(
            "Finetune model missed, evaluation task must load finetune model!")
    net_for_pretraining = network(bert_net_cfg, False, num_class)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)

    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "f1":
        callback = F1(False, num_class)
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError(
            "Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]"
        )

    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in dataset.create_dict_iterator():
        input_data = []
        for i in columns_list:
            input_data.append(Tensor(data[i]))
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
        callback.update(logits, label_ids)
    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
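The Accuracy, F1, MCC and Spearman_Correlation callbacks come from the surrounding project and are not shown here. To make the callback.update(logits, label_ids) pattern concrete, here is an illustrative stand-in for the accuracy case; it is not the project's actual implementation:

import numpy as np

class SimpleAccuracy:
    """Illustrative accuracy accumulator: count argmax hits over batches."""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, label_ids):
        # Accept either MindSpore tensors or plain numpy arrays.
        logits = logits.asnumpy() if hasattr(logits, "asnumpy") else np.asarray(logits)
        labels = label_ids.asnumpy() if hasattr(label_ids, "asnumpy") else np.asarray(label_ids)
        preds = np.argmax(logits, axis=-1).reshape(-1)
        labels = labels.reshape(-1)
        self.acc_num += int(np.sum(preds == labels))
        self.total_num += labels.size

metric = SimpleAccuracy()
metric.update(np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 1]))
print(metric.acc_num / metric.total_num)  # 0.5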
Example #4
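A fragment from an offline post-processing script (the function header and the first metric branches above this point are not included): it reads raw inference outputs and labels from .bin files, restores their dtype and shape, and feeds them to the metric callback; the CRF branch reassembles one file per time step.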
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")

    file_name = os.listdir(args.label_dir)
    for f in file_name:
        if use_crf.lower() == "true":
            # Reassemble logits from per-time-step .bin dumps: seq_length files,
            # each reshaped to (batch_size, num_class + 2), plus one trailing file.
            logits = ()
            for j in range(bert_net_cfg.seq_length):
                f_name = f.split('.')[0] + '_' + str(j) + '.bin'
                data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32)
                data_tmp = data_tmp.reshape(args.batch_size, num_class + 2)
                logits += ((Tensor(data_tmp),),)
            f_name = f.split('.')[0] + '_' + str(bert_net_cfg.seq_length) + '.bin'
            data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32).tolist()
            data_tmp = Tensor(data_tmp)
            logits = (logits, data_tmp)
        else:
            f_name = os.path.join(args.result_dir, f.split('.')[0] + '_0.bin')
            logits = np.fromfile(f_name, np.float32).reshape(bert_net_cfg.seq_length * args.batch_size, num_class)
            logits = Tensor(logits)
        label_ids = np.fromfile(os.path.join(args.label_dir, f), np.int32)
        label_ids = Tensor(label_ids.reshape(args.batch_size, bert_net_cfg.seq_length))
        callback.update(logits, label_ids)

    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
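Every input in this last example is a raw .bin dump, which carries no dtype or shape metadata; that is why each np.fromfile call passes the dtype explicitly and immediately reshapes. A tiny round-trip sketch (the sizes are made-up stand-ins for batch_size and num_class + 2):

import numpy as np

batch_size, width = 1, 41                     # invented example sizes

# tofile() writes only raw bytes, discarding dtype and shape ...
original = np.arange(batch_size * width, dtype=np.int32).reshape(batch_size, width)
original.tofile("demo_step_0.bin")

# ... so the reader must restore both explicitly, as the snippet above does.
restored = np.fromfile("demo_step_0.bin", np.int32).reshape(batch_size, width)
assert (restored == original).all()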