Example #1
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # If you customize hyper-parameters in the yaml config, they will overwrite the corresponding args.
        config = ALL_DATASETS[data_args.dataset]
        logger.info("Overwriting training config with yaml config!")
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config.keys():
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    data_args.label_list = getattr(raw_datasets['train'], "label_list", None)
    num_classes = 1 if raw_datasets["train"].label_list is None else len(
        raw_datasets["train"].label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)
    criterion = nn.loss.CrossEntropyLoss(
    ) if data_args.label_list else nn.loss.MSELoss()

    # Define dataset pre-process function
    if "clue" in data_args.dataset:
        trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)
    else:
        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collator
    data_collator = DataCollatorWithPadding(tokenizer)

    # Dataset pre-process
    if training_args.do_train:
        train_dataset = raw_datasets["train"].map(trans_fn)
    if training_args.do_eval:
        eval_dataset = raw_datasets["dev"].map(trans_fn)
    if training_args.do_predict:
        test_dataset = raw_datasets["test"].map(trans_fn)

    # Define the metrics of tasks.
    def compute_metrics(p):
        preds = p.predictions[0] if isinstance(p.predictions,
                                               tuple) else p.predictions

        preds = paddle.to_tensor(preds)
        label = paddle.to_tensor(p.label_ids)

        metric = Accuracy()
        metric.reset()
        result = metric.compute(preds, label)
        metric.update(result)
        accu = metric.accumulate()
        metric.reset()
        return {"accuracy": accu}

    trainer = Trainer(
        model=model,
        criterion=criterion,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate and test the model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    if training_args.do_predict:
        test_ret = trainer.predict(test_dataset)
        trainer.log_metrics("test", test_ret.metrics)
        if test_ret.label_ids is None:
            paddle.save(
                test_ret.predictions,
                os.path.join(training_args.output_dir,
                             "test_results.pdtensor"),
            )

    # Export the inference model
    if training_args.do_export:
        # You can also load from a specific checkpoint:
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        input_spec = [
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64")  # segment_ids
        ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(
                training_args.output_dir, "export")
        paddlenlp.transformers.export_model(model=trainer.model,
                                            input_spec=input_spec,
                                            path=model_args.export_model_dir)
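
The compute_metrics function above follows the standard paddle.metric.Accuracy flow: compute per-sample correctness from the logits, feed it to update, then read the aggregate with accumulate. A minimal self-contained sketch of that flow, using made-up logits:

import paddle
from paddle.metric import Accuracy

# Made-up logits for a batch of 4 samples and 3 classes.
logits = paddle.to_tensor([[0.1, 0.7, 0.2],
                           [0.8, 0.1, 0.1],
                           [0.2, 0.2, 0.6],
                           [0.3, 0.4, 0.3]])
labels = paddle.to_tensor([[1], [0], [2], [0]])

metric = Accuracy()
metric.reset()
correct = metric.compute(logits, labels)  # per-sample correctness
metric.update(correct)
print(metric.accumulate())  # 0.75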
Example #2
def do_train(args):
    # Set the paddle execution environment
    paddle.enable_static()
    place = paddle.set_device(args.device)
    set_seed(args)

    # Create the main_program for the training and dev_program for the validation
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    dev_program = paddle.static.Program()

    # Get the configuration of tokenizer and model
    args.task_name = args.task_name.lower()
    args.model_type = args.model_type.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Create the tokenizer and dataset
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    train_ds = load_dataset('glue', args.task_name, splits="train")

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         label_list=train_ds.label_list,
                         max_seq_length=args.max_seq_length)

    train_ds = train_ds.map(trans_func, lazy=True)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type 
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)

    train_batch_sampler = paddle.io.BatchSampler(train_ds,
                                                 batch_size=args.batch_size,
                                                 shuffle=True)

    # Define the input data and create the train/dev data_loader
    with paddle.static.program_guard(main_program, startup_program):
        [input_ids, token_type_ids,
         labels] = create_data_holder(args.task_name)

    train_data_loader = DataLoader(
        dataset=train_ds,
        feed_list=[input_ids, token_type_ids, labels],
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=False)

    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])

        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            feed_list=[input_ids, token_type_ids, labels],
            num_workers=0,
            return_list=False)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            feed_list=[input_ids, token_type_ids, labels],
            return_list=False)
    else:
        dev_ds = load_dataset('glue', args.task_name, splits='dev')
        dev_ds = dev_ds.map(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(dev_ds,
                                                   batch_size=args.batch_size,
                                                   shuffle=False)
        dev_data_loader = DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            feed_list=[input_ids, token_type_ids, labels],
            return_list=False)

    # Create the training-forward program, and clone it for the validation
    with paddle.static.program_guard(main_program, startup_program):
        num_class = 1 if train_ds.label_list is None else len(
            train_ds.label_list)
        model, pretrained_state_dict = model_class.from_pretrained(
            args.model_name_or_path, num_classes=num_class)
        loss_fct = paddle.nn.loss.CrossEntropyLoss(
        ) if train_ds.label_list else paddle.nn.loss.MSELoss()
        logits = model(input_ids, token_type_ids)
        loss = loss_fct(logits, labels)
        dev_program = main_program.clone(for_test=True)

    # Create the training-backward program; this pass will not be
    # executed during validation
    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.num_train_epochs
    with paddle.static.program_guard(main_program, startup_program):
        lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                             num_training_steps,
                                             args.warmup_steps)
        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_params)
        optimizer.minimize(loss)

    # Create the metric pass for the validation
    with paddle.static.program_guard(dev_program, startup_program):
        metric = metric_class()
        correct = metric.compute(logits, labels)

    # Initialize the fine-tuning parameters: load the parameters from the
    # pre-trained model, and initialize any parameter not present in the
    # pre-trained model from a normal distribution.
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()
    reset_state_dict = reset_program_state_dict(args, model, state_dict,
                                                pretrained_state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1
            loss_return = exe.run(main_program, feed=batch, fetch_list=[loss])
            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss_return[0],
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            lr_scheduler.step()
            if global_step % args.save_steps == 0:
                # Validation pass, record the loss and metric
                if args.task_name == "mnli":
                    evaluate(exe, metric, loss, correct, dev_program,
                             dev_data_loader_matched)
                    evaluate(exe, metric, loss, correct, dev_program,
                             dev_data_loader_mismatched)
                else:
                    evaluate(exe, metric, loss, correct, dev_program,
                             dev_data_loader)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                paddle.fluid.io.save_params(exe, output_dir)
                tokenizer.save_pretrained(output_dir)
            if global_step >= num_training_steps:
                return
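
The static-graph example above hinges on cloning the forward program for validation before the backward and optimizer ops are appended. A minimal sketch of that clone-before-backward idiom, on a toy network rather than the GLUE model:

import paddle

paddle.enable_static()
main_prog = paddle.static.default_main_program()
startup_prog = paddle.static.default_startup_program()

with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
    pred = paddle.static.nn.fc(x, size=1)
    loss = paddle.nn.functional.mse_loss(pred, y)
    # Clone for evaluation BEFORE adding backward/optimizer ops,
    # so the test program stays forward-only.
    test_prog = main_prog.clone(for_test=True)
    paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)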
Example #3
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, avg_loss, args.logging_steps /
                       (time.time() - tic_train)))
                tic_train = time.time()
            avg_loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                if paddle.distributed.get_rank() == 0:
                    if args.dataset == "peoples_daily_ner":
                        evaluate(model, loss_fct, metric, dev_data_loader,
                                 label_num, "valid")
                    evaluate(model, loss_fct, metric, test_data_loader,
                             label_num, "test")

                    paddle.save(
                        model.state_dict(),
                        os.path.join(args.output_dir,
                                     "model_%d.pdparams" % global_step))
            if global_step >= num_training_steps:
                return


if __name__ == "__main__":
    args = parser.parse_args()
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    do_train(args)
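
The checkpointing call in the fragment above pairs with paddle.load and set_state_dict at restore time. A minimal sketch of that save/restore round trip on a toy model (the path here is hypothetical):

import os
import paddle

model = paddle.nn.Linear(4, 4)
path = os.path.join("output", "model_100.pdparams")
os.makedirs(os.path.dirname(path), exist_ok=True)
paddle.save(model.state_dict(), path)

# Restoring the checkpoint later:
model.set_state_dict(paddle.load(path))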
Example #4
def do_train(args):
    # Initialization for the parallel environment
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()

    # Set the random seed for the training process
    set_seed(args)
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Get the model class and tokenizer class
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define the pretrain model and metric
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        model = BigBirdForPretraining(
            BigBirdModel(**model_class.pretrained_init_configuration[
                args.model_name_or_path]))
    else:
        model = BigBirdForPretraining.from_pretrained(args.model_name_or_path)
    # Get the BigBird config for generating random attention masks
    config = getattr(model, BigBirdForPretraining.base_model_prefix).config
    criterion = BigBirdPretrainingCriterion(config["vocab_size"], args.use_nsp)
    if worker_num > 1:
        model = paddle.DataParallel(model)

    # Define learning rate scheduler and optimizer
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, args.max_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            train_data_loader = create_dataloader(files[f_id], tokenizer,
                                                  worker_init, args.batch_size,
                                                  args.max_encoder_length,
                                                  args.max_pred_length, config)
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                (input_ids, segment_ids, masked_lm_positions, masked_lm_ids,
                 masked_lm_weights, next_sentence_labels,
                 masked_lm_scale) = batch[:7]
                rand_mask_idx_list = batch[7:]

                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    rand_mask_idx_list=rand_mask_idx_list,
                    masked_positions=masked_lm_positions)
                loss = criterion(prediction_scores, seq_relationship_score,
                                 masked_lm_ids, next_sentence_labels,
                                 masked_lm_scale, masked_lm_weights)
                if global_step % args.logging_steps == 0 and worker_index == 0:
                    logger.info(
                        "global step %d, epoch: %d, lr: %.10f, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, optimizer.get_lr(), loss,
                           args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()
                if global_step % args.save_steps == 0:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
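
The examples above exclude bias and LayerNorm parameters from weight decay with a name-substring filter. A minimal sketch of how that filter interacts with AdamW's apply_decay_param_fun; note it tests the structured name n while storing the framework-level p.name, which is what apply_decay_param_fun later receives (toy module for illustration):

import paddle

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 4)
        self.norm = paddle.nn.LayerNorm(4)

model = Block()

# Keep only parameters whose structured name contains neither
# "bias" nor "norm"; here that is just linear.weight.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=lambda x: x in decay_params)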
Example #5
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.eval_iters = 10
    training_args.test_iters = training_args.eval_iters * 10

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    model_class, tokenizer_class = MODEL_CLASSES['ernie-health']

    # Loads or initialize a model.
    pretrained_models = list(
        tokenizer_class.pretrained_init_configuration.keys())

    if model_args.model_name_or_path in pretrained_models:
        tokenizer = tokenizer_class.from_pretrained(
            model_args.model_name_or_path)
        generator = ElectraGenerator(
            ElectraModel(**model_class.pretrained_init_configuration[
                model_args.model_name_or_path + "-generator"]))
        discriminator = ErnieHealthDiscriminator(
            ElectraModel(**model_class.pretrained_init_configuration[
                model_args.model_name_or_path + "-discriminator"]))
        model = model_class(generator, discriminator)
    else:
        raise ValueError("Only support %s" % (", ".join(pretrained_models)))

    # Loads dataset.
    tic_load_data = time.time()
    logger.info("start load data : %s" %
                (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

    train_dataset = MedicalCorpus(data_path=data_args.input_dir,
                                  tokenizer=tokenizer)
    logger.info("load data done, total : %s s" % (time.time() - tic_load_data))

    # Reads data and generates mini-batches.
    data_collator = DataCollatorForErnieHealth(
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        mlm_prob=data_args.masked_lm_prob,
        return_dict=True)

    class CriterionWrapper(paddle.nn.Layer):
        """Wraps ErnieHealthPretrainingCriterion so that the Trainer can
        call it as criterion(output, labels).
        """

        def __init__(self):
            super(CriterionWrapper, self).__init__()
            self.criterion = ErnieHealthPretrainingCriterion(
                getattr(
                    model.generator,
                    ElectraGenerator.base_model_prefix).config["vocab_size"],
                model.gen_weight)

        def forward(self, output, labels):
            """forward function

            Args:
                output (tuple): generator_logits, logits_rtd, logits_mts, logits_csp, disc_labels, mask
                labels (tuple): generator_labels

            Returns:
                Tensor: final loss.
            """
            generator_logits, logits_rtd, logits_mts, logits_csp, disc_labels, masks = output
            generator_labels = labels

            loss, gen_loss, rtd_loss, mts_loss, csp_loss = self.criterion(
                generator_logits, generator_labels, logits_rtd, logits_mts,
                logits_csp, disc_labels, masks)

            return loss

    trainer = Trainer(
        model=model,
        criterion=CriterionWrapper(),
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
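
This example, like Example #1, builds its CLI from dataclasses through PdArgumentParser. A minimal sketch of that pattern with a single made-up field (assuming paddlenlp.trainer exports PdArgumentParser, as used above):

from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser

@dataclass
class ModelArguments:
    # Hypothetical field, for illustration only.
    model_name_or_path: str = field(
        default="ernie-health-chinese",
        metadata={"help": "Pretrained model name or local path."})

parser = PdArgumentParser((ModelArguments,))
model_args, = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path)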
Example #6
def create_pretrained_dataset(args,
                              input_path,
                              local_rank,
                              data_world_rank,
                              data_world_size,
                              eos_id,
                              worker_init=None,
                              max_seq_len=1024,
                              places=None,
                              data_holders=None):
    device_world_size = paddle.distributed.get_world_size()
    device_world_rank = paddle.distributed.get_rank()

    logger.info(
        "The distributed run, total device num:{}, distinct dataflow num:{}.".
        format(device_world_size, data_world_size))

    processed_data = np.load(input_path, mmap_mode="r+", allow_pickle=True)
    # All document ids, flattened into a 1-D array.
    sample_ids = processed_data["ids"]
    # len(sample_lens) is the number of documents;
    # sum(sample_lens) should equal len(sample_ids).
    sample_lens = processed_data["lens"]

    splits = get_train_valid_test_split_(args.split, len(sample_lens))
    assert len(sample_lens) >= splits[
        -1], "The number of documents should be larger than the max of splits, but %s < %s" % (
            len(sample_lens), splits[-1])

    def build_dataset(index, name, num_samples):
        dataset = GPTDataset(
            file_path=input_path,
            build_data_file=local_rank == 0,
            name="gpt_" + name,
            max_seq_len=max_seq_len,
            num_samples=num_samples,
            documents=np.arange(splits[index], splits[index + 1]),
            sample_ids=sample_ids,
            sample_lens=sample_lens,
            eos_id=eos_id,
            seed=args.seed)

        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=args.local_batch_size,
            num_replicas=data_world_size,
            rank=data_world_rank,
            shuffle=False,
            drop_last=True)

        data_loader = DataLoader(
            dataset=dataset,
            places=places,
            feed_list=data_holders,
            batch_sampler=batch_sampler,
            num_workers=0,
            worker_init_fn=worker_init,
            # collate_fn=Tuple(Stack(), Stack(), Stack(), Stack(), Stack()),
            collate_fn=Tuple(Stack(), Stack(), Stack()),
            return_list=False)
        return data_loader

    # Note: data should be broadcast to all devices.
    # For train, valid and test, the number of distinct dataflows is data_world_size.
    train_data_loader = build_dataset(0, "train", args.local_batch_size *
                                      args.max_steps * data_world_size)

    valid_data_loader = build_dataset(1, "valid", args.local_batch_size *
                                      (args.max_steps // args.eval_freq + 1) *
                                      args.eval_iters * data_world_size)
    test_data_loader = build_dataset(2, "test", args.local_batch_size *
                                     args.test_iters * data_world_size)

    return train_data_loader, valid_data_loader, test_data_loader
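
The loader above expects a pre-processed .npz file with "ids" and "lens" arrays. A minimal sketch of that expected layout, built from a tiny made-up corpus:

import numpy as np

# "ids" holds all documents' token ids flattened into one 1-D array;
# "lens" holds each document's length, so sum(lens) == len(ids).
ids = np.array([5, 6, 7, 8, 9, 10], dtype=np.int64)
lens = np.array([2, 4], dtype=np.int64)
assert lens.sum() == len(ids)
np.savez("corpus_demo.npz", ids=ids, lens=lens)

data = np.load("corpus_demo.npz", allow_pickle=True)
print(len(data["lens"]), data["ids"].shape)  # 2 documents, (6,) tokens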
Example #7
def run(args):
    if args.do_train:
        assert args.batch_size % args.gradient_accumulation_steps == 0, \
            "Please make sure the argument `batch_size` is divisible by `gradient_accumulation_steps`."
    max_seq_length = args.max_seq_length
    max_num_choices = 4

    def preprocess_function(examples, do_predict=False):
        def _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_length):
            """Truncates a sequence tuple in place to the maximum length."""
            # This is a simple heuristic which will always truncate the longer
            # sequence one token at a time. This makes more sense than
            # truncating an equal percent of tokens from each, since if one
            # sequence is very short then each token that's truncated likely
            # contains more information than a longer sequence.
            while True:
                total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
                if total_length <= max_length:
                    break
                if len(tokens_a) >= len(tokens_b) and len(tokens_a) >= len(
                        tokens_c):
                    tokens_a.pop()
                elif len(tokens_b) >= len(tokens_a) and len(tokens_b) >= len(
                        tokens_c):
                    tokens_b.pop()
                else:
                    tokens_c.pop()

        num_examples = len(examples.data["question"])
        if do_predict:
            result = {"input_ids": [], "token_type_ids": []}
        else:
            result = {"input_ids": [], "token_type_ids": [], "labels": []}
        for idx in range(num_examples):
            text = '\n'.join(examples.data["context"][idx]).lower()
            question = examples.data["question"][idx].lower()
            choice_list = examples.data["choice"][idx]
            choice_list = [choice.lower()
                           for choice in choice_list][:max_num_choices]
            if not do_predict:
                answer = examples.data["answer"][idx].lower()
                label = choice_list.index(answer)

            tokens_t = tokenizer.tokenize(text)
            tokens_q = tokenizer.tokenize(question)

            tokens_t_list = []
            tokens_c_list = []

            # Pad each new example for axis=1, [batch_size, num_choices, seq_len]
            while len(choice_list) < max_num_choices:
                choice_list.append('无效答案')

            for choice in choice_list:
                tokens_c = tokenizer.tokenize(choice.lower())
                _truncate_seq_tuple(tokens_t, tokens_q, tokens_c,
                                    max_seq_length - 4)

                tokens_c = tokens_q + ["[SEP]"] + tokens_c
                tokens_t_list.append(tokens_t)
                tokens_c_list.append(tokens_c)

            new_data = tokenizer(
                tokens_t_list,
                text_pair=tokens_c_list,
                is_split_into_words=True)

            # Pad each new example for axis=2 of [batch_size, num_choices, seq_len],
            # because length of each choice could be different.
            input_ids = Pad(
                axis=0, pad_val=tokenizer.pad_token_id)(new_data["input_ids"])
            token_type_ids = Pad(
                axis=0,
                pad_val=tokenizer.pad_token_id)(new_data["token_type_ids"])

            # Final shape of input_ids: [batch_size, num_choices, seq_len]
            result["input_ids"].append(input_ids)
            result["token_type_ids"].append(token_type_ids)
            if not do_predict:
                result["labels"].append([label])
            if (idx + 1) % 1000 == 0:
                logger.info("%d samples have been processed." % (idx + 1))
        return result

    paddle.set_device(args.device)
    set_seed(args)

    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForMultipleChoice.from_pretrained(
        args.model_name_or_path, num_choices=max_num_choices)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    train_ds, dev_ds, test_ds = load_dataset(
        "clue", "c3", split=["train", "validation", "test"])

    if args.do_train:
        args.batch_size = int(args.batch_size /
                              args.gradient_accumulation_steps)
        column_names = train_ds.column_names
        with main_process_first(desc="train dataset map pre-processing"):
            train_ds = train_ds.map(
                preprocess_function,
                batched=True,
                batch_size=len(train_ds),
                num_proc=args.num_proc,
                remove_columns=column_names,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on train dataset")
        batchify_fn = lambda samples, fn=Dict({
            'input_ids': Pad(axis=1, pad_val=tokenizer.pad_token_id),  # input
            'token_type_ids': Pad(axis=1, pad_val=tokenizer.pad_token_type_id),  # segment
            'labels': Stack(dtype="int64")  # label
        }): fn(samples)

        train_batch_sampler = paddle.io.DistributedBatchSampler(
            train_ds, batch_size=args.batch_size, shuffle=True)
        train_data_loader = paddle.io.DataLoader(
            dataset=train_ds,
            batch_sampler=train_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        with main_process_first(desc="evaluate dataset map pre-processing"):
            dev_ds = dev_ds.map(preprocess_function,
                                batched=True,
                                batch_size=len(dev_ds),
                                remove_columns=column_names,
                                num_proc=args.num_proc,
                                load_from_cache_file=not args.overwrite_cache,
                                desc="Running tokenizer on validation dataset")
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.eval_batch_size, shuffle=False)
        dev_data_loader = paddle.io.DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            return_list=True)

        num_training_steps = int(
            args.max_steps /
            args.gradient_accumulation_steps) if args.max_steps >= 0 else int(
                len(train_data_loader) * args.num_train_epochs /
                args.gradient_accumulation_steps)

        warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
        lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                             num_training_steps, warmup)

        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        grad_clip = paddle.nn.ClipGradByGlobalNorm(args.max_grad_norm)
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_params,
            grad_clip=grad_clip)
        loss_fct = paddle.nn.loss.CrossEntropyLoss()
        metric = paddle.metric.Accuracy()
        model.train()
        global_step = 0
        best_acc = 0.0
        tic_train = time.time()
        for epoch in range(args.num_train_epochs):
            for step, batch in enumerate(train_data_loader):
                input_ids, segment_ids, label = batch
                logits = model(input_ids=input_ids, token_type_ids=segment_ids)
                loss = loss_fct(logits, label)
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    global_step += 1
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.clear_grad()
                    if global_step % args.logging_steps == 0:
                        logger.info(
                            "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                            % (global_step, num_training_steps, epoch, step + 1,
                               paddle.distributed.get_rank(), loss,
                               optimizer.get_lr(),
                               args.logging_steps / (time.time() - tic_train)))
                        tic_train = time.time()
                if global_step >= num_training_steps:
                    logger.info("best_result: %.2f" % (best_acc * 100))
                    return
            tic_eval = time.time()
            acc = evaluation(model, loss_fct, dev_data_loader, metric)
            logger.info("eval acc: %.5f, eval done total : %s s" %
                        (acc, time.time() - tic_eval))
            if paddle.distributed.get_rank() == 0 and acc > best_acc:
                best_acc = acc
                if args.save_best_model:
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    if not os.path.exists(args.output_dir):
                        os.makedirs(args.output_dir)
                    model_to_save.save_pretrained(args.output_dir)
                    tokenizer.save_pretrained(args.output_dir)

        logger.info("best_result: %.2f" % (best_acc * 100))

    if args.do_predict:
        column_names = test_ds.column_names
        test_ds = test_ds.map(partial(
            preprocess_function, do_predict=True),
                              batched=True,
                              batch_size=len(test_ds),
                              remove_columns=column_names,
                              num_proc=args.num_proc)
        # Several samples have more than four choices.
        test_batch_sampler = paddle.io.BatchSampler(
            test_ds, batch_size=1, shuffle=False)

        batchify_fn = lambda samples, fn=Dict({
            'input_ids': Pad(axis=1, pad_val=tokenizer.pad_token_id),  # input
            'token_type_ids': Pad(axis=1, pad_val=tokenizer.pad_token_type_id),  # segment
        }): fn(samples)

        test_data_loader = paddle.io.DataLoader(
            dataset=test_ds,
            batch_sampler=test_batch_sampler,
            collate_fn=batchify_fn,
            return_list=True)

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        with open(os.path.join(args.output_dir, "c311_predict.json"),
                  'w') as f:
            result = {}
            idx = 0
            for step, batch in enumerate(test_data_loader):
                input_ids, segment_ids = batch
                with paddle.no_grad():
                    logits = model(input_ids, segment_ids)
                preds = paddle.argmax(logits, axis=1).numpy().tolist()
                for pred in preds:
                    result[str(idx)] = pred
                    j = json.dumps({"id": idx, "label": pred})
                    f.write(j + "\n")
                    idx += 1
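
The training loop above divides each loss by gradient_accumulation_steps and only steps the optimizer every N micro-batches. A minimal sketch of that accumulation loop on a toy model with random data:

import paddle

model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
loss_fct = paddle.nn.CrossEntropyLoss()
accum_steps = 4

for step in range(16):
    x = paddle.randn([8, 4])
    y = paddle.randint(0, 2, [8])
    # Scale so the accumulated gradient matches one large batch.
    loss = loss_fct(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.clear_grad()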
Example #8
def do_train(args):
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain rank information for hybrid parallelism
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Seed control under hybrid parallelism
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree

        # MOE config
        initialize_model_and_expert_group(hcg)

        model_config['expert_mode'] = args.expert_mode
        model_config['hcg'] = hcg
        model_config['num_experts'] = args.num_experts
        model_config['top_k'] = args.top_k
        if args.expert_mode:
            model_config['gate'] = args.gate

        if args.pp_degree == 1:
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model_config["recompute_partition"] = args.recompute_partition
            model_config["recompute_offload"] = args.recompute_offload
            if args.use_recompute and args.recompute_partition:
                raise Exception(
                    "when use_recompute is True, recompute_partition must be False in MoE."
                )

            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)


    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model = paddle.amp.decorate(models=model,
                                    optimizers=None,
                                    level='O2',
                                    save_dtype='float32')

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    opt_fused_tensors, decay_fused_tensors, reduce_fused_tensors, gate_fused_tensors, \
        expert_fusion_names = parameters_classify(model)
    decay_params = [p.name for p in decay_fused_tensors]

    clip = None
    if args.grad_clip > 0:
        is_expert_param_fun = lambda param: param.name in expert_fusion_names
        clip = moe.ClipGradByGlobalNorm(clip_norm=args.grad_clip, \
                                        is_expert_param_func = is_expert_param_fun, \
                                        moe_group = hcg.get_expert_parallel_group())

    optimizer = AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=opt_fused_tensors,
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)

    if paddle.distributed.get_world_size() > 1 and args.resume_dir is None:
        print(">> initialize....")
        initialize_mp_dp_parameters(model, hcg)

    # In order to restore the reader state.
    pass_num = 0
    file_id = 0
    start_epoch = 0
    args.resume_dir = None if len(args.resume_dir) <= 0 else args.resume_dir

    if args.resume_dir is not None:
        global_step, loss_scale, data_meta = load_checkpoint(
            args, model, optimizer, lr_scheduler, tokenizer, dp_rank, mp_rank,
            pp_rank)
        pass_num = data_meta["pass_num"]
        file_id = data_meta["file_id"]
        start_epoch = data_meta["start_epoch"]

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=loss_scale
            if args.resume_dir is not None else args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model, optimizer = paddle.amp.decorate(models=model,
                                               optimizers=optimizer,
                                               level='O2',
                                               save_dtype='float32')

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0 if args.resume_dir is None else global_step
    timers = get_timers()
    tic_train = time.time()
    for epoch in range(start_epoch, args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(file_id, num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)

            # Bug fix: if valid_data_loader is not called here, each enumerate
            # would invoke valid_data_loader again and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                # Skip the training data that has already been consumed.
                if step < global_step - pass_num:
                    continue

                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                loss = 0.0
                for i in range(accumulate_steps):
                    start_index = i * args.micro_batch_size
                    end_index = start_index + args.micro_batch_size
                    timers('forward-compute').start()
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum",
                                "c_softmax_with_cross_entropy",
                                "elementwise_div",
                            ],
                            level='O2'):
                        preds = model(tokens[start_index:end_index, :])
                        loss_mbs = criterion(
                            preds, labels[start_index:end_index, :],
                            loss_mask[start_index:end_index, :])
                    timers('forward-compute').stop()

                    if args.gate != "naive" and args.balance_loss_weight:
                        aux_loss_list = [
                            l.moe_mlp.gate.get_loss(clear=False)
                            for l in model.gpt.decoder.layers
                            if hasattr(l.moe_mlp, "gate")
                        ]
                        bal_loss = paddle.concat(aux_loss_list)
                        if bal_loss.dtype == paddle.float16:
                            bal_loss = paddle.cast(bal_loss,
                                                   dtype=paddle.float32)
                        bal_loss = bal_loss.mean()
                        loss_mbs += bal_loss * args.balance_loss_weight
                    loss_mbs = loss_mbs / accumulate_steps

                    timers('backward-compute').start()
                    if args.use_pure_fp16:
                        scaler.scale(loss_mbs).backward()
                    else:
                        loss_mbs.backward()
                    timers('backward-compute').stop()
                    loss = loss + loss_mbs

                timers('backward-params-all-reduce').start()
                all_reduce_parameters(gate_fused_tensors,
                                      hcg.get_expert_parallel_group())
                all_reduce_parameters(reduce_fused_tensors,
                                      hcg.get_data_parallel_group())
                timers('backward-params-all-reduce').stop()

                if args.use_pure_fp16:
                    scaler.minimize(optimizer, loss)
                else:
                    optimizer.step()
                learning_rate = optimizer.get_lr()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    if args.gate != "naive" and args.balance_loss_weight:
                        bal_loss = bal_loss.numpy()
                        avg_loss -= bal_loss
                    else:
                        bal_loss = -1
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, bal_loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, bal_loss, speed,
                           speed * default_global_tokens_num, learning_rate))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", learning_rate,
                                          global_step)

                    tic_train = time.time()
                    timer_log(args.logging_freq)

                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps):
                    loss_scale = scaler._scale if args.use_pure_fp16 else None
                    save_checkpoint(args, global_step, model, optimizer,
                                    lr_scheduler, tokenizer, loss_scale,
                                    dp_rank, mp_rank, pp_rank, pass_num,
                                    file_id, epoch)
                    print(
                        "save checkpoint for step_{} successfully...loss_scale = {}"
                        .format(global_step, loss_scale))

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            # Record the total length of train_data_loader that has been read.
            pass_num += len(train_data_loader())
            del train_data_loader
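
The pure-fp16 path above combines paddle.amp.decorate, auto_cast, and a fleet-wrapped GradScaler. A minimal single-device sketch of that pattern on a toy model (assuming a GPU; the fleet wrapping and the custom black list used above are omitted):

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-4,
                                   parameters=model.parameters())
# Decorate to O2 so the model runs in fp16 but checkpoints stay fp32.
model, optimizer = paddle.amp.decorate(models=model,
                                       optimizers=optimizer,
                                       level='O2',
                                       save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=2.**15)

x = paddle.randn([2, 4])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaled = scaler.scale(loss)   # scale the loss before backward
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()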
Example #9
def do_train(args):
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    dataset_class, metric_class = TASK_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = dataset_class.get_datasets(['train'])

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         label_list=train_ds.get_labels(),
                         max_seq_length=args.max_seq_length)
    train_ds = train_ds.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment
        Stack(),  # length
        Stack(dtype="int64" if train_ds.get_labels() else "float32")  # label
    ): [data for i, data in enumerate(fn(samples)) if i != 2]
    train_data_loader = DataLoader(dataset=train_ds,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)
    if args.task_name == "mnli":
        dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets(
            ["dev_matched", "dev_mismatched"])
        dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True)
        dev_dataset_mismatched = dev_dataset_mismatched.apply(trans_func,
                                                              lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_dataset_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_dataset_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_dataset_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_dataset = dataset_class.get_datasets(["dev"])
        dev_dataset = dev_dataset.apply(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False)
        dev_data_loader = DataLoader(dataset=dev_dataset,
                                     batch_sampler=dev_batch_sampler,
                                     collate_fn=batchify_fn,
                                     num_workers=0,
                                     return_list=True)

    num_labels = 1 if train_ds.get_labels() is None else len(
        train_ds.get_labels())

    model = model_class.from_pretrained(args.model_name_or_path,
                                        num_classes=num_labels)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    # Step1: Initialize a dictionary to save the weights from the original BERT model.
    origin_weights = {}
    for name, param in model.named_parameters():
        origin_weights[name] = param

    # Step2: Convert the original model to a supernet.
    sp_config = supernet(expand_ratio=args.width_mult_list)
    model = Convert(sp_config).convert(model)
    # Use the weights saved in the dictionary to initialize the supernet.
    utils.set_state_dict(model, origin_weights)
    del origin_weights

    # Step3: Define teacher model.
    teacher_model = model_class.from_pretrained(args.model_name_or_path,
                                                num_classes=num_labels)

    # Step4: Config about distillation.
    mapping_layers = ['bert.embeddings']
    for idx in range(model.bert.config['num_hidden_layers']):
        mapping_layers.append('bert.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': 0.1,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    distill_config=distill_config,
                    elastic_order=['width'])

    criterion = paddle.nn.loss.CrossEntropyLoss() if train_ds.get_labels(
    ) else paddle.nn.loss.MSELoss()

    metric = metric_class()

    if args.task_name == "mnli":
        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)

    # Step6: Calculate the importance of neurons and heads,
    # and then reorder them according to their importance.
    head_importance, neuron_importance = utils.compute_neuron_head_importance(
        args.task_name,
        ofa_model.model,
        dev_data_loader,
        loss_fct=criterion,
        num_layers=model.bert.config['num_hidden_layers'],
        num_heads=model.bert.config['num_attention_heads'])
    reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, args.warmup_steps)

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=ofa_model.model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in ofa_model.model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        # Step7: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('width')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for width_mult in args.width_mult_list:
                # Step8: Broadcast supernet config from width_mult,
                # and use this config in supernet training.
                net_config = apply_config(ofa_model, width_mult)
                ofa_model.set_net_config(net_config)
                logits, teacher_logits = ofa_model(input_ids,
                                                   segment_ids,
                                                   attention_mask=[None, None])
                rep_loss = ofa_model.calc_distill_loss()
                if args.task_name == 'sts-b':
                    logit_loss = 0.0
                else:
                    logit_loss = soft_cross_entropy(logits,
                                                    teacher_logits.detach())
                loss = rep_loss + args.lambda_logit * logit_loss
                loss.backward()
            optimizer.step()
            lr_scheduler.step()
            ofa_model.model.clear_gradients()

            if global_step % args.logging_steps == 0:
                if args.n_gpu <= 1 or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0:
                if args.task_name == "mnli":
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_matched,
                             width_mult=100)
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_mismatched,
                             width_mult=100)
                else:
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader,
                             width_mult=100)
                for idx, width_mult in enumerate(args.width_mult_list):
                    net_config = apply_config(ofa_model, width_mult)
                    ofa_model.set_net_config(net_config)
                    tic_eval = time.time()
                    if args.task_name == "mnli":
                        acc = evaluate(ofa_model, criterion, metric,
                                       dev_data_loader_matched, width_mult)
                        evaluate(ofa_model, criterion, metric,
                                 dev_data_loader_mismatched, width_mult)
                        print("eval done total : %s s" %
                              (time.time() - tic_eval))
                    else:
                        acc = evaluate(ofa_model, criterion, metric,
                                       dev_data_loader, width_mult)
                        print("eval done total : %s s" %
                              (time.time() - tic_eval))

                    if args.n_gpu <= 1 or paddle.distributed.get_rank() == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need a better way to get the inner model of DataParallel.
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
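
The example above calls soft_cross_entropy for the logit-distillation term without defining it. A minimal sketch of what such a helper typically computes, assuming the standard DynaBERT-style soft cross entropy (the function body below is an assumption, not taken from the example):

import paddle.nn.functional as F

def soft_cross_entropy(student_logits, teacher_logits):
    # Cross entropy between the student distribution and the (detached)
    # teacher distribution, averaged over the batch.
    student_log_prob = F.log_softmax(student_logits, axis=-1)
    teacher_prob = F.softmax(teacher_logits, axis=-1)
    return (-teacher_prob * student_log_prob).sum(axis=-1).mean()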
Example #10
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    train_ds = load_dataset(datafiles=('./data/train.json'))
    tags_to_idx = load_dict("./data/tags.txt")
    labels_to_idx = load_dict("./data/classifier_labels.txt")
    tokenizer = ErnieCtmTokenizer.from_pretrained(args.model_dir)
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_len=args.max_seq_len,
                         tags_to_idx=tags_to_idx,
                         labels_to_idx=labels_to_idx)
    train_ds.map(trans_func)

    ignore_label = tags_to_idx["O"]
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
            ),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # token_type_ids
        Stack(dtype='int64'),  # seq_len
        Pad(axis=0, pad_val=ignore_label, dtype='int64'),  # tags
        Stack(dtype='int64'),  # cls_label
    ): fn(samples)

    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=False, drop_last=True)
    train_data_loader = DataLoader(train_ds,
                                   batch_sampler=train_batch_sampler,
                                   num_workers=0,
                                   collate_fn=batchify_fn,
                                   return_list=True)

    model = ErnieCtmWordtagModel.from_pretrained(
        args.model_dir,
        num_cls_label=len(labels_to_idx),
        num_tag=len(tags_to_idx),
        ignore_index=tags_to_idx["O"])

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else (
        len(train_data_loader) * args.num_train_epochs)
    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, warmup)

    num_train_optimization_steps = len(
        train_ds) / args.batch_size * args.num_train_epochs

    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    logger.info("Total steps: %s" % num_training_steps)
    logger.info("WarmUp steps: %s" % warmup)

    cls_acc = paddle.metric.Accuracy()
    seq_acc = SequenceAccuracy()
    total_loss = 0

    global_step = 0

    for epoch in range(1, args.num_train_epochs + 1):
        logger.info(f"Epoch {epoch} beginning")
        start_time = time.time()

        for total_step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, token_type_ids, seq_len, tags, cls_label = batch

            outputs = model(input_ids,
                            token_type_ids,
                            lengths=seq_len,
                            tag_labels=tags,
                            cls_label=cls_label)
            loss, seq_logits, cls_logits = outputs[0], outputs[1], outputs[2]
            loss = loss.mean()
            total_loss += loss
            loss.backward()

            optimizer.step()
            optimizer.clear_grad()
            lr_scheduler.step()

            cls_correct = cls_acc.compute(pred=cls_logits.reshape(
                [-1, len(labels_to_idx)]),
                                          label=cls_label.reshape([-1]))
            cls_acc.update(cls_correct)
            seq_correct = seq_acc.compute(pred=seq_logits.reshape(
                [-1, len(tags_to_idx)]),
                                          label=tags.reshape([-1]),
                                          ignore_index=tags_to_idx["O"])
            seq_acc.update(seq_correct)

            if global_step % args.logging_steps == 0 and global_step != 0:
                end_time = time.time()
                speed = float(args.logging_steps) / (end_time - start_time)
                logger.info(
                    "[Training]["
                    "epoch: %s/%s][step: %s/%s] loss: %6f, Classification Accuracy: %6f, Sequence Labeling Accuracy: %6f, speed: %6f"
                    % (epoch, args.num_train_epochs, global_step,
                       num_training_steps, total_loss / args.logging_steps,
                       cls_acc.accumulate(), seq_acc.accumulate(), speed))
                start_time = time.time()
                cls_acc.reset()
                seq_acc.reset()
                total_loss = 0

            if (global_step % args.save_steps == 0
                    or global_step == num_training_steps
                ) and paddle.distributed.get_rank() == 0:
                output_dir = os.path.join(
                    args.output_dir,
                    "ernie_ctm_ft_model_%d.pdparams" % (global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
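
load_dict is not shown in this example. A plausible implementation, assuming the dictionary file lists one tag or label per line (hypothetical, for illustration only):

def load_dict(dict_path):
    # Map each line of the dictionary file to its zero-based line index.
    vocab = {}
    with open(dict_path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            vocab[line.rstrip("\n")] = idx
    return vocab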
Example #11
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)
    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # set_seed(args)
    data_args.dataset = data_args.dataset.strip()

    if data_args.dataset in ALL_DATASETS:
        # If you customize your hyper-parameters in the yaml config, it will overwrite all args.
        config = ALL_DATASETS[data_args.dataset]
        for args in (model_args, data_args, training_args):
            for arg in vars(args):
                if arg in config.keys():
                    setattr(args, arg, config[arg])

        training_args.per_device_train_batch_size = config["batch_size"]
        training_args.per_device_eval_batch_size = config["batch_size"]

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
        cache_dir=model_args.cache_dir)

    label_list = getattr(raw_datasets['train'], "label_list", None)
    data_args.label_list = label_list

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path)

    loss_fct = CrossEntropyLossForSQuAD()

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names

    if training_args.do_train:
        train_dataset = raw_datasets["train"]
        # Create train feature from dataset
        with training_args.main_process_first(
                desc="train dataset map pre-processing"):
            # Dataset pre-process
            train_dataset = train_dataset.map(
                partial(prepare_train_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    if training_args.do_eval:
        eval_examples = raw_datasets["validation"]
        with training_args.main_process_first(
                desc="evaluate dataset map pre-processing"):
            eval_dataset = eval_examples.map(
                partial(prepare_validation_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
    if training_args.do_predict:
        predict_examples = raw_datasets["test"]
        with training_args.main_process_first(
                desc="test dataset map pre-processing"):
            predict_dataset = predict_examples.map(
                partial(prepare_validation_features,
                        tokenizer=tokenizer,
                        args=data_args),
                batched=True,
                num_proc=4,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Define data collector
    data_collator = DataCollatorWithPadding(tokenizer)

    # Post-processing:
    def post_processing_function(examples,
                                 features,
                                 predictions,
                                 stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, all_nbest_json, scores_diff_json = compute_prediction(
            examples=examples,
            features=features,
            predictions=predictions,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
        )

        # # Format the result to the format the metric expects.
        # formatted_predictions = [{
        #     "id": k,
        #     "prediction_text": v
        # } for k, v in predictions.items()]

        references = [{
            "id": ex["id"],
            "answers": ex["answers"]
        } for ex in examples]
        return EvalPrediction(predictions=predictions, label_ids=references)

    def compute_metrics(p: EvalPrediction):
        ret = squad_evaluate(examples=p.label_ids,
                             preds=p.predictions,
                             is_whitespace_splited=False)
        return dict(ret)
        # return metric.compute(predictions=p.predictions, references=p.label_ids)

    trainer = QuestionAnsweringTrainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        eval_examples=eval_examples if training_args.do_eval else None,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    if training_args.do_train:
        # Training
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # model.set_state_dict(paddle.load("tmp/model_state.pdparams"))

    # Evaluate and tests model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    if training_args.do_predict:
        test_ret = trainer.predict(predict_dataset, predict_examples)
        trainer.log_metrics("predict", test_ret.metrics)

        if test_ret.label_ids is None:
            paddle.save(
                test_ret.predictions,
                os.path.join(training_args.output_dir,
                             "test_results.pdtensor"),
            )

    # export inference model
    if training_args.do_export:
        # You can also load from certain checkpoint
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        input_spec = [
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None],
                                    dtype="int64")  # segment_ids
        ]

        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(
                training_args.output_dir, "export")
        paddlenlp.transformers.export_model(model=trainer.model,
                                            input_spec=input_spec,
                                            path=model_args.export_model_dir)
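
CrossEntropyLossForSQuAD is instantiated above but not defined. A sketch of the usual formulation, which averages the cross entropy of the start-position and end-position predictions (the exact class body is an assumption):

import paddle
import paddle.nn.functional as F

class CrossEntropyLossForSQuAD(paddle.nn.Layer):
    def forward(self, y, label):
        # y: (start_logits, end_logits); label: (start_position, end_position)
        start_logits, end_logits = y
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = F.cross_entropy(input=start_logits, label=start_position)
        end_loss = F.cross_entropy(input=end_logits, label=end_position)
        return (start_loss + end_loss) / 2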
Example #12
def do_predict(args):
    place = "gpu"
    paddle.set_device(place)

    # Define model
    transformer = FasterDecoder(src_vocab_size=args.src_vocab_size,
                                trg_vocab_size=args.trg_vocab_size,
                                max_length=args.max_length + 1,
                                num_encoder_layers=args.num_encoder_layers,
                                num_decoder_layers=args.num_decoder_layers,
                                n_head=args.n_head,
                                d_model=args.d_model,
                                d_inner_hid=args.d_inner_hid,
                                dropout=args.dropout,
                                weight_sharing=args.weight_sharing,
                                bos_id=args.bos_idx,
                                eos_id=args.eos_idx,
                                max_out_len=args.max_out_len,
                                decoder_lib=args.decoder_lib,
                                use_fp16_decoder=args.use_fp16_decoder)

    # Load checkpoint.
    transformer.load(
        os.path.join(args.init_from_params, "transformer.pdparams"))
    # Set evaluate mode
    transformer.eval()

    # Generate data randomly
    dec_input = paddle.randn(shape=[args.infer_batch_size, 1, args.d_model],
                             dtype='float32')
    enc_output = paddle.randn(
        shape=[args.infer_batch_size, args.max_length, args.d_model],
        dtype='float32')
    mem_seq_lens = paddle.full(shape=[args.infer_batch_size, 1],
                               fill_value=args.max_length,
                               dtype='int32')
    dtype = 'float32'
    if args.use_fp16_decoder:
        dtype = 'float16'
        dec_input = paddle.cast(dec_input, dtype=dtype)
        enc_output = paddle.cast(enc_output, dtype=dtype)
    self_cache = paddle.zeros(shape=[
        args.num_decoder_layers, 2, 0, args.infer_batch_size, args.d_model
    ],
                              dtype=dtype)
    mem_cache = paddle.zeros(shape=[
        args.num_decoder_layers, 2, args.infer_batch_size, args.max_length,
        args.d_model
    ],
                             dtype=dtype)

    with paddle.no_grad():
        for i in range(100):
            # For warmup.
            if i == 50:
                start = time.time()
            dec_output, self_cache, mem_cache = transformer.decoder(
                from_tensor=dec_input,
                memory_tensor=enc_output,
                mem_seq_len=mem_seq_lens,
                self_cache=self_cache,
                mem_cache=mem_cache)
        logger.info("Average test time for decoder is %f ms" %
                    ((time.time() - start) / 50 * 1000))
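
The loop above follows a common warm-up-then-measure pattern: the first 50 iterations absorb lazy initialization and kernel compilation, and only the last 50 are timed. A reusable sketch of the same idea (not part of the original script):

import time

def benchmark(fn, warmup=50, iters=50):
    # Run `fn` untimed for `warmup` iterations, then return the average
    # latency of the next `iters` calls in milliseconds.
    for _ in range(warmup):
        fn()
    start = time.time()
    for _ in range(iters):
        fn()
    return (time.time() - start) / iters * 1000.0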
Example #13
def parse_args(MODEL_CLASSES):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(
            sum([
                list(classes[-1].pretrained_init_configuration.keys())
                for classes in MODEL_CLASSES.values()
            ], [])),
    )

    # Train I/O config
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input directory where the data will be read from.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the training logs and checkpoints will be written."
    )
    parser.add_argument("--split",
                        type=str,
                        default='949,50,1',
                        help="Train/valid/test data split.")

    parser.add_argument("--max_seq_len",
                        type=int,
                        default=1024,
                        help="Max sequence length.")
    parser.add_argument(
        "--micro_batch_size",
        default=8,
        type=int,
        help="Batch size per device for one step training.",
    )
    parser.add_argument(
        "--global_batch_size",
        default=None,
        type=int,
        help=
        "Global batch size for all training process. None for not check the size is valid. If we only use data parallelism, it should be device_num * micro_batch_size."
    )

    # Default training config
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--grad_clip",
                        default=0.0,
                        type=float,
                        help="Grad clip for the parameter.")
    parser.add_argument("--max_lr",
                        default=1e-5,
                        type=float,
                        help="The initial max learning rate for Adam.")
    parser.add_argument("--min_lr",
                        default=5e-5,
                        type=float,
                        help="The initial min learning rate for Adam.")
    parser.add_argument(
        "--warmup_rate",
        default=0.01,
        type=float,
        help="Linear warmup over warmup_steps for learing rate.")

    # Adam optimizer config
    parser.add_argument(
        "--adam_beta1",
        default=0.9,
        type=float,
        help=
        "The beta1 for Adam optimizer. The exponential decay rate for the 1st moment estimates."
    )
    parser.add_argument(
        "--adam_beta2",
        default=0.999,
        type=float,
        help=
        "The bate2 for Adam optimizer. The exponential decay rate for the 2nd moment estimates."
    )
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")

    # Training steps config
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_steps",
        default=500000,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--save_steps",
                        type=int,
                        default=500,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--decay_steps",
        default=360000,
        type=int,
        help=
        "The steps use to control the learing rate. If the step > decay_steps, will use the min_lr."
    )
    parser.add_argument("--logging_freq",
                        type=int,
                        default=1,
                        help="Log every X updates steps.")
    parser.add_argument("--eval_freq",
                        type=int,
                        default=500,
                        help="Evaluate for every X updates steps.")
    parser.add_argument("--eval_iters",
                        type=int,
                        default=10,
                        help="Evaluate the model use X steps data.")

    # Config for 4D Parallelism
    parser.add_argument("--use_sharding",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Use sharding Parallelism to training.")
    parser.add_argument(
        "--sharding_degree",
        type=int,
        default=1,
        help="Sharding degree. Share the parameters to many cards.")
    parser.add_argument("--dp_degree",
                        type=int,
                        default=1,
                        help="Data Parallelism degree.")
    parser.add_argument(
        "--mp_degree",
        type=int,
        default=1,
        help=
        "Model Parallelism degree. Spliting the linear layers to many cards.")
    parser.add_argument(
        "--pp_degree",
        type=int,
        default=1,
        help=
        "Pipeline Parallelism degree.  Spliting the the model layers to different parts."
    )
    parser.add_argument("--use_recompute",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Using the recompute to save the memory.")

    # AMP config
    parser.add_argument("--use_amp",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Enable mixed precision training.")
    parser.add_argument(
        "--enable_addto",
        type=str2bool,
        nargs='?',
        const=True,
        help=
        "Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training."
    )
    parser.add_argument(
        "--scale_loss",
        type=float,
        default=128,
        help=
        "The value of scale_loss for fp16. This is only used for AMP training."
    )
    parser.add_argument("--hidden_dropout_prob",
                        type=float,
                        default=0.1,
                        help="The hidden dropout prob.")
    parser.add_argument("--attention_probs_dropout_prob",
                        type=float,
                        default=0.1,
                        help="The attention probs dropout prob.")
    # Other config
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="Random seed for initialization")
    parser.add_argument("--check_accuracy",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Check accuracy for training process.")
    parser.add_argument("--device",
                        type=str,
                        default="gpu",
                        choices=["cpu", "gpu", "xpu"],
                        help="select cpu, gpu, xpu devices.")
    parser.add_argument("--lr_decay_style",
                        type=str,
                        default="cosine",
                        choices=["cosine", "none"],
                        help="Learning rate decay style.")
    args = parser.parse_args()
    args.test_iters = args.eval_iters * 10

    if args.check_accuracy:
        if args.hidden_dropout_prob != 0:
            args.hidden_dropout_prob = .0
            logger.warning(
                "The hidden_dropout_prob should set to 0 for accuracy checking."
            )
        if args.attention_probs_dropout_prob != 0:
            args.attention_probs_dropout_prob = .0
            logger.warning(
                "The attention_probs_dropout_prob should set to 0 for accuracy checking."
            )

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    return args
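
Several flags above use type=str2bool, which argparse does not provide out of the box. A common implementation (assumed here) is:

import argparse

def str2bool(v):
    # Accept the usual textual spellings of a boolean flag value.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Unsupported boolean value: %r" % v)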
Example #14
def predict():
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    dev_dataset = Poetry.get_datasets(['dev'])
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    test_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=False)
    data_loader = DataLoader(dataset=dev_dataset,
                             batch_sampler=test_batch_sampler,
                             collate_fn=batchify_fn,
                             num_workers=0,
                             return_list=True)

    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    model.eval()
    vocab = tokenizer.vocab
    eos_id = vocab[tokenizer.sep_token]
    sos_id = vocab[tokenizer.cls_token]
    pad_id = vocab[tokenizer.pad_token]
    unk_id = vocab[tokenizer.unk_token]
    vocab_size = len(vocab)
    evaluated_sentences = []
    evaluated_sentences_ids = []
    logger.info("Predicting...")
    for data in data_loader:
        (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
         raw_tgt_labels) = data  # the target is never used at inference
        # Use greedy_search_infilling or beam_search_infilling to get predictions
        output_ids = beam_search_infilling(model,
                                           src_ids,
                                           src_sids,
                                           eos_id=eos_id,
                                           sos_id=sos_id,
                                           attn_id=attn_id,
                                           pad_id=pad_id,
                                           unk_id=unk_id,
                                           vocab_size=vocab_size,
                                           max_decode_len=args.max_decode_len,
                                           max_encode_len=args.max_encode_len,
                                           beam_width=args.beam_width,
                                           length_penalty=args.length_penalty,
                                           tgt_type_id=tgt_type_id)

        for source_ids, target_ids, predict_ids in zip(
                src_ids.numpy().tolist(),
                raw_tgt_labels.numpy().tolist(), output_ids.tolist()):
            if eos_id in predict_ids:
                predict_ids = predict_ids[:predict_ids.index(eos_id)]
            source_sentence = ''.join(
                map(post_process,
                    vocab.to_tokens(source_ids[1:source_ids.index(eos_id)])))
            tgt_sentence = ''.join(
                map(post_process,
                    vocab.to_tokens(target_ids[1:target_ids.index(eos_id)])))
            predict_sentence = ''.join(
                map(post_process, vocab.to_tokens(predict_ids)))
            print("source :%s\ntarget :%s\npredict:%s\n" %
                  (source_sentence, tgt_sentence, predict_sentence))
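
post_process is used to turn sub-tokens back into plain text but is not defined in this example. One plausible version simply strips the WordPiece continuation prefix (an assumption, for illustration):

def post_process(token):
    # Remove the "##" continuation marker so sub-tokens join cleanly.
    return token[2:] if token.startswith("##") else token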
Example #15
def main():
    parser = PdArgumentParser(
        (ModelArguments, DataArguments, PreTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    set_seed(training_args)
    paddle.set_device(training_args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    training_args.eval_iters = 10
    training_args.test_iters = training_args.eval_iters * 10

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        # if last_checkpoint is None and len(
        #         os.listdir(training_args.output_dir)) > 1:
        #     raise ValueError(
        #         f"Output directory ({training_args.output_dir}) already exists and is not empty. "
        #         "Use --overwrite_output_dir to overcome.")
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        model_args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if model_args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            model_args.model_name_or_path]
        model_config["hidden_dropout_prob"] = model_args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = model_args.attention_probs_dropout_prob
        model = model_class(base_class(**model_config))
    else:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            hidden_dropout_prob=model_args.hidden_dropout_prob,
            attention_probs_dropout_prob=model_args.
            attention_probs_dropout_prob)

    class CriterionWrapper(paddle.nn.Layer):
        """Wrap the pretraining criterion so it can be called with
        (output, labels) tuples produced by the Trainer.
        """

        def __init__(self):
            super(CriterionWrapper, self).__init__()
            self.criterion = criterion_class()

        def forward(self, output, labels):
            """forward function

            Args:
                output (tuple): prediction_scores, seq_relationship_score
                labels (tuple): masked_lm_labels, next_sentence_labels

            Returns:
                Tensor: final loss.
            """
            prediction_scores, seq_relationship_score = output
            masked_lm_labels, next_sentence_labels = labels

            lm_loss, sop_loss = self.criterion(prediction_scores,
                                               seq_relationship_score,
                                               masked_lm_labels,
                                               next_sentence_labels)

            loss = lm_loss + sop_loss
            return loss

    # Create the learning rate scheduler and the optimizer
    if training_args.decay_steps is None:
        training_args.decay_steps = training_args.max_steps
    warmup_steps = training_args.warmup_ratio * training_args.max_steps

    lr_scheduler = LinearAnnealingWithWarmupDecay(
        training_args.learning_rate,
        training_args.min_learning_rate,
        warmup_step=warmup_steps,
        decay_step=training_args.decay_steps)

    data_file = get_train_data_file(data_args)
    tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path)

    train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(
        data_args, training_args, data_file, tokenizer)

    trainer = PretrainingTrainer(
        model=model,
        criterion=CriterionWrapper(),
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        optimizers=(None, lr_scheduler),
        tokenizer=tokenizer,
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    if training_args.do_predict:
        test_ret = trainer.predict(test_dataset)
        trainer.log_metrics("test", test_ret.metrics)
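
LinearAnnealingWithWarmupDecay comes from the surrounding project; in essence it warms the learning rate up linearly to the peak value and then decays it linearly to the minimum. A plain-Python sketch of that schedule (names and exact behavior are assumptions):

def linear_annealing_with_warmup(step, max_lr, min_lr, warmup_step, decay_step):
    # Linear warmup from 0 to max_lr, then linear decay down to min_lr,
    # after which the learning rate stays at min_lr.
    if step < warmup_step:
        return max_lr * step / warmup_step
    if step < decay_step:
        frac = (step - warmup_step) / (decay_step - warmup_step)
        return max_lr - (max_lr - min_lr) * frac
    return min_lr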
Example #16
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = load_dataset('glue', args.task_name, splits="train")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=args.max_seq_length)
    train_ds = train_ds.map(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    train_data_loader = DataLoader(
        dataset=train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)
    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_ds = load_dataset('glue', args.task_name, splits='dev')
        dev_ds = dev_ds.map(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.batch_size, shuffle=False)
        dev_data_loader = DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)

    num_labels = 1 if train_ds.label_list is None else len(train_ds.label_list)

    # Step1: Initialize the original BERT model.
    model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)
    origin_weights = model.state_dict()

    # Step2: Convert the original model to a supernet.
    sp_config = supernet(expand_ratio=args.width_mult_list)
    model = Convert(sp_config).convert(model)
    # Use the weights saved in the dictionary to initialize the supernet.
    utils.set_state_dict(model, origin_weights)

    # Step3: Define teacher model.
    teacher_model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)
    new_dict = utils.utils.remove_model_fn(teacher_model, origin_weights)
    teacher_model.set_state_dict(new_dict)
    del origin_weights, new_dict

    default_run_config = {'elastic_depth': args.depth_mult_list}
    run_config = RunConfig(**default_run_config)

    # Step4: Config about distillation.
    mapping_layers = ['bert.embeddings']
    for idx in range(model.bert.config['num_hidden_layers']):
        mapping_layers.append('bert.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': args.lambda_rep,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    run_config=run_config,
                    distill_config=distill_config,
                    elastic_order=['depth'])
    # Alternative: elastic_order=['width']

    criterion = paddle.nn.CrossEntropyLoss(
    ) if train_ds.label_list else paddle.nn.MSELoss()

    metric = metric_class()

    if args.task_name == "mnli":
        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)

    if paddle.distributed.get_world_size() > 1:
        ofa_model.model = paddle.DataParallel(
            ofa_model.model, find_unused_parameters=True)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = math.ceil(num_training_steps /
                                     len(train_data_loader))
    else:
        num_training_steps = len(train_data_loader) * args.num_train_epochs
        num_train_epochs = args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=ofa_model.model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_step = 0
    tic_train = time.time()
    for epoch in range(num_train_epochs):
        # Step6: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('depth')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for depth_mult in args.depth_mult_list:
                for width_mult in args.width_mult_list:
                    # Step7: Broadcast supernet config from width_mult,
                    # and use this config in supernet training.
                    net_config = utils.dynabert_config(ofa_model, width_mult,
                                                       depth_mult)
                    ofa_model.set_net_config(net_config)
                    logits, teacher_logits = ofa_model(
                        input_ids, segment_ids, attention_mask=[None, None])
                    rep_loss = ofa_model.calc_distill_loss()
                    if args.task_name == 'sts-b':
                        logit_loss = 0.0
                    else:
                        logit_loss = soft_cross_entropy(logits,
                                                        teacher_logits.detach())
                    loss = rep_loss + args.lambda_logit * logit_loss
                    loss.backward()
            optimizer.step()
            lr_scheduler.step()
            ofa_model.model.clear_gradients()

            if global_step % args.logging_steps == 0:
                if paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss,
                           args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0:
                if args.task_name == "mnli":
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader_matched,
                        width_mult=100)
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader_mismatched,
                        width_mult=100)
                else:
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader,
                        width_mult=100)
                for depth_mult in args.depth_mult_list:
                    for width_mult in args.width_mult_list:
                        net_config = utils.dynabert_config(
                            ofa_model, width_mult, depth_mult)
                        ofa_model.set_net_config(net_config)
                        tic_eval = time.time()
                        if args.task_name == "mnli":
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader_matched, width_mult,
                                           depth_mult)
                            evaluate(ofa_model, criterion, metric,
                                     dev_data_loader_mismatched, width_mult,
                                     depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))
                        else:
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader, width_mult,
                                           depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))

                        if paddle.distributed.get_rank() == 0:
                            output_dir = os.path.join(args.output_dir,
                                                      "model_%d" % global_step)
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            # Need a better way to get the inner model of DataParallel.
                            model_to_save = model._layers if isinstance(
                                model, paddle.DataParallel) else model
                            model_to_save.save_pretrained(output_dir)
                            tokenizer.save_pretrained(output_dir)
            if global_step >= num_training_steps:
                return
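
The batchify_fn lambdas used throughout these examples compose collators from paddlenlp.data. A small standalone demonstration of how Tuple, Pad and Stack act on a toy batch:

from paddlenlp.data import Pad, Stack, Tuple

# Two toy samples of (input_ids, segment_ids, label) with unequal lengths.
samples = [
    ([1, 2, 3], [0, 0, 0], 1),
    ([4, 5], [0, 0], 0),
]
batchify_fn = Tuple(
    Pad(axis=0, pad_val=0),  # pad input_ids up to the longest sample
    Pad(axis=0, pad_val=0),  # pad segment_ids the same way
    Stack(dtype="int64"),    # stack the scalar labels into one array
)
input_ids, segment_ids, labels = batchify_fn(samples)
print(input_ids.shape, segment_ids.shape, labels.shape)  # (2, 3) (2, 3) (2,)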
Example #17
def train(args):
    paddle.set_device(args.device)

    trainer_num = paddle.distributed.get_world_size()
    if trainer_num > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    # Create dataset.
    train_ds, test_ds = load_dataset(
        datafiles=(os.path.join(args.data_dir, 'train.tsv'),
                   os.path.join(args.data_dir, 'test.tsv')))

    word_vocab = load_vocab(os.path.join(args.data_dir, 'word.dic'))
    label_vocab = load_vocab(os.path.join(args.data_dir, 'tag.dic'))
    # q2b.dic is used to convert full-width (DBC) characters to half-width (SBC) ones
    normlize_vocab = load_vocab(os.path.join(args.data_dir, 'q2b.dic'))

    trans_func = partial(convert_example,
                         max_seq_len=args.max_seq_len,
                         word_vocab=word_vocab,
                         label_vocab=label_vocab,
                         normlize_vocab=normlize_vocab)
    train_ds.map(trans_func)
    test_ds.map(trans_func)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=word_vocab.get("[PAD]", 0), dtype='int64'
            ),  # word_ids
        Stack(dtype='int64'),  # length
        Pad(axis=0, pad_val=label_vocab.get("O", 0), dtype='int64'
            ),  # label_ids
    ): fn(samples)

    # Create sampler for dataloader
    train_sampler = paddle.io.DistributedBatchSampler(
        dataset=train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)
    train_loader = paddle.io.DataLoader(dataset=train_ds,
                                        batch_sampler=train_sampler,
                                        return_list=True,
                                        collate_fn=batchify_fn)

    test_sampler = paddle.io.BatchSampler(dataset=test_ds,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          drop_last=False)
    test_loader = paddle.io.DataLoader(dataset=test_ds,
                                       batch_sampler=test_sampler,
                                       return_list=True,
                                       collate_fn=batchify_fn)

    # Define the model network and its loss
    model = BiGruCrf(args.emb_dim,
                     args.hidden_size,
                     len(word_vocab),
                     len(label_vocab),
                     crf_lr=args.crf_lr)
    # Prepare optimizer, loss and metric evaluator
    optimizer = paddle.optimizer.Adam(learning_rate=args.base_lr,
                                      parameters=model.parameters())
    chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(),
                                     suffix=True)

    if args.init_checkpoint:
        if os.path.exists(args.init_checkpoint):
            logger.info("Init checkpoint from %s" % args.init_checkpoint)
            model_dict = paddle.load(args.init_checkpoint)
            model.load_dict(model_dict)
        else:
            logger.info("Cannot init checkpoint from %s which doesn't exist" %
                        args.init_checkpoint)
    logger.info("Start training")
    # Start training
    global_step = 0
    last_step = args.epochs * len(train_loader)
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()
    max_f1_score = -1
    for epoch in range(args.epochs):
        for step, batch in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            global_step += 1
            token_ids, length, label_ids = batch
            train_start = time.time()
            loss = model(token_ids, length, label_ids)
            avg_loss = paddle.mean(loss)
            train_run_cost += time.time() - train_start
            total_samples += args.batch_size
            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                    %
                    (global_step, last_step, avg_loss,
                     train_reader_cost / args.logging_steps,
                     (train_reader_cost + train_run_cost) / args.logging_steps,
                     total_samples / args.logging_steps, total_samples /
                     (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            if global_step % args.save_steps == 0 or global_step == last_step:
                if rank == 0:
                    paddle.save(
                        model.state_dict(),
                        os.path.join(args.model_save_dir,
                                     "model_%d.pdparams" % global_step))
                    logger.info("Save %d steps model." % (global_step))
                    if args.do_eval:
                        precision, recall, f1_score = evaluate(
                            model, chunk_evaluator, test_loader)
                        if f1_score > max_f1_score:
                            max_f1_score = f1_score
                            paddle.save(
                                model.state_dict(),
                                os.path.join(args.model_save_dir,
                                             "best_model.pdparams"))
                            logger.info("Save best model.")

            reader_start = time.time()
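
A minimal sketch of how this entry point might be invoked; the attribute names mirror exactly what train() reads, while the values and the use of SimpleNamespace are illustrative assumptions (the real script presumably builds args from its own CLI):

from types import SimpleNamespace

args = SimpleNamespace(
    device="gpu",                  # or "cpu"
    data_dir="./data",             # must contain train.tsv, test.tsv, word.dic, tag.dic, q2b.dic
    max_seq_len=64,
    batch_size=32,
    emb_dim=128,
    hidden_size=128,
    crf_lr=0.2,
    base_lr=1e-3,
    init_checkpoint=None,          # or a path to a .pdparams file
    epochs=3,
    logging_steps=10,
    save_steps=100,
    model_save_dir="./checkpoints",
    do_eval=True,
)
train(args)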
Example #18
def do_eval(args):
    paddle.set_device(args.device)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
    tokenizer = tokenizer_class.from_pretrained(args.model_name)

    if args.init_checkpoint_path is not None:
        model = GPT2ForPretraining(
            GPT2Model(
                **model_class.pretrained_init_configuration[args.model_name]))

        logger.info("Load model checkpoint from %s" %
                    args.init_checkpoint_path)
        model_dict = paddle.load(os.path.join(args.init_checkpoint_path))
        model.set_dict(model_dict)
    else:
        model = model_class.from_pretrained(args.model_name)

    tic_eval = time.time()
    eval_data_loader = create_eval_dataset(args)
    model.eval()
    total_score = 0
    score_name = "loss" if not args.cloze_eval else "number correct"
    with paddle.no_grad():
        for step, batch in enumerate(eval_data_loader):
            tokens, loss_mask, attention_mask, position_ids, labels = batch
            preds = model(tokens, position_ids, attention_mask)
            if not args.cloze_eval:
                masked_lm_loss = paddle.nn.functional.cross_entropy(
                    preds, labels, reduction="none")
                loss = paddle.sum(masked_lm_loss * loss_mask)
                total_score += loss.numpy() / (args.num_tokenized_tokens - 1)
            else:
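                # A sample counts as correct only if every position kept by
                # loss_mask is predicted correctly: positions outside the mask
                # are forced to 1 before taking the product over the sequence.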
                outputs = paddle.argmax(preds, -1)
                acc = paddle.cast(outputs == labels, 'float32')
                acc = paddle.where(paddle.cast(loss_mask, 'bool'), acc,
                                   paddle.ones_like(acc))
                acc = paddle.sum(paddle.prod(acc, -1))
                total_score += acc.numpy()
            if step % args.logging_steps == 0:
                logger.info(
                    "step %d, batch: %d, %s: %f, speed: %.2f step/s" %
                    (step, step, score_name, total_score, args.logging_steps /
                     (time.time() - tic_eval)))
                tic_eval = time.time()

    if not args.cloze_eval:
        total_loss = float(total_score)
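        # Cap the exponent at 20 to avoid floating-point overflow for large losses.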
        ppl = math.exp(min(20, total_loss))
        token_ratio = (args.num_tokenized_tokens -
                       1) / (args.num_original_tokens - 1)
        adjusted_ppl = math.exp(min(20, total_loss * token_ratio))
        string = ' validation results on {} | '.format(args.eval_path)
        string += 'avg loss: {:.4E} | '.format(total_loss)
        string += 'ppl: {:.4E} | '.format(ppl)
        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
        string += 'token ratio: {} |'.format(token_ratio)
    else:
        num_correct = float(total_score)
        acc = float(num_correct / args.num_examples)
        string = ' validation results on {} | '.format(args.eval_path)
        string += 'number correct: {:.4E} | '.format(num_correct)
        string += 'total examples: {:.4E} | '.format(args.num_examples)
        string += 'avg accuracy: {:.4E}'.format(acc)
    logger.info(string)
Example #19
def do_train(args):
    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    set_seed(args)
    worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    eod_id = tokenizer.command_name_map["eod"].Id

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        model = GPT2ForPretraining(
            GPT2Model(**model_class.pretrained_init_configuration[
                args.model_name_or_path]))
    else:
        model = GPT2ForPretraining.from_pretrained(args.model_name_or_path)

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps
    lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
        max_lr=args.max_lr,
        min_lr=args.min_lr,
        warmup_step=warmup_step,
        decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByNorm(clip_norm=args.grad_clip)

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])
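    # apply_decay_param_fun above excludes bias and LayerNorm parameters from
    # weight decay, a common practice when pre-training transformers.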
    if args.model_name_or_path not in pretrained_models_list:
        opt_dict = paddle.load(
            os.path.join(args.model_name_or_path, "model_state.pdopt"))
        optimizer.set_state_dict(opt_dict)

    # Create the criterion for the GPT model
    criterion = GPT2PretrainingCriterion()

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader = create_pretrained_dataset(args,
                                                          data_file,
                                                          worker_init,
                                                          worker_index,
                                                          worker_num,
                                                          eod_id=eod_id)
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                tokens, loss_mask, attention_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                attention_mask.stop_gradient = True

                preds = model(tokens, position_ids, attention_mask)
                loss = criterion(preds, labels, loss_mask)

                if global_step % args.logging_steps == 0:
                    if worker_index == 0:
                        logger.info(
                            "global step %d, epoch: %d, lr: %.10f, batch: %d, loss: %f, speed: %.2f step/s"
                            % (global_step, epoch, optimizer.get_lr(), step,
                               loss, args.logging_steps /
                               (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()
                if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # TODO: find a better way to get the inner model out of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        logger.info("Save model to %s" % output_dir)
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))
                if global_step >= args.max_steps:
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            del train_data_loader
Example #20
def create_pretrained_dataset(args,
                              input_path,
                              local_rank,
                              data_world_rank,
                              data_world_size,
                              eos_id,
                              worker_init=None,
                              max_seq_len=1024,
                              places=None,
                              data_holders=None):
    if local_rank == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from data_tools.dataset_utils import compile_helper
        compile_helper()
        print(
            '>>> done with dataset index builder. Compilation time: {:.3f} '
            'seconds'.format(time.time() - start_time),
            flush=True)

    device_world_size = paddle.distributed.get_world_size()
    device_world_rank = paddle.distributed.get_rank()

    if device_world_size > 1 and local_rank != 0:
        while True:
            try:
                import data_tools.helpers as helpers
                break
            except Exception as e:
                print("> wait for helpers to be compiled!")
                time.sleep(1)

    logger.info(
        "Distributed run: total device num: {}, distinct dataflow num: {}.".
        format(device_world_size, data_world_size))

    assert len(input_path) == 1, "GPT only supports one dataset for now."

    input_prefix = input_path[0]

    if os.path.isfile(input_prefix + "_ids.npz"):
        logger.warning(
            "You are using a legacy-format dataset; please regenerate it as described in the README!"
        )
        process_datas = np.load(
            input_prefix + "_ids.npz", mmap_mode="r+", allow_pickle=True)
        sample_ids = process_datas["ids"]
        sample_lens = process_datas["lens"].astype("int32")
    else:
        for suffix in ["_ids.npy", "_idx.npz"]:
            if not os.path.isfile(input_prefix + suffix):
                raise ValueError("File not found: %s" % (input_prefix + suffix))

        sample_ids = np.load(
            input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
        # All document ids, extended as a 1-D array.

        process_datas = np.load(input_prefix + "_idx.npz")
        # len(sample_lens) is the number of docs;
        # sum(sample_lens) should equal len(sample_ids).
        sample_lens = process_datas["lens"]

    splits = get_train_valid_test_split_(args.split, len(sample_lens))
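    # `splits` holds cumulative document-index boundaries for the train/valid/
    # test portions; args.split is assumed to be a ratio string like "949,50,1".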
    assert len(sample_lens) >= splits[
        -1], "The number of documents should be at least the max of splits, but %s < %s" % (
            len(sample_lens), splits[-1])

    def build_dataset(index, name, num_samples):
        dataset = GPTDataset(
            file_path=input_prefix,
            build_data_file=local_rank == 0,
            name="gpt_" + name,
            max_seq_len=max_seq_len,
            num_samples=num_samples,
            documents=np.arange(splits[index], splits[index + 1]),
            sample_ids=sample_ids,
            sample_lens=sample_lens,
            eos_id=eos_id,
            seed=args.seed)

        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=args.local_batch_size,
            num_replicas=data_world_size,
            rank=data_world_rank,
            shuffle=False,
            drop_last=True)

        data_loader = DataLoader(
            dataset=dataset,
            places=places,
            feed_list=data_holders,
            batch_sampler=batch_sampler,
            num_workers=1,
            worker_init_fn=worker_init,
            # collate_fn=Tuple(Stack(), Stack(), Stack(), Stack(), Stack()),
            collate_fn=Tuple(Stack(), Stack(), Stack(), Stack()),
            return_list=False)
        return data_loader

    # Note: the data should be broadcast to all devices.
    # For train, valid and test, the number of distinct data streams is data_world_size.
    train_data_loader = build_dataset(0, "train", args.local_batch_size *
                                      args.max_steps * data_world_size)

    valid_data_loader = build_dataset(1, "valid", args.local_batch_size *
                                      (args.max_steps // args.eval_freq + 1) *
                                      args.eval_iters * data_world_size)
    test_data_loader = build_dataset(2, "test", args.local_batch_size *
                                     args.test_iters * data_world_size)

    return train_data_loader, valid_data_loader, test_data_loader
Example #21
    def preprocess_function(examples, do_predict=False):
        def _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_length):
            """Truncates a sequence tuple in place to the maximum length."""
            # This is a simple heuristic which will always truncate the longer
            # sequence one token at a time. This makes more sense than
            # truncating an equal percent of tokens from each, since if one
            # sequence is very short then each token that's truncated likely
            # contains more information than a longer sequence.
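            # Example: with max_length=8 and lengths (5, 3, 2), tokens_a is
            # popped twice, leaving lengths (3, 3, 2).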
            while True:
                total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
                if total_length <= max_length:
                    break
                if len(tokens_a) >= len(tokens_b) and len(tokens_a) >= len(
                        tokens_c):
                    tokens_a.pop()
                elif len(tokens_b) >= len(tokens_a) and len(tokens_b) >= len(
                        tokens_c):
                    tokens_b.pop()
                else:
                    tokens_c.pop()

        num_examples = len(examples.data["question"])
        if do_predict:
            result = {"input_ids": [], "token_type_ids": []}
        else:
            result = {"input_ids": [], "token_type_ids": [], "labels": []}
        for idx in range(num_examples):
            text = '\n'.join(examples.data["context"][idx]).lower()
            question = examples.data["question"][idx].lower()
            choice_list = examples.data["choice"][idx]
            choice_list = [choice.lower()
                           for choice in choice_list][:max_num_choices]
            if not do_predict:
                answer = examples.data["answer"][idx].lower()
                label = choice_list.index(answer)

            tokens_t = tokenizer.tokenize(text)
            tokens_q = tokenizer.tokenize(question)

            tokens_t_list = []
            tokens_c_list = []

            # Pad each new example for axis=1, [batch_size, num_choices, seq_len]
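            # Missing choices are padded with the dummy string '无效答案'
            # ("invalid answer").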
            while len(choice_list) < max_num_choices:
                choice_list.append('无效答案')

            for choice in choice_list:
                tokens_c = tokenizer.tokenize(choice.lower())
                _truncate_seq_tuple(tokens_t, tokens_q, tokens_c,
                                    max_seq_length - 4)

                tokens_c = tokens_q + ["[SEP]"] + tokens_c
                tokens_t_list.append(tokens_t)
                tokens_c_list.append(tokens_c)

            new_data = tokenizer(
                tokens_t_list,
                text_pair=tokens_c_list,
                is_split_into_words=True)

            # Pad each new example for axis=2 of [batch_size, num_choices, seq_len],
            # because length of each choice could be different.
            input_ids = Pad(
                axis=0, pad_val=tokenizer.pad_token_id)(new_data["input_ids"])
            token_type_ids = Pad(
                axis=0,
                pad_val=tokenizer.pad_token_id)(new_data["token_type_ids"])

            # Final shape of input_ids: [batch_size, num_choices, seq_len]
            result["input_ids"].append(input_ids)
            result["token_type_ids"].append(token_type_ids)
            if not do_predict:
                result["labels"].append([label])
            if (idx + 1) % 1000 == 0:
                logger.info("%d samples have been processed." % (idx + 1))
        return result
Example #22
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Creates an instance of `PretrainedModel`. Model weights are loaded
        by specifying name of a built-in pretrained model, or a community contributed model,
        or a local file directory path.

        Args:
            pretrained_model_name_or_path (str): Name of pretrained model or dir path
                to load from. The string can be:

                - Name of a built-in pretrained model
                - Name of a community-contributed pretrained model.
                - Local directory path which contains the model weights file
                  ("model_state.pdparams") and the model config file ("model_config.json").
            *args (tuple): Position arguments for model `__init__`. If provided,
                use these as position argument values for model initialization.
            **kwargs (dict): Keyword arguments for model `__init__`. If provided,
                use these to update pre-defined keyword argument values for model
                initialization. If the keyword is in `__init__` argument names of
                base model, update argument values of the base model; else update
                argument values of derived model.
            load_state_as_np (bool, optional): Whether to load the model
                weights as `numpy.ndarray` on CPU instead of as tensors on the
                default device. When the default device is GPU, loading as
                tensors creates extra temporary tensors in addition to the
                model weights, which roughly doubles memory usage, so `True`
                is suggested for big models on GPU. Defaults to `False`.

        Returns:
            PretrainedModel: An instance of `PretrainedModel`.

        Example:
            .. code-block::

                from paddlenlp.transformers import BertForSequenceClassification

                # Name of built-in pretrained model
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

                # Name of community-contributed pretrained model
                model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned')

                # Load from local directory path
                model = BertForSequenceClassification.from_pretrained('./my_bert/')
        """
        pretrained_models = list(cls.pretrained_init_configuration.keys())
        resource_files = {}
        init_configuration = {}
        load_state_as_np = kwargs.pop("load_state_as_np", False)

        # From built-in pretrained models
        if pretrained_model_name_or_path in pretrained_models:
            for file_id, map_list in cls.pretrained_resource_files_map.items():
                resource_files[file_id] = map_list[
                    pretrained_model_name_or_path]
            init_configuration = copy.deepcopy(
                cls.
                pretrained_init_configuration[pretrained_model_name_or_path])
        # From local dir path
        elif os.path.isdir(pretrained_model_name_or_path):
            for file_id, file_name in cls.resource_files_names.items():
                full_file_name = os.path.join(pretrained_model_name_or_path,
                                              file_name)
                resource_files[file_id] = full_file_name
            resource_files["model_config_file"] = os.path.join(
                pretrained_model_name_or_path, cls.model_config_file)
        else:
            # Assuming from community-contributed pretrained models
            for file_id, file_name in cls.resource_files_names.items():
                full_file_name = os.path.join(COMMUNITY_MODEL_PREFIX,
                                              pretrained_model_name_or_path,
                                              file_name)
                resource_files[file_id] = full_file_name
            resource_files["model_config_file"] = os.path.join(
                COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path,
                cls.model_config_file)

        default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path)
        resolved_resource_files = {}
        for file_id, file_path in resource_files.items():
            if file_path is None or os.path.isfile(file_path):
                resolved_resource_files[file_id] = file_path
                continue
            path = os.path.join(default_root, file_path.split('/')[-1])
            if os.path.exists(path):
                logger.info("Already cached %s" % path)
                resolved_resource_files[file_id] = path
            else:
                logger.info("Downloading %s and saved to %s" %
                            (file_path, default_root))
                try:
                    resolved_resource_files[file_id] = get_path_from_url(
                        file_path, default_root)
                except RuntimeError as err:
                    logger.error(err)
                    raise RuntimeError(
                        f"Can't load weights for '{pretrained_model_name_or_path}'.\n"
                        f"Please make sure that '{pretrained_model_name_or_path}' is:\n"
                        "- a correct model identifier of a built-in pretrained model,\n"
                        "- or a correct model identifier of a community-contributed pretrained model,\n"
                        "- or the correct path to a directory containing the relevant modeling files (model weights and model config).\n"
                    )

        # Prepare model initialization kwargs
        # Did we save some inputs and kwargs to reload?
        model_config_file = resolved_resource_files.pop(
            "model_config_file", None)
        if model_config_file is not None:
            with io.open(model_config_file, encoding="utf-8") as f:
                init_kwargs = json.load(f)
        else:
            init_kwargs = init_configuration
        # Positional args are stored under "init_args" in the config; maybe better not to include them
        init_args = init_kwargs.pop("init_args", ())
        # class name corresponds to this configuration
        init_class = init_kwargs.pop("init_class",
                                     cls.base_model_class.__name__)
        # Check if the loaded config matches the current model class's __init__
        # arguments. If not match, the loaded config is for the base model class.
        if init_class == cls.base_model_class.__name__:
            base_args = init_args
            base_kwargs = init_kwargs
            derived_args = ()
            derived_kwargs = {}
            base_arg_index = None
        else:  # extract config for base model
            derived_args = list(init_args)
            derived_kwargs = init_kwargs
            base_arg = None
            for i, arg in enumerate(init_args):
                if isinstance(arg, dict) and "init_class" in arg:
                    assert arg.pop(
                        "init_class") == cls.base_model_class.__name__, (
                            "pretrained base model should be {}").format(
                                cls.base_model_class.__name__)
                    base_arg_index = i
                    base_arg = arg
                    break
            for arg_name, arg in init_kwargs.items():
                if isinstance(arg, dict) and "init_class" in arg:
                    assert arg.pop(
                        "init_class") == cls.base_model_class.__name__, (
                            "pretrained base model should be {}").format(
                                cls.base_model_class.__name__)
                    base_arg_index = arg_name
                    base_arg = arg
                    break

            base_args = base_arg.pop("init_args", ())
            base_kwargs = base_arg

        if cls == cls.base_model_class:
            # Update with newly provided args and kwargs for base model
            base_args = base_args if not args else args
            base_kwargs.update(kwargs)
            model = cls(*base_args, **base_kwargs)
        else:
            # Update with newly provided args and kwargs for derived model
            base_parameters_dict = inspect.signature(
                cls.base_model_class.__init__).parameters
            for k, v in kwargs.items():
                if k in base_parameters_dict:
                    base_kwargs[k] = v
            base_model = cls.base_model_class(*base_args, **base_kwargs)
            if base_arg_index is not None:
                derived_args[base_arg_index] = base_model
            else:
                derived_args = (base_model, )  # assume at the first position
            derived_args = derived_args if not args else args
            derived_parameters_dict = inspect.signature(
                cls.__init__).parameters
            for k, v in kwargs.items():
                if k in derived_parameters_dict:
                    derived_kwargs[k] = v
            model = cls(*derived_args, **derived_kwargs)

        # Maybe need more ways to load resources.
        weight_path = resolved_resource_files["model_state"]
        assert weight_path.endswith(
            ".pdparams"), "The weight file suffix must be .pdparams"

        # NOTE: Allow to load partial model for model parallel.
        # TODO(guosheng): To make model loading for the model parallel automatic,
        # maybe we should make rank 0 worker load weights of the full model on
        # CPU, then split weights into multiple parts and pickle separately.
        # The other workers wait until pickling finishes and then load the corresponding
        # partial weights. Also we can directly use separate weight files for
        # simplicity.
        state_dict = paddle.load(weight_path, return_numpy=load_state_as_np)

        # Make sure we are able to load base models as well as derived models
        # (with heads)
        start_prefix = ""
        model_to_load = model
        state_to_load = state_dict
        unexpected_keys = []
        missing_keys = []
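        # Two mismatch cases are handled below: (1) the state dict came from a
        # derived model but `model` is a base model, so the base-model prefix
        # is stripped from the keys; (2) the state dict came from a base model
        # but `model` is a derived one, so weights are loaded into the base
        # sub-model and the head weights are reported as missing.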
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix)
                for s in state_dict.keys()):
            # base model
            state_to_load = {}
            start_prefix = cls.base_model_prefix + "."
            for k, v in state_dict.items():
                if k.startswith(cls.base_model_prefix):
                    state_to_load[k[len(start_prefix):]] = v
                else:
                    unexpected_keys.append(k)
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix)
                for s in state_dict.keys()):
            # derived model (base model with heads)
            model_to_load = getattr(model, cls.base_model_prefix)
            for k in model.state_dict().keys():
                if not k.startswith(cls.base_model_prefix):
                    missing_keys.append(k)
        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".
                format(model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
        # Allow the float16 model to load float32 weights, which decreases memory
        # usage in model loading stage and is useful to big models.
        dtype_prefix_len = len("paddle.")  # paddle.float16
        for k, v in model_to_load.state_dict().items():
            if not isinstance(v, np.ndarray):
                dtype = str(v.dtype)[dtype_prefix_len:]
            # TODO(guosheng): add warnings for unmatched dtypes
            if k in state_to_load:
                state_to_load[k] = state_to_load[k].astype(dtype)
        # For model parallel if FasterGeneration
        state_to_load = ft_decoding.get_ft_para_conf().fit_partial_model(
            model_to_load, state_to_load)
        if paddle.in_dynamic_mode():
            model_to_load.set_state_dict(state_to_load)
            return model
        return model, state_to_load
Example #23
def do_train(args):
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Create the model for the gpt pretrain
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class()
        lm_loss, sop_loss = criterion(prediction_scores,
                                      seq_relationship_score, masked_lm_labels,
                                      next_sentence_labels)
        loss = lm_loss + sop_loss

        # Create the learning_rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model, paddle.load(dygraph_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    fetch_vars = collections.OrderedDict()
    fetch_vars["loss"] = loss
    fetch_vars["lm_loss"] = lm_loss
    fetch_vars["sop_loss"] = sop_loss
    fetch_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()
    while True:
        fetchs = []
        fetchs_keys = []
        if topo.is_last:
            fetchs = list(fetch_vars.values()) + list(additional_vars.values())
            fetchs_keys = list(fetch_vars.keys()) + list(
                additional_vars.keys())

        # Bug fix: if valid_data_loader is not instantiated here, each enumerate
        # over it would re-create the reader and start a new random dataloader.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Accumulate args.accumulate_steps micro-steps into one global step
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the 2.0 API, this must be called explicitly to update the learning rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = {}
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)
                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, res["loss"], res["lm_loss"],
                        res["sop_loss"], speed, speed * args.global_batch_size,
                        res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss, lm_loss, sop_loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables under sharding, though it is slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss, lm_loss, sop_loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
Example #24
def train():
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev'])
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len,
                                 noise_prob=args.noise_prob,
                                 use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_dataset,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    label_num = model.word_emb.weight.shape[0]
    train_model = StackModel(model)
    if paddle.distributed.get_world_size() > 1:
        # All 'forward' outputs derived from the module's parameters under
        # DataParallel must participate in the loss and subsequent gradient
        # calculations, so StackModel is used to make the model output only
        # its loss from 'forward'.
        train_model = paddle.DataParallel(train_model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
             attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels, _) = batch
            if args.label_smooth > 0.:
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            tgt_pos = paddle.nonzero(attn_ids == attn_id)
            loss = train_model(src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
                               tgt_pids, attn_ids, mask_src_2_src,
                               mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
                               tgt_labels, tgt_pos)
            if global_step % args.logging_steps == 0:
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
            if global_step % args.save_steps == 0 and (
                (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0):
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
Example #25
    def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
                      unk_vector, keep_extended_vocab_only):
        """
        Constructs index to word list, word to index dict and embedding weight using
        extended vocab.
        """
        logger.info("Start extending vocab.")
        extend_vocab_list = self._read_vocab_list_from_file(
            extended_vocab_path)
        extend_vocab_set = set(extend_vocab_list)
        # update idx_to_word
        self._idx_to_word = extend_vocab_list
        self._word_to_idx = self._construct_word_to_idx(self._idx_to_word)

        # Initialize the embedding table with Xavier uniform init:
        # U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
        xavier_scale = np.sqrt(
            6.0 / float(len(self._idx_to_word) + self.embedding_dim))
        embedding_table = np.random.uniform(
            low=-1.0 * xavier_scale,
            high=xavier_scale,
            size=(len(self._idx_to_word),
                  self.embedding_dim)).astype(paddle.get_default_dtype())

        pretrained_idx_to_word = list(vector_np['vocab'])
        pretrained_word_to_idx = self._construct_word_to_idx(
            pretrained_idx_to_word)
        pretrained_embedding_table = np.array(vector_np['embedding'])

        pretrained_vocab_set = set(pretrained_idx_to_word)
        extend_vocab_set = set(self._idx_to_word)
        vocab_intersection = pretrained_vocab_set & extend_vocab_set
        vocab_subtraction = pretrained_vocab_set - extend_vocab_set

        # Copy rows for shared words from the pretrained embedding table into the extended table
        pretrained_vocab_intersect_index = [
            pretrained_word_to_idx[word] for word in vocab_intersection
        ]
        pretrained_vocab_subtract_index = [
            pretrained_word_to_idx[word] for word in vocab_subtraction
        ]
        extend_vocab_intersect_index = [
            self._word_to_idx[word] for word in vocab_intersection
        ]
        embedding_table[
            extend_vocab_intersect_index] = pretrained_embedding_table[
                pretrained_vocab_intersect_index]
        if not keep_extended_vocab_only:
            for idx in pretrained_vocab_subtract_index:
                word = pretrained_idx_to_word[idx]
                self._idx_to_word.append(word)
                self._word_to_idx[word] = len(self._idx_to_word) - 1

            embedding_table = np.append(
                embedding_table,
                pretrained_embedding_table[pretrained_vocab_subtract_index],
                axis=0)

        if self.unknown_token not in extend_vocab_set:
            self._idx_to_word.append(self.unknown_token)
            self._word_to_idx[self.unknown_token] = len(self._idx_to_word) - 1
            embedding_table = np.append(embedding_table, [unk_vector], axis=0)
        else:
            unk_idx = self._word_to_idx[self.unknown_token]
            embedding_table[unk_idx] = unk_vector

        if PAD_TOKEN not in extend_vocab_set:
            self._idx_to_word.append(PAD_TOKEN)
            self._word_to_idx[PAD_TOKEN] = len(self._idx_to_word) - 1
            embedding_table = np.append(embedding_table, [pad_vector], axis=0)
        else:
            embedding_table[self._word_to_idx[PAD_TOKEN]] = pad_vector

        logger.info("Finish extending vocab.")
        return embedding_table
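
The heart of _extend_vocab is an index-aligned row copy: wherever the extended vocab overlaps the pretrained vocab, the pretrained rows overwrite the freshly initialized rows. A toy NumPy sketch with made-up vocabularies (not the TokenEmbedding API):

import numpy as np

pretrained_vocab = ["the", "cat", "sat"]
extended_vocab = ["sat", "on", "the", "mat"]
dim = 4

pretrained_table = np.random.rand(len(pretrained_vocab), dim).astype("float32")
# Stand-in for the Xavier-initialized extended table.
extended_table = np.zeros((len(extended_vocab), dim), dtype="float32")

p_idx = {w: i for i, w in enumerate(pretrained_vocab)}
e_idx = {w: i for i, w in enumerate(extended_vocab)}
shared = sorted(set(pretrained_vocab) & set(extended_vocab))

# Copy pretrained rows into the matching positions of the extended table.
extended_table[[e_idx[w] for w in shared]] = \
    pretrained_table[[p_idx[w] for w in shared]]
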
Example #26
def do_train(args):
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree,
        "sharding_degree": args.sharding_degree
    }

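    # One optimizer step accumulates local_batch_size // micro_batch_size
    # micro-batches, matching the accumulation loop in the training step below.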
    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    # Set the seed used for tensor-parallel parameter initialization.
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain this process's rank in each hybrid-parallel group.
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()

    # Sharding stage 2/3 does not support hybrid parallelism yet.
    if args.sharding_stage in [2, 3]:
        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"

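    # Flatten the dp and sharding ranks into a single data-parallel coordinate.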
    sharding_size = hcg.get_sharding_parallel_world_size()
    data_world_rank = dp_rank * sharding_size + sharding_rank
    data_world_size = args.dp_degree * args.sharding_degree
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Seed control under hybrid parallelism.
    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree
        model_config['use_recompute'] = args.use_recompute
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model.
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    # Default: constant learning rate (lr_decay_style == "none").
    lr_scheduler = None
    if args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    if args.sharding_stage == 1 and args.sharding_degree > 1:
        optimizer = DygraphShardingOptimizer(
            hcg=fleet.get_hybrid_communicate_group(),
            user_defined_strategy=strategy,
            params=model.parameters(),
            inner_optimizer_class=paddle.optimizer.AdamW,
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params)
    else:
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params,
            # TODO: remove 'multi_precision' in definition of optimizer
            # and add it to 'paddle.amp.decorate'
            multi_precision=args.use_pure_fp16)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        # level O2 means converting the network to FP16
        if args.sharding_stage not in [2, 3]:
            scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(models=model,
                                    level='O2',
                                    save_dtype='float32')

    # Wrap sharding stage 2/3 and add the collective group.
    # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in the future
    if args.sharding_stage in [2, 3]:
        scaler = scaler if args.use_pure_fp16 else None
        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                     args.sharding_offload)

    elif paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = get_train_data_file(args)
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=data_world_size,
                data_world_rank=data_world_rank,
                max_seq_len=args.max_seq_len,
                eos_id=tokenizer.eos_token_id)
            # Bug fix: if valid_data_loader is not called here, each enumerate
            # would call it again and start a new, randomly seeded dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # Timing counters for data loading and compute.
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True
                position_ids.stop_gradient = True

                if args.pp_degree == 1:
                    # In data-parallel mode, 'no_sync' can be combined with
                    # gradient accumulation to improve performance.
                    loss = 0.0
                    for i in range(accumulate_steps):
                        start_index = i * args.micro_batch_size
                        end_index = start_index + args.micro_batch_size
                        with paddle.amp.auto_cast(
                                args.use_pure_fp16,
                                custom_black_list=[
                                    "reduce_sum",
                                    "c_softmax_with_cross_entropy",
                                    "elementwise_div"
                                ],
                                level='O2'):
                            preds = model(
                                tokens[start_index:end_index, :],
                                position_ids[start_index:end_index, :])
                            loss_mbs = criterion(
                                preds, labels[start_index:end_index, :],
                                loss_mask[start_index:end_index, :])
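                        # Scale each micro-batch loss so the accumulated
                        # gradient matches one full local batch.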
                        loss_mbs = loss_mbs / accumulate_steps
                        if args.use_pure_fp16:
                            scaler.scale(loss_mbs).backward()
                        else:
                            loss_mbs.backward()
                        loss = loss + loss_mbs

                    if args.use_pure_fp16:
                        if args.sharding_stage in [2, 3]:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            scaler.minimize(optimizer, loss)
                    else:
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()

                else:
                    data = [(tokens, position_ids), (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "elementwise_div"
                            ],
                            level='O2'):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_pure_fp16 else None)

                # Synchronize so the timings below are accurate; removing this may be slightly faster.
                paddle.device.cuda.synchronize()
                train_run_cost += time.time() - train_start
                # Profiler step for model benchmarking.
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (train_reader_cost +
                                                 train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq

                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips_total: %.0f tokens/s, ips: %.0f tokens/s, learning rate: %.5e"
                        %
                        (global_step, epoch, step, avg_loss, avg_reader_cost,
                         1. / speed, speed, speed * default_global_tokens_num,
                         speed * default_global_tokens_num / nranks,
                         optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", optimizer.get_lr(),
                                          global_step)

                    tic_train = time.time()
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the validation data is broadcast to all devices, we evaluate on every device.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # TODO: 1. merge parameters while saving the model. 2. ensure the model is saved and loaded correctly.
                # Only dp_rank == 0 saves the model.
                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps) and dp_rank == 0:

                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 and args.sharding_stage not in [2, 3] else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt"
                                .format(mp_rank, sharding_rank, pp_rank)))
                    else:
                        if args.sharding_stage == 3:
                            # If parameters need to be moved to CPU, pass convert2cpu=True.
                            model_to_save.get_all_parameters(convert2cpu=False)
                        if mp_rank == 0 and sharding_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt"
                                .format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader
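
One detail worth isolating from do_train is the weight-decay filter: paddle.optimizer.AdamW applies decay only to parameters for which apply_decay_param_fun returns True, which is how biases and LayerNorm weights are excluded above. A minimal sketch, using a hypothetical Block layer rather than the GPT model:

import paddle


class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 8)
        self.norm = paddle.nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.linear(x))


model = Block()
# Structured names look like "linear.weight" or "norm.bias", so this keeps
# only the linear weight; p.name is the framework-level name AdamW sees.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=lambda name: name in decay_params)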