Beispiel #1
0
    def setup_model(self):
        """Create the network, optimizer, loss criterion and step-wise schedules.

        All hyper-parameters come from ``self.config``. The created objects are
        stored on ``self`` as ``model``, ``optimizer``, ``criterion``,
        ``drop_n_heads`` and ``reduction_factor``. When ``self.parallel`` is
        true the network is wrapped in ``paddle.DataParallel`` before the
        optimizer collects its parameters.
        """
        cfg = self.config
        model_cfg = cfg.model

        network = TransformerTTS(
            English(),
            d_encoder=model_cfg.d_encoder,
            d_decoder=model_cfg.d_decoder,
            d_mel=cfg.data.d_mel,
            n_heads=model_cfg.n_heads,
            d_ffn=model_cfg.d_ffn,
            encoder_layers=model_cfg.encoder_layers,
            decoder_layers=model_cfg.decoder_layers,
            d_prenet=model_cfg.d_prenet,
            d_postnet=model_cfg.d_postnet,
            postnet_layers=model_cfg.postnet_layers,
            postnet_kernel_size=model_cfg.postnet_kernel_size,
            max_reduction_factor=model_cfg.max_reduction_factor,
            decoder_prenet_dropout=model_cfg.decoder_prenet_dropout,
            dropout=model_cfg.dropout,
        )
        network = paddle.DataParallel(network) if self.parallel else network

        train_cfg = cfg.training
        # Fixed Adam moments/epsilon; only the learning rate is configurable.
        adam_hparams = dict(beta1=0.9, beta2=0.98, epsilon=1e-9)
        optim = paddle.optimizer.Adam(
            learning_rate=train_cfg.lr,
            parameters=network.parameters(),
            **adam_hparams)

        loss_fn = TransformerTTSLoss(model_cfg.stop_loss_scale)
        head_drop_schedule = scheduler.StepWise(train_cfg.drop_n_heads)
        reduction_schedule = scheduler.StepWise(train_cfg.reduction_factor)

        # Publish everything only after all objects were built successfully.
        self.model = network
        self.optimizer = optim
        self.criterion = loss_fn
        self.drop_n_heads = head_drop_schedule
        self.reduction_factor = reduction_schedule
Beispiel #2
0
def alignments(args):
    """Extract teacher-forced decoder attention alignments for a corpus.

    For every row of ``<args.data>/metadata.csv`` the normalized text and the
    mel spectrogram of ``<args.data>/wavs/<fname>.wav`` are fed through a
    trained TransformerTTS model; the returned attention is reduced with
    ``get_alignment`` and stored under the row's file name.

    Args:
        args: parsed command-line namespace. Reads ``use_gpu``, ``config``
            (path to a YAML config), ``checkpoint_transformer``, ``data``
            (corpus root) and ``output`` (output path prefix).

    Side effects:
        Writes an ``OrderedDict`` (file name -> alignment) with
        ``pickle.dump`` to ``args.output + '.txt'``.
    """
    local_rank = dg.parallel.Env().local_rank
    place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())

    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load trusted config files here.
        cfg = yaml.load(f, Loader=yaml.Loader)
    audio_cfg = cfg['audio']
    network_cfg = cfg['network']

    with dg.guard(place):
        model = TransformerTTS(
            network_cfg['embedding_size'], network_cfg['hidden_size'],
            network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
            audio_cfg['num_mels'], network_cfg['outputs_per_step'],
            network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
        # Load parameters; the returned global step is not needed here.
        io.load_parameters(
            model=model, checkpoint_path=args.checkpoint_transformer)
        model.eval()

        # get text data
        csv_path = Path(args.data).joinpath("metadata.csv")
        table = pd.read_csv(csv_path,
                            sep="|",
                            header=None,
                            quoting=csv.QUOTE_NONE,
                            names=["fname", "raw_text", "normalized_text"])
        ljspeech_processor = audio.AudioProcessor(
            sample_rate=audio_cfg['sr'],
            num_mels=audio_cfg['num_mels'],
            min_level_db=audio_cfg['min_level_db'],
            ref_level_db=audio_cfg['ref_level_db'],
            n_fft=audio_cfg['n_fft'],
            win_length=audio_cfg['win_length'],
            hop_length=audio_cfg['hop_length'],
            power=audio_cfg['power'],
            preemphasis=audio_cfg['preemphasis'],
            signal_norm=True,
            symmetric_norm=False,
            max_norm=1.,
            mel_fmin=0,
            mel_fmax=None,
            clip_norm=True,
            griffin_lim_iters=60,
            do_trim_silence=False,
            sound_norm=False)

        alignment_map = OrderedDict()
        for i in tqdm(range(len(table))):
            fname, _, normalized_text = table.iloc[i]

            # Text ids and their 1-based positions, each given a leading
            # batch axis of size 1.
            text = np.asarray(text_to_sequence(normalized_text))
            text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
            pos_text = np.arange(1, text.shape[1] + 1)
            pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])

            # Ground-truth mel spectrogram, transposed so that axis 1 (after
            # adding the batch axis) is the frame/position axis used below.
            wav = ljspeech_processor.load_wav(
                os.path.join(args.data, 'wavs', fname + ".wav"))
            mel_input = ljspeech_processor.melspectrogram(wav).astype(
                np.float32)
            mel_input = np.transpose(mel_input, axes=(1, 0))
            mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
            mel_lens = mel_input.shape[1]

            # Decoder self-attention mask: entries where get_triu_tensor is
            # nonzero become -2**32 + 1, all others become 0.
            dec_slf_mask = get_triu_tensor(mel_input,
                                           mel_input).astype(np.float32)
            dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
            dec_slf_mask = fluid.layers.cast(dg.to_variable(dec_slf_mask != 0),
                                             np.float32) * (-2**32 + 1)
            pos_mel = np.arange(1, mel_input.shape[1] + 1)
            pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])

            # Only the attention probabilities are needed for alignments.
            _, _, attn_probs, _, _, _ = model(text, mel_input, pos_text,
                                              pos_mel, dec_slf_mask)

            alignment, _ = get_alignment(attn_probs, mel_lens,
                                         network_cfg['decoder_num_head'])
            alignment_map[fname] = alignment

        # NOTE: the file is a binary pickle despite the '.txt' suffix; the
        # suffix is kept so existing consumers of this path keep working.
        with open(args.output + '.txt', "wb") as f:
            pickle.dump(alignment_map, f)
Beispiel #3
0
def synthesis(text_input, args):
    """Synthesize ``text_input`` with TransformerTTS and two vocoder paths.

    The mel spectrogram is generated autoregressively for exactly
    ``args.max_len`` decoder steps. The final postnet output is converted to
    audio twice: through the ``Vocoder`` network followed by
    ``inv_spectrogram`` ("cbhg"), and directly with ``inv_melspectrogram``
    ("griffin"). The decoder attention maps and both waveforms are logged
    under ``<args.output>/log``; the waveforms are also written to
    ``<args.output>/samples/cbhg.wav`` and ``griffin.wav``.

    Args:
        text_input: text to synthesize.
        args: parsed command-line namespace. Reads ``use_gpu``, ``config``,
            ``output``, ``checkpoint_transformer``, ``checkpoint_vocoder`` and
            ``max_len``.

    Raises:
        ValueError: if ``args.max_len`` is smaller than 1 (no decoder step
            would run, leaving nothing to vocode).
    """
    if args.max_len < 1:
        raise ValueError('max_len must be >= 1, but received %s.' %
                         args.max_len)

    local_rank = dg.parallel.Env().local_rank
    place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())

    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)
    audio_cfg = cfg['audio']
    network_cfg = cfg['network']

    os.makedirs(args.output, exist_ok=True)

    fluid.enable_dygraph(place)
    with fluid.unique_name.guard():
        model = TransformerTTS(
            network_cfg['embedding_size'], network_cfg['hidden_size'],
            network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
            audio_cfg['num_mels'], network_cfg['outputs_per_step'],
            network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
        # Load parameters (the returned global step is not needed).
        io.load_parameters(
            model=model, checkpoint_path=args.checkpoint_transformer)
        model.eval()

    with fluid.unique_name.guard():
        model_vocoder = Vocoder(cfg['train']['batch_size'],
                                cfg['vocoder']['hidden_size'],
                                audio_cfg['num_mels'], audio_cfg['n_fft'])
        # Load parameters (the returned global step is not needed).
        io.load_parameters(
            model=model_vocoder, checkpoint_path=args.checkpoint_vocoder)
        model_vocoder.eval()

    # init input: text ids/positions and one all-zero starting mel frame whose
    # width matches the num_mels the model was built with.
    text = np.asarray(text_to_sequence(text_input))
    text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
    mel_input = dg.to_variable(
        np.zeros([1, 1, audio_cfg['num_mels']], dtype=np.float32))
    pos_text = np.arange(1, text.shape[1] + 1)
    pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])

    for _ in tqdm(range(args.max_len)):
        pos_mel = np.arange(1, mel_input.shape[1] + 1)
        pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
        _, postnet_pred, attn_probs, _, _, _ = model(text, mel_input,
                                                     pos_text, pos_mel)
        # Append the newest predicted frame and feed the sequence back in.
        mel_input = fluid.layers.concat([mel_input, postnet_pred[:, -1:, :]],
                                        axis=1)

    mag_pred = model_vocoder(postnet_pred)

    _ljspeech_processor = audio.AudioProcessor(
        sample_rate=audio_cfg['sr'],
        num_mels=audio_cfg['num_mels'],
        min_level_db=audio_cfg['min_level_db'],
        ref_level_db=audio_cfg['ref_level_db'],
        n_fft=audio_cfg['n_fft'],
        win_length=audio_cfg['win_length'],
        hop_length=audio_cfg['hop_length'],
        power=audio_cfg['power'],
        preemphasis=audio_cfg['preemphasis'],
        signal_norm=True,
        symmetric_norm=False,
        max_norm=1.,
        mel_fmin=0,
        mel_fmax=None,
        clip_norm=True,
        griffin_lim_iters=60,
        do_trim_silence=False,
        sound_norm=False)

    samples_dir = os.path.join(args.output, 'samples')
    # tensorboard
    writer = SummaryWriter(os.path.join(args.output, 'log'))
    try:
        # synthesis with cbhg
        wav = _ljspeech_processor.inv_spectrogram(
            fluid.layers.transpose(fluid.layers.squeeze(mag_pred, [0]),
                                   [1, 0]).numpy())

        # One image per (layer, head); the step index stays unique for any
        # head count.
        n_heads = network_cfg['decoder_num_head']
        for i, prob in enumerate(attn_probs):
            prob_np = prob.numpy()
            for j in range(n_heads):
                x = np.uint8(cm.viridis(prob_np[j]) * 255)
                writer.add_image('Attention_0_0',
                                 x,
                                 i * n_heads + j,
                                 dataformats="HWC")

        writer.add_audio(text_input + '(cbhg)', wav, 0, audio_cfg['sr'])

        os.makedirs(samples_dir, exist_ok=True)
        write(os.path.join(samples_dir, 'cbhg.wav'), audio_cfg['sr'], wav)

        # synthesis with griffin-lim
        wav = _ljspeech_processor.inv_melspectrogram(
            fluid.layers.transpose(fluid.layers.squeeze(postnet_pred, [0]),
                                   [1, 0]).numpy())
        writer.add_audio(text_input + '(griffin)', wav, 0, audio_cfg['sr'])

        write(os.path.join(samples_dir, 'griffin.wav'), audio_cfg['sr'], wav)
        print("Synthesis completed !!!")
    finally:
        writer.close()
Beispiel #4
0
def synthesis(text_input, args):
    """Synthesize ``text_input`` with TransformerTTS and a selectable vocoder.

    The mel spectrogram is generated autoregressively for at most
    ``args.max_len`` decoder steps, stopping early once the stop prediction at
    index ``[0, -1]`` exceeds ``args.stop_threshold``. The postnet output is
    turned into audio with ``synthesis_with_griffinlim`` or
    ``synthesis_with_waveflow`` according to ``args.vocoder``. Decoder
    attention maps and the waveform are logged under ``<args.output>/log``,
    and the waveform is written to ``<args.output>/samples/<vocoder>.wav``.

    Args:
        text_input: text to synthesize.
        args: parsed command-line namespace. Reads ``vocoder``, ``max_len``,
            ``use_gpu``, ``config``, ``output``, ``checkpoint_transformer``,
            ``checkpoint_vocoder`` and ``stop_threshold``.

    Raises:
        ValueError: if ``args.vocoder`` is neither ``'griffin-lim'`` nor
            ``'waveflow'``, or ``args.max_len`` is smaller than 1. Both are
            checked before any model is loaded.
    """
    # Validate up front: previously an unknown vocoder only printed a message
    # after the full decode and then crashed on the undefined `wav`.
    if args.vocoder not in ('griffin-lim', 'waveflow'):
        raise ValueError(
            'vocoder error, we only support griffin-lim and waveflow, '
            'but received %s.' % args.vocoder)
    if args.max_len < 1:
        raise ValueError('max_len must be >= 1, but received %s.' %
                         args.max_len)

    local_rank = dg.parallel.Env().local_rank
    place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())

    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)
    audio_cfg = cfg['audio']
    network_cfg = cfg['network']

    os.makedirs(args.output, exist_ok=True)

    fluid.enable_dygraph(place)
    with fluid.unique_name.guard():
        model = TransformerTTS(
            network_cfg['embedding_size'], network_cfg['hidden_size'],
            network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
            audio_cfg['num_mels'], network_cfg['outputs_per_step'],
            network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
        # Load parameters (the returned global step is not needed).
        io.load_parameters(
            model=model, checkpoint_path=args.checkpoint_transformer)
        model.eval()

    # init input: int64 text ids/positions and one all-zero starting mel
    # frame whose width matches the num_mels the model was built with.
    text = np.asarray(text_to_sequence(text_input))
    text = fluid.layers.unsqueeze(dg.to_variable(text).astype(np.int64), [0])
    mel_input = dg.to_variable(
        np.zeros([1, 1, audio_cfg['num_mels']], dtype=np.float32))
    pos_text = np.arange(1, text.shape[1] + 1)
    pos_text = fluid.layers.unsqueeze(
        dg.to_variable(pos_text).astype(np.int64), [0])

    for _ in range(args.max_len):
        pos_mel = np.arange(1, mel_input.shape[1] + 1)
        pos_mel = fluid.layers.unsqueeze(
            dg.to_variable(pos_mel).astype(np.int64), [0])
        _, postnet_pred, attn_probs, stop_preds, _, _ = model(
            text, mel_input, pos_text, pos_mel)
        if stop_preds.numpy()[0, -1] > args.stop_threshold:
            break
        # Append the newest predicted frame and feed the sequence back in.
        mel_input = fluid.layers.concat([mel_input, postnet_pred[:, -1:, :]],
                                        axis=1)

    # tensorboard
    writer = SummaryWriter(os.path.join(args.output, 'log'))
    try:
        # One image per (layer, head); the step index stays unique for any
        # head count.
        n_heads = network_cfg['decoder_num_head']
        for i, prob in enumerate(attn_probs):
            prob_np = prob.numpy()
            for j in range(n_heads):
                x = np.uint8(cm.viridis(prob_np[j]) * 255)
                writer.add_image('Attention_0_0',
                                 x,
                                 i * n_heads + j,
                                 dataformats="HWC")

        if args.vocoder == 'griffin-lim':
            # synthesis use griffin-lim
            wav = synthesis_with_griffinlim(postnet_pred, audio_cfg)
        else:
            # synthesis use waveflow (the only other value accepted above)
            wav = synthesis_with_waveflow(postnet_pred, args,
                                          args.checkpoint_vocoder, place)

        writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0,
                         audio_cfg['sr'])
        samples_dir = os.path.join(args.output, 'samples')
        os.makedirs(samples_dir, exist_ok=True)
        write(os.path.join(samples_dir, args.vocoder + '.wav'),
              audio_cfg['sr'], wav)
        print("Synthesis completed !!!")
    finally:
        writer.close()
Beispiel #5
0
def alignments(args):
    """Extract teacher-forced decoder attention alignments for a corpus.

    For every row of ``<args.data>/metadata.csv`` the normalized text and a
    log-mel spectrogram of ``<args.data>/wavs/<fname>.wav`` (computed with
    librosa from the ``cfg['audio']`` settings) are fed through a trained
    TransformerTTS model; the returned attention is reduced with
    ``get_alignment`` and stored under the row's file name.

    Args:
        args: parsed command-line namespace. Reads ``use_gpu``, ``config``
            (path to a YAML config), ``checkpoint_transformer``, ``data``
            (corpus root) and ``output`` (output path prefix).

    Side effects:
        Writes an ``OrderedDict`` (file name -> alignment) with
        ``pickle.dump`` to ``args.output + '.pkl'``.
    """
    local_rank = dg.parallel.Env().local_rank
    place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())

    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load trusted config files here.
        cfg = yaml.load(f, Loader=yaml.Loader)
    audio_cfg = cfg['audio']
    network_cfg = cfg['network']

    with dg.guard(place):
        model = TransformerTTS(
            network_cfg['embedding_size'], network_cfg['hidden_size'],
            network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
            audio_cfg['num_mels'], network_cfg['outputs_per_step'],
            network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
        # Load parameters; the returned global step is not needed here.
        io.load_parameters(
            model=model, checkpoint_path=args.checkpoint_transformer)
        model.eval()

        # get text data
        csv_path = Path(args.data).joinpath("metadata.csv")
        table = pd.read_csv(
            csv_path,
            sep="|",
            header=None,
            quoting=csv.QUOTE_NONE,
            names=["fname", "raw_text", "normalized_text"])

        # The mel filterbank depends only on the config, so build it once.
        mel_basis = librosa.filters.mel(sr=audio_cfg['sr'],
                                        n_fft=audio_cfg['n_fft'],
                                        n_mels=audio_cfg['num_mels'],
                                        fmin=audio_cfg['fmin'],
                                        fmax=audio_cfg['fmax'])

        alignment_map = OrderedDict()
        for i in tqdm(range(len(table))):
            fname, _, normalized_text = table.iloc[i]

            # Text ids and their 1-based positions, each given a leading
            # batch axis of size 1.
            text = np.asarray(text_to_sequence(normalized_text))
            text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
            pos_text = np.arange(1, text.shape[1] + 1)
            pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])

            # Load at the configured sample rate. Without `sr`, librosa
            # resamples to its own default, which would not match the
            # sample rate the mel filterbank above was built for.
            wav, _ = librosa.load(
                os.path.join(args.data, 'wavs', fname + ".wav"),
                sr=audio_cfg['sr'])

            spec = librosa.stft(
                y=wav,
                n_fft=audio_cfg['n_fft'],
                win_length=audio_cfg['win_length'],
                hop_length=audio_cfg['hop_length'])
            mel = np.matmul(mel_basis, np.abs(spec))
            # Log compression with a 1e-5 floor to avoid log(0).
            mel = np.log(np.maximum(mel, 1e-5))

            # Transpose so that axis 1 (after adding the batch axis) is the
            # frame/position axis used below.
            mel_input = np.transpose(mel, axes=(1, 0))
            mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
            mel_lens = mel_input.shape[1]

            pos_mel = np.arange(1, mel_input.shape[1] + 1)
            pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])

            # Only the attention probabilities are needed for alignments.
            _, _, attn_probs, _, _, _ = model(text, mel_input, pos_text,
                                              pos_mel)

            alignment, _ = get_alignment(attn_probs, mel_lens,
                                         network_cfg['decoder_num_head'])
            alignment_map[fname] = alignment

        with open(args.output + '.pkl', "wb") as f:
            pickle.dump(alignment_map, f)
Beispiel #6
0
def main(args):
    """Train TransformerTTS until ``cfg['train']['max_iteration']`` steps.

    Supports single-process and data-parallel (``nranks > 1``) training.
    Rank 0 logs the mel/postnet/stop losses, the encoder/decoder ``alpha``
    values, the learning rate and, whenever
    ``global_step % image_interval == 1``, the attention maps to a
    ``LogWriter`` under ``<args.output>/log``. Rank 0 also saves a checkpoint
    under ``<args.output>/checkpoints`` every ``checkpoint_interval`` steps.

    Args:
        args: parsed command-line namespace. Reads ``config``, ``use_gpu``,
            ``output``, ``iteration``, ``checkpoint`` and ``data``.
    """
    local_rank = dg.parallel.Env().local_rank
    nranks = dg.parallel.Env().nranks
    parallel = nranks > 1

    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)
    network_cfg = cfg['network']
    train_cfg = cfg['train']

    place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()

    # Every rank runs this; exist_ok avoids a FileExistsError race between
    # ranks that all observe the directory as missing.
    os.makedirs(args.output, exist_ok=True)

    writer = LogWriter(os.path.join(args.output,
                                    'log')) if local_rank == 0 else None

    fluid.enable_dygraph(place)
    model = TransformerTTS(
        network_cfg['embedding_size'], network_cfg['hidden_size'],
        network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
        cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
        network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])

    model.train()
    # NOTE(review): the first NoamDecay argument looks chosen so the peak
    # learning rate (at warm_up_step) equals cfg['train']['learning_rate'];
    # confirm against the dg.NoamDecay formula.
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=dg.NoamDecay(1 / (train_cfg['warm_up_step'] *
                                        (train_cfg['learning_rate']**2)),
                                   train_cfg['warm_up_step']),
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(
            train_cfg['grad_clip_thresh']))

    # Load parameters and optimizer state; the returned value is used as the
    # current global step.
    global_step = io.load_parameters(
        model=model,
        optimizer=optimizer,
        checkpoint_dir=os.path.join(args.output, 'checkpoints'),
        iteration=args.iteration,
        checkpoint_path=args.checkpoint)
    print("Rank {}: checkpoint loaded.".format(local_rank))

    if parallel:
        strategy = dg.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)
    # The unwrapped network, used to read the encoder/decoder alpha values.
    net = model._layers if parallel else model

    reader = LJSpeechLoader(
        cfg['audio'],
        place,
        args.data,
        train_cfg['batch_size'],
        nranks,
        local_rank,
        shuffle=True).reader

    iterator = iter(tqdm(reader))

    global_step += 1

    while global_step <= train_cfg['max_iteration']:
        try:
            batch = next(iterator)
        except StopIteration:
            # Reader exhausted: start another pass over the data.
            iterator = iter(tqdm(reader))
            batch = next(iterator)

        character, mel, mel_input, pos_text, pos_mel, stop_tokens = batch

        mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
            character, mel_input, pos_text, pos_mel)

        # Mean absolute error of the decoder and postnet mel predictions.
        mel_loss = layers.mean(
            layers.abs(layers.elementwise_sub(mel_pred, mel)))
        post_mel_loss = layers.mean(
            layers.abs(layers.elementwise_sub(postnet_pred, mel)))
        stop_loss = cross_entropy(
            stop_preds, stop_tokens, weight=network_cfg['stop_loss_weight'])
        loss = mel_loss + post_mel_loss + stop_loss

        if local_rank == 0:
            writer.add_scalar('training_loss/mel_loss', mel_loss.numpy(),
                              global_step)
            writer.add_scalar('training_loss/post_mel_loss',
                              post_mel_loss.numpy(), global_step)
            writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
            writer.add_scalar('alphas/encoder_alpha',
                              net.encoder.alpha.numpy(), global_step)
            writer.add_scalar('alphas/decoder_alpha',
                              net.decoder.alpha.numpy(), global_step)
            writer.add_scalar('learning_rate',
                              optimizer._learning_rate.step().numpy(),
                              global_step)

            if global_step % train_cfg['image_interval'] == 1:
                attention_groups = (
                    ('Attention_%d_0', attn_probs,
                     network_cfg['decoder_num_head']),
                    ('Attention_enc_%d_0', attn_enc,
                     network_cfg['encoder_num_head']),
                    ('Attention_dec_%d_0', attn_dec,
                     network_cfg['decoder_num_head']),
                )
                for tag_fmt, attn_list, n_heads in attention_groups:
                    tag = tag_fmt % global_step
                    for i, prob in enumerate(attn_list):
                        prob_np = prob.numpy()
                        for j in range(n_heads):
                            # NOTE(review): index assumes a head-major
                            # (heads * per-rank batch) layout - confirm.
                            x = np.uint8(
                                cm.viridis(prob_np[j * train_cfg[
                                    'batch_size'] // nranks]) * 255)
                            # i * n_heads + j keeps the step unique for every
                            # (layer, head) pair, whatever the head count.
                            writer.add_image(tag, x, i * n_heads + j)

        if parallel:
            loss = model.scale_loss(loss)
            loss.backward()
            model.apply_collective_grads()
        else:
            loss.backward()
        optimizer.minimize(loss)
        model.clear_gradients()

        # save checkpoint
        if local_rank == 0 and global_step % train_cfg[
                'checkpoint_interval'] == 0:
            io.save_parameters(
                os.path.join(args.output, 'checkpoints'), global_step, model,
                optimizer)
        global_step += 1

    if local_rank == 0:
        writer.close()
def main(args):
    """Train TransformerTTS for ``cfg['train']['max_epochs']`` epochs.

    Supports single-process and data-parallel (``nranks > 1``) training.
    Rank 0 logs the mel/postnet losses (plus the stop loss when
    ``cfg['network']['stop_token']`` is set), the encoder/decoder ``alpha``
    values, the learning rate and, whenever
    ``global_step % image_interval == 1``, the attention maps to a
    ``SummaryWriter`` under ``<args.output>/log``. Rank 0 also saves a
    checkpoint under ``<args.output>/checkpoints`` every
    ``checkpoint_interval`` steps.

    Args:
        args: parsed command-line namespace. Reads ``config``, ``use_gpu``,
            ``output``, ``iteration``, ``checkpoint`` and ``data``.
    """
    local_rank = dg.parallel.Env().local_rank
    nranks = dg.parallel.Env().nranks
    parallel = nranks > 1

    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)
    network_cfg = cfg['network']
    train_cfg = cfg['train']

    place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()

    # Every rank runs this; exist_ok avoids a FileExistsError race between
    # ranks that all observe the directory as missing.
    os.makedirs(args.output, exist_ok=True)

    writer = SummaryWriter(os.path.join(args.output,
                                        'log')) if local_rank == 0 else None

    fluid.enable_dygraph(place)
    model = TransformerTTS(
        network_cfg['embedding_size'], network_cfg['hidden_size'],
        network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
        cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
        network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])

    model.train()
    # NOTE(review): the first NoamDecay argument looks chosen so the peak
    # learning rate (at warm_up_step) equals cfg['train']['learning_rate'];
    # confirm against the dg.NoamDecay formula.
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=dg.NoamDecay(
            1 / (train_cfg['warm_up_step'] *
                 (train_cfg['learning_rate']**2)),
            train_cfg['warm_up_step']),
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(
            train_cfg['grad_clip_thresh']))

    # Load parameters and optimizer state; the returned value is used as the
    # current global step.
    global_step = io.load_parameters(model=model,
                                     optimizer=optimizer,
                                     checkpoint_dir=os.path.join(
                                         args.output, 'checkpoints'),
                                     iteration=args.iteration,
                                     checkpoint_path=args.checkpoint)
    print("Rank {}: checkpoint loaded.".format(local_rank))

    if parallel:
        strategy = dg.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)
    # The unwrapped network, used to read the encoder/decoder alpha values.
    net = model._layers if parallel else model

    reader = LJSpeechLoader(cfg['audio'],
                            place,
                            args.data,
                            train_cfg['batch_size'],
                            nranks,
                            local_rank,
                            shuffle=True).reader()

    use_stop_token = network_cfg['stop_token']

    for epoch in range(train_cfg['max_epochs']):
        pbar = tqdm(reader)
        pbar.set_description('Processing at epoch %d' % epoch)
        for data in pbar:
            character, mel, mel_input, pos_text, pos_mel = data

            global_step += 1

            mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                character, mel_input, pos_text, pos_mel)

            # Mean absolute error of the decoder and postnet mel predictions.
            mel_loss = layers.mean(
                layers.abs(layers.elementwise_sub(mel_pred, mel)))
            post_mel_loss = layers.mean(
                layers.abs(layers.elementwise_sub(postnet_pred, mel)))
            loss = mel_loss + post_mel_loss

            # Note: training did not work when the stop token loss was used.
            if use_stop_token:
                # Label is 1 where pos_mel == 0 (presumably padded frames
                # past the end of the utterance - TODO confirm).
                label = (pos_mel == 0).astype(np.float32)
                stop_loss = cross_entropy(stop_preds, label)
                loss = loss + stop_loss

            if local_rank == 0:
                writer.add_scalars(
                    'training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy()
                    }, global_step)

                if use_stop_token:
                    writer.add_scalar('stop_loss', stop_loss.numpy(),
                                      global_step)

                writer.add_scalars(
                    'alphas', {
                        'encoder_alpha': net.encoder.alpha.numpy(),
                        'decoder_alpha': net.decoder.alpha.numpy(),
                    }, global_step)

                writer.add_scalar('learning_rate',
                                  optimizer._learning_rate.step().numpy(),
                                  global_step)

                if global_step % train_cfg['image_interval'] == 1:
                    attention_groups = (
                        ('Attention_%d_0', attn_probs,
                         network_cfg['decoder_num_head']),
                        ('Attention_enc_%d_0', attn_enc,
                         network_cfg['encoder_num_head']),
                        ('Attention_dec_%d_0', attn_dec,
                         network_cfg['decoder_num_head']),
                    )
                    for tag_fmt, attn_list, n_heads in attention_groups:
                        tag = tag_fmt % global_step
                        for i, prob in enumerate(attn_list):
                            prob_np = prob.numpy()
                            for j in range(n_heads):
                                # Divide by nranks, not a hard-coded 2, so
                                # the index stays inside the tensor for any
                                # rank count. NOTE(review): assumes a
                                # head-major (heads * per-rank batch)
                                # layout - confirm.
                                x = np.uint8(
                                    cm.viridis(prob_np[
                                        j * train_cfg['batch_size'] //
                                        nranks]) * 255)
                                # i * n_heads + j keeps the step unique for
                                # every (layer, head) pair.
                                writer.add_image(tag,
                                                 x,
                                                 i * n_heads + j,
                                                 dataformats="HWC")

            if parallel:
                loss = model.scale_loss(loss)
                loss.backward()
                model.apply_collective_grads()
            else:
                loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()

            # save checkpoint
            if local_rank == 0 and global_step % train_cfg[
                    'checkpoint_interval'] == 0:
                io.save_parameters(os.path.join(args.output, 'checkpoints'),
                                   global_step, model, optimizer)

    if local_rank == 0:
        writer.close()