Example No. 1
def do_train(args):
    paddle.enable_static()
    if args.use_gpu:
        trainer_count = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
        place = paddle.set_device("gpu:0")
    else:
        trainer_count = int(os.environ['CPU_NUM'])
        place = paddle.set_device("cpu")

    # Set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader
    # NOTE: To guarantee all data is involved, use world_size=1 and rank=0.
    (train_loader,
     train_steps_fn), (eval_loader,
                       eval_steps_fn) = reader.create_data_loader(args)

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(train_program, startup_program):
        src_word = paddle.static.data(name="src_word",
                                      shape=[None, None],
                                      dtype="int64")
        trg_word = paddle.static.data(name="trg_word",
                                      shape=[None, None],
                                      dtype="int64")
        lbl_word = paddle.static.data(name="lbl_word",
                                      shape=[None, None, 1],
                                      dtype="int64")

        # Define model
        transformer = TransformerModel(src_vocab_size=args.src_vocab_size,
                                       trg_vocab_size=args.trg_vocab_size,
                                       max_length=args.max_length + 1,
                                       n_layer=args.n_layer,
                                       n_head=args.n_head,
                                       d_model=args.d_model,
                                       d_inner_hid=args.d_inner_hid,
                                       dropout=args.dropout,
                                       weight_sharing=args.weight_sharing,
                                       bos_id=args.bos_idx,
                                       eos_id=args.eos_idx)
        # Define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

        logits = transformer(src_word=src_word, trg_word=trg_word)

        sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

        scheduler = paddle.optimizer.lr.NoamDecay(args.d_model,
                                                  args.warmup_steps,
                                                  args.learning_rate,
                                                  last_epoch=0)

        # Define optimizer
        optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                          beta1=args.beta1,
                                          beta2=args.beta2,
                                          epsilon=float(args.eps),
                                          parameters=transformer.parameters())

        optimizer.minimize(avg_cost)

    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.enable_inplace = True
    exec_strategy = paddle.static.ExecutionStrategy()

    compiled_train_program = paddle.static.CompiledProgram(
        train_program).with_data_parallel(loss_name=avg_cost.name,
                                          build_strategy=build_strategy,
                                          exec_strategy=exec_strategy)

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log((1. - args.label_smooth_eps)) +
        args.label_smooth_eps * np.log(args.label_smooth_eps /
                                       (args.trg_vocab_size - 1) + 1e-20))

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    for pass_id in range(args.epoch):
        batch_id = 0
        batch_start = time.time()
        pass_start_time = batch_start
        for data in batch_creator(train_loader, trainer_count):
            # NOTE: Used for benchmarking; max_iter is None by default.
            if args.max_iter and step_idx == args.max_iter:
                return
            train_reader_cost = time.time() - batch_start

            outs = exe.run(compiled_train_program,
                           feed=[{
                               'src_word': data[i][0],
                               'trg_word': data[i][1],
                               'lbl_word': data[i][2],
                           } for i in range(trainer_count)],
                           fetch_list=[sum_cost.name, token_num.name])
            scheduler.step()

            train_batch_cost = time.time() - batch_start
            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)
            batch_ips_avg.record(train_batch_cost, np.asarray(outs[1]).sum())

            if step_idx % args.print_step == 0:
                sum_cost_val = np.array(outs[0])
                token_num_val = np.array(outs[1])
                # Sum the cost from multi-devices
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num

                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                else:
                    train_avg_batch_cost = (
                        args.print_step / batch_cost_avg.get_total_time())
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         train_avg_batch_cost, batch_cost_avg.get_average(),
                         reader_cost_avg.get_average(),
                         batch_ips_avg.get_total_cnt(),
                         batch_ips_avg.get_average_per_sec()))
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                if args.save_model:
                    model_path = os.path.join(args.save_model,
                                              "step_" + str(step_idx),
                                              "transformer")
                    paddle.static.save(train_program, model_path)

            batch_id += 1
            step_idx += 1
            batch_start = time.time()

    paddle.disable_static()
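
The batch_creator generator driving the loop above is not reproduced on this page. A minimal sketch of what it could look like, assuming the loader yields one (src_word, trg_word, lbl_word) tuple per step and each of the trainer_count devices receives one batch per executor run:

def batch_creator(loader, trainer_count):
    # Hypothetical helper: group consecutive batches from the loader so that
    # the feed list passed to exe.run() contains one dict per device.
    batch = []
    for data in loader:
        batch.append(data)
        if len(batch) == trainer_count:
            yield batch
            batch = []
    # A trailing incomplete group is dropped here; the original helper may
    # handle it differently.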
Example No. 2
def do_train(args):
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1

    if trainer_count > 1:
        dist.init_parallel_env()

    # Set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader
    (train_loader), (eval_loader) = reader.create_data_loader(args)

    # Define model
    transformer = TransformerModel(src_vocab_size=args.src_vocab_size,
                                   trg_vocab_size=args.trg_vocab_size,
                                   max_length=args.max_length + 1,
                                   n_layer=args.n_layer,
                                   n_head=args.n_head,
                                   d_model=args.d_model,
                                   d_inner_hid=args.d_inner_hid,
                                   dropout=args.dropout,
                                   weight_sharing=args.weight_sharing,
                                   bos_id=args.bos_idx,
                                   eos_id=args.eos_idx)

    # Define loss
    criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

    scheduler = paddle.optimizer.lr.NoamDecay(args.d_model,
                                              args.warmup_steps,
                                              args.learning_rate,
                                              last_epoch=0)

    # Define optimizer
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                      beta1=args.beta1,
                                      beta2=args.beta2,
                                      epsilon=float(args.eps),
                                      parameters=transformer.parameters())

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "transformer.pdopt"))
        transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "transformer.pdparams"))
        transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        transformer = paddle.DataParallel(transformer)

    # The best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log((1. - args.label_smooth_eps)) +
        args.label_smooth_eps * np.log(args.label_smooth_eps /
                                       (args.trg_vocab_size - 1) + 1e-20))

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    # Train loop
    for pass_id in range(args.epoch):
        epoch_start = time.time()

        batch_id = 0
        batch_start = time.time()
        for input_data in train_loader:
            # NOTE: Used for benchmarking; max_iter is None by default.
            if args.max_iter and step_idx == args.max_iter:
                return
            train_reader_cost = time.time() - batch_start
            (src_word, trg_word, lbl_word) = input_data

            logits = transformer(src_word=src_word, trg_word=trg_word)

            sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

            avg_cost.backward()

            optimizer.step()
            optimizer.clear_grad()

            tokens_per_cards = token_num.numpy()

            train_batch_cost = time.time() - batch_start
            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)
            batch_ips_avg.record(train_batch_cost, tokens_per_cards)

            # NOTE: For benchmarking, loss information on all cards will be printed.
            if step_idx % args.print_step == 0:
                total_avg_cost = avg_cost.numpy()

                if step_idx == 0:
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f " %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                else:
                    train_avg_batch_cost = (
                        args.print_step / batch_cost_avg.get_total_time())
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/sec, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         train_avg_batch_cost, batch_cost_avg.get_average(),
                         reader_cost_avg.get_average(),
                         batch_ips_avg.get_total_cnt(),
                         batch_ips_avg.get_average_per_sec()))
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Validation
                transformer.eval()
                total_sum_cost = 0
                total_token_num = 0
                with paddle.no_grad():
                    for input_data in eval_loader:
                        (src_word, trg_word, lbl_word) = input_data
                        logits = transformer(src_word=src_word,
                                             trg_word=trg_word)
                        sum_cost, avg_cost, token_num = criterion(
                            logits, lbl_word)
                        total_sum_cost += sum_cost.numpy()
                        total_token_num += token_num.numpy()
                        total_avg_cost = total_sum_cost / total_token_num
                    logger.info(
                        "validation, step_idx: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, total_avg_cost, total_avg_cost -
                         loss_normalizer, np.exp([min(total_avg_cost, 100)])))
                transformer.train()

                if args.save_model and rank == 0:
                    model_dir = os.path.join(args.save_model,
                                             "step_" + str(step_idx))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        transformer.state_dict(),
                        os.path.join(model_dir, "transformer.pdparams"))
                    paddle.save(optimizer.state_dict(),
                                os.path.join(model_dir, "transformer.pdopt"))

            batch_id += 1
            step_idx += 1
            scheduler.step()
            batch_start = time.time()

        train_epoch_cost = time.time() - epoch_start
        logger.info("train epoch: %d, epoch_cost: %.5f s" %
                    (pass_id, train_epoch_cost))

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(transformer.state_dict(),
                    os.path.join(model_dir, "transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "transformer.pdopt"))
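
Every example on this page records throughput with an AverageStatistical helper that is not shown. The class below is only a sketch consistent with the calls used above (record, reset, get_average, get_total_time, get_total_cnt, get_average_per_sec); the original implementation may differ.

class AverageStatistical(object):
    # Accumulates elapsed time plus an optional item count (e.g. tokens).
    def __init__(self):
        self.reset()

    def reset(self):
        self.total_cnt = 0
        self.time = 0.0

    def record(self, val, cnt=1):
        self.time += val
        self.total_cnt += cnt

    def get_total_time(self):
        return self.time

    def get_total_cnt(self):
        return self.total_cnt

    def get_average(self):
        # Average recorded value, e.g. seconds per batch.
        return self.time / self.total_cnt if self.total_cnt > 0 else 0.0

    def get_average_per_sec(self):
        # Items per second, e.g. tokens per second.
        return self.total_cnt / self.time if self.time > 0.0 else 0.0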
Example No. 3
def do_train(args):
    paddle.enable_static()
    if args.is_distributed:
        fleet.init(is_collective=True)
        assert args.device != "xpu", "xpu doesn't support distributed training"
        places = [paddle.set_device("gpu")] if \
                 args.device == "gpu" else paddle.static.cpu_places()
        trainer_count = len(places)
    else:
        if args.device == "gpu":
            places = paddle.static.cuda_places()
        elif args.device == "xpu":
            places = paddle.static.xpu_places()
            paddle.set_device("xpu")
        else:
            places = paddle.static.cpu_places()
            paddle.set_device("cpu")
        trainer_count = len(places)

    # Set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader
    (train_loader), (eval_loader) = reader.create_data_loader(args,
                                                              places=places)

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(train_program, startup_program):
        src_word = paddle.static.data(name="src_word",
                                      shape=[None, None],
                                      dtype=args.input_dtype)
        trg_word = paddle.static.data(name="trg_word",
                                      shape=[None, None],
                                      dtype=args.input_dtype)
        lbl_word = paddle.static.data(name="lbl_word",
                                      shape=[None, None, 1],
                                      dtype=args.input_dtype)

        # Define model
        transformer = TransformerModel(src_vocab_size=args.src_vocab_size,
                                       trg_vocab_size=args.trg_vocab_size,
                                       max_length=args.max_length + 1,
                                       num_encoder_layers=args.n_layer,
                                       num_decoder_layers=args.n_layer,
                                       n_head=args.n_head,
                                       d_model=args.d_model,
                                       d_inner_hid=args.d_inner_hid,
                                       dropout=args.dropout,
                                       weight_sharing=args.weight_sharing,
                                       bos_id=args.bos_idx,
                                       eos_id=args.eos_idx)
        # Define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

        logits = transformer(src_word=src_word, trg_word=trg_word)

        sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

        scheduler = paddle.optimizer.lr.NoamDecay(args.d_model,
                                                  args.warmup_steps,
                                                  args.learning_rate,
                                                  last_epoch=0)

        # Define optimizer
        optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                          beta1=args.beta1,
                                          beta2=args.beta2,
                                          epsilon=float(args.eps),
                                          parameters=transformer.parameters())

        if args.is_distributed:
            build_strategy = paddle.static.BuildStrategy()
            exec_strategy = paddle.static.ExecutionStrategy()
            dist_strategy = fleet.DistributedStrategy()
            dist_strategy.build_strategy = build_strategy
            dist_strategy.execution_strategy = exec_strategy
            dist_strategy.fuse_grad_size_in_MB = 16

            if args.use_amp:
                dist_strategy.amp = True
                dist_strategy.amp_configs = {
                    'custom_white_list': ['softmax', 'layer_norm'],
                    'init_loss_scaling': args.scale_loss,
                    'custom_black_list': ['lookup_table_v2'],
                    'use_pure_fp16': args.use_pure_fp16
                }

            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)
        else:
            if args.use_amp:
                amp_list = paddle.static.amp.AutoMixedPrecisionLists(
                    custom_white_list=['softmax', 'layer_norm'],
                    custom_black_list=['lookup_table_v2'])
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    amp_list,
                    init_loss_scaling=args.scale_loss,
                    use_dynamic_loss_scaling=True,
                    use_pure_fp16=args.use_pure_fp16)
        optimizer.minimize(avg_cost)

    if args.is_distributed:
        exe = paddle.static.Executor(places[0])
    else:
        exe = paddle.static.Executor()
        build_strategy = paddle.static.BuildStrategy()
        exec_strategy = paddle.static.ExecutionStrategy()

        compiled_train_program = paddle.static.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy,
                                              exec_strategy=exec_strategy)
    exe.run(startup_program)

    if args.use_amp:
        optimizer.amp_init(places[0])

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log((1. - args.label_smooth_eps)) +
        args.label_smooth_eps * np.log(args.label_smooth_eps /
                                       (args.trg_vocab_size - 1) + 1e-20))

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    for pass_id in range(args.epoch):
        batch_id = 0
        batch_start = time.time()
        pass_start_time = batch_start
        for data in train_loader:
            # NOTE: Used for benchmarking; max_iter is None by default.
            if args.max_iter and step_idx == args.max_iter:
                break
            if trainer_count == 1:
                data = [data]
            train_reader_cost = time.time() - batch_start

            if args.is_distributed:
                outs = exe.run(train_program,
                               feed=[{
                                   'src_word': data[i][0],
                                   'trg_word': data[i][1],
                                   'lbl_word': data[i][2],
                               } for i in range(trainer_count)],
                               fetch_list=[sum_cost.name, token_num.name])
                train_batch_cost = time.time() - batch_start
                batch_ips_avg.record(train_batch_cost,
                                     np.asarray(outs[1]).sum())
            else:
                outs = exe.run(compiled_train_program,
                               feed=[{
                                   'src_word': data[i][0],
                                   'trg_word': data[i][1],
                                   'lbl_word': data[i][2],
                               } for i in range(trainer_count)],
                               fetch_list=[sum_cost.name, token_num.name])
                train_batch_cost = time.time() - batch_start
                batch_ips_avg.record(train_batch_cost,
                                     np.asarray(outs[1]).sum() / trainer_count)
            scheduler.step()

            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)

            # Profile for model benchmark
            if args.profiler_options is not None:
                profiler.add_profiler_step(args.profiler_options)

            if step_idx % args.print_step == 0 and (
                    args.benchmark
                    or (args.is_distributed and dist.get_rank() == 0)
                    or not args.is_distributed):
                sum_cost_val = np.array(outs[0])
                token_num_val = np.array(outs[1])
                # Sum the cost from multi-devices
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num

                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                else:
                    train_avg_batch_cost = (
                        args.print_step / batch_cost_avg.get_total_time())
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         train_avg_batch_cost, batch_cost_avg.get_average(),
                         reader_cost_avg.get_average(),
                         batch_ips_avg.get_total_cnt(),
                         batch_ips_avg.get_average_per_sec()))
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                if args.save_model and dist.get_rank() == 0:
                    model_path = os.path.join(args.save_model,
                                              "step_" + str(step_idx),
                                              "transformer")
                    paddle.static.save(train_program, model_path)

            batch_id += 1
            step_idx += 1
            batch_start = time.time()

        # NOTE: Used for benchmarking; max_iter is None by default.
        if args.max_iter and step_idx == args.max_iter:
            break

    if args.save_model and dist.get_rank() == 0:
        model_path = os.path.join(args.save_model, "step_final", "transformer")
        paddle.static.save(train_program, model_path)

    paddle.disable_static()
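
In each of these examples, loss_normalizer is the entropy of the label-smoothed target distribution, i.e. the lowest cross-entropy the model could reach, which is why it is subtracted from the average loss to report a "normalized loss". A standalone check with hypothetical settings (label_smooth_eps = 0.1 and a target vocabulary of 10000 tokens):

import numpy as np

label_smooth_eps = 0.1  # hypothetical value, for illustration only
trg_vocab_size = 10000  # hypothetical value, for illustration only

loss_normalizer = -(
    (1. - label_smooth_eps) * np.log(1. - label_smooth_eps) +
    label_smooth_eps * np.log(label_smooth_eps / (trg_vocab_size - 1) + 1e-20))
print(loss_normalizer)  # roughly 1.25 nats for these settings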
Example No. 4
def do_train(args):
    if args.device == "gpu":
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    # Set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader
    (train_loader), (eval_loader) = reader.create_data_loader(args)

    # Define model
    transformer = TransformerModel(src_vocab_size=args.src_vocab_size,
                                   trg_vocab_size=args.trg_vocab_size,
                                   max_length=args.max_length + 1,
                                   num_encoder_layers=args.n_layer,
                                   num_decoder_layers=args.n_layer,
                                   n_head=args.n_head,
                                   d_model=args.d_model,
                                   d_inner_hid=args.d_inner_hid,
                                   dropout=args.dropout,
                                   weight_sharing=args.weight_sharing,
                                   bos_id=args.bos_idx,
                                   eos_id=args.eos_idx)

    transformer = apply_to_static(args, transformer)

    # Define loss
    criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

    scheduler = paddle.optimizer.lr.NoamDecay(args.d_model,
                                              args.warmup_steps,
                                              args.learning_rate,
                                              last_epoch=0)

    # Define optimizer
    if 'use_multi_tensor' not in inspect.getfullargspec(
            paddle.optimizer.Adam.__init__).args:
        optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                          beta1=args.beta1,
                                          beta2=args.beta2,
                                          epsilon=float(args.eps),
                                          parameters=transformer.parameters())
    else:
        optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                          beta1=args.beta1,
                                          beta2=args.beta2,
                                          epsilon=float(args.eps),
                                          parameters=transformer.parameters(),
                                          use_multi_tensor=True)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "transformer.pdopt"))
        transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "transformer.pdparams"))
        transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    # for amp training
    if args.use_amp:
        amp_level = 'O2' if args.use_pure_fp16 else 'O1'
        scaler = paddle.amp.GradScaler(enable=True,
                                       init_loss_scaling=args.scale_loss)
        transformer = paddle.amp.decorate(models=transformer,
                                          level=amp_level,
                                          save_dtype='float32')

    # for distributed training
    if trainer_count > 1:
        transformer = paddle.DataParallel(transformer)

    # The best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log((1. - args.label_smooth_eps)) +
        args.label_smooth_eps * np.log(args.label_smooth_eps /
                                       (args.trg_vocab_size - 1) + 1e-20))

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    # Train loop
    for pass_id in range(args.epoch):
        epoch_start = time.time()

        batch_id = 0
        batch_start = time.time()
        for input_data in train_loader:
            train_reader_cost = time.time() - batch_start
            (src_word, trg_word, lbl_word) = input_data

            if args.use_amp:
                amp_black_list = ({'scale', 'reduce_sum', 'elementwise_div'}
                                  if amp_level == 'O2' else {})
                with paddle.amp.auto_cast(custom_black_list=amp_black_list,
                                          level=amp_level):
                    logits = transformer(src_word=src_word, trg_word=trg_word)
                    sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

                tokens_per_cards = token_num.numpy()
                scaled = scaler.scale(avg_cost)  # scale the loss
                scaled.backward()  # do backward

                scaler.minimize(optimizer, scaled)  # update parameters
                if 'set_to_zero' in inspect.getfullargspec(
                        optimizer.clear_grad).args:
                    optimizer.clear_grad(set_to_zero=False)
                else:
                    optimizer.clear_grad()
            else:
                logits = transformer(src_word=src_word, trg_word=trg_word)
                sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
                tokens_per_cards = token_num.numpy()

                avg_cost.backward()

                optimizer.step()
                optimizer.clear_grad()

            train_batch_cost = time.time() - batch_start
            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)
            batch_ips_avg.record(train_batch_cost, tokens_per_cards)

            # NOTE: For benchmarking, loss information on all cards will be printed.
            if step_idx % args.print_step == 0 and (args.benchmark
                                                    or rank == 0):
                total_avg_cost = avg_cost.numpy()

                if step_idx == 0:
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f " %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                else:
                    train_avg_batch_cost = (
                        args.print_step / batch_cost_avg.get_total_time())
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/sec, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         train_avg_batch_cost, batch_cost_avg.get_average(),
                         reader_cost_avg.get_average(),
                         batch_ips_avg.get_total_cnt(),
                         batch_ips_avg.get_average_per_sec()))
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Validation
                transformer.eval()
                total_sum_cost = 0
                total_token_num = 0
                with paddle.no_grad():
                    for input_data in eval_loader:
                        (src_word, trg_word, lbl_word) = input_data
                        if args.use_amp:
                            amp_black_list = (
                                {'scale', 'reduce_sum', 'elementwise_div'}
                                if amp_level == 'O2' else {})
                            with paddle.amp.auto_cast(
                                    custom_black_list=amp_black_list,
                                    level=amp_level):
                                logits = transformer(src_word=src_word,
                                                     trg_word=trg_word)
                                sum_cost, avg_cost, token_num = criterion(
                                    logits, lbl_word)

                        else:
                            logits = transformer(src_word=src_word,
                                                 trg_word=trg_word)
                            sum_cost, avg_cost, token_num = criterion(
                                logits, lbl_word)

                        total_sum_cost += sum_cost.numpy()
                        total_token_num += token_num.numpy()
                        total_avg_cost = total_sum_cost / total_token_num
                    logger.info(
                        "validation, step_idx: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, total_avg_cost, total_avg_cost -
                         loss_normalizer, np.exp([min(total_avg_cost, 100)])))
                transformer.train()

                if args.save_model and rank == 0:
                    model_dir = os.path.join(args.save_model,
                                             "step_" + str(step_idx))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        transformer.state_dict(),
                        os.path.join(model_dir, "transformer.pdparams"))
                    paddle.save(optimizer.state_dict(),
                                os.path.join(model_dir, "transformer.pdopt"))

            # NOTE: Used for benchmarking; max_iter is None by default.
            if args.max_iter and step_idx == args.max_iter:
                break
            batch_id += 1
            step_idx += 1
            scheduler.step()
            batch_start = time.time()

        # NOTE: Used for benchmarking; max_iter is None by default.
        if args.max_iter and step_idx == args.max_iter:
            break

        train_epoch_cost = time.time() - epoch_start
        logger.info("train epoch: %d, epoch_cost: %.5f s" %
                    (pass_id, train_epoch_cost))

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(transformer.state_dict(),
                    os.path.join(model_dir, "transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "transformer.pdopt"))
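
Example No. 4 wraps the model with an apply_to_static helper before training; that helper is not part of this page. A minimal sketch of what it might do, assuming a boolean to_static flag on args that turns on dynamic-to-static conversion through paddle.jit.to_static:

import paddle


def apply_to_static(args, model):
    # Hypothetical helper: convert the dygraph Layer to static graph mode
    # when the (assumed) to_static flag is set; otherwise return it unchanged.
    if getattr(args, "to_static", False):
        model = paddle.jit.to_static(model)
    return model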