# NOTE(review): run_ner is defined again later in this file; at import time the
# later definition replaces this one. Confirm which version is meant to be live.
def run_ner():
    """Run the NER task: optionally fine-tune, then optionally evaluate.

    All settings come from ``parse_args()``. The string flags ``do_train``,
    ``do_eval`` and ``use_crf`` are enabled when they equal "true"
    (case-insensitive).

    When both training and evaluation are enabled, the checkpoint that is
    evaluated is looked up with ``LoadNewestCkpt`` in the directory training
    saved to. This overrides ``load_finetune_checkpoint_path``.

    Raises:
        ValueError: if ``device_target`` is neither "Ascend" nor "GPU".
    """
    args_opt = parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    # Parse the "true"/"false" string flags once instead of at every use.
    use_crf = args_opt.use_crf.lower() == "true"
    train_enabled = args_opt.do_train.lower() == "true"
    eval_enabled = args_opt.do_eval.lower() == "true"

    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                            device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        # Only fp32 is used on GPU here; downgrade any other compute type.
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        # ValueError subclasses Exception, so handlers written for the
        # previous bare Exception still catch it.
        raise ValueError("Target error, GPU or Ascend is supported.")

    # The label file lists one label per line. Blank lines (e.g. trailing
    # newlines) are skipped so they do not become empty-string labels.
    with open(args_opt.label_file_path, encoding="utf-8") as label_file:
        stripped = (line.strip() for line in label_file)
        label_list = [label for label in stripped if label]
    tag_to_index = convert_labels_to_index(label_list)
    if use_crf:
        # With CRF enabled, append <START>/<STOP> tags after the largest
        # existing index, and size the label space from the full mapping.
        max_val = max(tag_to_index.values())
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class

    if train_enabled:
        netwithloss = BertNER(bert_net_cfg, args_opt.train_batch_size, True,
                              num_labels=number_labels, use_crf=use_crf,
                              tag_to_index=tag_to_index, dropout_prob=0.1)
        ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
                                assessment_method=assessment_method,
                                data_file_path=args_opt.train_data_file_path,
                                schema_file_path=args_opt.schema_file_path,
                                do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(ds, netwithloss, load_pretrain_checkpoint_path,
                 save_finetune_checkpoint_path, epoch_num)

        if eval_enabled:
            # Evaluate what this run just saved. The lookup uses the training
            # dataset size, so it must stay inside the training branch.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           ds.get_dataset_size(),
                                                           epoch_num, "ner")

    if eval_enabled:
        ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                assessment_method=assessment_method,
                                data_file_path=args_opt.eval_data_file_path,
                                schema_file_path=args_opt.schema_file_path,
                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        # NOTE(review): do_eval receives the raw use_crf string, as before.
        do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                args_opt.eval_data_file_path, load_finetune_checkpoint_path,
                args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index,
                args_opt.eval_batch_size)
def _build_ner_arg_parser():
    """Create the command-line parser holding every NER task option."""
    parser = argparse.ArgumentParser(description="run ner")
    parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
                        help="Device type, default is Ascend")
    parser.add_argument("--assessment_method", type=str, default="F1",
                        choices=["F1", "clue_benchmark"],
                        help="assessment_method include: [F1, clue_benchmark], default is F1")
    parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
                        help="Enable train, default is false")
    parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
                        help="Enable eval, default is false")
    parser.add_argument("--use_crf", type=str, default="false", choices=["true", "false"],
                        help="Use crf, default is false")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--epoch_num", type=int, default=1, help="Epoch number, default is 1.")
    parser.add_argument("--num_class", type=int, default=2,
                        help="The number of class, default is 2.")
    parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable train data shuffle, default is true")
    parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
                        help="Enable eval data shuffle, default is false")
    parser.add_argument("--vocab_file_path", type=str, default="",
                        help="Vocab file path, used in clue benchmark")
    parser.add_argument("--label2id_file_path", type=str, default="",
                        help="label2id file path, used in clue benchmark and crf")
    parser.add_argument("--save_finetune_checkpoint_path", type=str, default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="",
                        help="Load checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path", type=str, default="",
                        help="Load checkpoint file path")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    return parser


def _validate_ner_args(args_opt):
    """Raise ValueError when the parsed options are missing a required path."""
    train_enabled = args_opt.do_train.lower() == "true"
    eval_enabled = args_opt.do_eval.lower() == "true"
    use_crf = args_opt.use_crf.lower() == "true"
    clue_benchmark = args_opt.assessment_method.lower() == "clue_benchmark"
    # argparse `choices` limits the flags to "true"/"false", so "not true"
    # is the same as "== false".
    if not train_enabled and not eval_enabled:
        raise ValueError(
            "At least one of 'do_train' or 'do_eval' must be true")
    if train_enabled and args_opt.train_data_file_path == "":
        raise ValueError(
            "'train_data_file_path' must be set when do finetune task")
    if eval_enabled and args_opt.eval_data_file_path == "":
        raise ValueError(
            "'eval_data_file_path' must be set when do evaluation task")
    if clue_benchmark and args_opt.vocab_file_path == "":
        raise ValueError("'vocab_file_path' must be set to do clue benchmark")
    if use_crf and args_opt.label2id_file_path == "":
        raise ValueError("'label2id_file_path' must be set to use crf")
    if clue_benchmark and args_opt.label2id_file_path == "":
        raise ValueError(
            "'label2id_file_path' must be set to do clue benchmark")


def run_ner():
    """Run the NER task: optionally fine-tune, then optionally evaluate.

    Options are read from the command line. See ``_build_ner_arg_parser`` for
    the full list and ``_validate_ner_args`` for the combinations that are
    rejected.

    When both training and evaluation are enabled, the checkpoint that is
    evaluated is looked up with ``LoadNewestCkpt`` in the directory training
    saved to. This overrides ``--load_finetune_checkpoint_path``.

    Raises:
        ValueError: on an invalid option combination, or if
            ``device_target`` is neither "Ascend" nor "GPU".
    """
    args_opt = _build_ner_arg_parser().parse_args()
    _validate_ner_args(args_opt)

    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    # Parse the "true"/"false" string flags once instead of at every use.
    use_crf = args_opt.use_crf.lower() == "true"
    train_enabled = args_opt.do_train.lower() == "true"
    eval_enabled = args_opt.do_eval.lower() == "true"

    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                            device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        # Only fp32 is used on GPU here; downgrade any other compute type.
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        # ValueError subclasses Exception, so handlers written for the
        # previous bare Exception still catch it.
        raise ValueError("Target error, GPU or Ascend is supported.")

    tag_to_index = None
    if use_crf:
        # The label2id mapping is read with json.load; JSON text is UTF-8, so
        # do not depend on the platform's locale encoding.
        with open(args_opt.label2id_file_path, encoding="utf-8") as json_file:
            tag_to_index = json.load(json_file)
        # Append <START>/<STOP> tags after the largest existing index.
        max_val = max(tag_to_index.values())
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class

    if train_enabled:
        # Build the training network only when it will actually be trained.
        netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels,
                              use_crf=use_crf, tag_to_index=tag_to_index,
                              dropout_prob=0.1)
        ds = create_ner_dataset(
            batch_size=bert_net_cfg.batch_size, repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.train_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(ds, netwithloss, load_pretrain_checkpoint_path,
                 save_finetune_checkpoint_path, epoch_num)

        if eval_enabled:
            # Evaluate what this run just saved. The lookup uses the training
            # dataset size, so it must stay inside the training branch.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(
                    save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(
                load_finetune_checkpoint_dir, ds.get_dataset_size(), epoch_num, "ner")

    if eval_enabled:
        ds = create_ner_dataset(
            batch_size=bert_net_cfg.batch_size, repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.eval_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        # NOTE(review): do_eval receives the raw use_crf string, as before.
        do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                args_opt.eval_data_file_path, load_finetune_checkpoint_path,
                args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)