Example #1
def do_eval(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]

    dev_ds = load_dataset('clue', args.task_name, splits='dev')

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    trans_func = partial(
        convert_example,
        label_list=dev_ds.label_list,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    dev_ds = dev_ds.map(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(
        dev_ds, batch_size=args.batch_size, shuffle=False)

    batchify_fn = DataCollatorWithPadding(tokenizer)

    dev_data_loader = DataLoader(
        dataset=dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    num_classes = 1 if dev_ds.label_list is None else len(dev_ds.label_list)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    metric = metric_class()
    best_acc = 0.0
    global_step = 0
    tic_train = time.time()
    model.eval()
    metric.reset()
    for batch in dev_data_loader:
        labels = batch.pop("labels")
        with paddle.no_grad():
            logits = model(**batch)
        correct = metric.compute(logits, labels)
        metric.update(correct)
    res = metric.accumulate()
    print("acc: %s\n, " % (res), end='')
Example #2
def do_predict(args):
    paddle.set_device(args.device)
    args.task_name = args.task_name.lower()

    train_ds, test_ds = load_dataset(
        'clue', args.task_name, splits=('train', 'test'))
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=args.max_seq_length,
        is_test=True)

    batchify_fn = DataCollatorWithPadding(tokenizer)

    test_ds = test_ds.map(trans_func, lazy=True)
    test_batch_sampler = paddle.io.BatchSampler(
        test_ds, batch_size=args.batch_size, shuffle=False)
    test_data_loader = DataLoader(
        dataset=test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if args.task_name == 'ocnli':
        args.task_name = 'ocnli_50k'
    model.eval()
    # Keep a running example id across batches and close the prediction file
    # when writing is done.
    idx = 0
    with open(
            os.path.join(args.output_dir, args.task_name + "_predict.json"),
            'w') as f:
        for step, batch in enumerate(test_data_loader):
            with paddle.no_grad():
                logits = model(**batch)
            preds = paddle.argmax(logits, axis=1).numpy()
            for pred in preds:
                j = json.dumps({"id": idx, "label": train_ds.label_list[pred]})
                f.write(j + "\n")
                idx += 1
Example #3
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)
    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # If you customize hyper-parameters in the yaml config, they will overwrite the corresponding args.
        config = ALL_DATASETS[data_args.dataset]
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config.keys():
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
        cache_dir=model_args.cache_dir)

    label_list = getattr(raw_datasets['train'], "label_list", None)
    data_args.label_list = label_list

    # Define tokenizer, model, loss function. 
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path)

    loss_fct = CrossEntropyLossForSQuAD()

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    # The train and validation splits of this task share the same column names.
    column_names = raw_datasets["train"].column_names

    train_dataset = raw_datasets["train"]
    # Create train feature from dataset
    with training_args.main_process_first(
            desc="train dataset map pre-processing"):
        # Dataset pre-process
        train_dataset = train_dataset.map(
            partial(
                prepare_train_features, tokenizer=tokenizer, args=data_args),
            batched=True,
            num_proc=4,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset", )
    eval_examples = raw_datasets["validation"]
    with training_args.main_process_first(
            desc="evaluate dataset map pre-processing"):
        eval_dataset = eval_examples.map(
            partial(
                prepare_validation_features,
                tokenizer=tokenizer,
                args=data_args),
            batched=True,
            num_proc=4,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset", )

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    # Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, all_nbest_json, scores_diff_json = compute_prediction(
            examples=examples,
            features=features,
            predictions=predictions,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold, )

        references = [{
            "id": ex["id"],
            "answers": ex["answers"]
        } for ex in examples]
        return EvalPrediction(predictions=predictions, label_ids=references)

    trainer = QuestionAnsweringTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        eval_examples=eval_examples,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        tokenizer=tokenizer)

    output_dir = os.path.join(model_args.model_name_or_path, "compress")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    prune = True
    compress_config = CompressConfig(quantization_config=PTQConfig(
        algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16]))
    trainer.compress(
        data_args.dataset,
        output_dir,
        pruning=prune,
        quantization=True,
        compress_config=compress_config)
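
CrossEntropyLossForSQuAD is instantiated above but defined in the example's helper code. A sketch of the implementation these PaddleNLP QA examples typically use, assuming the model returns a (start_logits, end_logits) pair:

import paddle

class CrossEntropyLossForSQuAD(paddle.nn.Layer):
    # Average the cross entropy of the start-position and end-position predictions.
    def forward(self, y, label):
        start_logits, end_logits = y
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = paddle.nn.functional.cross_entropy(
            input=start_logits, label=start_position)
        end_loss = paddle.nn.functional.cross_entropy(
            input=end_logits, label=end_position)
        return (start_loss + end_loss) / 2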
Example #4
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)

    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # If you customize hyper-parameters in the yaml config, they will overwrite the corresponding args.
        config = ALL_DATASETS[data_args.dataset]
        logger.info("Over-writing training config by yaml config!")
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config.keys():
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
        splits=("train", "dev", "test"))

    data_args.label_list = getattr(raw_datasets['train'], "label_list", None)
    num_classes = 1 if raw_datasets["train"].label_list is None else len(
        raw_datasets['train'].label_list)

    criterion = paddle.nn.CrossEntropyLoss()
    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)

    # Define dataset pre-process function
    if "clue" in data_args.dataset:
        trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)
    else:
        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    train_dataset = raw_datasets["train"].map(trans_fn)
    eval_dataset = raw_datasets["dev"].map(trans_fn)

    trainer = Trainer(model=model,
                      args=training_args,
                      data_collator=data_collator,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      tokenizer=tokenizer,
                      criterion=criterion)

    output_dir = os.path.join(model_args.model_name_or_path, "compress")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    compress_config = CompressConfig(quantization_config=PTQConfig(
        algo_list=['hist', 'mse'], batch_size_list=[4, 8, 16]))

    trainer.compress(data_args.dataset,
                     output_dir,
                     pruning=True,
                     quantization=True,
                     compress_config=compress_config)
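
seq_trans_fn and clue_trans_fn come from the example's utils module and are not shown here. A rough, illustrative sketch of what such a preprocessing function does; the field names and the use of args.max_seq_length are assumptions based on how the function is called above:

def clue_trans_fn_sketch(example, tokenizer, args):
    # Tokenize a single sentence or a sentence pair and keep the label if present.
    if "sentence" in example:
        result = tokenizer(example["sentence"], max_seq_len=args.max_seq_length)
    else:
        result = tokenizer(example["sentence1"], example["sentence2"],
                           max_seq_len=args.max_seq_length)
    if "label" in example:
        result["labels"] = example["label"]
    return result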
Example #5
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()

    sentence1_key, sentence2_key = task_to_keys[args.task_name]

    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = load_dataset('glue', args.task_name, split="train")
    columns = train_ds.column_names
    is_regression = args.task_name == "stsb"
    label_list = None
    if not is_regression:
        label_list = train_ds.features["label"].names
        num_classes = len(label_list)
    else:
        num_classes = 1
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    def preprocess_function(examples):
        # Tokenize the texts
        texts = ((examples[sentence1_key], ) if sentence2_key is None else
                 (examples[sentence1_key], examples[sentence2_key]))
        result = tokenizer(*texts, max_seq_len=args.max_seq_length)
        if "label" in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result["labels"] = examples["label"]
        return result

    train_ds = train_ds.map(preprocess_function,
                            batched=True,
                            remove_columns=columns)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = DataCollatorWithPadding(tokenizer)
    train_data_loader = DataLoader(
        dataset=train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)
    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue',
            args.task_name,
            split=["validation_matched", "validation_mismatched"])

        dev_ds_matched = dev_ds_matched.map(preprocess_function,
                                            batched=True,
                                            remove_columns=columns)
        dev_ds_mismatched = dev_ds_mismatched.map(preprocess_function,
                                                  batched=True,
                                                  remove_columns=columns)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_ds = load_dataset('glue', args.task_name, split='validation')
        dev_ds = dev_ds.map(preprocess_function,
                            batched=True,
                            remove_columns=columns)
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.batch_size, shuffle=False)
        dev_data_loader = DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)

    model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else (
        len(train_data_loader) * args.num_train_epochs)
    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_fct = paddle.nn.loss.CrossEntropyLoss(
    ) if not is_regression else paddle.nn.loss.MSELoss()

    metric = metric_class()
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=["layer_norm", "softmax", "gelu"]):
                logits = model(batch['input_ids'], batch['token_type_ids'])
                loss = loss_fct(logits, batch['labels'])
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.logging_steps == 0:
                print(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                    % (global_step, num_training_steps, epoch, step,
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                if args.task_name == "mnli":
                    evaluate(model, loss_fct, metric, dev_data_loader_matched)
                    evaluate(model, loss_fct, metric,
                             dev_data_loader_mismatched)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                else:
                    evaluate(model, loss_fct, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                if paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "%s_ft_model_%d.pdparams" %
                                              (args.task_name, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
            if global_step >= num_training_steps:
                return
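
The training loop above calls an evaluate() helper that is not reproduced here. A sketch of how such a helper is usually written in these GLUE example scripts; the exact print format is an assumption:

@paddle.no_grad()
def evaluate(model, loss_fct, metric, data_loader):
    # Accumulate the task metric over the dev set in eval mode, then restore train mode.
    model.eval()
    metric.reset()
    for batch in data_loader:
        labels = batch.pop("labels")
        logits = model(**batch)
        loss = loss_fct(logits, labels)
        correct = metric.compute(logits, labels)
        metric.update(correct)
    res = metric.accumulate()
    print("eval loss: %f, metric: %s" % (float(loss), res))
    model.train()
    return res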
Example #6
def main(args):
    # For memory saving when using FasterGeneration:
    # If the environment variable `PPFG_QKV_MEM_OPT` is set and the q/k/v weights
    # are fused, it will try to delete the original unfused weights. Note that
    # rolling back to the original model is no longer guaranteed if the faster
    # model fails after the original weights have been deleted.
    os.environ["PPFG_QKV_MEM_OPT"] = "1"
    if args.use_fp16:
        paddle.set_default_dtype("float16")
    enable_ft_para()
    # TODO(guosheng): Maybe device can be set in `enable_ft_para`
    paddle.set_device("gpu:" + str(get_ft_para_conf().rank))

    if args.profile:
        UnifiedTransformerLMHeadModel.generate = profile(args.batch_size)(
            UnifiedTransformerLMHeadModel.generate)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-xl")
    model = UnifiedTransformerLMHeadModel.from_pretrained(
        "plato-xl", load_state_as_np=True)
    model.eval()

    history = [
        "hi , Mary ! What do you usually like to do in your spare time ?",
        "well , I spend a lot of time watching movies .",
        "what a confidence ! I always watch a lot of movies , too .",
        "oh really , Frank ? What kind of movies do you like ?"
    ]
    inputs = [history] * args.batch_size
    inputs = list(
        map(
            lambda history: tokenizer.dialogue_encode(
                history=history,
                add_start_token_as_response=True,
                return_length=True,
                return_role_ids=args.use_role,
                position_style=args.position_style), inputs))
    collator = DataCollatorWithPadding(tokenizer)
    data = collator(inputs)

    outputs, _ = model.generate(
        input_ids=data['input_ids'],
        token_type_ids=data['token_type_ids'],
        position_ids=data['position_ids'],
        attention_mask=data['attention_mask'].cast(
            "float32"),  # TODO(guosheng): remove this cast
        role_ids=data.get('role_ids', None),
        seq_len=data['seq_len'],
        max_length=args.max_out_len,
        min_length=args.min_out_len,
        decode_strategy='sampling',
        top_k=args.topk,
        top_p=args.topp,
        temperature=args.temperature,
        num_return_sequences=args.num_return_sequences,
        use_faster=True,
        use_fp16_decoding=args.use_fp16)

    # Only let the first process print the output.
    if get_ft_para_conf().rank == 0:
        for i in range(len(outputs)):
            result = postprocess_response(outputs[i].numpy(), tokenizer)
            print("Result:", result)
Example #7
def run(args):
    if args.do_train:
        assert args.batch_size % args.gradient_accumulation_steps == 0, \
            "Please make sure `batch_size` is divisible by `gradient_accumulation_steps`."
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    set_seed(args)

    train_examples, dev_examples, test_examples = load_dataset(
        'clue', 'cmrc2018', split=["train", "validation", "test"])

    column_names = train_examples.column_names
    if rank == 0:
        if os.path.exists(args.model_name_or_path):
            logger.info("init checkpoint from %s" % args.model_name_or_path)

    model = AutoModelForQuestionAnswering.from_pretrained(
        args.model_name_or_path)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    def prepare_train_features(examples):
        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit with the context of the previous feature.
        # NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
        # that HuggingFace uses ArrowTable as the basic data structure, while we use a list of dictionaries instead.
        contexts = examples['context']
        questions = examples['question']

        tokenized_examples = tokenizer(questions,
                                       contexts,
                                       stride=args.doc_stride,
                                       max_seq_len=args.max_seq_length)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples['token_type_ids'][i]

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples['answers'][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != 1:
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != 1:
                    token_end_index -= 1
                token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char
                        and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[
                            token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(
                        token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(
                        token_end_index + 1)

        return tokenized_examples

    def prepare_validation_features(examples):
        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit with the context of the previous feature.
        # NOTE: Almost the same functionality as HuggingFace's prepare_validation_features function. The main
        # difference is that HuggingFace uses ArrowTable as the basic data structure, while we use a list of
        # dictionaries instead.
        contexts = examples['context']
        questions = examples['question']

        tokenized_examples = tokenizer(questions,
                                       contexts,
                                       stride=args.doc_stride,
                                       max_seq_len=args.max_seq_length,
                                       return_attention_mask=True)

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples['token_type_ids'][i]
            context_index = 1

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(
                examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index
                 and k != len(sequence_ids) - 1 else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if args.do_train:
        args.batch_size = int(args.batch_size /
                              args.gradient_accumulation_steps)

        with main_process_first(desc="train dataset map pre-processing"):
            train_ds = train_examples.map(
                prepare_train_features,
                batched=True,
                remove_columns=column_names,
                load_from_cache_file=not args.overwrite_cache,
                num_proc=args.num_proc,
                desc="Running tokenizer on train dataset")
        train_batch_sampler = paddle.io.DistributedBatchSampler(
            train_ds, batch_size=args.batch_size, shuffle=True)

        batchify_fn = DataCollatorWithPadding(tokenizer)
        train_data_loader = DataLoader(dataset=train_ds,
                                       batch_sampler=train_batch_sampler,
                                       collate_fn=batchify_fn,
                                       return_list=True)

        with main_process_first(desc="evaluate dataset map pre-processing"):
            dev_ds = dev_examples.map(
                prepare_validation_features,
                batched=True,
                remove_columns=column_names,
                num_proc=args.num_proc,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on validation dataset")
        dev_ds_for_model = dev_ds.remove_columns(
            ["example_id", "offset_mapping", "attention_mask"])
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.eval_batch_size, shuffle=False)

        dev_data_loader = DataLoader(dataset=dev_ds_for_model,
                                     batch_sampler=dev_batch_sampler,
                                     collate_fn=batchify_fn,
                                     return_list=True)

        num_training_steps = int(
            args.max_steps /
            args.gradient_accumulation_steps) if args.max_steps > 0 else int(
                len(train_data_loader) * args.num_train_epochs /
                args.gradient_accumulation_steps)

        warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
        lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                             num_training_steps, warmup)

        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_params)
        criterion = CrossEntropyLossForSQuAD()
        best_res = (0.0, 0.0)
        global_step = 0
        tic_train = time.time()
        for epoch in range(args.num_train_epochs):
            for step, batch in enumerate(train_data_loader):
                start_positions = batch.pop("start_positions")
                end_positions = batch.pop("end_positions")
                logits = model(**batch)
                loss = criterion(logits, (start_positions, end_positions))
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    global_step += 1
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.clear_grad()

                    if global_step % args.logging_steps == 0:
                        logger.info(
                            "global step %d/%d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                            % (global_step, num_training_steps, epoch,
                               step + 1, loss, args.logging_steps /
                               (time.time() - tic_train)))
                        tic_train = time.time()
                    if global_step >= num_training_steps:
                        logger.info("best_result: %.2f/%.2f" %
                                    (best_res[0], best_res[1]))
                        return
            em, f1 = evaluate(model, dev_examples, dev_ds, dev_data_loader,
                              args)
            if paddle.distributed.get_rank() == 0 and em > best_res[0]:
                best_res = (em, f1)
                if args.save_best_model:
                    output_dir = args.output_dir
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
        logger.info("best_result: %.2f/%.2f" % (best_res[0], best_res[1]))

    if args.do_predict and rank == 0:
        test_ds = test_examples.map(prepare_validation_features,
                                    batched=True,
                                    remove_columns=column_names,
                                    num_proc=args.num_proc)
        test_ds_for_model = test_ds.remove_columns(
            ["example_id", "offset_mapping", "attention_mask"])

        test_batch_sampler = paddle.io.BatchSampler(
            test_ds_for_model, batch_size=args.eval_batch_size, shuffle=False)

        batchify_fn = DataCollatorWithPadding(tokenizer)
        test_data_loader = DataLoader(dataset=test_ds_for_model,
                                      batch_sampler=test_batch_sampler,
                                      collate_fn=batchify_fn,
                                      return_list=True)

        evaluate(model,
                 test_examples,
                 test_ds,
                 test_data_loader,
                 args,
                 do_eval=False)
Example #8
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    data_args.dataset = data_args.dataset.strip()

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        name=None if len(dataset_config) <= 1 else dataset_config[1],
        splits=('train', 'dev', 'test'))

    data_args.label_list = getattr(raw_datasets['train'], "label_list", None)
    num_classes = 1 if raw_datasets["train"].label_list is None else len(
        raw_datasets['train'].label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)
    criterion = nn.loss.CrossEntropyLoss(
    ) if data_args.label_list else nn.loss.MSELoss()

    # Define dataset pre-process function
    trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    # Dataset pre-process
    if training_args.do_train:
        train_dataset = raw_datasets["train"].map(trans_fn)
    if training_args.do_eval:
        eval_dataset = raw_datasets["dev"].map(trans_fn)
    if training_args.do_predict:
        test_dataset = raw_datasets["test"].map(trans_fn)

    # Define the metrics of tasks.
    def compute_metrics(p):
        preds = p.predictions[0] if isinstance(p.predictions,
                                               tuple) else p.predictions

        preds = paddle.to_tensor(preds)
        label = paddle.to_tensor(p.label_ids)

        probs = F.softmax(preds, axis=1)
        metric = Accuracy()
        metric.reset()
        result = metric.compute(preds, label)
        metric.update(result)
        accu = metric.accumulate()
        metric.reset()
        return {"accuracy": accu}

    trainer = Trainer(
        model=model,
        criterion=criterion,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate and tests model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    if training_args.do_predict:
        test_ret = trainer.predict(test_dataset)
        trainer.log_metrics("test", test_ret.metrics)
        if test_ret.label_ids is None:
            paddle.save(
                test_ret.predictions,
                os.path.join(training_args.output_dir,
                             "test_results.pdtensor"),
            )

    # export inference model
    if training_args.do_export:
        # You can also load from certain checkpoint
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        input_spec = [
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64")  # segment_ids
        ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(
                training_args.output_dir, "export")
        paddlenlp.transformers.export_model(model=trainer.model,
                                            input_spec=input_spec,
                                            path=model_args.export_model_dir)
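
After export_model finishes, the exported static graph can be loaded back for inference. A sketch, assuming the exported files use the default "model" prefix under export_model_dir; verify the prefix against the files actually written to disk:

import os
import paddle

def load_exported_model(export_model_dir):
    # Load the exported static graph produced by paddlenlp.transformers.export_model.
    return paddle.jit.load(os.path.join(export_model_dir, "model"))

# e.g.
# infer_model = load_exported_model(model_args.export_model_dir)
# logits = infer_model(input_ids, token_type_ids)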
Example #9
def do_train(args):
    assert args.batch_size % args.gradient_accumulation_steps == 0, \
        "Please make sure `batch_size` is divisible by `gradient_accumulation_steps`."
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
    train_ds, dev_ds = load_dataset(
        'clue', args.task_name, splits=('train', 'dev'))

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    trans_func = partial(
        convert_example,
        label_list=train_ds.label_list,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    train_ds = train_ds.map(trans_func, lazy=True)

    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)

    dev_ds = dev_ds.map(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(
        dev_ds, batch_size=args.batch_size, shuffle=False)

    batchify_fn = DataCollatorWithPadding(tokenizer)

    train_data_loader = DataLoader(
        dataset=train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)
    dev_data_loader = DataLoader(
        dataset=dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)

    if args.dropout != 0.1:
        update_model_dropout(model, args.dropout)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    if args.max_steps > 0:
        num_training_steps = args.max_steps / args.gradient_accumulation_steps
        num_train_epochs = math.ceil(num_training_steps /
                                     len(train_data_loader))
    else:
        num_training_steps = len(
            train_data_loader
        ) * args.num_train_epochs / args.gradient_accumulation_steps
        num_train_epochs = args.num_train_epochs

    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))

    loss_fct = paddle.nn.loss.CrossEntropyLoss(
    ) if train_ds.label_list else paddle.nn.loss.MSELoss()

    metric = metric_class()
    best_acc = 0.0
    global_step = 0
    tic_train = time.time()
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            labels = batch.pop("labels")
            logits = model(**batch)
            loss = loss_fct(logits, labels)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()
                if global_step % args.logging_steps == 0:
                    print(
                        "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                        % (global_step, num_training_steps, epoch, step,
                           paddle.distributed.get_rank(), loss,
                           optimizer.get_lr(),
                           args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                if global_step % args.save_steps == 0 or global_step == num_training_steps:
                    tic_eval = time.time()
                    acc = evaluate(model, loss_fct, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                    if acc > best_acc:
                        best_acc = acc
                        output_dir = args.output_dir
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                if global_step >= num_training_steps:
                    print("best_acc: ", best_acc)
                    return
    print("best_acc: ", best_acc)
Example #10
def run(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    if args.version_2_with_negative:
        train_examples = load_dataset('squad_v2', split='train')
        dev_examples = load_dataset('squad_v2', split='validation')
    else:
        train_examples = load_dataset('squad', split='train')
        dev_examples = load_dataset('squad', split='validation')
    set_seed(args)
    if rank == 0:
        if os.path.exists(args.model_name_or_path):
            print("init checkpoint from %s" % args.model_name_or_path)

    model = model_class.from_pretrained(args.model_name_or_path)
    column_names = train_examples.column_names
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    if args.do_train:
        train_ds = train_examples.map(partial(prepare_train_features,
                                              tokenizer=tokenizer,
                                              args=args),
                                      batched=True,
                                      remove_columns=column_names,
                                      num_proc=4)
        train_batch_sampler = paddle.io.DistributedBatchSampler(
            train_ds, batch_size=args.batch_size, shuffle=True)
        train_batchify_fn = DataCollatorWithPadding(tokenizer)

        train_data_loader = DataLoader(dataset=train_ds,
                                       batch_sampler=train_batch_sampler,
                                       collate_fn=train_batchify_fn,
                                       return_list=True)

        num_training_steps = args.max_steps if args.max_steps > 0 else len(
            train_data_loader) * args.num_train_epochs
        num_train_epochs = math.ceil(num_training_steps /
                                     len(train_data_loader))

        lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                             num_training_steps,
                                             args.warmup_proportion)

        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_params)
        criterion = CrossEntropyLossForSQuAD()

        global_step = 0
        tic_train = time.time()

        for epoch in range(num_train_epochs):
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                logits = model(input_ids=batch['input_ids'],
                               token_type_ids=batch['token_type_ids'],
                               attention_mask=batch['attention_mask'])
                loss = criterion(
                    logits, (batch['start_positions'], batch['end_positions']))
                if global_step % args.logging_steps == 0:
                    print(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch + 1, step + 1, loss,
                           args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % args.save_steps == 0 or global_step == num_training_steps:
                    if rank == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        print('Saving checkpoint to:', output_dir)
                    if global_step == num_training_steps:
                        break

    if args.do_predict and rank == 0:
        dev_ds = dev_examples.map(partial(prepare_validation_features,
                                          tokenizer=tokenizer,
                                          args=args),
                                  batched=True,
                                  remove_columns=column_names,
                                  num_proc=4)
        dev_batch_sampler = paddle.io.BatchSampler(dev_ds,
                                                   batch_size=args.batch_size,
                                                   shuffle=False)
        dev_ds_for_model = dev_ds.remove_columns(
            ["example_id", "offset_mapping"])
        dev_batchify_fn = DataCollatorWithPadding(tokenizer)

        dev_data_loader = DataLoader(dataset=dev_ds_for_model,
                                     batch_sampler=dev_batch_sampler,
                                     collate_fn=dev_batchify_fn,
                                     return_list=True)

        evaluate(model, dev_data_loader, dev_examples, dev_ds, args)
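
The evaluate() helper called above is defined in the example script itself. A rough sketch of its usual shape for SQuAD-style evaluation, built on paddlenlp.metrics.squad; the argument order follows the call above, but details may differ from the real helper:

import paddle
from paddlenlp.metrics.squad import compute_prediction, squad_evaluate

@paddle.no_grad()
def evaluate(model, data_loader, examples, features, args):
    # Collect start/end logits for every feature, decode answer spans, then score them.
    model.eval()
    all_start_logits, all_end_logits = [], []
    for batch in data_loader:
        start_logits, end_logits = model(**batch)
        for i in range(start_logits.shape[0]):
            all_start_logits.append(start_logits.numpy()[i])
            all_end_logits.append(end_logits.numpy()[i])
    predictions, _, _ = compute_prediction(
        examples, features, (all_start_logits, all_end_logits),
        args.version_2_with_negative, args.n_best_size,
        args.max_answer_length, args.null_score_diff_threshold)
    squad_evaluate(examples=[ex for ex in examples], preds=predictions)
    model.train()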
Example #11
def main():
    paddle.seed(42)
    args = parse_args()

    args.task_name = args.task_name.lower()

    predictor = Predictor.create_predictor(args)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    if args.task_name == "msra_ner":

        def ner_trans_fn(example,
                         tokenizer,
                         max_seq_length=128,
                         no_entity_id=0):
            return tokenize_and_align_labels(example,
                                             tokenizer=tokenizer,
                                             no_entity_id=no_entity_id,
                                             max_seq_len=max_seq_length)

        trans_fn = partial(ner_trans_fn,
                           tokenizer=tokenizer,
                           max_seq_length=args.max_seq_length)
        dev_ds = load_dataset("msra_ner", split="test")
        label_list = dev_ds.features['ner_tags'].feature.names
        args.label_list = label_list

        column_names = dev_ds.column_names
        dev_ds = dev_ds.map(trans_fn, remove_columns=column_names)
        batchify_fn = DataCollatorForTokenClassification(tokenizer)
        outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args)
    elif args.task_name == "cmrc2018":
        dev_example = load_dataset("cmrc2018", split="validation")
        column_names = dev_example.column_names
        dev_ds = dev_example.map(
            partial(prepare_validation_features,
                    tokenizer=tokenizer,
                    doc_stride=128,
                    max_seq_length=args.max_seq_length),
            batched=True,
            num_proc=4,
            remove_columns=column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on validation dataset",
        )

        batchify_fn = DataCollatorWithPadding(tokenizer)
        outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args,
                                    dev_example)
    else:
        dev_ds = ppnlp_load_dataset('clue', args.task_name, splits='dev')

        trans_func = partial(convert_example,
                             label_list=dev_ds.label_list,
                             tokenizer=tokenizer,
                             max_seq_length=args.max_seq_length,
                             is_test=False)
        dev_ds = dev_ds.map(trans_func, lazy=False)
        batchify_fn = DataCollatorWithPadding(tokenizer)

        outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args)
Example #12
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)
    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # set_seed(args)
    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # If you customize hyper-parameters in the yaml config, they will overwrite the corresponding args.
        config = ALL_DATASETS[data_args.dataset]
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config.keys():
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
        cache_dir=model_args.cache_dir)

    label_list = getattr(raw_datasets['train'], "label_list", None)
    data_args.label_list = label_list

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path)

    loss_fct = CrossEntropyLossForSQuAD()

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names

    if training_args.do_train:
        train_dataset = raw_datasets["train"]
        # Create train feature from dataset
        with training_args.main_process_first(
                desc="train dataset map pre-processing"):
            # Dataset pre-process
            train_dataset = train_dataset.map(
                partial(prepare_train_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    if training_args.do_eval:
        eval_examples = raw_datasets["validation"]
        with training_args.main_process_first(
                desc="evaluate dataset map pre-processing"):
            eval_dataset = eval_examples.map(
                partial(prepare_validation_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
    if training_args.do_predict:
        predict_examples = raw_datasets["test"]
        with training_args.main_process_first(
                desc="test dataset map pre-processing"):
            predict_dataset = predict_examples.map(
                partial(prepare_validation_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    # Post-processing:
    def post_processing_function(examples,
                                 features,
                                 predictions,
                                 stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, all_nbest_json, scores_diff_json = compute_prediction(
            examples=examples,
            features=features,
            predictions=predictions,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
        )

        # # Format the result to the format the metric expects.
        # formatted_predictions = [{
        #     "id": k,
        #     "prediction_text": v
        # } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex["answers"]
        } for ex in examples]
        return EvalPrediction(predictions=predictions, label_ids=references)

    def compute_metrics(p: EvalPrediction):
        ret = squad_evaluate(examples=p.label_ids,
                             preds=p.predictions,
                             is_whitespace_splited=False)
        return dict(ret)
        # return metric.compute(predictions=p.predictions, references=p.label_ids)

    trainer = QuestionAnsweringTrainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        eval_examples=eval_examples if training_args.do_eval else None,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    if training_args.do_train:
        # Training
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # model.set_state_dict(paddle.load("tmp/model_state.pdparams"))

    # Evaluate and tests model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    if training_args.do_predict:
        test_ret = trainer.predict(predict_dataset, predict_examples)
        trainer.log_metrics("predict", test_ret.metrics)

        if test_ret.label_ids is None:
            paddle.save(
                test_ret.predictions,
                os.path.join(training_args.output_dir,
                             "test_results.pdtensor"),
            )

    # export inference model
    if training_args.do_export:
        # You can also load from certain checkpoint
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        input_spec = [
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64")  # segment_ids
        ]

        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(
                training_args.output_dir, "export")
        paddlenlp.transformers.export_model(model=trainer.model,
                                            input_spec=input_spec,
                                            path=model_args.export_model_dir)