Example #1
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Checkpoint the latest trained model parameters.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        iteration (int): the latest iteration number.
        model (obj): model to be checkpointed.
        optimizer (obj, optional): optimizer to be checkpointed.
            Defaults to None.

    Returns:
        None
    """
    checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
    model_dict = model.state_dict()
    dg.save_dygraph(model_dict, checkpoint_path)
    print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))

    if optimizer:
        opt_dict = optimizer.state_dict()
        dg.save_dygraph(opt_dict, checkpoint_path)
        print("[checkpoint] Saved optimzier state to {}.pdopt".format(
            checkpoint_path))

    _save_checkpoint(checkpoint_dir, iteration)
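
A minimal call-site sketch for the helper above; the directory and interval are placeholders, and it assumes `os` and `paddle.fluid.dygraph as dg` are imported along with the `_save_checkpoint` bookkeeping helper referenced in the body:

# Hypothetical training loop: checkpoint every 1000 iterations.
if iteration % 1000 == 0:
    save_parameters("runs/exp1/checkpoints", iteration, model, optimizer)
    # Produces step-<iteration>.pdparams and step-<iteration>.pdopt in that directory.
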
Example #2

    # Continuation of a loop that maps pretrained weight names onto the
    # Paddle DensePose module layout (the earlier branches are not shown here).
    elif 'body_conv' in k:
        is_bias = k.split('_')[-1]
        new_k = '.'.join(['roi_body_uv_heads', k[:-2], rename_dict[is_bias]])
        new_weight_dict[new_k] = v
    else:
        print(k)

import sys

sys.path.append('/home/aistudio/')
import numpy as np

import paddle.fluid.dygraph as dg

from densepose.config.config import get_cfg
from densepose.modeling.build import build_model

cfg = get_cfg()

densepose = build_model(cfg)

densepose.load_dict(new_weight_dict)

dg.save_dygraph(
    densepose.state_dict(),
    '/home/aistudio/densepose/pretrained_models/DensePose_ResNet101_FPN_32x8d_s1x-e2e'
)
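
A quick sanity check (a sketch, not part of the original snippet): `dg.load_dygraph` reads the .pdparams file back and returns a parameter dict plus an optimizer dict, which is None here because only parameters were saved.

# Reload the converted parameters to confirm the save round-trips.
state_dict, _ = dg.load_dygraph(
    '/home/aistudio/densepose/pretrained_models/DensePose_ResNet101_FPN_32x8d_s1x-e2e')
densepose.set_dict(state_dict)
print("reloaded {} parameters".format(len(state_dict)))
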
Example #3
                weight_decay=0.01)
    g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0)
    for epoch in range(EPOCH):
        for step, (ids_student, ids, sids,
                   labels) in enumerate(train_ds.start(place)):
            loss, logits = teacher_model(ids, labels=labels)
            loss.backward()
            if step % 10 == 0:
                print('[step %03d] teacher train loss %.5f lr %.3e' %
                      (step, loss.numpy(), opt.current_step_lr()))
            opt.minimize(loss, grad_clip=g_clip)
            teacher_model.clear_gradients()
            if step % 100 == 0:
                f1 = evaluate_teacher(teacher_model, dev_ds)
                print('teacher f1: %.5f' % f1)
    D.save_dygraph(teacher_model.state_dict(), './teacher_model')
else:
    state_dict, _ = D.load_dygraph('./teacher_model')
    teacher_model.set_dict(state_dict)
    f1 = evaluate_teacher(teacher_model, dev_ds)
    print('teacher f1: %.5f' % f1)

# Hyperparameters for finetuning the student model
SEQLEN = 256
BATCH = 100
EPOCH = 10
LR = 1e-4


def evaluate_student(model, dataset):
    all_pred, all_label = [], []
Example #4
def train_model(model, loader, criterion, optimizer, clipper, writer, args,
                hparams):
    assert fluid.framework.in_dygraph_mode(
    ), "this function must be run within dygraph guard"

    n_trainers = dg.parallel.Env().nranks
    local_rank = dg.parallel.Env().local_rank

    # amount of shifting when computing losses
    linear_shift = hparams.outputs_per_step
    mel_shift = hparams.outputs_per_step

    global_step = 0
    global_epoch = 0
    ismultispeaker = model.n_speakers > 1
    checkpoint_dir = os.path.join(args.output, "checkpoints")
    tensorboard_dir = os.path.join(args.output, "log")

    ce_loss = 0
    start_time = time.time()

    for epoch in range(hparams.nepochs):
        epoch_loss = 0.
        for step, inputs in tqdm(enumerate(loader())):

            if len(inputs) == 9:
                (text, input_lengths, mel, linear, text_positions,
                 frame_positions, done, target_lengths, speaker_ids) = inputs
            else:
                (text, input_lengths, mel, linear, text_positions,
                 frame_positions, done, target_lengths) = inputs
                speaker_ids = None

            model.train()
            if not (args.train_seq2seq_only or args.train_postnet_only):
                results = model(text, input_lengths, mel, speaker_ids,
                                text_positions, frame_positions)
                mel_outputs, linear_outputs, alignments, done_hat = results
            elif args.train_seq2seq_only:

                if speaker_ids is not None:
                    speaker_embed = model.speaker_embedding(speaker_ids)
                else:
                    speaker_embed = None
                results = model.seq2seq(text, input_lengths, mel,
                                        speaker_embed, text_positions,
                                        frame_positions)
                mel_outputs, alignments, done_hat, decoder_states = results
                if model.r > 1:
                    mel_outputs = fluid.layers.transpose(
                        mel_outputs, [0, 3, 2, 1])
                    mel_outputs = fluid.layers.reshape(
                        mel_outputs,
                        [mel_outputs.shape[0], -1, 1, model.mel_dim])
                    mel_outputs = fluid.layers.transpose(
                        mel_outputs, [0, 3, 2, 1])

                linear_outputs = None
            else:
                assert (
                    model.use_decoder_state_for_postnet_input is False
                ), "when train only the converter, you have no decoder states"

                if speaker_ids is not None:
                    speaker_embed = model.speaker_embedding(speaker_ids)
                else:
                    speaker_embed = None
                linear_outputs = model.converter(mel, speaker_embed)
                alignments = None
                mel_outputs = None
                done_hat = None

            if not args.train_seq2seq_only:
                n_priority_freq = int(hparams.priority_freq /
                                      (hparams.sample_rate * 0.5) *
                                      model.linear_dim)
                linear_mask = fluid.layers.sequence_mask(
                    target_lengths, maxlen=linear.shape[-1], dtype="float32")
                linear_mask = linear_mask[:, linear_shift:]
                linear_predicted = linear_outputs[:, :, :, :-linear_shift]
                linear_target = linear[:, :, :, linear_shift:]
                lin_l1_loss = criterion.l1_loss(linear_predicted,
                                                linear_target,
                                                linear_mask,
                                                priority_bin=n_priority_freq)
                lin_div = criterion.binary_divergence(linear_predicted,
                                                      linear_target,
                                                      linear_mask)
                lin_loss = criterion.binary_divergence_weight * lin_div \
                    + (1 - criterion.binary_divergence_weight) * lin_l1_loss
                if writer is not None and local_rank == 0:
                    writer.add_scalar("linear_loss", float(lin_loss.numpy()),
                                      global_step)
                    writer.add_scalar("linear_l1_loss",
                                      float(lin_l1_loss.numpy()), global_step)
                    writer.add_scalar("linear_binary_div_loss",
                                      float(lin_div.numpy()), global_step)

            if not args.train_postnet_only:
                mel_lengths = target_lengths // hparams.downsample_step
                mel_mask = fluid.layers.sequence_mask(mel_lengths,
                                                      maxlen=mel.shape[-1],
                                                      dtype="float32")
                mel_mask = mel_mask[:, mel_shift:]
                mel_predicted = mel_outputs[:, :, :, :-mel_shift]
                mel_target = mel[:, :, :, mel_shift:]
                mel_l1_loss = criterion.l1_loss(mel_predicted, mel_target,
                                                mel_mask)
                mel_div = criterion.binary_divergence(mel_predicted,
                                                      mel_target, mel_mask)
                mel_loss = criterion.binary_divergence_weight * mel_div \
                    + (1 - criterion.binary_divergence_weight) * mel_l1_loss
                if writer is not None and local_rank == 0:
                    writer.add_scalar("mel_loss", float(mel_loss.numpy()),
                                      global_step)
                    writer.add_scalar("mel_l1_loss",
                                      float(mel_l1_loss.numpy()), global_step)
                    writer.add_scalar("mel_binary_div_loss",
                                      float(mel_div.numpy()), global_step)

                done_loss = criterion.done_loss(done_hat, done)
                if writer is not None and local_rank == 0:
                    writer.add_scalar("done_loss", float(done_loss.numpy()),
                                      global_step)

                if hparams.use_guided_attention:
                    decoder_length = target_lengths.numpy() / (
                        hparams.outputs_per_step * hparams.downsample_step)
                    attn_loss = criterion.attention_loss(
                        alignments, input_lengths.numpy(), decoder_length)
                    if writer is not None and local_rank == 0:
                        writer.add_scalar("attention_loss",
                                          float(attn_loss.numpy()),
                                          global_step)

            if not (args.train_seq2seq_only or args.train_postnet_only):
                if hparams.use_guided_attention:
                    loss = lin_loss + mel_loss + done_loss + attn_loss
                else:
                    loss = lin_loss + mel_loss + done_loss
            elif args.train_seq2seq_only:
                if hparams.use_guided_attention:
                    loss = mel_loss + done_loss + attn_loss
                else:
                    loss = mel_loss + done_loss
            else:
                loss = lin_loss

            if writer is not None and local_rank == 0:
                writer.add_scalar("loss", float(loss.numpy()), global_step)

            if isinstance(optimizer._learning_rate,
                          fluid.optimizer.LearningRateDecay):
                current_lr = optimizer._learning_rate.step().numpy()
            else:
                current_lr = optimizer._learning_rate
            if writer is not None and local_rank == 0:
                writer.add_scalar("learning_rate", current_lr, global_step)

            epoch_loss += loss.numpy()[0]

            if (local_rank == 0 and global_step > 0
                    and global_step % hparams.checkpoint_interval == 0):

                save_states(global_step, writer, mel_outputs,
                            linear_outputs, alignments, mel, linear,
                            input_lengths.numpy(), checkpoint_dir)
                step_path = os.path.join(
                    checkpoint_dir, "checkpoint_{:09d}".format(global_step))
                dg.save_dygraph(model.state_dict(), step_path)
                dg.save_dygraph(optimizer.state_dict(), step_path)

            if (local_rank == 0 and global_step > 0
                    and global_step % hparams.eval_interval == 0):
                eval_model(global_step, writer, model, checkpoint_dir,
                           ismultispeaker)

            if args.use_data_parallel:
                loss = model.scale_loss(loss)
                loss.backward()
                model.apply_collective_grads()
            else:
                loss.backward()

            if not (args.train_seq2seq_only or args.train_postnet_only):
                param_list = model.parameters()
            elif args.train_seq2seq_only:
                if ismultispeaker:
                    param_list = chain(model.speaker_embedding.parameters(),
                                       model.seq2seq.parameters())
                else:
                    param_list = model.seq2seq.parameters()
            else:
                if ismultispeaker:
                    param_list = chain(model.speaker_embedding.parameters(),
                                       model.converter.parameters())
                else:
                    param_list = model.converter.parameters()

            optimizer.minimize(loss,
                               grad_clip=clipper,
                               parameter_list=param_list)

            if not (args.train_seq2seq_only or args.train_postnet_only):
                model.clear_gradients()
            elif args.train_seq2seq_only:
                if ismultispeaker:
                    model.speaker_embedding.clear_gradients()
                model.seq2seq.clear_gradients()
            else:
                if ismultispeaker:
                    model.speaker_embedding.clear_gradients()
                model.converter.clear_gradients()

            global_step += 1

        average_loss_in_epoch = epoch_loss / (step + 1)
        print("Epoch loss: {}".format(average_loss_in_epoch))
        if writer is not None and local_rank == 0:
            writer.add_scalar("average_loss_in_epoch", average_loss_in_epoch,
                              global_epoch)
        ce_loss = average_loss_in_epoch
        global_epoch += 1

    end_time = time.time()
    epoch_time = (end_time - start_time) / global_epoch
    print("kpis\teach_epoch_duration_frame%s_card%s\t%s" %
          (hparams.outputs_per_step, n_trainers, epoch_time))
    print("kpis\ttrain_cost_frame%s_card%s\t%f" %
          (hparams.outputs_per_step, n_trainers, ce_loss))
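
To resume from one of the checkpoints written above, `dg.load_dygraph` reads the prefix back and returns the parameter dict and the optimizer dict saved under the same path. A minimal sketch, reusing the `step_path` naming from `train_model` (the step number and the `model`/`optimizer` objects are assumed from the surrounding context):

import os
import paddle.fluid.dygraph as dg

# Restore model and optimizer state from a checkpoint prefix written by train_model.
step_path = os.path.join(checkpoint_dir, "checkpoint_{:09d}".format(10000))
model_dict, opt_dict = dg.load_dygraph(step_path)
model.set_dict(model_dict)
if opt_dict is not None:
    optimizer.set_dict(opt_dict)
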
Example #5
def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'fastspeech')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        with fluid.unique_name.guard():
            transformer_tts = TransformerTTS(cfg)
            model_dict, _ = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.transtts_path, "transformer"))
            transformer_tts.set_dict(model_dict)
            transformer_tts.eval()

        model = FastSpeech(cfg)
        model.train()
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (
                cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())
        reader = LJSpeechLoader(
            cfg, args, nranks, local_rank, shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.fastspeech_step),
                os.path.join(args.checkpoint_path, "fastspeech"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.fastspeech_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)

            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                (character, mel, mel_input, pos_text, pos_mel, text_length,
                 mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
                 enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data

                _, _, attn_probs, _, _, _ = transformer_tts(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)
                alignment, max_attn = get_alignment(attn_probs, mel_lens,
                                                    cfg['transformer_head'])
                alignment = dg.to_variable(alignment).astype(np.float32)

                if local_rank == 0 and global_step % 5 == 1:
                    x = np.uint8(
                        cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
                    writer.add_image(
                        'Attention_%d_0' % global_step,
                        x,
                        0,
                        dataformats="HWC")

                global_step += 1

                # Forward
                result = model(
                    character,
                    pos_text,
                    mel_pos=pos_mel,
                    length_target=alignment,
                    enc_non_pad_mask=enc_query_mask,
                    enc_slf_attn_mask=enc_slf_mask,
                    dec_non_pad_mask=dec_query_slf_mask,
                    dec_slf_attn_mask=dec_slf_mask)
                mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
                mel_loss = layers.mse_loss(mel_output, mel)
                mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
                duration_loss = layers.mean(
                    layers.abs(
                        layers.elementwise_sub(duration_predictor_output,
                                               alignment)))
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                if local_rank == 0:
                    writer.add_scalar('mel_loss',
                                      mel_loss.numpy(), global_step)
                    writer.add_scalar('post_mel_loss',
                                      mel_postnet_loss.numpy(), global_step)
                    writer.add_scalar('duration_loss',
                                      duration_loss.numpy(), global_step)
                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                if args.use_data_parallel:
                    total_loss = model.scale_loss(total_loss)
                    total_loss.backward()
                    model.apply_collective_grads()
                else:
                    total_loss.backward()
                optimizer.minimize(
                    total_loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                        cfg['grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path,
                                             'fastspeech/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()
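
The learning-rate construction `dg.NoamDecay(1 / (cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step'])` used here (and in the next two examples) is a small trick: NoamDecay implements the standard Noam schedule `d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)`, so choosing `d_model = 1 / (warm_up_step * lr**2)` makes the schedule peak at exactly `args.lr` when `step == warm_up_step`. A quick check in plain Python (names are illustrative):

def noam_lr(step, warm_up_step, peak_lr):
    # Same quantity dg.NoamDecay computes, written out for inspection.
    d_model = 1.0 / (warm_up_step * peak_lr ** 2)
    return d_model ** -0.5 * min(step ** -0.5, step * warm_up_step ** -1.5)

print(noam_lr(4000, warm_up_step=4000, peak_lr=1e-3))  # ~1e-3 at the warmup peak
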
Example #6
def main(args):

    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else
             fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'vocoder')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        model = Vocoder(cfg, args.batch_size)

        model.train()
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(
                1 / (cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.vocoder_step),
                os.path.join(args.checkpoint_path, "vocoder"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.vocoder_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        reader = LJSpeechLoader(cfg, args, nranks, local_rank,
                                is_vocoder=True).reader()

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                mel, mag = data
                mag = dg.to_variable(mag.numpy())
                mel = dg.to_variable(mel.numpy())
                global_step += 1

                mag_pred = model(mel)
                loss = layers.mean(
                    layers.abs(layers.elementwise_sub(mag_pred, mag)))

                if args.use_data_parallel:
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
                optimizer.minimize(
                    loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                        cfg['grad_clip_thresh']))
                model.clear_gradients()

                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'loss': loss.numpy(),
                    }, global_step)

                    if global_step % args.save_step == 0:
                        if not os.path.exists(args.save_path):
                            os.mkdir(args.save_path)
                        save_path = os.path.join(args.save_path,
                                                 'vocoder/%d' % global_step)
                        dg.save_dygraph(model.state_dict(), save_path)
                        dg.save_dygraph(optimizer.state_dict(), save_path)

        if local_rank == 0:
            writer.close()
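
Examples #5, #6, and #7 all call a `load_checkpoint(step, path)` helper that is not shown here. A minimal sketch of what such a helper might look like, built on `dg.load_dygraph`; the file layout (one `<step>` prefix per checkpoint under the given directory) is an assumption inferred from the `dg.save_dygraph` calls above:

import os
import paddle.fluid.dygraph as dg

def load_checkpoint(step, model_path):
    # Assumed layout: <model_path>/<step>.pdparams and <model_path>/<step>.pdopt,
    # matching the '<save_path>/<name>/<global_step>' prefixes written above.
    model_dict, opti_dict = dg.load_dygraph(os.path.join(model_path, step))
    return model_dict, opti_dict
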
Example #7
def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else
             fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'transformer')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        model = TransformerTTS(cfg)

        model.train()
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(
                1 / (cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.checkpoint_path, "transformer"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.transformer_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        reader = LJSpeechLoader(cfg, args, nranks, local_rank,
                                shuffle=True).reader()

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                (character, mel, mel_input, pos_text, pos_mel, text_length, _,
                 enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask,
                 dec_query_slf_mask, dec_query_mask) = data

                global_step += 1

                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)

                mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss

                # Note: when the stop token loss was used, learning did not work.
                if args.stop_token:
                    label = (pos_mel == 0).astype(np.float32)
                    stop_loss = cross_entropy(stop_preds, label)
                    loss = loss + stop_loss

                if local_rank == 0:
                    writer.add_scalars(
                        'training_loss', {
                            'mel_loss': mel_loss.numpy(),
                            'post_mel_loss': post_mel_loss.numpy()
                        }, global_step)

                    if args.stop_token:
                        writer.add_scalar('stop_loss', stop_loss.numpy(),
                                          global_step)

                    if args.use_data_parallel:
                        writer.add_scalars(
                            'alphas', {
                                'encoder_alpha':
                                model._layers.encoder.alpha.numpy(),
                                'decoder_alpha':
                                model._layers.decoder.alpha.numpy(),
                            }, global_step)
                    else:
                        writer.add_scalars(
                            'alphas', {
                                'encoder_alpha': model.encoder.alpha.numpy(),
                                'decoder_alpha': model.decoder.alpha.numpy(),
                            }, global_step)

                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                    if global_step % args.image_step == 1:
                        for i, prob in enumerate(attn_probs):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image('Attention_%d_0' %
                                                 global_step,
                                                 x,
                                                 i * 4 + j,
                                                 dataformats="HWC")

                        for i, prob in enumerate(attn_enc):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image('Attention_enc_%d_0' %
                                                 global_step,
                                                 x,
                                                 i * 4 + j,
                                                 dataformats="HWC")

                        for i, prob in enumerate(attn_dec):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image('Attention_dec_%d_0' %
                                                 global_step,
                                                 x,
                                                 i * 4 + j,
                                                 dataformats="HWC")

                if args.use_data_parallel:
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
                optimizer.minimize(
                    loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                        cfg['grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path,
                                             'transformer/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()