Esempio n. 1
0
def do_train():
    """Fine-tune a sequence-classification model, then evaluate, test and export it.

    Command-line arguments are parsed into model, data and training argument
    dataclasses. Every attribute that also appears in the ``ALL_DATASETS``
    entry for the chosen dataset is overwritten with the configured value, and
    both per-device batch sizes are taken from its ``batch_size`` key.

    Raises:
        ValueError: If the output directory is non-empty but holds no
            checkpoint (while training without ``--overwrite_output_dir``), or
            if the dataset is not listed in ``ALL_DATASETS``.
    """
    parser = PdArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    data_args.dataset = data_args.dataset.strip()
    if data_args.dataset not in ALL_DATASETS:
        raise ValueError("Not found dataset {}".format(data_args.dataset))

    # Use yaml config to rewrite all args.
    config = ALL_DATASETS[data_args.dataset]
    for args in (model_args, data_args, training_args):
        for arg in vars(args):
            if arg in config:
                setattr(args, arg, config[arg])

    training_args.per_device_train_batch_size = config["batch_size"]
    training_args.per_device_eval_batch_size = config["batch_size"]

    # The dataset argument is "<path>" or "<path> <name>".
    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    # Read the label list once, with a default, so that a dataset without the
    # attribute is handled the same way for both the data args and the
    # class count (a single output is used when there is no label list).
    label_list = getattr(raw_datasets["train"], "label_list", None)
    data_args.label_list = label_list
    num_classes = 1 if label_list is None else len(label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)
    loss_fct = nn.loss.CrossEntropyLoss(
    ) if data_args.label_list else nn.loss.MSELoss()

    # Define dataset pre-process function
    if "clue" in data_args.dataset:
        trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)
    else:
        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    batchify_fn = defaut_collator(tokenizer, data_args)

    # Dataset pre-process
    train_dataset = raw_datasets["train"].map(trans_fn)
    eval_dataset = raw_datasets["dev"].map(trans_fn)
    test_dataset = raw_datasets["test"].map(trans_fn)

    # Define the metrics of tasks.
    def compute_metrics(p):
        """Return the accuracy of the predictions against the label ids."""
        preds = p.predictions[0] if isinstance(p.predictions,
                                               tuple) else p.predictions

        preds = paddle.to_tensor(preds)
        label = paddle.to_tensor(p.label_ids)

        metric = Accuracy()
        metric.reset()
        result = metric.compute(preds, label)
        metric.update(result)
        accu = metric.accumulate()
        metric.reset()
        return {"accuracy": accu}

    trainer = Trainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        data_collator=batchify_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Log model and data config
    trainer.print_config(model_args, "Model")
    trainer.print_config(data_args, "Data")

    # An explicitly requested checkpoint wins over an auto-detected one.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Evaluate and tests model
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

    test_ret = trainer.predict(test_dataset)
    trainer.log_metrics("test", test_ret.metrics)
    # Without gold labels only the raw predictions can be kept.
    if test_ret.label_ids is None:
        paddle.save(
            test_ret.predictions,
            os.path.join(training_args.output_dir, "test_results.pdtensor"),
        )

    # export inference model
    input_spec = [
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64")  # segment_ids
    ]
    trainer.export_model(input_spec=input_spec,
                         load_best_model=True,
                         output_dir=model_args.export_model_dir)
Esempio n. 2
0
def do_train():
    """Fine-tune a token-classification (NER) model, then evaluate, test and export it.

    Command-line arguments are parsed into model, data and training argument
    dataclasses. Every attribute that also appears in the ``ALL_DATASETS``
    entry for the chosen dataset is overwritten with the configured value, and
    both per-device batch sizes are taken from its ``batch_size`` key.

    Raises:
        ValueError: If the output directory is non-empty but holds no
            checkpoint (while training without ``--overwrite_output_dir``), if
            the dataset is not listed in ``ALL_DATASETS``, or if the training
            split exposes no ``label_list``.
    """
    parser = PdArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    data_args.dataset = data_args.dataset.strip()
    if data_args.dataset not in ALL_DATASETS:
        raise ValueError("Not found dataset {}".format(data_args.dataset))

    # Use yaml config to rewrite all args.
    config = ALL_DATASETS[data_args.dataset]
    for args in (model_args, data_args, training_args):
        for arg in vars(args):
            if arg in config:
                setattr(args, arg, config[arg])

    training_args.per_device_train_batch_size = config["batch_size"]
    training_args.per_device_eval_batch_size = config["batch_size"]

    # The dataset argument is "<path>" or "<path> <name>".
    dataset_config = data_args.dataset.split(" ")
    all_ds = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    label_list = getattr(all_ds['train'], "label_list", None)
    # The label list is indexed below (no_entity_id, metric decoding), so fail
    # with a clear message instead of a TypeError from len(None).
    if label_list is None:
        raise ValueError(
            "Dataset {} does not provide a label_list, which is required "
            "for token classification".format(data_args.dataset))
    ignore_label = -100
    data_args.label_list = label_list
    data_args.ignore_label = ignore_label
    # The last entry of the label list is used as the "no entity" id.
    data_args.no_entity_id = len(label_list) - 1

    num_classes = len(label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)

    class criterion(nn.Layer):
        """Mean cross-entropy loss that skips positions labelled ``ignore_label``."""

        def __init__(self):
            super().__init__()
            self.loss_fn = paddle.nn.loss.CrossEntropyLoss(
                ignore_index=ignore_label)

        def forward(self, *args, **kwargs):
            return paddle.mean(self.loss_fn(*args, **kwargs))

    loss_fct = criterion()

    # Define dataset pre-process function
    trans_fn = partial(ner_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    batchify_fn = ner_collator(tokenizer, data_args)

    # Dataset pre-process
    train_dataset = all_ds["train"].map(trans_fn)
    eval_dataset = all_ds["dev"].map(trans_fn)
    test_dataset = all_ds["test"].map(trans_fn)

    # Define the metrics of tasks.
    metric = load_metric("seqeval")

    def compute_metrics(p):
        """Return overall precision, recall, F1 and accuracy from the seqeval metric."""
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)

        # Remove ignored index (special tokens) and map ids to label names.
        true_predictions = [[
            label_list[pred_id]
            for pred_id, label_id in zip(prediction, label)
            if label_id != ignore_label
        ] for prediction, label in zip(predictions, labels)]
        true_labels = [[
            label_list[label_id]
            for _, label_id in zip(prediction, label)
            if label_id != ignore_label
        ] for prediction, label in zip(predictions, labels)]
        results = metric.compute(predictions=true_predictions,
                                 references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    trainer = Trainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        data_collator=batchify_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Log model and data config
    trainer.print_config(model_args, "Model")
    trainer.print_config(data_args, "Data")

    # An explicitly requested checkpoint wins over an auto-detected one.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Evaluate and tests model
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

    test_ret = trainer.predict(test_dataset)
    trainer.log_metrics("test", test_ret.metrics)
    # Without gold labels only the raw predictions can be kept.
    if test_ret.label_ids is None:
        paddle.save(
            test_ret.predictions,
            os.path.join(training_args.output_dir, "test_results.pdtensor"),
        )

    # export inference model
    input_spec = [
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64")  # segment_ids
    ]
    trainer.export_model(input_spec=input_spec,
                         load_best_model=True,
                         output_dir=model_args.export_model_dir)
Esempio n. 3
0
def main():
    """Compress (prune and quantize) a sequence-classification model.

    Command-line arguments are parsed into model, data and training argument
    dataclasses. When the dataset is listed in ``ALL_DATASETS``, every
    attribute that also appears in its entry is overwritten with the
    configured value and both per-device batch sizes are taken from its
    ``batch_size`` key. The compressed output is written to a ``compress``
    directory under ``model_args.model_name_or_path``.
    """
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)

    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # Hyper-parameters customised in the yaml config overwrite all args.
        config = ALL_DATASETS[data_args.dataset]
        logger.info("Over-writing training config by yaml config!")
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config:
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    # The dataset argument is "<path>" or "<path> <name>".
    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
        splits=("train", "dev", "test"))

    # Read the label list once, with a default, so that a dataset without the
    # attribute is handled the same way for both the data args and the
    # class count (a single output is used when there is no label list).
    label_list = getattr(raw_datasets["train"], "label_list", None)
    data_args.label_list = label_list
    num_classes = 1 if label_list is None else len(label_list)

    # NOTE(review): cross-entropy is used even when there is no label list
    # (num_classes == 1) - confirm this is intended for such datasets.
    criterion = paddle.nn.CrossEntropyLoss()
    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)

    # Define dataset pre-process function
    if "clue" in data_args.dataset:
        trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)
    else:
        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    train_dataset = raw_datasets["train"].map(trans_fn)
    eval_dataset = raw_datasets["dev"].map(trans_fn)

    trainer = Trainer(model=model,
                      args=training_args,
                      data_collator=data_collator,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      tokenizer=tokenizer,
                      criterion=criterion)

    output_dir = os.path.join(model_args.model_name_or_path, "compress")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    compress_config = CompressConfig(quantization_config=PTQConfig(
        algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16]))

    trainer.compress(data_args.dataset,
                     output_dir,
                     pruning=True,
                     quantization=True,
                     compress_config=compress_config)
def main():
    """Fine-tune a sequence-classification model and optionally evaluate, predict and export.

    Each stage runs only when the matching ``training_args`` flag
    (``do_train``, ``do_eval``, ``do_predict``, ``do_export``) is set.

    Raises:
        ValueError: If the output directory is non-empty but holds no
            checkpoint (while training without ``--overwrite_output_dir``).
    """
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    data_args.dataset = data_args.dataset.strip()

    # The dataset argument is "<path>" or "<path> <name>".
    dataset_config = data_args.dataset.split(" ")
    print(dataset_config)

    # The test split is read below when predicting, so it has to be requested
    # here as well; otherwise only the train and dev splits are loaded.
    splits = ('train', 'dev')
    if training_args.do_predict:
        splits = splits + ('test', )
    raw_datasets = load_dataset(
        dataset_config[0],
        name=None if len(dataset_config) <= 1 else dataset_config[1],
        splits=splits)

    # Read the label list once, with a default, so that a dataset without the
    # attribute is handled the same way for both the data args and the
    # class count (a single output is used when there is no label list).
    label_list = getattr(raw_datasets['train'], "label_list", None)
    data_args.label_list = label_list
    num_classes = 1 if label_list is None else len(label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)
    criterion = nn.loss.CrossEntropyLoss(
    ) if data_args.label_list else nn.loss.MSELoss()

    # Define dataset pre-process function
    trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    # Dataset pre-process; a split stays None when its stage is disabled.
    train_dataset = None
    eval_dataset = None
    test_dataset = None
    if training_args.do_train:
        train_dataset = raw_datasets["train"].map(trans_fn)
    if training_args.do_eval:
        eval_dataset = raw_datasets["dev"].map(trans_fn)
    if training_args.do_predict:
        test_dataset = raw_datasets["test"].map(trans_fn)

    # Define the metrics of tasks.
    def compute_metrics(p):
        """Return the accuracy of the predictions against the label ids."""
        preds = p.predictions[0] if isinstance(p.predictions,
                                               tuple) else p.predictions

        preds = paddle.to_tensor(preds)
        label = paddle.to_tensor(p.label_ids)

        metric = Accuracy()
        metric.reset()
        result = metric.compute(preds, label)
        metric.update(result)
        accu = metric.accumulate()
        metric.reset()
        return {"accuracy": accu}

    trainer = Trainer(
        model=model,
        criterion=criterion,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # An explicitly requested checkpoint wins over an auto-detected one.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate and tests model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    if training_args.do_predict:
        test_ret = trainer.predict(test_dataset)
        trainer.log_metrics("test", test_ret.metrics)
        # Without gold labels only the raw predictions can be kept.
        if test_ret.label_ids is None:
            paddle.save(
                test_ret.predictions,
                os.path.join(training_args.output_dir,
                             "test_results.pdtensor"),
            )

    # export inference model
    if training_args.do_export:
        # You can also load from certain checkpoint
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        input_spec = [
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64")  # segment_ids
        ]
        # Default the export location to an "export" folder in the output dir.
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(
                training_args.output_dir, "export")
        paddlenlp.transformers.export_model(model=trainer.model,
                                            input_spec=input_spec,
                                            path=model_args.export_model_dir)
def main():
    """Compress (prune and quantize) a token-classification (NER) model.

    Command-line arguments are parsed into model, data and training argument
    dataclasses. Every attribute that also appears in the ``ALL_DATASETS``
    entry for the chosen dataset is overwritten with the configured value, and
    both per-device batch sizes are taken from its ``batch_size`` key. The
    compressed output is written to a ``compress`` directory under
    ``model_args.model_name_or_path``.

    Raises:
        ValueError: If the dataset is not listed in ``ALL_DATASETS``.
    """
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)

    data_args.dataset = data_args.dataset.strip()
    if data_args.dataset not in ALL_DATASETS:
        raise ValueError("Not found dataset {}".format(data_args.dataset))

    # Hyper-parameters customised in the yaml config overwrite all args.
    # (The dataset is known to be configured: the check above raised otherwise.)
    config = ALL_DATASETS[data_args.dataset]
    for args in (model_args, data_args, training_args):
        for arg in vars(args):
            if arg in config:
                setattr(args, arg, config[arg])

    training_args.per_device_train_batch_size = config["batch_size"]
    training_args.per_device_eval_batch_size = config["batch_size"]

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    # The dataset argument is "<path>" or "<path> <name>".
    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    label_list = raw_datasets['train'].features['ner_tags'].feature.names
    data_args.label_list = label_list
    data_args.ignore_label = -100

    data_args.no_entity_id = 0
    num_classes = 1 if label_list is None else len(label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)

    class criterion(nn.Layer):
        """Mean cross-entropy loss that skips positions labelled ``ignore_label``."""

        def __init__(self):
            super().__init__()
            self.loss_fn = paddle.nn.loss.CrossEntropyLoss(
                ignore_index=data_args.ignore_label)

        def forward(self, *args, **kwargs):
            return paddle.mean(self.loss_fn(*args, **kwargs))

    loss_fct = criterion()

    # Define dataset pre-process function
    trans_fn = partial(ner_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    data_collator = DataCollatorForTokenClassification(
        tokenizer, label_pad_token_id=data_args.ignore_label)

    column_names = raw_datasets["train"].column_names

    # Dataset pre-process; the original columns are dropped after mapping.
    train_dataset = raw_datasets["train"].map(trans_fn,
                                              remove_columns=column_names)
    train_dataset.label_list = label_list

    # The "test" split is the one used as the evaluation set here.
    eval_dataset = raw_datasets["test"].map(trans_fn,
                                            remove_columns=column_names)

    trainer = Trainer(model=model,
                      criterion=loss_fct,
                      args=training_args,
                      data_collator=data_collator,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      tokenizer=tokenizer)

    output_dir = os.path.join(model_args.model_name_or_path, "compress")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    compress_config = CompressConfig(quantization_config=PTQConfig(
        algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16]))

    trainer.compress(data_args.dataset,
                     output_dir,
                     pruning=True,
                     quantization=True,
                     compress_config=compress_config)
Esempio n. 6
0
def main():
    """Pre-train the 'ernie-health' generator/discriminator model with the Trainer.

    Only model names listed in the tokenizer class's pretrained configurations
    are accepted. Training runs only when ``training_args.do_train`` is set.

    Raises:
        ValueError: If the output directory holds more than one entry but no
            checkpoint (while training without ``--overwrite_output_dir``), or
            if the requested model name is not a supported pretrained model.
    """
    argument_parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = (
        argument_parser.parse_args_into_dataclasses())

    training_args.eval_iters = 10
    training_args.test_iters = training_args.eval_iters * 10

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    may_resume = (os.path.isdir(training_args.output_dir)
                  and training_args.do_train
                  and not training_args.overwrite_output_dir)
    if may_resume:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None:
            # NOTE(review): a single directory entry is tolerated here
            # (threshold is 1, not 0) - confirm which entry is expected.
            if len(os.listdir(training_args.output_dir)) > 1:
                raise ValueError(
                    f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                    "Use --overwrite_output_dir to overcome.")
        elif training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    model_class, tokenizer_class = MODEL_CLASSES['ernie-health']

    # Loads or initialize a model.
    pretrained_models = list(
        tokenizer_class.pretrained_init_configuration.keys())

    # Reject unsupported names up front, then build the model.
    if model_args.model_name_or_path not in pretrained_models:
        raise ValueError("Only support %s" % (", ".join(pretrained_models)))

    tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path)
    generator_config = model_class.pretrained_init_configuration[
        model_args.model_name_or_path + "-generator"]
    generator = ElectraGenerator(ElectraModel(**generator_config))
    discriminator_config = model_class.pretrained_init_configuration[
        model_args.model_name_or_path + "-discriminator"]
    discriminator = ErnieHealthDiscriminator(
        ElectraModel(**discriminator_config))
    model = model_class(generator, discriminator)

    # Loads dataset.
    tic_load_data = time.time()
    load_started_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    logger.info("start load data : %s" % (load_started_at))

    train_dataset = MedicalCorpus(data_path=data_args.input_dir,
                                  tokenizer=tokenizer)
    logger.info("load data done, total : %s s" % (time.time() - tic_load_data))

    # Reads data and generates mini-batches.
    data_collator = DataCollatorForErnieHealth(
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        mlm_prob=data_args.masked_lm_prob,
        return_dict=True)

    class CriterionWrapper(paddle.nn.Layer):
        """Adapts ErnieHealthPretrainingCriterion to an (output, labels) call."""

        def __init__(self):
            """Build the wrapped criterion from the generator's vocabulary size."""
            super(CriterionWrapper, self).__init__()
            generator_base = getattr(model.generator,
                                     ElectraGenerator.base_model_prefix)
            self.criterion = ErnieHealthPretrainingCriterion(
                generator_base.config["vocab_size"], model.gen_weight)

        def forward(self, output, labels):
            """Compute the total pre-training loss.

            Args:
                output (tuple): generator_logits, logits_rtd, logits_mts,
                    logits_csp, disc_labels, mask
                labels (tuple): generator_labels

            Returns:
                Tensor: final loss.
            """
            (generator_logits, logits_rtd, logits_mts, logits_csp,
             disc_labels, masks) = output

            # The criterion yields five values; only the first is returned.
            loss, _gen_loss, _rtd_loss, _mts_loss, _csp_loss = self.criterion(
                generator_logits, labels, logits_rtd, logits_mts, logits_csp,
                disc_labels, masks)

            return loss

    trainer = Trainer(
        model=model,
        criterion=CriterionWrapper(),
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    # An explicitly requested checkpoint wins over an auto-detected one.
    checkpoint = training_args.resume_from_checkpoint
    if checkpoint is None:
        checkpoint = last_checkpoint

    # Training
    if not training_args.do_train:
        return

    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    train_metrics = train_result.metrics
    trainer.save_model()
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()