f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}',
                pos=pos + 2)

    # Per-step logging: losses first, then the current values of the
    # hyper-parameters exposed on the model.
    summary_manager.display_loss(output, tag='Train')
    summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout',
                                   scalar_value=model.decoder_prenet.rate)
    summary_manager.display_scalar(tag='Meta/learning_rate',
                                   scalar_value=model.optimizer.lr)
    summary_manager.display_scalar(tag='Meta/reduction_factor',
                                   scalar_value=model.r)
    summary_manager.display_scalar(tag='Meta/drop_n_heads',
                                   scalar_value=model.drop_n_heads)

    # Image summaries are emitted only every
    # `train_images_plotting_frequency` steps.
    if model.step % config['train_images_plotting_frequency'] == 0:
        summary_manager.display_attention_heads(output,
                                                tag='TrainAttentionHeads')
        # Index [0] selects the first entry along the leading axis
        # (presumably the first sample of the batch -- TODO confirm).
        summary_manager.display_mel(mel=output['mel_linear'][0],
                                    tag='Train/linear_mel_out')
        summary_manager.display_mel(mel=output['final_output'][0],
                                    tag='Train/predicted_mel')
        # Element-wise absolute difference between the 'mel_linear' and
        # 'final_output' model outputs.
        residual = abs(output['mel_linear'] - output['final_output'])
        summary_manager.display_mel(mel=residual[0],
                                    tag='Train/conv-linear_residual')
        summary_manager.display_mel(mel=mel[0], tag='Train/target_mel')

    # Save a checkpoint every `weights_save_frequency` steps and report
    # where it was written.
    if model.step % config['weights_save_frequency'] == 0:
        save_path = manager.save()
        # Display slot is offset by the number of configured average-loss
        # windows (presumably so it lands below those lines -- verify
        # against the display helper).
        t.display(f'checkpoint at step {model.step}: {save_path}',
                  pos=len(config['n_steps_avg_losses']) + 2)

    if model.step % config['validation_frequency'] == 0:
        val_loss, time_taken = validate(model=model,
                                        val_dataset=val_dataset,
Esempio n. 2
0
        # Only report once strictly more than `n_steps` losses are recorded.
        if len(losses) > n_steps:
            # Mean over the most recent `n_steps` recorded losses.
            recent_average = sum(losses[-n_steps:]) / n_steps
            t.display(f'{n_steps}-steps average loss: {recent_average}',
                      pos=pos + 2)

    # Per-step logging: losses first, then the current values of the
    # hyper-parameters exposed on the model.
    summary_manager.display_loss(output, tag='Train')
    summary_manager.display_scalar(tag='Meta/learning_rate',
                                   scalar_value=model.optimizer.lr)
    summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout',
                                   scalar_value=model.decoder_prenet.rate)
    summary_manager.display_scalar(tag='Meta/drop_n_heads',
                                   scalar_value=model.drop_n_heads)

    # Image and histogram summaries are emitted only every
    # `train_images_plotting_frequency` steps.
    if model.step % config['train_images_plotting_frequency'] == 0:
        summary_manager.display_attention_heads(output,
                                                tag='TrainAttentionHeads')
        # Index [0] selects the first entry along the leading axis
        # (presumably the first sample of the batch -- TODO confirm).
        summary_manager.display_mel(mel=output['mel'][0],
                                    tag='Train/predicted_mel')
        summary_manager.display_mel(mel=mel[0], tag='Train/target_mel')
        # Histograms of the model's 'duration' output and of the
        # `durations` value supplied by the surrounding code.
        summary_manager.add_histogram(tag='Train/Predicted durations',
                                      values=output['duration'])
        summary_manager.add_histogram(tag='Train/Target durations',
                                      values=durations)

    # Save a checkpoint every `weights_save_frequency` steps and report
    # where it was written.
    if model.step % config['weights_save_frequency'] == 0:
        save_path = manager.save()
        # Display slot is offset by the number of configured average-loss
        # windows (presumably so it lands below those lines -- verify
        # against the display helper).
        t.display(f'checkpoint at step {model.step}: {save_path}',
                  pos=len(config['n_steps_avg_losses']) + 2)

    if model.step % config['validation_frequency'] == 0:
        t.display(f'Validating', pos=len(config['n_steps_avg_losses']) + 3)
        val_loss, time_taken = validate(model=model,
                                        val_dataset=val_dataset,