def do_train(args):
    """Build the static-graph pretraining program and run the training loop.

    Steps, in order:

    1. Enable static mode, initialize fleet (collective) and seed the RNGs.
    2. Build the hybrid-parallel ``Topology`` and the distributed strategy.
    3. Read ``<output_dir>/model_last/config.yml`` (if present) to resume
       ``global_step``.
    4. Inside the default main/startup programs, create the data holders,
       data loaders, model, criterion, LR scheduler and AdamW optimizer, and
       let fleet rewrite the program via ``optimizer.minimize``.
    5. Dump the program descriptions, run the startup program, and load
       initial parameters and/or the ``model_last`` checkpoint variables.
    6. Train until ``global_step >= args.max_steps``, logging every
       ``logging_freq`` steps, evaluating every ``eval_freq`` steps, saving
       every ``save_steps`` steps and checkpointing every
       ``checkpoint_steps`` steps; then run a final test evaluation.

    Args:
        args: Parsed run configuration. Attributes read here include
            ``seed``, ``device``, ``dp_degree``, ``sharding_degree``,
            ``mp_degree``, ``pp_degree``, ``output_dir``, ``model_type``,
            ``model_name_or_path``, ``global_batch_size``, ``max_seq_len``,
            ``hidden_dropout_prob``, ``attention_probs_dropout_prob``,
            ``binary_head``, ``decay_steps``, ``max_steps``, ``max_lr``,
            ``warmup_rate``, ``grad_clip``, ``adam_beta1``, ``adam_beta2``,
            ``adam_epsilon``, ``weight_decay``, ``use_amp``,
            ``use_recompute``, ``accumulate_steps``, ``logging_freq``,
            ``eval_freq``, ``eval_iters``, ``test_iters``, ``save_steps``
            and ``checkpoint_steps``. ``args.decay_steps`` is overwritten
            with ``args.max_steps`` when it is ``None``.

    Returns:
        None. The function returns right after the final test evaluation.

    Raises:
        AssertionError: If ``args.device`` is not one of cpu/gpu/xpu, if the
            product of the parallel degrees differs from the worker count,
            or if the checkpoint's global batch size differs from
            ``args.global_batch_size``.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    # The local seed is offset by the worker index so it differs per worker.
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # log_writer is always bound (None on other workers) because it is passed
    # to run_evaluate on every worker below; leaving it unbound raised
    # UnboundLocalError wherever topo.is_last is false.
    # NOTE(review): run_evaluate also receives topo.is_last, so presumably it
    # only touches the writer when that flag is set - confirm.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                # NOTE(review): this file is written by the checkpoint code
                # further down; yaml.safe_load would be the safer choice if
                # output_dir can ever contain files from another source.
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        # NOTE(review): fleet was already initialized at the top of this
        # function; this second call looks redundant - confirm before
        # removing it.
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            # Round the vocabulary size up to the next multiple of 8.
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Build the forward graph of the pretraining model
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning_rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        # Weight decay applies to every parameter whose structured name does
        # not contain "bias" or "norm".
        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    # Every worker reaches this point, so create the directory in a way that
    # tolerates another worker creating it first (and a missing output_dir).
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        # Prefer static vars; fall back to dygraph params when those were
        # not loaded.
        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygrah_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model, paddle.load(dygrah_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()

    # The fetch lists do not change between passes, so build them once.
    # Only workers with topo.is_last fetch anything.
    fetchs = []
    fetchs_keys = []
    if topo.is_last:
        fetchs = list(fetch_loss_vars.values()) + list(
            fetch_other_vars.values()) + list(additional_vars.values())
        fetchs_keys = list(fetch_loss_vars.keys()) + list(
            fetch_other_vars.keys()) + list(additional_vars.keys())

    # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
    # many times. and start a new random dataloader.
    # This is done once, before the outer loop: the names are rebound to the
    # call results, so repeating it on a later pass would call those results
    # instead of the original loaders.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()

    while True:
        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Skip for accumulate_steps in global step
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars.keys()
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                # Restart the timer so evaluation time is not counted in the
                # next speed measurement.
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    # Rotate the previous checkpoint to model_last_bak and
                    # start from an empty model_last directory.
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

                # Wait until worker 0 has prepared the directory before any
                # worker writes variables into it.
                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables in sharding, but more slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
def do_train(args):
    """Build the static-graph GPT pretraining program and run training.

    Steps, in order:

    1. Enable static mode, initialize fleet (collective) and seed the RNGs.
    2. Build the hybrid-parallel ``Topology`` and the distributed strategy.
    3. Inside the default main/startup programs, create the data holders,
       data loaders, GPT model, criterion, cosine-annealing LR scheduler and
       AdamW optimizer, and let fleet rewrite the program via
       ``optimizer.minimize``.
    4. Dump the program descriptions, run the startup program, and load
       initial parameters when ``args.model_name_or_path`` is not one of the
       built-in configurations.
    5. Train until ``global_step >= args.max_steps``, logging every
       ``logging_freq`` steps, evaluating every ``eval_freq`` steps and
       saving every ``save_steps`` steps; then run a final test evaluation.

    Args:
        args: Parsed run configuration. Attributes read here include
            ``seed``, ``device``, ``dp_degree``, ``sharding_degree``,
            ``mp_degree``, ``pp_degree``, ``output_dir``, ``model_type``,
            ``model_name_or_path``, ``global_batch_size``, ``max_seq_len``,
            ``hidden_dropout_prob``, ``attention_probs_dropout_prob``,
            ``decay_steps``, ``max_steps``, ``max_lr``, ``min_lr``,
            ``warmup_rate``, ``grad_clip``, ``adam_beta1``, ``adam_beta2``,
            ``adam_epsilon``, ``weight_decay``, ``use_amp``,
            ``use_recompute``, ``check_accuracy``, ``logging_freq``,
            ``eval_freq``, ``eval_iters``, ``test_iters`` and
            ``save_steps``. ``args.decay_steps`` is overwritten with
            ``args.max_steps`` when it is ``None``.

    Returns:
        None. Returns once ``global_step >= args.max_steps``: immediately
        when ``args.check_accuracy`` is set, otherwise after the final save
        and test evaluation.

    Raises:
        AssertionError: If ``args.device`` is not one of cpu/gpu/xpu.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    # The local seed is offset by the worker index so it differs per worker.
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # log_writer is always bound (None on other workers) because it is passed
    # to run_evaluate on every worker below; leaving it unbound raised
    # UnboundLocalError wherever topo.is_last is false.
    # NOTE(review): run_evaluate also receives topo.is_last, so presumably it
    # only touches the writer when that flag is set - confirm.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # Start from an empty log directory on every run.
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, attention_mask, position_ids,
                 labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False,
                )

                if args.model_name_or_path in pretrained_models_list:
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]

                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    # The base model is wrapped with guard('gpu:0') and the
                    # pretraining head with the guard for the last pipeline
                    # stage index (pp_degree - 1).
                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(
                            guard('gpu:0')(GPTModel)(**model_config))
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)

                # Create the model for the gpt pretrain
                preds = model(tokens, position_ids, attention_mask)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning_rate scheduler and optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            # Weight decay applies to every parameter whose structured name
            # does not contain "bias" or "norm".
            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]

            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    # Every worker reaches this point, so create the directory in a way that
    # tolerates another worker creating it first (and a missing output_dir).
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        # Prefer static vars; fall back to dygraph params when those were
        # not loaded.
        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygrah_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model, paddle.load(dygrah_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
    learning_rate = main_program.global_block().vars["learning_rate_0"]

    # The fetch list does not change between epochs, so build it once.
    # Only workers with topo.is_last fetch anything.
    fetchs = []
    if topo.is_last:
        fetchs = [loss, learning_rate]

    # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
    # many times. and start a new random dataloader.
    # This is done once, before the epoch loop: the names are rebound to the
    # call results, so repeating it on a later epoch would call those results
    # instead of the original loaders.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()

    while True:
        for step, batch in enumerate(train_data_loader()):
            global_step += 1
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f steps/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0], speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()

            # In accuracy-check mode, evaluation and saving are skipped
            # entirely; the run just stops at max_steps.
            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                # Restart the timer so evaluation time is not counted in the
                # next speed measurement.
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                # Drop the topology object from the stored init config on the
                # first periodic save, before save_pretrained is called.
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
        epoch += 1
# Example #3
# 0
def do_generation(args):
    """Build a static-graph GPT generation program, run it, and optionally export it.

    The function initializes fleet with tensor parallelism enabled, builds the
    generation network inside the default main program, loads pretrained
    weights when the downloaded file exists, decodes two built-in sample
    prompts, runs a few batches from the test data loader and, when
    ``args.save_inference_model_then_exist`` is set, saves an inference model
    for the current rank.

    Args:
        args: Parsed command-line namespace. Besides the attributes read
            directly here (parallel degrees, seed, AMP flags, device, model
            and decoding options), it is forwarded to ``get_data_file``,
            ``create_data_holder`` and ``create_pretrained_dataset``.
            ``args.mp_degree`` is overwritten with the world size when it is
            left at 1.

    Raises:
        AssertionError: If an unsupported parallel degree, AMP combination or
            device is requested.
        ValueError: If ``args.model_name_or_path`` is not one of the built-in
            pretrained configuration names, because no model can be built in
            that case.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()

    assert args.dp_degree == 1, "Data parallel is not supported in inference"
    assert args.sharding_degree == 1, "Sharding parallel is temporarily not supported in inference"
    assert args.pp_degree == 1, "Pipeline parallel will be supported later"

    # An mp_degree of 1 is treated as "not specified": use every worker for
    # tensor parallelism. Otherwise it has to match the world size exactly.
    if args.mp_degree == 1:
        args.mp_degree = paddle.distributed.get_world_size()
    else:
        assert args.mp_degree == paddle.distributed.get_world_size(), \
            "If mp_degree is specified, the size must be the same as world_size"

    strategy = fleet.DistributedStrategy()
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {
        "tensor_parallel_degree": args.mp_degree
    }

    fleet.init(is_collective=True, strategy=strategy)

    # temp use dynamic init, use HybridParallelInferenceHelper in future?
    paddle.distributed.init_parallel_env()

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    # The local seed is offset by the worker index so it differs per worker.
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    if args.use_amp and args.amp_level == "O2":
        assert (args.mp_degree == 1 and args.pp_degree == 1
                ), "When amp level is O2, mp_degree and pp_degree should be 1."
        assert (args.use_sharding == False
                ), "When amp level is O2, use_sharding should be False."

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    local_rank = 0 if fleet.local_rank() is None else int(fleet.local_rank())

    topo = Topology(
        device_rank=worker_index,
        world_size=worker_num,
        dp_degree=args.dp_degree,
        pp_degree=args.pp_degree,
        sharding_degree=args.sharding_degree,
        mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                feeds = create_data_holder(args)
                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                # Only the third returned loader is used for generation.
                _, _, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    local_rank=local_rank,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=feeds,
                    pipeline_mode=False)

                if args.model_name_or_path in pretrained_models_list:
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]
                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo
                    model_config["fuse"] = args.fuse
                    model = GPTForGeneration(
                        GPTModel(**model_config),
                        max_length=args.max_dec_len,
                        decoding_strategy=args.decoding_strategy,
                        temperature=args.temperature,
                        top_k=args.topk,
                        top_p=args.topp,
                        eos_id=eos_id,
                        fuse=args.fuse)
                else:
                    logger.error("No checkpoint load.")
                    # Without a known configuration there is no model to
                    # build; fail here with a clear message instead of
                    # hitting an unbound `model` on the next statement.
                    raise ValueError(
                        "Unknown model_name_or_path {!r}; expected one of {}".
                        format(args.model_name_or_path,
                               pretrained_models_list))
                model.eval()
                ins = {v.name: v for v in feeds}
                preds = model(ins)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    # All later runs (and the export) use a test-mode clone of the program.
    main_program = main_program.clone(for_test=True)

    model_urls = model.pretrained_resource_files_map['model_state']
    model_path = args.model_name_or_path
    if model_path in pretrained_models_list and model_path in model_urls:
        flag_loaded = False
        from paddle.utils.download import get_weights_path_from_url
        dygraph_path = get_weights_path_from_url(model_urls[model_path])
        if os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model,
                    paddle.load(
                        dygraph_path, return_numpy=True),
                    topo,
                    main_program)
                flag_loaded = True
        if not flag_loaded:
            logger.error("No checkpoint load.")

    fetchs = [preds]

    # Check results: decode two fixed sample prompts and log the output.
    text = [
        "Question: Where is the capital of China? Answer:",
        "Question:Who is the CEO of Apple? Answer:"
    ]
    inputs = tokenizer(
        text,
        padding=True,
        return_attention_mask=True,
        return_position_ids=True)
    ids = np.array(inputs["input_ids"]).reshape(len(text), -1).astype('int64')
    position_ids = np.array(inputs["position_ids"]).reshape(len(text),
                                                            -1).astype('int64')
    attention_mask = np.array(inputs["attention_mask"]).reshape(
        len(text), -1).astype('float32')

    # Wrap the numpy inputs in core tensors placed on the selected device.
    t_ids = paddle.fluid.core.Tensor()
    t_ids.set(ids, place)
    t_mask = paddle.fluid.core.Tensor()
    t_mask.set(attention_mask, place)
    t_pos = paddle.fluid.core.Tensor()
    t_pos.set(position_ids, place)
    feed_data = {'src_ids': t_ids, 'pos_ids': t_pos, 'input_mask': t_mask}
    ret = exe.run(main_program, feed=feed_data, fetch_list=fetchs)
    ret = np.array(ret[0])
    for i in range(ret.shape[0]):
        o = [int(x) for x in ret[i]]
        ret_str = tokenizer.convert_ids_to_string(o)
        # Log the prompt followed by the generated continuation.
        ret_str = text[i] + ret_str
        logger.info(ret_str)
    ##################

    # Run at most six batches (steps 0..5) from the test loader; the fetched
    # results are not used.
    for step, batch in enumerate(test_data_loader()):
        exe.run(main_program, feed=batch, fetch_list=fetchs)
        if step == 5:
            break

    if args.save_inference_model_then_exist:
        save_inference_model_dir = 'inference_model_pp{pp_degree}mp{mp_degree}'.format(
            pp_degree=args.pp_degree, mp_degree=args.mp_degree)
        # Each rank writes into its own sub-directory.
        inference_save_path = os.path.join(save_inference_model_dir,
                                           'rank_' + str(fleet.worker_index()),
                                           'step_' + str(0))
        print("saving inference models to {}".format(inference_save_path))
        feed_names = [v.name for v in feeds]
        fetchs_names = [v.name for v in fetchs]
        print('feeds: ', feed_names, 'fetches: ', fetchs_names)
        paddle.static.save_inference_model(
            inference_save_path, feeds, fetchs, exe, program=main_program)