def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy",
            data_file="", load_checkpoint_path="", vocab_file="", label2id_file="",
            tag_to_index=None):
    """
    Evaluate a finetuned BERT model (NER / CLUE benchmark variant).

    Args:
        dataset: evaluation dataset yielding "input_ids", "input_mask",
            "segment_ids" and "label_ids" columns (unused for clue_benchmark).
        network: network class instantiated for evaluation.
        use_crf (str): "true"/"false" (string form) enabling the CRF layer.
        num_class (int): number of output classes.
        assessment_method (str): one of "accuracy", "f1", "mcc",
            "spearman_correlation" or "clue_benchmark".
        data_file (str): input data path, used by the clue_benchmark submission.
        load_checkpoint_path (str): finetuned checkpoint path; required.
        vocab_file (str): vocabulary path, used by the clue_benchmark submission.
        label2id_file (str): label-to-id mapping path for clue_benchmark.
        tag_to_index: tag mapping forwarded to the network (CRF mode).

    Raises:
        ValueError: if no checkpoint path is given or the assessment method is
            not supported.
    """
    # `not` also rejects None, not only the empty-string default — the old
    # `== ""` check let a None path slip through to load_checkpoint().
    if not load_checkpoint_path:
        raise ValueError(
            "Finetune model missed, evaluation task must load finetune model!")
    # Convert the string flag once; it was previously computed in two places.
    use_crf_flag = use_crf.lower() == "true"
    if assessment_method == "clue_benchmark":
        # Submission inference runs one sample at a time.
        bert_net_cfg.batch_size = 1
    net_for_pretraining = network(bert_net_cfg, False, num_class,
                                  use_crf=use_crf_flag, tag_to_index=tag_to_index)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    if assessment_method == "clue_benchmark":
        from src.cluener_evaluation import submit
        submit(model=model, path=data_file, vocab_file=vocab_file,
               use_crf=use_crf, label2id_file=label2id_file)
    else:
        if assessment_method == "accuracy":
            callback = Accuracy()
        elif assessment_method == "f1":
            callback = F1(use_crf_flag, num_class)
        elif assessment_method == "mcc":
            callback = MCC()
        elif assessment_method == "spearman_correlation":
            callback = Spearman_Correlation()
        else:
            raise ValueError(
                "Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]"
            )
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            # Order matters: it must line up with the unpack below.
            input_data = [Tensor(data[col]) for col in columns_list]
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        eval_result_print(assessment_method, callback)
        print("==============================================================")
def do_eval_standalone(): """ do eval standalone """ ckpt_file = args_opt.load_td1_ckpt_path if ckpt_file == '': raise ValueError("Student ckpt file should not be None") if args_opt.device_target == "Ascend": context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) elif args_opt.device_target == "GPU": context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) else: raise Exception("Target error, GPU or Ascend is supported.") eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") param_dict = load_checkpoint(ckpt_file) new_param_dict = {} for key, value in param_dict.items(): new_key = re.sub('tinybert_', 'bert_', key) new_key = re.sub('^bert.', '', new_key) new_param_dict[new_key] = value load_param_into_net(eval_model, new_param_dict) eval_model.set_train(False) eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size, device_num=1, rank=0, do_shuffle="false", data_dir=args_opt.eval_data_dir, schema_dir=args_opt.schema_dir) print('eval dataset size: ', eval_dataset.get_dataset_size()) print('eval dataset batch size: ', eval_dataset.get_batch_size()) callback = Accuracy() columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] for data in eval_dataset.create_dict_iterator(num_epochs=1): input_data = [] for i in columns_list: input_data.append(data[i]) input_ids, input_mask, token_type_id, label_ids = input_data logits = eval_model(input_ids, token_type_id, input_mask) callback.update(logits[3], label_ids) acc = callback.acc_num / callback.total_num print("======================================") print("============== acc is {}".format(acc)) print("======================================")
def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy",
            load_checkpoint_path=""):
    """
    Evaluate a finetuned BERT classifier.

    Args:
        dataset: evaluation dataset yielding "input_ids", "input_mask",
            "segment_ids" and "label_ids" columns.
        network: network class instantiated for evaluation.
        num_class (int): number of output classes.
        assessment_method (str): one of "accuracy", "f1", "mcc" or
            "spearman_correlation".
        load_checkpoint_path (str): finetuned checkpoint path; required.

    Raises:
        ValueError: if no checkpoint path is given or the assessment method is
            not supported.
    """
    # `not` also rejects None, not only the empty-string default — the old
    # `== ""` check let a None path slip through to load_checkpoint().
    if not load_checkpoint_path:
        raise ValueError(
            "Finetune model missed, evaluation task must load finetune model!")
    net_for_pretraining = network(bert_net_cfg, False, num_class)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "f1":
        callback = F1(False, num_class)
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError(
            "Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]"
        )
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in dataset.create_dict_iterator():
        # Order matters: it must line up with the unpack below.
        input_data = [Tensor(data[col]) for col in columns_list]
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
        callback.update(logits, label_ids)
    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
callback = MCC() elif assessment_method == "spearman_correlation": callback = Spearman_Correlation() else: raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") file_name = os.listdir(args.label_dir) for f in file_name: if use_crf.lower() == "true": logits = () for j in range(bert_net_cfg.seq_length): f_name = f.split('.')[0] + '_' + str(j) + '.bin' data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32) data_tmp = data_tmp.reshape(args.batch_size, num_class + 2) logits += ((Tensor(data_tmp),),) f_name = f.split('.')[0] + '_' + str(bert_net_cfg.seq_length) + '.bin' data_tmp = np.fromfile(os.path.join(args.result_dir, f_name), np.int32).tolist() data_tmp = Tensor(data_tmp) logits = (logits, data_tmp) else: f_name = os.path.join(args.result_dir, f.split('.')[0] + '_0.bin') logits = np.fromfile(f_name, np.float32).reshape(bert_net_cfg.seq_length * args.batch_size, num_class) logits = Tensor(logits) label_ids = np.fromfile(os.path.join(args.label_dir, f), np.int32) label_ids = Tensor(label_ids.reshape(args.batch_size, bert_net_cfg.seq_length)) callback.update(logits, label_ids) print("==============================================================") eval_result_print(assessment_method, callback) print("==============================================================")