Example #1
def do_train(args):
    paddle.enable_static()
    if args.is_distributed:
        fleet.init(is_collective=True)
        assert args.device != "xpu", "xpu doesn't support distributed training"
        places = [paddle.set_device("gpu")] if \
                 args.device == "gpu" else paddle.static.cpu_places()
        trainer_count = len(places)
    else:
        if args.device == "gpu":
            places = paddle.static.cuda_places()
        elif args.device == "xpu":
            places = paddle.static.xpu_places()
            paddle.set_device("xpu")
        else:
            places = paddle.static.cpu_places()
            paddle.set_device("cpu")
        trainer_count = len(places)

    # Set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader
    train_loader, eval_loader = reader.create_data_loader(args, places=places)

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(train_program, startup_program):
        src_word = paddle.static.data(name="src_word",
                                      shape=[None, None],
                                      dtype=args.input_dtype)
        trg_word = paddle.static.data(name="trg_word",
                                      shape=[None, None],
                                      dtype=args.input_dtype)
        lbl_word = paddle.static.data(name="lbl_word",
                                      shape=[None, None, 1],
                                      dtype=args.input_dtype)

        # Define model
        transformer = TransformerModel(src_vocab_size=args.src_vocab_size,
                                       trg_vocab_size=args.trg_vocab_size,
                                       max_length=args.max_length + 1,
                                       num_encoder_layers=args.n_layer,
                                       num_decoder_layers=args.n_layer,
                                       n_head=args.n_head,
                                       d_model=args.d_model,
                                       d_inner_hid=args.d_inner_hid,
                                       dropout=args.dropout,
                                       weight_sharing=args.weight_sharing,
                                       bos_id=args.bos_idx,
                                       eos_id=args.eos_idx)
        # Define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

        logits = transformer(src_word=src_word, trg_word=trg_word)

        sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

        scheduler = paddle.optimizer.lr.NoamDecay(args.d_model,
                                                  args.warmup_steps,
                                                  args.learning_rate,
                                                  last_epoch=0)

        # Define optimizer
        optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                          beta1=args.beta1,
                                          beta2=args.beta2,
                                          epsilon=float(args.eps),
                                          parameters=transformer.parameters())

        if args.is_distributed:
            build_strategy = paddle.static.BuildStrategy()
            exec_strategy = paddle.static.ExecutionStrategy()
            dist_strategy = fleet.DistributedStrategy()
            dist_strategy.build_strategy = build_strategy
            dist_strategy.execution_strategy = exec_strategy
            dist_strategy.fuse_grad_size_in_MB = 16

            if args.use_amp:
                dist_strategy.amp = True
                dist_strategy.amp_configs = {
                    'custom_white_list': ['softmax', 'layer_norm'],
                    'init_loss_scaling': args.scale_loss,
                    'custom_black_list': ['lookup_table_v2'],
                    'use_pure_fp16': args.use_pure_fp16
                }

            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)
        else:
            if args.use_amp:
                amp_list = paddle.static.amp.AutoMixedPrecisionLists(
                    custom_white_list=['softmax', 'layer_norm'],
                    custom_black_list=['lookup_table_v2'])
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=args.scale_loss,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=args.use_pure_fp16)
        optimizer.minimize(avg_cost)

    if args.is_distributed:
        exe = paddle.static.Executor(places[0])
    else:
        exe = paddle.static.Executor()
        build_strategy = paddle.static.BuildStrategy()
        exec_strategy = paddle.static.ExecutionStrategy()

        compiled_train_program = paddle.static.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy,
                                              exec_strategy=exec_strategy)
    exe.run(startup_program)

    if args.use_amp:
        optimizer.amp_init(places[0])

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log((1. - args.label_smooth_eps)) +
        args.label_smooth_eps * np.log(args.label_smooth_eps /
                                       (args.trg_vocab_size - 1) + 1e-20))

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    for pass_id in range(args.epoch):
        batch_id = 0
        batch_start = time.time()
        pass_start_time = batch_start
        for data in train_loader:
            # NOTE: max_iter is used for benchmarking; it defaults to None.
            if args.max_iter and step_idx == args.max_iter:
                break
            if trainer_count == 1:
                data = [data]
            train_reader_cost = time.time() - batch_start

            if args.is_distributed:
                outs = exe.run(train_program,
                               feed=[{
                                   'src_word': data[i][0],
                                   'trg_word': data[i][1],
                                   'lbl_word': data[i][2],
                               } for i in range(trainer_count)],
                               fetch_list=[sum_cost.name, token_num.name])
                train_batch_cost = time.time() - batch_start
                batch_ips_avg.record(train_batch_cost,
                                     np.asarray(outs[1]).sum())
            else:
                outs = exe.run(compiled_train_program,
                               feed=[{
                                   'src_word': data[i][0],
                                   'trg_word': data[i][1],
                                   'lbl_word': data[i][2],
                               } for i in range(trainer_count)],
                               fetch_list=[sum_cost.name, token_num.name])
                train_batch_cost = time.time() - batch_start
                batch_ips_avg.record(train_batch_cost,
                                     np.asarray(outs[1]).sum() / trainer_count)
            scheduler.step()

            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)

            # Profile for model benchmark
            if args.profiler_options is not None:
                profiler.add_profiler_step(args.profiler_options)

            if step_idx % args.print_step == 0 and (
                    args.benchmark or
                    (args.is_distributed and dist.get_rank() == 0) or
                    not args.is_distributed):
                sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                    outs[1])
                # Sum the cost from multiple devices
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num

                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                else:
                    train_avg_batch_cost = (args.print_step /
                                            batch_cost_avg.get_total_time())
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         train_avg_batch_cost, batch_cost_avg.get_average(),
                         reader_cost_avg.get_average(),
                         batch_ips_avg.get_total_cnt(),
                         batch_ips_avg.get_average_per_sec()))
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                if args.save_model and dist.get_rank() == 0:
                    model_path = os.path.join(args.save_model,
                                              "step_" + str(step_idx),
                                              "transformer")
                    paddle.static.save(train_program, model_path)

            batch_id += 1
            step_idx += 1
            batch_start = time.time()

        # NOTE: max_iter is used for benchmarking; it defaults to None.
        if args.max_iter and step_idx == args.max_iter:
            break

    if args.save_model and dist.get_rank() == 0:
        model_path = os.path.join(args.save_model, "step_final", "transformer")
        paddle.static.save(train_program, model_path)

    paddle.disable_static()
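
The example above follows the usual Paddle static-graph recipe: build the program under program_guard, attach a loss and an optimizer, then push batches through an Executor. Below is a stripped-down sketch of that skeleton with a toy one-layer regression model; the shapes, names, and hyperparameters are illustrative only and are not taken from the example.

import numpy as np
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
    pred = paddle.static.nn.fc(x, size=1)
    loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))
    paddle.optimizer.Adam(learning_rate=1e-3).minimize(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
for _ in range(3):
    feed = {"x": np.random.rand(4, 8).astype("float32"),
            "y": np.random.rand(4, 1).astype("float32")}
    loss_val, = exe.run(main_prog, feed=feed, fetch_list=[loss])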
Example #2
def do_train(args):
    if args.device == "gpu":
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    # Set seed for CE
    if args.seed is not None:
        set_seed(args.seed)

    benchmark_model = MODEL_REGISTRY[args.model]()
    benchmark_optimizer = OPTIMIZER_REGISTRY[args.optimizer]()

    # Define data loader
    train_loader, eval_loader = benchmark_model.create_data_loader(args)

    if args.max_steps is None or args.max_steps < 0:
        args.max_steps = len(train_loader) * args.epoch

    # Define model
    model = benchmark_model.build_model(args)

    if args.lr_scheduler is not None:
        benchmark_lr_scheduler = LR_SCHEDULER_REGISTRY[args.lr_scheduler]()
        lr = benchmark_lr_scheduler.build_scheculer(args)
    else:
        lr = args.learning_rate

    optimizer = benchmark_optimizer.build_optimizer(args, lr, model)

    # for amp training
    if args.use_amp:
        scaler = paddle.amp.GradScaler(
            enable=True, init_loss_scaling=args.scale_loss)
        model = paddle.amp.decorate(
            models=model, level=args.amp_level, save_dtype='float32')

    # for distributed training
    if trainer_count > 1:
        model = paddle.DataParallel(model)

    step_id = 1

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    # Train loop
    for pass_id in range(args.epoch):
        epoch_start = time.time()

        batch_id = 0
        batch_start = time.time()
        for input_data in train_loader:
            train_reader_cost = time.time() - batch_start

            if args.use_amp:
                with paddle.amp.auto_cast(
                        custom_black_list=args.custom_black_list
                        if args.amp_level == 'O2' else {},
                        level=args.amp_level):
                    loss, sample_per_cards = benchmark_model.forward(
                        model, args, input_data)

                scaled = scaler.scale(loss)
                scaled.backward()

                scaler.minimize(optimizer, scaled)
                if 'set_to_zero' in inspect.getfullargspec(
                        optimizer.clear_grad).args:
                    optimizer.clear_grad(set_to_zero=False)
                else:
                    optimizer.clear_grad()
            else:
                loss, sample_per_cards = benchmark_model.forward(model, args,
                                                                 input_data)

                loss.backward()

                optimizer.step()
                optimizer.clear_grad()

            train_batch_cost = time.time() - batch_start
            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)
            batch_ips_avg.record(train_batch_cost, sample_per_cards)

            if args.profiler_options is not None:
                profiler.add_profiler_step(args.profiler_options)

            if step_id % args.logging_steps == 0:
                total_avg_loss = loss.numpy()

                benchmark_model.logger(
                    args,
                    step_id=step_id,
                    pass_id=pass_id,
                    batch_id=batch_id,
                    loss=total_avg_loss,
                    batch_cost=batch_cost_avg.get_average(),
                    reader_cost=reader_cost_avg.get_average(),
                    num_samples=sample_per_cards,
                    ips=batch_ips_avg.get_average_per_sec())

                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if args.max_steps and step_id == args.max_steps:
                if args.save_model and rank == 0:
                    model_dir = args.save_model
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(model.state_dict(),
                                os.path.join(model_dir, "model.pdparams"))
                    paddle.save(optimizer.state_dict(),
                                os.path.join(model_dir, "model.pdopt"))
                return
            batch_id += 1
            step_id += 1
            if args.lr_scheduler is not None and not args.scheduler_update_by_epoch:
                lr.step()
            batch_start = time.time()

        if args.lr_scheduler is not None and args.scheduler_update_by_epoch:
            lr.step()

        train_epoch_cost = time.time() - epoch_start
        logger.info("train epoch: %d, epoch_cost: %.5f s" %
                    (pass_id, train_epoch_cost))
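
The AMP branch above is the standard dynamic-graph mixed-precision step: run the forward pass under auto_cast, scale the loss before backward, and let the GradScaler unscale the gradients and step the optimizer. A minimal self-contained sketch of that step follows; the Linear layer, tensor shapes, and loss-scaling value are placeholders.

import paddle

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 16], dtype="float32")
with paddle.amp.auto_cast():              # O1 auto mixed precision
    loss = model(x).mean()
scaled = scaler.scale(loss)               # scale the loss to avoid FP16 underflow
scaled.backward()
scaler.minimize(optimizer, scaled)        # unscale grads, step, update the scale
optimizer.clear_grad()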
Example #3
def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    place = paddle.set_device(args.device)
    fleet.init(is_collective=True)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Create the random seed for the worker
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()

    data_holders = create_data_holder(args)

    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    config = model_class.pretrained_init_configuration[args.model_name_or_path]
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Define the dynamic learning_rate scheduler and optimizer
    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)

    # Use the fleet api to compile the distributed optimizer
    optimizer = dist_optimizer(args, optimizer)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    if args.use_amp:
        optimizer.amp_init(place)

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in
            f
        ]
        files.sort()
        num_files = len(files)
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Select one file for each worker and create the DataLoader for the file
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, worker_num, worker_index)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        for f_id in range(f_start_id + 1, len(files)):
            data_file = select_dataset_file_for_each_worker(
                files, f_id, worker_num, worker_index)
            dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                         args.max_predictions_per_seq, args,
                                         data_holders, worker_init,
                                         paddle.static.cuda_places())

            train_cost_avg = TimeCostAverage()
            reader_cost_avg = TimeCostAverage()
            total_samples = 0
            batch_start = time.time()
            for step, batch in enumerate(train_data_loader):
                train_reader_cost = time.time() - batch_start
                reader_cost_avg.record(train_reader_cost)
                global_step += 1
                train_start = time.time()
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                total_samples += args.batch_size
                # In the 2.0 API, this must be called to update the learning rate
                lr_scheduler.step()
                train_run_cost = time.time() - batch_start
                train_cost_avg.record(train_run_cost)

                # Profile for model benchmark
                if args.profiler_options is not None:
                    profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_steps == 0:
                    print(
                        "tobal step: %d, epoch: %d, batch: %d, loss: %f, "
                        "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                        % (global_step, epoch, step, loss_return[0],
                           reader_cost_avg.get_average(),
                           train_cost_avg.get_average(), total_samples /
                           args.logging_steps, args.batch_size / (
                               reader_cost_avg.get_average() +
                               train_cost_avg.get_average())))
                    total_samples = 0
                    train_cost_avg.reset()
                    reader_cost_avg.reset()
                if global_step % args.save_steps == 0:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model.save_model_config(output_dir)
                        paddle.static.save(main_program,
                                           os.path.join(output_dir,
                                                        "model_state"))
                        tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    reader_start = time.time()
                    del train_data_loader
                    return
                batch_start = time.time()
            del train_data_loader
            train_data_loader, data_file = dataset_future.result(timeout=None)
        epoch += 1
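
A detail worth noting in the example above: while the DataLoader for the current file is being consumed, the DataLoader for the next file is built in a background thread through ThreadPoolExecutor, so data preparation overlaps training. A minimal sketch of that one-file-ahead pattern follows; build_loader and the file names are hypothetical stand-ins for create_pretraining_dataset and the real shard files.

from concurrent.futures import ThreadPoolExecutor

def build_loader(path):
    # stand-in for a real DataLoader factory; pretend each file yields 3 batches
    return [f"{path}-batch-{i}" for i in range(3)]

pool = ThreadPoolExecutor(1)
files = ["part-0", "part-1", "part-2"]
loader = build_loader(files[0])
for f_id in range(1, len(files)):
    future = pool.submit(build_loader, files[f_id])  # prefetch the next file
    for batch in loader:
        pass                                         # train on the current file
    loader = future.result()                         # swap in the prefetched loader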
Example #4
def do_train(args):
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree,
        "sharding_degree": args.sharding_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    # set control in tensor parallel
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    # obtain the rank information of hybrid parallelism
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()

    # sharding stage 2/3 does not support hybrid parallelism yet
    if args.sharding_stage in [2, 3]:
        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"

    sharding_size = hcg.get_sharding_parallel_world_size()
    data_world_rank = dp_rank * sharding_size + sharding_rank
    data_world_size = args.dp_degree * args.sharding_degree
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_pure_fp16,
            False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree
        model_config['use_recompute'] = args.use_recompute
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    if args.sharding_stage == 1 and args.sharding_degree > 1:
        optimizer = DygraphShardingOptimizer(
            hcg=fleet.get_hybrid_communicate_group(),
            user_defined_strategy=strategy,
            params=model.parameters(),
            inner_optimizer_class=paddle.optimizer.AdamW,
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params)
    else:
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params,
            # TODO: remove 'multi_precision' in definition of optimizer
            # and add it to 'paddle.amp.decorate'
            multi_precision=args.use_pure_fp16)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        # level O2 means converting the network to FP16
        if args.sharding_stage not in [2, 3]:
            scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')

    # wrap sharding stage2/3 and add collective group
    # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in the future
    if args.sharding_stage in [2, 3]:
        scaler = scaler if args.use_pure_fp16 else None
        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                     args.sharding_offload)

    elif paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = get_train_data_file(args)
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=data_world_size,
                data_world_rank=data_world_rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix: if valid_data_loader is not called here, the enumerate below
            # would call it many times and start a new random dataloader each time.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # time count
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True
                position_ids.stop_gradient = True

                if args.pp_degree == 1:
                    # In ParallelMode of DataParallel, 'no_sync' can be used to improve
                    # performance of the model via gradient accumulation.
                    loss = 0.0
                    for i in range(accumulate_steps):
                        start_index = i * args.micro_batch_size
                        end_index = start_index + args.micro_batch_size
                        with paddle.amp.auto_cast(
                                args.use_pure_fp16,
                                custom_black_list=[
                                    "reduce_sum",
                                    "c_softmax_with_cross_entropy",
                                    "elementwise_div"
                                ],
                                level='O2'):
                            preds = model(
                                tokens[start_index:end_index, :],
                                position_ids[start_index:end_index, :])
                            loss_mbs = criterion(
                                preds, labels[start_index:end_index, :],
                                loss_mask[start_index:end_index, :])
                        loss_mbs = loss_mbs / accumulate_steps
                        if args.use_pure_fp16:
                            scaler.scale(loss_mbs).backward()
                        else:
                            loss_mbs.backward()
                        loss = loss + loss_mbs

                    if args.use_pure_fp16:
                        if args.sharding_stage in [2, 3]:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            scaler.minimize(optimizer, loss)
                    else:
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()

                else:
                    data = [(tokens, position_ids), (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "elementwise_div"
                            ],
                            level='O2'):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_pure_fp16 else None)

                # Synchronize for accurate profiling time; removing this may be slightly faster
                paddle.device.cuda.synchronize()
                train_run_cost += time.time() - train_start
                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq

                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, avg_reader_cost,
                           1. / speed, speed, speed * default_global_tokens_num,
                           speed * default_global_tokens_num / nranks,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    tic_train = time.time()
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the validation data is broadcast to all devices, we evaluate on every device.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # TODO: 1. merge parameters while saving the model. 2. ensure that the model is saved and loaded correctly
                # only dp_rank = 0 saves the model
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:

                    model_to_save = (model._layers
                                     if paddle.distributed.get_world_size() > 1 and
                                     args.sharding_stage not in [2, 3] else model)
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank, pp_rank)))
                    else:
                        if args.sharding_stage == 3:
                            # If parameters need to be converted to CPU, pass convert2cpu=True
                            model_to_save.get_all_parameters(convert2cpu=False)
                        if mp_rank == 0 and sharding_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader
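
The setup in the example above boils down to describing the hybrid-parallel degrees on a DistributedStrategy, initializing fleet with it, and querying the hybrid communicate group for the current process's ranks. A minimal sketch follows; the degree values are placeholders and the script would need to be launched with paddle.distributed.launch across a matching number of processes.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # data parallel
    "mp_degree": 1,        # tensor (model) parallel
    "pp_degree": 1,        # pipeline parallel
    "sharding_degree": 1,  # optimizer-state sharding
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
print("global rank:", hcg.get_global_rank(),
      "dp rank:", hcg.get_data_parallel_rank())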
Example #5
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)
    global final_res

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    model_class, tokenizer_class = XLNetForSequenceClassification, XLNetTokenizer

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    if args.task_name == "mnli":
        train_data_loader, dev_data_loader_matched, dev_data_loader_mismatched, train_ds, dev_ds_matched, dev_ds_mismatched = create_data_loader(
            args, tokenizer)
    else:
        train_data_loader, dev_data_loader, train_ds, dev_ds = create_data_loader(
            args, tokenizer)

    num_classes = 1 if train_ds.label_list is None else len(
        train_ds.label_list)
    model = XLNetForSequenceClassification.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = ceil(num_training_steps / len(train_data_loader))
    else:
        num_training_steps = len(train_data_loader) * args.num_train_epochs
        num_train_epochs = args.num_train_epochs

    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, warmup)

    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.max_grad_norm)
    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "layer_norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        grad_clip=clip,
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_fct = (paddle.nn.loss.CrossEntropyLoss()
                if train_ds.label_list else paddle.nn.loss.MSELoss())

    metric = metric_class()

    global_step = 0
    model.train()

    train_reader_cost = 0.0
    train_run_cost = 0.0
    reader_start = time.time()
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            global_step += 1
            input_ids, token_type_ids, attention_mask, labels = batch
            logits = model(input_ids, token_type_ids, attention_mask)
            loss = loss_fct(logits, labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            train_run_cost += time.time() - train_start
            # Profile for model benchmark
            profiler.add_profiler_step(args.profiler_options)

            if global_step % args.logging_steps == 0:
                speed = args.logging_steps / (train_reader_cost +
                                              train_run_cost)
                avg_reader_cost = train_reader_cost / args.logging_steps
                print(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s, avg_reader_cost: %.4f sec, avg_batch_cost: %.4f sec, avg_samples: %d, avg_ips: %.4f sequences/sec"
                    % (
                        global_step,
                        num_training_steps,
                        epoch,
                        step,
                        paddle.distributed.get_rank(),
                        loss,
                        optimizer.get_lr(),
                        speed,
                        avg_reader_cost,
                        1.0 / speed,
                        args.batch_size,
                        speed * args.batch_size,
                    ))
                train_reader_cost = 0.0
                train_run_cost = 0.0

            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                if args.task_name == "mnli":
                    print("matched ", end="")
                    evaluate(model, loss_fct, metric, dev_data_loader_matched)
                    final_res1 = "matched " + final_res
                    print("mismatched ", end="")
                    evaluate(model, loss_fct, metric,
                             dev_data_loader_mismatched)
                    final_res2 = "mismatched " + final_res
                    final_res = final_res1 + "\r\n" + final_res2
                    print("eval done total : %s s" % (time.time() - tic_eval))
                else:
                    evaluate(model, loss_fct, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                if (paddle.distributed.get_world_size() <= 1
                        or paddle.distributed.get_rank() == 0):
                    output_dir = os.path.join(
                        args.output_dir,
                        "%s_ft_model_%d" % (args.task_name, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step == num_training_steps:
                    print(final_res)
                    exit(0)

            reader_start = time.time()
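
A pattern repeated in nearly every example on this page is excluding bias and LayerNorm parameters from weight decay by filtering parameter names and passing the allowed set to apply_decay_param_fun. The compact sketch below shows just that mechanism with a made-up two-layer model and arbitrary hyperparameters.

import paddle

class Toy(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 8)
        self.norm = paddle.nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.linear(x))

model = Toy()
# Keep only parameters whose attribute paths contain neither "bias" nor "norm".
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=lambda x: x in decay_params)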
Example #6
def on_batch_end(self, mode, step=None, logs=None):
    if mode == 'train':
        profiler.add_profiler_step(self.profiler_options)
Example #7
def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    if args.use_amp and args.amp_level == "O2":
        assert (args.mp_degree == 1 and args.pp_degree == 1
                ), "When amp level is O2, mp_degree and pp_degree should be 1."
        assert (args.use_sharding == False
                ), "When amp level is O2, use_sharding should be False."

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    local_rank = 0 if fleet.local_rank() is None else int(fleet.local_rank())

    topo = Topology(
        device_rank=worker_index,
        world_size=worker_num,
        dp_degree=args.dp_degree,
        pp_degree=args.pp_degree,
        sharding_degree=args.sharding_degree,
        mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, position_ids, labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    local_rank=local_rank,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False, )

                if args.model_name_or_path in pretrained_models_list:
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]

                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(guard(f'gpu:0')(GPTModel)(
                            **model_config))
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)
                # Create the model for the gpt pretrain
                preds = model(tokens, position_ids)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning_rate scheduler and optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]
            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(
                optimizer, strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.use_amp and args.amp_level == "O2":
        optimizer.amp_init(place)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model,
                    paddle.load(
                        dygraph_path, return_numpy=True),
                    topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
    learning_rate = main_program.global_block().vars["learning_rate_0"]
    while True:
        fetchs = []
        if topo.is_last:
            fetchs = [loss, learning_rate]

        # Bug fix: if valid_data_loader is not called here, the enumerate below
        # would call it many times and start a new random dataloader each time.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        train_reader_cost = 0.0
        train_run_cost = 0.0
        reader_start = time.time()
        for step, batch in enumerate(train_data_loader()):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            global_step += 1

            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # In the 2.0 API, this must be called to update the learning rate
            lr_scheduler.step()
            train_run_cost += time.time() - train_start

            # Profile for model benchmark
            profiler.add_profiler_step(args.profiler_options)

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    #speed = args.logging_freq / (time.time() - tic_train)
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f steps/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0],
                           avg_reader_cost, 1. / speed, speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           speed * args.global_batch_size * args.max_seq_len /
                           worker_num, lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()
                train_reader_cost = 0.0
                train_run_cost = 0.0

            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe,
                                  os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step <= args.save_steps:
                    model.init_config["init_args"][0].init_config.pop("topo",
                                                                      None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
            reader_start = time.time()

        epoch += 1
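
The logging block above derives its throughput figures from two accumulators: train_reader_cost (time spent waiting on the data loader) and train_run_cost (time spent inside exe.run). The sketch below, with made-up values standing in for args.logging_freq, args.global_batch_size, args.max_seq_len, and worker_num, is only meant to illustrate how speed, avg_batch_cost, and the two ips numbers relate.

def throughput_metrics(logging_freq, reader_cost, run_cost,
                       global_batch_size, max_seq_len, worker_num):
    # Steps per second over the whole logging window (reader wait + compute).
    speed = logging_freq / (reader_cost + run_cost)
    avg_reader_cost = reader_cost / logging_freq
    avg_batch_cost = 1.0 / speed
    # Every step consumes global_batch_size * max_seq_len tokens.
    ips = speed * global_batch_size * max_seq_len
    ips_per_card = ips / worker_num
    return speed, avg_reader_cost, avg_batch_cost, ips, ips_per_card

# Hypothetical numbers, used only to exercise the formulas.
print(throughput_metrics(logging_freq=10, reader_cost=0.5, run_cost=4.5,
                         global_batch_size=8, max_seq_len=1024, worker_num=1))
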
Example #8
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)
    worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())

    args.model_type = args.model_type.lower()
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        model = model_class(
            base_class(**model_class.pretrained_init_configuration[
                args.model_name_or_path]))
    else:
        model = model_class.from_pretrained(args.model_name_or_path)
    criterion = criterion_class(
        getattr(model, model_class.base_model_prefix).config["vocab_size"])
    # Decorate the model with @to_static for benchmarking; skipped by default.
    if args.to_static:
        specs = create_input_specs()
        model = paddle.jit.to_static(model, input_spec=specs)
        logger.info("Successfully to apply @to_static with specs: {}".format(
            specs))

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    # If the default last_epoch is used, the learning rate of the first iteration is 0.
    # Use `last_epoch = 0` to be consistent with NVIDIA BERT.
    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(
        args.learning_rate, num_training_steps, args.warmup_steps, last_epoch=0)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f)) and "train" in f
        ]
        files.sort()
        num_files = len(files)
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        shared_file_list = {}

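        # Pick this worker's data file: the index combines the pass id and the
        # worker rank, wrapped modulo the number of files.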
        if paddle.distributed.get_world_size() > num_files:
            remainder = paddle.distributed.get_world_size() % num_files
            data_file = files[(
                f_start_id * paddle.distributed.get_world_size() +
                paddle.distributed.get_rank() + remainder * f_start_id) %
                              num_files]
        else:
            data_file = files[(f_start_id * paddle.distributed.get_world_size()
                               + paddle.distributed.get_rank()) % num_files]

        previous_file = data_file

        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, shared_file_list, args,
            worker_init)

        # TODO(guosheng): better way to process single file
        single_file = f_start_id + 1 == len(files)

        for f_id in range(f_start_id, len(files)):
            if not single_file and f_id == f_start_id:
                continue
            if paddle.distributed.get_world_size() > num_files:
                data_file = files[(
                    f_id * paddle.distributed.get_world_size() +
                    paddle.distributed.get_rank() + remainder * f_id) %
                                  num_files]
            else:
                data_file = files[(f_id * paddle.distributed.get_world_size() +
                                   paddle.distributed.get_rank()) % num_files]

            previous_file = data_file
            dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                         args.max_predictions_per_seq,
                                         shared_file_list, args, worker_init)
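            # Prefetch the next file's dataset on the single-thread pool while the
            # current file is being trained on; the result is collected after the
            # inner loop below finishes.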
            train_cost_avg = TimeCostAverage()
            reader_cost_avg = TimeCostAverage()
            total_samples = 0
            batch_start = time.time()
            for step, batch in enumerate(train_data_loader):
                train_reader_cost = time.time() - batch_start
                reader_cost_avg.record(train_reader_cost)
                global_step += 1
                (input_ids, segment_ids, input_mask, masked_lm_positions,
                 masked_lm_labels, next_sentence_labels,
                 masked_lm_scale) = batch
                with paddle.amp.auto_cast(
                        args.use_amp,
                        custom_white_list=["layer_norm", "softmax", "gelu"]):
                    prediction_scores, seq_relationship_score = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask,
                        masked_positions=masked_lm_positions)
                    loss = criterion(prediction_scores, seq_relationship_score,
                                     masked_lm_labels, next_sentence_labels,
                                     masked_lm_scale)
                if args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.minimize(optimizer, loss)
                else:
                    loss.backward()
                    optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()
                total_samples += args.batch_size
                train_run_cost = time.time() - batch_start
                train_cost_avg.record(train_run_cost)

                # Profile for model benchmark
                if args.profiler_options is not None:
                    profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_steps == 0:
                    if paddle.distributed.get_rank() == 0:
                        logger.info(
                            "global step: %d, epoch: %d, batch: %d, loss: %f, "
                            "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                            % (global_step, epoch, step, loss,
                               reader_cost_avg.get_average(),
                               train_cost_avg.get_average(), total_samples /
                               args.logging_steps, total_samples / (
                                   args.logging_steps *
                                   train_cost_avg.get_average())))
                    total_samples = 0
                    train_cost_avg.reset()
                    reader_cost_avg.reset()
                if global_step % args.save_steps == 0:
                    if paddle.distributed.get_rank() == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need a better way to get the inner model of DataParallel.
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
                batch_start = time.time()

            del train_data_loader
            train_data_loader, data_file = dataset_future.result(timeout=None)
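
Both this example and the next build the AdamW weight-decay exclusion the same way: parameters whose structural name contains "bias" or "norm" are filtered out through apply_decay_param_fun. The following is a minimal, self-contained sketch of that pattern on a toy model; the layer names and hyperparameter values are made up for illustration.

import paddle

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(16, 16)
        self.norm = paddle.nn.LayerNorm(16)

    def forward(self, x):
        return self.norm(self.linear(x))

model = TinyModel()

# Filter on the structural attribute name (n) to skip bias and LayerNorm
# parameters, but collect the framework parameter name (p.name), which is
# what apply_decay_param_fun is queried with.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,  # placeholder value
    parameters=model.parameters(),
    weight_decay=0.01,  # placeholder value
    apply_decay_param_fun=lambda x: x in decay_params)
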
Example #9
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
    set_seed(args)
    # Currently, only data parallelism is supported in dygraph mode.
    topo = Topology(
        device_rank=worker_index, world_size=worker_num, dp_degree=worker_num)

    default_global_batch_size = topo.data_info.size * args.micro_batch_size
    default_global_tokens_num = default_global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.micro_batch_size *
            topo.data_info.size, False, False, worker_index).lower())
    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)
    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = GPTForPretraining(GPTModel(**model_config))
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model.
    criterion = GPTPretrainingCriterion()

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    epoch = 0
    tic_train = time.time()
    while True:
        files = get_train_data_file(args)
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=topo.data_info.size,
                data_world_rank=topo.data_info.rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix: if valid_data_loader is not instantiated here, each enumerate
            # call would re-invoke valid_data_loader and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # Timers for reader and training cost.
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, attention_mask, position_ids, labels = batch
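                # The masks are fixed inputs rather than trainable tensors, so
                # exclude them from gradient computation.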
                loss_mask.stop_gradient = True
                attention_mask.stop_gradient = True
                with paddle.amp.auto_cast(
                        args.use_amp,
                        custom_white_list=["layer_norm", "softmax", "gelu"],
                        custom_black_list=[
                            "reduce_sum", "c_softmax_with_cross_entropy",
                            "c_embedding"
                        ]):

                    preds = model(tokens, position_ids, attention_mask)
                    loss = criterion(preds, labels, loss_mask)

                if args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.minimize(optimizer, loss)
                else:
                    loss.backward()
                    optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                loss_numpy = loss.numpy()
                train_run_cost += time.time() - train_start

                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        %
                        (global_step, epoch, step, loss_numpy, avg_reader_cost,
                         1. / speed, speed, speed * default_global_tokens_num,
                         speed * default_global_tokens_num / worker_num,
                         optimizer.get_lr()))
                    log_writer.add_scalar("loss", loss_numpy, global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    tic_train = time.time()
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the validation data is broadcast to all devices,
                    # we evaluate on every device.
                    run_evaluate(valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need a better way to get the inner model of DataParallel.
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        logger.info("Save model to %s" % output_dir)
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(optimizer.state_dict(),
                                    os.path.join(output_dir,
                                                 "model_state.pdopt"))

                if global_step >= args.max_steps:
                    run_evaluate(test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader