def do_train(args):
    # Initialize the parallel environment
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()

    # Set the random seed for the training process
    set_seed(args)
    worker_init = WorkerInitObj(args.seed + worker_index)
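    # WorkerInitObj is handed to the data pipeline (see create_dataloader below)
    # so each DataLoader worker gets a distinct, reproducible random seed.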

    # Get the model class and tokenizer class
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define the pretraining model and the loss criterion
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        model = BigBirdForPretraining(
            BigBirdModel(**model_class.pretrained_init_configuration[
                args.model_name_or_path]))
    else:
        model = BigBirdForPretraining.from_pretrained(args.model_name_or_path)
    # Get the BigBird config needed to generate the random attention masks
    config = getattr(model, BigBirdForPretraining.base_model_prefix).config
    criterion = BigBirdPretrainingCriterion(config["vocab_size"], args.use_nsp)
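    # In multi-GPU runs, wrapping with DataParallel synchronizes (all-reduces)
    # gradients across workers during backward().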
    if worker_num > 1:
        model = paddle.DataParallel(model)

    # Define the learning rate scheduler and optimizer
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, args.max_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
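    # apply_decay_param_fun below restricts weight decay to the parameter names
    # in decay_params, so bias and LayerNorm parameters are not decayed.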
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            train_data_loader = create_dataloader(
                files[f_id], tokenizer, worker_init, args.batch_size,
                args.max_encoder_length, args.max_pred_length, config)
            for step, batch in enumerate(train_data_loader):
                global_step += 1
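                # The first seven entries of the batch are the fixed pretraining
                # tensors; the remaining entries are the random-attention index
                # tensors (typically one per layer) used by BigBird's sparse
                # attention.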
                (input_ids, segment_ids, masked_lm_positions, masked_lm_ids,
                 masked_lm_weights, next_sentence_labels,
                 masked_lm_scale) = batch[:7]
                rand_mask_idx_list = batch[7:]

                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    rand_mask_idx_list=rand_mask_idx_list,
                    masked_positions=masked_lm_positions)
                loss = criterion(prediction_scores, seq_relationship_score,
                                 masked_lm_ids, next_sentence_labels,
                                 masked_lm_scale, masked_lm_weights)
                if global_step % args.logging_steps == 0 and worker_index == 0:
                    logger.info(
                        "global step %d, epoch: %d, lr: %.10f, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, optimizer.get_lr(), loss,
                           args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()
                if global_step % args.save_steps == 0:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need a better way to get the inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
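
# A minimal launcher sketch, not part of the original example: it builds an
# argparse namespace exposing the attribute names that do_train reads
# (args.device, args.input_dir, args.output_dir, ...). The defaults here are
# illustrative assumptions; the real training script defines its own parser.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", default="bigbird")
    parser.add_argument("--model_name_or_path", default="bigbird-base-uncased")
    parser.add_argument("--input_dir", required=True,
                        help="Directory with the pretraining data files.")
    parser.add_argument("--output_dir", required=True,
                        help="Directory where checkpoints are written.")
    parser.add_argument("--device", default="gpu")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--max_encoder_length", type=int, default=512)
    parser.add_argument("--max_pred_length", type=int, default=75)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--use_nsp", action="store_true",
                        help="Also train the next-sentence-prediction head.")
    args = parser.parse_args()
    do_train(args)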
Example #2
def setUp(self):
    self.config['vocab_size'] = 1024
    self.config['use_nsp'] = False
    self.criterion = BigBirdPretrainingCriterion(**self.config)
    self.np_criterion = NpBigBirdPretrainingCriterion(**self.config)