Exemplo n.º 1
0
def run_ner():
    """Run the BERT NER fine-tuning and/or evaluation task.

    All settings are read from ``parse_args()``. The function:

    1. Configures the MindSpore context for the requested device target
       ("Ascend" or "GPU"). On GPU, ``bert_net_cfg.compute_type`` is forced to
       ``mstype.float32`` (with a warning) because only fp32 is supported there.
    2. Reads the label file (one label per line, surrounding whitespace
       stripped) and builds ``tag_to_index`` with ``convert_labels_to_index``.
       When CRF is enabled, ``<START>`` and ``<STOP>`` tags are appended after
       the largest existing index and the label count is taken from the
       enlarged mapping; otherwise ``args_opt.num_class`` is used.
    3. If ``do_train`` is "true", builds a ``BertNER`` training network and the
       training dataset and calls ``do_train``. If ``do_eval`` is also "true",
       the newest checkpoint in the save directory is located with
       ``LoadNewestCkpt`` and evaluated instead of the user-supplied
       ``load_finetune_checkpoint_path``.
    4. If ``do_eval`` is "true", builds the evaluation dataset and calls
       ``do_eval``.

    Raises:
        ValueError: if ``device_target`` is neither "Ascend" nor "GPU".
            This is a subclass of ``Exception``, so callers catching
            ``Exception`` are unaffected.
    """
    args_opt = parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise ValueError("Target error, GPU or Ascend is supported.")

    # One label per line. Decode explicitly as UTF-8 so the labels read do not
    # depend on the platform's locale encoding.
    with open(args_opt.label_file_path, encoding="utf-8") as f:
        label_list = [line.strip() for line in f]
    tag_to_index = convert_labels_to_index(label_list)

    # These options arrive as strings ("true"/"false"); parse each one once.
    use_crf = args_opt.use_crf.lower() == "true"
    run_train = args_opt.do_train.lower() == "true"
    run_eval = args_opt.do_eval.lower() == "true"

    if use_crf:
        # Append <START>/<STOP> after the largest existing index; the number of
        # labels then comes from the enlarged mapping.
        max_val = max(tag_to_index.values())
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class

    if run_train:
        netwithloss = BertNER(bert_net_cfg, args_opt.train_batch_size, True, num_labels=number_labels,
                              use_crf=use_crf, tag_to_index=tag_to_index, dropout_prob=0.1)
        train_ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
                                      assessment_method=assessment_method,
                                      data_file_path=args_opt.train_data_file_path,
                                      schema_file_path=args_opt.schema_file_path,
                                      do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(train_ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

        if run_eval:
            # Evaluate the checkpoint just written by training rather than the
            # user-supplied load_finetune_checkpoint_path.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           train_ds.get_dataset_size(), epoch_num, "ner")

    if run_eval:
        eval_ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                     assessment_method=assessment_method,
                                     data_file_path=args_opt.eval_data_file_path,
                                     schema_file_path=args_opt.schema_file_path,
                                     do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        # do_eval is given the raw use_crf string, exactly as before.
        # NOTE(review): presumably do_eval parses it itself; verify before passing a bool.
        do_eval(eval_ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
                args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
Exemplo n.º 2
0
                    type=str,
                    default="Ascend",
                    choices=["Ascend", "GPU", "CPU"],
                    help="device target (default: Ascend)")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    # device_id is only set explicitly when targeting Ascend.
    context.set_context(device_id=args.device_id)

# Read the label list: one label per line, surrounding whitespace stripped.
# Decode explicitly as UTF-8 so the result does not depend on the platform locale.
label_list = []
with open(args.label_file_path, encoding="utf-8") as f:
    for label in f:
        label_list.append(label.strip())

tag_to_index = convert_labels_to_index(label_list)

if args.use_crf.lower() == "true":
    # With CRF enabled, append <START>/<STOP> after the largest existing index;
    # the number of labels then comes from the enlarged mapping.
    max_val = max(tag_to_index.values())
    tag_to_index["<START>"] = max_val + 1
    tag_to_index["<STOP>"] = max_val + 2
    number_labels = len(tag_to_index)
else:
    number_labels = args.num_class

if __name__ == "__main__":
    if args.downstream_task == "NER":
        if args.use_crf.lower() == "true":
            net = BertNER(bert_net_cfg,
                          args.batch_size,
                          False,