Example 1
def train():
    '''
    Fine-tune function.

    BertCLS: fine-tune for classification.
    BertNER: fine-tune for sequence labeling (NER).
    '''

    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            with open(cfg.label2id_file) as f:
                tag_to_index = json.loads(f.read())
            print(tag_to_index)
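            # the CRF layer needs extra <START>/<STOP> transition tags appended after the real labels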
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels

        netwithloss = BertNER(bert_net_cfg,
                              cfg.batch_size,
                              True,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
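        # warm up the learning rate over the first 10% of all training steps,
        # then decay it polynomially towards end_learning_rate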
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.AdamWeightDecay.power)
        params = netwithloss.trainable_params()
        decay_params = list(
            filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
                   params))
        group_params = [{
            'params': decay_params,
            'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=optimizer_cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.Lamb.learning_rate,
            end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(netwithloss.trainable_params(),
                         learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    # dynamic loss scaling for mixed precision: halve the scale on overflow,
    # double it again after 1000 consecutive overflow-free steps
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(dataset.get_dataset_size()), ckpoint_cb
    ]
    model.train(cfg.epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
Example 2
def run_classifier():
    """run classifier task"""
    parser = argparse.ArgumentParser(description="run classifier")
    parser.add_argument("--device_target",
                        type=str,
                        default="Ascend",
                        choices=["Ascend", "GPU"],
                        help="Device type, default is Ascend")
    parser.add_argument(
        "--assessment_method",
        type=str,
        default="Accuracy",
        choices=["Mcc", "Spearman_correlation", "Accuracy", "F1"],
        help="Assessment method, one of [Mcc, Spearman_correlation, Accuracy, F1], "
             "default is Accuracy")
    parser.add_argument("--do_train",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable train, default is false")
    parser.add_argument("--do_eval",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable eval, default is false")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--epoch_num",
                        type=int,
                        default="1",
                        help="Epoch number, default is 1.")
    parser.add_argument("--num_class",
                        type=int,
                        default="2",
                        help="The number of class, default is 2.")
    parser.add_argument("--train_data_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable train data shuffle, default is true")
    parser.add_argument("--eval_data_shuffle",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Enable eval data shuffle, default is false")
    parser.add_argument("--save_finetune_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--train_data_file_path",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower(
    ) == "false":
        raise ValueError(
            "At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower(
    ) == "true" and args_opt.train_data_file_path == "":
        raise ValueError(
            "'train_data_file_path' must be set when do finetune task")
    if args_opt.do_eval.lower(
    ) == "true" and args_opt.eval_data_file_path == "":
        raise ValueError(
            "'eval_data_file_path' must be set when do evaluation task")

    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target="Ascend",
                            device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")

    netwithloss = BertCLS(bert_net_cfg,
                          True,
                          num_labels=args_opt.num_class,
                          dropout_prob=0.1,
                          assessment_method=assessment_method)

    if args_opt.do_train.lower() == "true":
        ds = create_classification_dataset(
            batch_size=bert_net_cfg.batch_size,
            repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.train_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
        do_train(ds, netwithloss, load_pretrain_checkpoint_path,
                 save_finetune_checkpoint_path, epoch_num)

        if args_opt.do_eval.lower() == "true":
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(
                    save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(
                load_finetune_checkpoint_dir, ds.get_dataset_size(), epoch_num,
                "classifier")

    if args_opt.do_eval.lower() == "true":
        ds = create_classification_dataset(
            batch_size=bert_net_cfg.batch_size,
            repeat_count=1,
            assessment_method=assessment_method,
            data_file_path=args_opt.eval_data_file_path,
            schema_file_path=args_opt.schema_file_path,
            do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
        do_eval(ds, BertCLS, args_opt.num_class, assessment_method,
                load_finetune_checkpoint_path)
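
Note that the flags above pass booleans as the strings "true"/"false" rather than using store_true, which is why the script compares e.g. args_opt.do_train.lower() == "true". A minimal, standalone sketch of that convention (a hypothetical parser, not this script's actual argument list):

import argparse

# Minimal sketch of the string-boolean flag convention used above.
parser = argparse.ArgumentParser()
parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"])
parser.add_argument("--epoch_num", type=int, default=1)

args = parser.parse_args(["--do_train", "true", "--epoch_num", "3"])
print(args.do_train.lower() == "true")        # True
print(type(args.epoch_num), args.epoch_num)   # <class 'int'> 3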
Example 3
def eval():
    '''
    evaluation function
    '''
    if cfg.data_file[-4:] == 'json':
        dataset = None
    else:
        dataset = get_dataset(cfg.data_file, 1)

    if cfg.task == "NER":
        if cfg.use_crf:
            with open(cfg.label2id_file) as f:
                tag_to_index = json.loads(f.read())
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels
            tag_to_index = None
        netwithloss = BertNER(bert_net_cfg,
                              1,
                              False,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg,
                              False,
                              cfg.num_labels,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    netwithloss.set_train(False)
    param_dict = load_checkpoint(cfg.finetune_ckpt)
    load_param_into_net(netwithloss, param_dict)
    model = Model(netwithloss)

    if cfg.data_file[-4:] == 'json':
        submit(model, cfg.data_file, bert_net_cfg.seq_length)
        # import moxing as mox
        # mox.file.copy_parallel(src_url=cfg.eval_out_file, dst_url=os.path.join(args_opt.train_url, cfg.eval_out_file))
    else:
        callback = F1() if cfg.task == "NER" else Accuracy()
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id,
                                   label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        if cfg.task == "NER":
            print("Precision {:.6f} ".format(callback.TP /
                                             (callback.TP + callback.FP)))
            print("Recall {:.6f} ".format(callback.TP /
                                          (callback.TP + callback.FN)))
            print("F1 {:.6f} ".format(
                2 * callback.TP /
                (2 * callback.TP + callback.FP + callback.FN)))
        else:
            print("acc_num {} , total_num {}, accuracy {:.6f}".format(
                callback.acc_num, callback.total_num,
                callback.acc_num / callback.total_num))
        print("==============================================================")