コード例 #1 (Code example #1)
ファイル: run_ner.py プロジェクト: mark14wu/bert_demo
def run_ner():
    """Run the BERT NER fine-tune / evaluation task.

    Parses command-line options, configures the MindSpore execution context
    for the chosen device, builds the label-to-index mapping from the label
    file, then runs training and/or evaluation according to the
    ``do_train`` / ``do_eval`` flags.

    Raises:
        ValueError: If ``device_target`` is neither "Ascend" nor "GPU".
    """
    args_opt = parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            # GPU only supports fp32 here; fall back with a warning instead of failing.
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        # ValueError is the precise exception for a bad option value; it is a
        # subclass of Exception, so existing callers catching Exception still work.
        raise ValueError("Target error, GPU or Ascend is supported.")

    # One label per line in the label file; strip trailing newline/whitespace.
    with open(args_opt.label_file_path) as f:
        label_list = [label.strip() for label in f]
    tag_to_index = convert_labels_to_index(label_list)

    use_crf = (args_opt.use_crf.lower() == "true")
    if use_crf:
        # CRF decoding needs explicit start/stop transition tags appended
        # after the real labels.
        max_val = max(tag_to_index.values())
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class

    if args_opt.do_train.lower() == "true":
        netwithloss = BertNER(bert_net_cfg, args_opt.train_batch_size, True, num_labels=number_labels,
                              use_crf=use_crf,
                              tag_to_index=tag_to_index, dropout_prob=0.1)
        ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
                                assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
                                schema_file_path=args_opt.schema_file_path,
                                do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

        if args_opt.do_eval.lower() == "true":
            # Evaluate against the newest checkpoint written by the training
            # run above rather than any explicitly supplied checkpoint.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           ds.get_dataset_size(), epoch_num, "ner")

    if args_opt.do_eval.lower() == "true":
        ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
                                schema_file_path=args_opt.schema_file_path,
                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
                args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
コード例 #2 (Code example #2)
ファイル: run_ner.py プロジェクト: egnlife/mindspore
def build_ner_arg_parser():
    """Build the command-line argument parser for the NER task.

    Internal helper for :func:`run_ner`; exposed at module level so the
    option set can be inspected and tested in isolation.

    Returns:
        argparse.ArgumentParser: Parser with every NER option registered.
    """
    # NOTE: description previously said "run classifier" — a copy-paste
    # error from the classifier script; corrected to match this task.
    parser = argparse.ArgumentParser(description="run ner")
    parser.add_argument("--device_target",
                        type=str,
                        default="Ascend",
                        choices=["Ascend", "GPU"],
                        help="Device type, default is Ascend")
    parser.add_argument(
        "--assessment_method",
        type=str,
        default="F1",
        choices=["F1", "clue_benchmark"],
        help="assessment_method include: [F1, clue_benchmark], default is F1")
    # Typo fix: "Eable" -> "Enable" in the two help strings below.
    parser.add_argument("--do_train",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable train, default is false")
    parser.add_argument("--do_eval",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable eval, default is false")
    parser.add_argument("--use_crf",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Use crf, default is false")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    # Defaults below were the strings "1"/"2"; argparse only coerced them
    # because string defaults are passed through `type`. Use real ints.
    parser.add_argument("--epoch_num",
                        type=int,
                        default=1,
                        help="Epoch number, default is 1.")
    parser.add_argument("--num_class",
                        type=int,
                        default=2,
                        help="The number of class, default is 2.")
    parser.add_argument("--train_data_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable train data shuffle, default is true")
    parser.add_argument("--eval_data_shuffle",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable eval data shuffle, default is false")
    parser.add_argument("--vocab_file_path",
                        type=str,
                        default="",
                        help="Vocab file path, used in clue benchmark")
    parser.add_argument("--label2id_file_path",
                        type=str,
                        default="",
                        help="label2id file path, used in clue benchmark")
    parser.add_argument("--save_finetune_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--train_data_file_path",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    return parser


def validate_ner_args(args_opt):
    """Check inter-option constraints for the NER task.

    Internal helper for :func:`run_ner`.

    Args:
        args_opt (argparse.Namespace): Parsed command-line options.

    Raises:
        ValueError: If a required path is missing for the selected mode, or
            neither training nor evaluation is requested.
    """
    do_train = args_opt.do_train.lower() == "true"
    do_eval = args_opt.do_eval.lower() == "true"
    clue = args_opt.assessment_method.lower() == "clue_benchmark"
    if not do_train and not do_eval:
        raise ValueError(
            "At least one of 'do_train' or 'do_eval' must be true")
    if do_train and args_opt.train_data_file_path == "":
        raise ValueError(
            "'train_data_file_path' must be set when do finetune task")
    if do_eval and args_opt.eval_data_file_path == "":
        raise ValueError(
            "'eval_data_file_path' must be set when do evaluation task")
    if clue and args_opt.vocab_file_path == "":
        raise ValueError("'vocab_file_path' must be set to do clue benchmark")
    if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
        raise ValueError("'label2id_file_path' must be set to use crf")
    if clue and args_opt.label2id_file_path == "":
        raise ValueError(
            "'label2id_file_path' must be set to do clue benchmark")


def run_ner():
    """Run the BERT NER fine-tune / evaluation task.

    Parses and validates command-line options, configures the MindSpore
    execution context, loads the CRF label mapping if requested, then runs
    training and/or evaluation per the ``do_train`` / ``do_eval`` flags.

    Raises:
        ValueError: If options are inconsistent or ``device_target`` is
            neither "Ascend" nor "GPU".
    """
    args_opt = build_ner_arg_parser().parse_args()
    validate_ner_args(args_opt)
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target="Ascend",
                            device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            # GPU only supports fp32 here; fall back with a warning.
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        # ValueError is more precise than bare Exception and remains
        # backward compatible with callers catching Exception.
        raise ValueError("Target error, GPU or Ascend is supported.")

    use_crf = args_opt.use_crf.lower() == "true"
    tag_to_index = None
    if use_crf:
        # CRF decoding needs explicit start/stop transition tags appended
        # after the labels loaded from the label2id JSON mapping.
        with open(args_opt.label2id_file_path) as json_file:
            tag_to_index = json.load(json_file)
        max_val = max(tag_to_index.values())
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class

    if args_opt.do_train.lower() == "true":
        # The loss network is only needed for training (do_eval builds its
        # own network from the BertNER class), so construct it lazily here
        # instead of unconditionally.
        netwithloss = BertNER(bert_net_cfg,
                              True,
                              num_labels=number_labels,
                              use_crf=use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
        ds = create_ner_dataset(
            batch_size=bert_net_cfg.batch_size,
            repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.train_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(ds, netwithloss, load_pretrain_checkpoint_path,
                 save_finetune_checkpoint_path, epoch_num)

        if args_opt.do_eval.lower() == "true":
            # Evaluate against the newest checkpoint written by the
            # training run above.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(
                    save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(
                load_finetune_checkpoint_dir, ds.get_dataset_size(), epoch_num,
                "ner")

    if args_opt.do_eval.lower() == "true":
        ds = create_ner_dataset(
            batch_size=bert_net_cfg.batch_size,
            repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.eval_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        do_eval(ds, BertNER, args_opt.use_crf, number_labels,
                assessment_method, args_opt.eval_data_file_path,
                load_finetune_checkpoint_path, args_opt.vocab_file_path,
                args_opt.label2id_file_path, tag_to_index)