def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Set the random seeds for this worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)
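    # Megatron-style seeding convention (an assumption about how the tracker is
    # used here): ops that must agree across ranks draw from 'global_seed',
    # while dropout draws from the per-worker 'local_seed' so masks differ
    # between ranks; the +2021 offset is arbitrary.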

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Supported devices are cpu, gpu, and xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of the parallelism degrees must equal worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)
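    # Example satisfying the degree check above: on 8 workers, dp_degree=2,
    # sharding_degree=1, mp_degree=2, pp_degree=2 gives 2 * 1 * 2 * 2 == 8.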

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # Load training progress from the checkpoint config, if present
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
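            # Rounding the vocab size up to a multiple of 8 (e.g. 50257 -> 50264)
            # is a common performance trick so the embedding / output projection
            # shapes divide evenly, e.g. for tensor cores or model-parallel splits.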
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Run the forward pass of the pretraining model
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
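        # All parameters except those whose names contain "bias" or "norm"
        # (i.e. biases and LayerNorm weights) receive weight decay, following
        # the usual AdamW convention for Transformer pretraining.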
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygrah_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model, paddle.load(dygrah_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # Load checkpoint variables, if a checkpoint exists
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()
    while True:
        fetchs = []
        fetchs_keys = []
        if topo.is_last:
            fetchs = list(fetch_loss_vars.values()) + list(
                fetch_other_vars.values()) + list(additional_vars.values())
            fetchs_keys = list(fetch_loss_vars.keys()) + list(
                fetch_other_vars.keys()) + list(additional_vars.keys())

        # Bug fix: call valid_data_loader()/test_data_loader() once here;
        # otherwise each later use would create a new, randomly shuffled dataloader.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Only advance the global step once every accumulate_steps micro-batches
            if (step + 1) % args.accumulate_steps != 0:
                continue
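            # One global step therefore consumes accumulate_steps micro-batches;
            # e.g. with accumulate_steps=4, only every 4th exe.run advances
            # global_step and the learning-rate schedule.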
            global_step += 1
            # In the 2.0 API, step() must be called explicitly to update the learning rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = "loss: %.6f, lm_loss: %.6f, sop_loss: %.6f"

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars.keys()
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO: check the validation input data
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)
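                    # Checkpoint rotation: the previous "model_last" is moved to
                    # "model_last_bak" before a fresh directory is created, so an
                    # interrupted save still leaves a usable copy behind.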

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables under sharding; it is slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Set the random seeds for this worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Supported devices are cpu, gpu, and xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in static mode

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, attention_mask, position_ids,
                 labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False,
                )

                if args.model_name_or_path in pretrained_models_list:
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]

                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(
                            guard(f'gpu:0')(GPTModel)(**model_config))
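                    # Pipeline placement as written here: the GPT backbone is
                    # created under gpu:0 and the pretraining head on the last
                    # stage, gpu:{pp_degree - 1}; `guard` pins the resulting ops
                    # to those pipeline devices.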
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)

                # Run the forward pass of the GPT pretraining model
                preds = model(tokens, position_ids, attention_mask)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning rate scheduler and optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)
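            # Expected schedule shape (Megatron-style, stated as an assumption
            # about this scheduler): linear warmup from 0 to max_lr over
            # warmup_step steps, then cosine decay towards min_lr by decay_step.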

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]

            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }
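                # Recompute trades compute for memory: only the listed checkpoint
                # activations are kept, and the segments between them are re-run
                # during the backward pass.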

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygrah_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model, paddle.load(dygrah_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
    learning_rate = main_program.global_block().vars["learning_rate_0"]
    while True:
        fetchs = []
        if topo.is_last:
            fetchs = [loss, learning_rate]

        # Bug fix: call valid_data_loader()/test_data_loader() once here;
        # otherwise each later use would create a new, randomly shuffled dataloader.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            global_step += 1
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # In the 2.0 API, step() must be called explicitly to update the learning rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f steps/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0], speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()

            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO: check the validation input data
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
        epoch += 1
Example #3
def train(args):
    log.info("pretraining start")
    profile = False

    place = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))

    # set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    # define execution strategy
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_threads = 2
    exec_strategy.num_iteration_per_drop_scope = 1

    # define distribution strategy
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.execution_strategy = exec_strategy
    dist_strategy.nccl_comm_num = 3
    if args.use_recompute:
        log.info("using recompute.")
    dist_strategy.recompute = args.use_recompute
    dist_strategy.sharding = args.use_sharding
    dist_strategy.pipeline = args.num_pp > 1

    # define topology structure for dp/pp/mp
    topo = Topology(rank=fleet.worker_index(),
                    world_size=fleet.worker_num(),
                    dp=args.num_dp,
                    pp=args.num_pp,
                    sharding=args.num_sharding,
                    mp=args.num_mp)

    is_last = False
    if topo.pp.rank == (topo.pp.size - 1):
        is_last = True

    dp_sharding_rank = topo.dp.rank * topo.sharding.size + topo.sharding.rank
    dp_worldsize = topo.dp.size * topo.sharding.size
    bsz_per_dp = args.global_bsz // dp_worldsize

    micro_bsz = args.micro_bsz
    assert args.global_bsz % micro_bsz == 0, f"cannot do gradient accumulation, global_bsz: {args.global_bsz} micro_bsz: {micro_bsz}"
    acc_steps = bsz_per_dp // micro_bsz
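    # Illustrative numbers: global_bsz=1024, num_dp=2, num_sharding=2 gives
    # dp_worldsize=4 and bsz_per_dp=256; with micro_bsz=8 each rank then
    # accumulates acc_steps=32 micro-batches per optimizer step.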

    # sharding / model parallel / pipeline
    assert dist_strategy.sharding == True
    dist_strategy.sharding_configs = {
        "segment_broadcast_MB": 32,
        "sharding_degree": args.num_sharding,
        "mp_degree": args.num_mp,
        "pp_degree": args.num_pp,
        "dp_degree": args.num_dp,
        "optimize_offload": True,
    }
    dist_strategy.pipeline_configs = {
        "schedule_mode": "1F1B",
        "micro_batch_size": micro_bsz,
        "accumulate_steps": acc_steps,
    }
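    # The "1F1B" schedule interleaves one forward with one backward micro-batch
    # per pipeline stage, bounding in-flight activations compared with running
    # all forwards before any backward (GPipe-style).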
    log.info(
        f"using global_bsz: {args.global_bsz} micro_bsz: {micro_bsz}, acc_steps: {acc_steps}"
    )

    dist_strategy.amp = args.use_amp
    dist_strategy.amp_configs = {
        "custom_white_list": ['softmax', 'layer_norm', 'gelu'],
        "init_loss_scaling": 32768,
        "decr_every_n_nan_or_inf": 2,
        "incr_every_n_steps": 1000,
        "incr_ratio": 2.0,
        "use_dynamic_loss_scaling": True,
        "decr_ratio": 0.5,
        "use_pure_fp16": False,
        "use_fp16_guard": False,
    }
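    # Dynamic loss scaling as configured above: start at 32768, halve
    # (decr_ratio=0.5) once 2 NaN/Inf batches are seen, and double
    # (incr_ratio=2.0) after 1000 consecutive clean steps.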

    dist_strategy.lamb = args.use_lamb
    dist_strategy.lamb_configs = {
        'lamb_weight_decay':
        0.01,
        'exclude_from_weight_decay':
        ['layer_norm_bias', 'layer_norm_scale', '.b_0']
    }

    train_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            graph_vars = create_model(args, 'train', micro_bsz,
                                      dp_sharding_rank, dp_worldsize, topo)
            data_loader = graph_vars['data_loader']
            for op in train_program.global_block().ops:
                if op.type == 'fill_constant':
                    op._set_attr(
                        'op_device', "gpu:0"
                    )  # XXX: hack: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/layers/tensor.py#L1376

            if args.use_recompute:
                dist_strategy.recompute_configs = {
                    "checkpoints": graph_vars['checkpoints'],
                    # "enable_offload": args.use_offload,
                    # "checkpoint_shape": [micro_bsz, args.max_seq_len, 4096],
                }

            log.debug("base lr: {}".format(args.learning_rate))
            scheduled_lr = linear_warmup_decay(
                learning_rate=args.learning_rate,
                warmup_steps=args.warmup_steps,
                num_train_steps=args.num_train_steps)
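            # Assumed behaviour of linear_warmup_decay (standard warmup shape):
            # the LR rises linearly from 0 to learning_rate over warmup_steps,
            # then decays linearly towards 0 by num_train_steps.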

            clip_norm_thres = 1.0
            if paddlenlp.ops.optimizer._jit_compile():
                optimizer = paddlenlp.ops.optimizer.AdamwOptimizer(
                    learning_rate=scheduled_lr,
                    grad_clip=fluid.clip.GradientClipByGlobalNorm(
                        clip_norm=clip_norm_thres),
                    weight_decay=args.weight_decay,
                    apply_decay_param_fun=apply_weight_decay_fun)
            else:
                optimizer = fluid.optimizer.Adam(
                    learning_rate=scheduled_lr,
                    grad_clip=fluid.clip.GradientClipByGlobalNorm(
                        clip_norm=clip_norm_thres),
                    #multi_precision=True,
                    #weight_decay=args.weight_decay, # merge this pr to use weight_decay: https://github.com/PaddlePaddle/Paddle/pull/29248
                    #exclude_from_weight_decay_fn=exclude_from_weight_decay
                )

            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
            log.info(f"using dist strategy: {dist_strategy}")

            optimizer.minimize(graph_vars['total_loss'])

            final_strategy = fleet._final_strategy()
            applied_meta_list = fleet._get_applied_meta_list()
            log.info("final strategy: {}".format(final_strategy))
            log.info("applied_meta_list: {}".format(applied_meta_list))

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(
            program_desc_dir + "/main_program.txt.%d" %
        (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f:
        f.write(str(train_program))

    with open(
            program_desc_dir + "/startup_program.txt.%d" %
        (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f:
        f.write(str(startup_program))

    exe = fluid.Executor(place)
    exe.run(startup_program)

    optimizer.amp_init(place)

    #save_path = os.path.join(args.output_dir, 'step_0')
    #log.debug("saving models to {}".format(save_path))
    #save_persistables(exe, save_path, train_program)

    if args.init_checkpoint and args.init_checkpoint != "":
        log.info(' ')
        log.info(
            '############################WARNING############################')
        log.info(
            '###### using init_checkpoint, not init_pretraining_params #####')
        log.info(
            '## hyper params such as lr will be inherited from checkpoint ##')
        log.info(
            '###############################################################')
        init_checkpoint(exe, args.init_checkpoint, train_program)
        log.info(' ')

    output_dir = args.output_dir
    save_steps = args.save_steps
    total_time = 0
    cost_vals, lm_losses, sop_accs = [], [], []
    global_steps = args.global_steps + 1
    steps = 0
    log_path = 'train_log/node-%d' % fleet.worker_index()
    start_time = time.time()
    with LogWriter(os.path.join(args.output_dir, log_path)) as swriter:
        data_loader.start()
        while True:
            #if steps < global_steps:
            #    steps += 1
            #    continue
            if not is_last:
                fetch_list = []
            else:
                fetch_list = [
                    graph_vars['total_loss'], graph_vars['mean_mask_lm_loss'],
                    scheduled_lr
                ]
                if args.use_sop:
                    fetch_list.extend(
                        [graph_vars['sop_acc'], graph_vars['sop_loss']])
                if args.use_amp:
                    loss_scaling = train_program.global_block(
                    ).vars['loss_scaling_0']
                    fetch_list.append(loss_scaling)

            ret = exe.run(train_program, fetch_list=fetch_list
                          )  # run one mini-batch (= acc_steps micro-batches)
            #use_program_cache=True)

            steps += 1

            if is_last:
                if args.use_sop and args.use_amp:
                    cost_val, lm_loss, lr, sop_acc, sop_loss, loss_scaling_0 = ret
                elif args.use_sop:
                    cost_val, lm_loss, lr, sop_acc, sop_loss = ret
                elif args.use_amp:
                    cost_val, lm_loss, lr, loss_scaling_0 = ret
                else:
                    cost_val, lm_loss, lr = ret
                cost_vals.append(cost_val[0])
                lm_losses.append(lm_loss[0])
                if args.use_sop:
                    sop_accs.append(sop_acc[0])

                if steps > 0 and (steps % args.log_steps) == 0:
                    end_time = time.time()
                    total_time = end_time - start_time
                    cost_val = np.mean(cost_vals)
                    lm_loss = np.mean(lm_losses)
                    swriter.add_scalar('loss/total_loss', cost_val, steps)
                    swriter.add_scalar('loss/mlm_loss', lm_loss, steps)
                    swriter.add_scalar('lr/scheduled_lr', lr[0], steps)

                    if args.use_sop:
                        sop_acc = np.mean(sop_accs)
                        swriter.add_scalar('loss/sop_loss', sop_loss, steps)
                        swriter.add_scalar('train/sop_acc', sop_acc, steps)
                    else:
                        sop_acc = 0.0

                    if args.use_amp:
                        swriter.add_scalar('lr/loss_scaling',
                                           loss_scaling_0[0], steps)
                    else:
                        loss_scaling_0 = [0.0]

                    log.info(
                        "worker_index: %d, step: %d, cost: %f, "
                        "mlm loss: %f, sentence order acc: %f, "
                        "speed: %f steps/s, "
                        "speed: %f samples/s, "
                        "speed: %f tokens/s, "
                        "learning rate: %.3e, loss_scalings: %f" %
                        (fleet.worker_index(), steps, cost_val, lm_loss,
                         sop_acc, args.log_steps / total_time,
                         args.log_steps * args.global_bsz / total_time,
                         args.log_steps * args.global_bsz * args.max_seq_len /
                         total_time, lr[0], loss_scaling_0[0]))

                    cost_vals, lm_losses, sop_accs = [], [], []
                    start_time = time.time()

            # TODO: add evaluation
            if steps > 0 and args.eval_steps > 0 and steps % args.eval_steps == 0:
                pass

            if steps > 0 and args.save_steps > 0 and steps % args.save_steps == 0:
                if args.use_hybrid_dp and fleet.worker_index() > 8:
                    continue
                save_path = os.path.join(output_dir, 'step_' + str(steps))
                log.debug("saving models to {}".format(save_path))
                save_persistables(exe, save_path, train_program)

            if steps == args.num_train_steps:
                if args.use_hybrid_dp and fleet.worker_index() > 8:
                    continue
                save_path = os.path.join(output_dir,
                                         'final_step_' + str(steps))
                save_persistables(exe, save_path, train_program)
                log.debug("saving final models to {}".format(save_path))
                log.debug("end of training, total steps: {}".format(steps))