Example #1
0
 def log_validation(self, reduced_loss_dict, reduced_bestval_loss_dict, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing):
     """Log validation losses and per-item alignment / spectrogram images.

     Args:
         reduced_loss_dict: losses of this validation pass, passed to ``plot_loss_dict``.
         reduced_bestval_loss_dict: best validation losses so far, passed to ``plot_loss_dict``.
         model: unused here; kept for interface compatibility.
         y: dict holding at least ``'gt_mel'`` (batch-first tensor).
         y_pred: dict holding at least ``'alignments'`` and ``'pred_mel_postnet'``
             (batch-first tensors).
         iteration: global step attached to every logged item.
         val_teacher_force_till, val_p_teacher_forcing: unused here.

     Up to ``self.n_items`` batch items are plotted. Ground-truth mels are only
     logged while ``self.plotted_targets_val`` is below 2, and that counter is
     incremented on every call.
     """
     prepend = 'validation'

     # plot datapoints/graphs
     self.plot_loss_dict(reduced_loss_dict,         iteration, f'{prepend}')
     self.plot_loss_dict(reduced_bestval_loss_dict, iteration, f'{prepend}_best')

     # plot spects / imgs
     n_items = min(self.n_items, y['gt_mel'].shape[0])

     # Element-wise absolute (L1) error, only for the items that are actually
     # plotted. It is purely for visualisation, so no autograd graph is built.
     # NOTE: the image tag below says "mel_SE" but the map is L1, not squared error.
     with torch.no_grad():
         mel_L1_map = torch.nn.L1Loss(reduction='none')(
             y_pred['pred_mel_postnet'][:n_items], y['gt_mel'][:n_items])
     # Pin one pixel to a fixed value so the colour-map scale does not collapse
     # onto the (usually small) error values.
     mel_L1_map[:, -1, -1] = 5.0

     for idx in range(n_items):
         self.add_image(
             f"{prepend}_{idx}/alignment",
             plot_alignment_to_numpy(y_pred['alignments'][idx].data.cpu().numpy().T),
             iteration, dataformats='HWC')
         self.add_image(
             f"{prepend}_{idx}/mel_pred",
             plot_spectrogram_to_numpy(y_pred['pred_mel_postnet'][idx].data.cpu().numpy()),
             iteration, dataformats='HWC')
         self.add_image(
             f"{prepend}_{idx}/mel_SE",
             plot_spectrogram_to_numpy(mel_L1_map[idx].data.cpu().numpy()),
             iteration, dataformats='HWC')
         if self.plotted_targets_val < 2:
             self.add_image(
                 f"{prepend}_{idx}/mel_gt",
                 plot_spectrogram_to_numpy(y['gt_mel'][idx].data.cpu().numpy()),
                 iteration, dataformats='HWC')
     # The target spectrogram does not change between passes, so it is only
     # plotted on the first two calls.
     self.plotted_targets_val +=1
Example #2
0
 def log_infer(self, reduced_loss_dict, reduced_bestval_loss_dict, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing):
     """Log inference-mode losses and per-item alignment / predicted-mel images.

     ``model``, ``val_teacher_force_till`` and ``val_p_teacher_forcing`` are
     accepted for interface parity and not used. Ground-truth mels are added
     while ``self.plotted_targets_inf`` is below 10; the counter grows by one
     per call.
     """
     tag_root = 'inference'
     for loss_dict, suffix in ((reduced_loss_dict, ''), (reduced_bestval_loss_dict, '_best')):
         self.plot_loss_dict(loss_dict, iteration, f'{tag_root}{suffix}')

     gt_mels = y['gt_mel']
     alignments = y_pred['alignments']
     pred_mels = y_pred['pred_mel_postnet']

     item_count = min(self.n_items, gt_mels.shape[0])
     for item in range(item_count):
         group = f"{tag_root}_{item}"
         align_img = plot_alignment_to_numpy(alignments[item].data.cpu().numpy().T)
         self.add_image(f"{group}/alignment", align_img, iteration, dataformats='HWC')
         pred_img = plot_spectrogram_to_numpy(pred_mels[item].data.cpu().numpy())
         self.add_image(f"{group}/mel_pred", pred_img, iteration, dataformats='HWC')
         if self.plotted_targets_inf < 10:
             gt_img = plot_spectrogram_to_numpy(gt_mels[item].data.cpu().numpy())
             self.add_image(f"{group}/mel_gt", gt_img, iteration, dataformats='HWC')
     # targets are static, so they only need drawing during the first few calls
     self.plotted_targets_inf += 1
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Send the validation loss and diagnostic plots for one random item to wandb.

        ``y_pred`` is unpacked as a 4-tuple whose last three entries are the
        predicted mels, gate logits and alignments; ``y`` as (mel targets,
        gate targets). ``model`` and ``iteration`` are not used.
        """
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # visualise one randomly chosen batch element
        pick = random.randint(0, alignments.size(0) - 1)

        def host(tensor):
            # detach-free copy to host memory as a numpy array
            return tensor.data.cpu().numpy()

        align = plot_alignment_to_numpy(host(alignments[pick]).T)
        spec = plot_spectrogram_to_numpy(host(mel_targets[pick]))
        mel = plot_spectrogram_to_numpy(host(mel_outputs[pick]))
        gate = plot_gate_outputs_to_numpy(
            host(gate_targets[pick]),
            host(torch.sigmoid(gate_outputs[pick])))

        run = self.wandb
        figures = {
            "alignment": align,
            "spectrogram": spec,
            "mel_spec": mel,
            "gate": gate,
        }
        payload = {"validation loss": reduced_loss}
        payload.update((key, run.Image(arr)) for key, arr in figures.items())
        # NOTE(review): no explicit step is given, so wandb uses its own step
        # counter rather than `iteration` — confirm this is intended.
        run.log(payload)
Example #4
0
    def log_validation(self, reduced_loss, model, y, y_pred, gst_scores,
                       iteration):
        """Log validation loss, parameter histograms, GST scores and plots for one random item.

        ``y_pred`` is unpacked as a 5-tuple (entries 1..3 are mels, gate logits,
        alignments); ``y`` as (mel targets, gate targets). ``gst_scores`` is
        transposed and plotted as a whole.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments, _ = y_pred
        mel_targets, gate_targets = y

        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        pick = random.randint(0, alignments.size(0) - 1)

        alignment_np = alignments[pick].data.cpu().numpy().T
        scores_np = gst_scores.data.cpu().numpy().T
        target_np = mel_targets[pick].data.cpu().numpy()
        predicted_np = mel_outputs[pick].data.cpu().numpy()
        gate_plot = plot_gate_outputs_to_numpy(
            gate_targets[pick].data.cpu().numpy(),
            gate_outputs[pick].sigmoid().data.cpu().numpy())

        panels = (
            ("alignment", plot_alignment_to_numpy(alignment_np)),
            ("gst_scores", plot_gst_scores_to_numpy(scores_np)),
            ("mel_target", plot_spectrogram_to_numpy(target_np)),
            ("mel_predicted", plot_spectrogram_to_numpy(predicted_np)),
            ("gate", gate_plot),
        )
        for tag, image in panels:
            self.add_image(tag, image, iteration)
Example #5
0
    def log_validation(self, reduced_loss, reduced_dur_loss, model, y, y_pred,
                       iteration):
        """Log ForwardTacotron validation losses, parameter histograms and mels for one random item.

        ``y_pred`` is unpacked as (pre-postnet mel, post-postnet mel, durations);
        the durations are not plotted. ``y`` is unpacked as (mel targets,
        gate targets); the gate targets are not plotted.
        """
        self.add_scalar("forwardtaco/validation.loss", reduced_loss, iteration)
        self.add_scalar("forwardtaco/validation.dur_loss", reduced_dur_loss, iteration)
        mel_pre, mel_post, _dur_hat = y_pred
        mel_targets, _gate_targets = y

        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        pick = random.randint(0, mel_targets.size(0) - 1)
        spectrograms = (
            ("forwardtaco/mel_target", mel_targets),
            ("forwardtaco/mel_predicted_pre", mel_pre),
            ("forwardtaco/mel_predicted_post", mel_post),
        )
        for tag, batch in spectrograms:
            self.add_image(tag,
                           plot_spectrogram_to_numpy(batch[pick].data.cpu().numpy()),
                           iteration,
                           dataformats='HWC')
Example #6
0
    def log_alignment(self, model, enc_slf_attn, dec_enc_attn, out_mel, target,
                      iteration):
        """Log parameter histograms plus attention maps and mels of batch element 0."""
        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        first = 0  # always visualise the first batch element
        attention_maps = (
            ("encoder_self_alignment", enc_slf_attn),
            ("decoder_encoder_alignment", dec_enc_attn),
        )
        for tag, attn in attention_maps:
            self.add_image(tag, plot_alignment_to_numpy(attn[first].data.cpu().numpy().T), iteration)

        for tag, mel in (("mel_target", target), ("mel_predicted", out_mel)):
            self.add_image(tag, plot_spectrogram_to_numpy(mel[first].data.cpu().numpy()), iteration)
Example #7
0
 def summary(self, outputs, epoch):
     """Log the ground-truth and reconstructed mel-spectrograms of batch item 0.

     Args:
         outputs: mapping with ``'GT'`` and ``'recon'`` tensors; item 0 of each
             is transposed and plotted.
         epoch: step value passed to ``self.logger.image_summary``.

     Each image is logged under its own tag (``mel-spectrogram/<title>``) so the
     two do not collide at the same step.
     """
     for title, key in (("ground-truth", 'GT'), ("reconstruction", 'recon')):
         # detach()/cpu() are no-ops for CPU tensors without grad, and make the
         # .numpy() call safe for GPU or grad-tracking tensors.
         mel = outputs[key][0].detach().cpu().numpy().T
         self.logger.image_summary(f"mel-spectrogram/{title}",
                                   plot_spectrogram_to_numpy(title, mel),
                                   epoch)
Example #8
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Log validation loss, parameter histograms and plots for one random batch item.

        ``y`` must unpack into three items (mel targets, gate targets, alignment
        targets); the alignment targets are not plotted.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets, _alignment_targets = y

        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        pick = random.randint(0, alignments.size(0) - 1)

        def to_numpy(tensor):
            return tensor.data.cpu().numpy()

        figures = [
            ("alignment", plot_alignment_to_numpy(to_numpy(alignments[pick]).T)),
            ("mel_target", plot_spectrogram_to_numpy(to_numpy(mel_targets[pick]))),
            ("mel_predicted", plot_spectrogram_to_numpy(to_numpy(mel_outputs[pick]))),
            ("gate", plot_gate_outputs_to_numpy(
                to_numpy(gate_targets[pick]),
                to_numpy(torch.sigmoid(gate_outputs[pick])))),
        ]
        for tag, image in figures:
            self.add_image(tag, image, iteration, dataformats='HWC')
Example #9
0
    def log_validation(self, reduced_loss_main, reduced_loss_join,
                       reduced_loss_class, reduced_loss, model, y, y_pred,
                       iteration):
        """Log validation losses, parameter histograms and plots for one random item.

        Args:
            reduced_loss_main, reduced_loss_join, reduced_loss_class: component
                validation losses, each logged as its own scalar.
            reduced_loss: total validation loss.
            model: module whose parameters are logged as histograms.
            y: 4-tuple (mel targets, alignment targets, alignment weights,
                padded text alignment); the last entry is not plotted.
            y_pred: 6-tuple; entries 1..5 are plotted, entry 0 is ignored.
            iteration: global step for every logged item.
        """
        # NOTE: this tag was previously misspelled "validation.loss_mian"; runs
        # logged before the fix show the main loss under that old name.
        self.add_scalar("validation.loss_main", reduced_loss_main, iteration)
        self.add_scalar("validation.loss_join", reduced_loss_join, iteration)
        self.add_scalar("validation.loss_class", reduced_loss_class, iteration)

        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, alignment_outputs, acoustics_of_phone, join_outs, text_alignment = y_pred
        mel_targets, alignment_targets, alignments_weights, _text_alignment_padded = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            self.add_histogram(tag.replace('.', '/'), value.data.cpu().numpy(), iteration)

        idx = random.randint(0, mel_targets.size(0) - 1)

        def to_numpy(tensor):
            return tensor.data.cpu().numpy()

        self.add_image(
            "alignment_target",
            plot_alignment_to_numpy(to_numpy(alignment_targets[idx]).T), iteration)
        self.add_image(
            "alignment_output",
            plot_alignment_to_numpy(to_numpy(alignment_outputs[idx]).T), iteration)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(to_numpy(mel_targets[idx])), iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(to_numpy(mel_outputs[idx])), iteration)
        self.add_image(
            "alignments_weights",
            plot_weight_outputs_to_numpy(to_numpy(alignments_weights[idx])), iteration)
        self.add_image(
            "acoustic_of_phone_reference",
            plot_alignment_to_numpy(to_numpy(acoustics_of_phone[idx]).T,
                                    x_label='Encoder timestep',
                                    y_label='Phoneme Acoustic'), iteration)
        self.add_image(
            "acoustic_of_phone_predicted",
            plot_alignment_to_numpy(to_numpy(join_outs[idx]).T,
                                    x_label='Encoder timestep',
                                    y_label='Phoneme Acoustic'), iteration)
        self.add_image(
            "phone_level_acoustic_text_alignment_output",
            plot_alignment_to_numpy(to_numpy(text_alignment[idx]).T,
                                    figsize=(8, 6),
                                    x_label='Encoder timestep',
                                    y_label='Encoder timestep'), iteration)
Example #10
0
    def log_teacher_forced_validation(self, reduced_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob):
        """Log teacher-forced validation metrics, parameter histograms and plots.

        Args:
            reduced_loss: validation loss.
            model: module whose parameters are logged as histograms.
            y: sequence starting with (mel targets, gate targets, ...).
            y_pred: 4-tuple; entries 1..3 are mels, gate logits and alignments.
            iteration: global step for every logged item.
            val_teacher_force_till, val_p_teacher_forcing: unused here.
            diagonality: attention-diagonality metric logged as a scalar.
            avg_prob: average max attention weight logged as a scalar.

        Plots batch items 0 and 1 (tags without suffix and with suffix ``2``).
        Smaller batches plot only the items that exist instead of raising
        IndexError.
        """
        self.add_scalar("teacher_forced_validation.loss", reduced_loss, iteration)
        self.add_scalar("teacher_forced_validation.attention_alignment_diagonality", diagonality, iteration)
        self.add_scalar("teacher_forced_validation.average_max_attention_weight", avg_prob, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets, *_ = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # Item 0 is the longest file and item 1 the 2nd longest (per the original
        # comments; presumably the batch is sorted by length — confirm).
        n_plots = min(2, alignments.size(0))
        for idx in range(n_plots):
            suffix = '' if idx == 0 else str(idx + 1)
            self.add_image(
                f"teacher_forced_alignment{suffix}",
                plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
                iteration, dataformats='HWC')
            self.add_image(
                f"mel_target{suffix}",
                plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
                iteration, dataformats='HWC')
            self.add_image(
                f"mel_predicted{suffix}",
                plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
                iteration, dataformats='HWC')
            self.add_image(
                f"gate{suffix}",
                plot_gate_outputs_to_numpy(
                    gate_targets[idx].data.cpu().numpy(),
                    torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
                iteration, dataformats='HWC')
Example #11
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Save diagnostic plots for one random item to disk and log them to wandb.

        Writes ``align_``, ``target_``, ``output_`` and ``gate_{iteration:08}.png``
        into ``self.outdir``. It then calls ``wandb.log`` at ``step=iteration``
        with the validation loss (``"loss/val"``) and the same four images.
        ``model`` is unused.
        """
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)

        # Render each figure once and reuse it for both the PNG file and wandb
        # (previously every plot was rendered twice).
        align = Image.fromarray(
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T))
        target = Image.fromarray(
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()))
        output = Image.fromarray(
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()))
        gate = Image.fromarray(
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()))

        for prefix, image in (('align', align), ('target', target),
                              ('output', output), ('gate', gate)):
            image.save(os.path.join(self.outdir, f'{prefix}_{iteration:08}.png'))

        log_dict = {
            "loss/val": reduced_loss,
            # NOTE(review): every image is captioned 'att', including the mel and
            # gate plots; kept unchanged.
            "alignment": wandb.Image(align, caption='att'),
            "mel_target": wandb.Image(target, caption='att'),
            "mel_predicted": wandb.Image(output, caption='att'),
            "gate": wandb.Image(gate, caption='att'),
        }
        wandb.log(log_dict, step=iteration)
Example #12
0
    def log_validation(self,
                       reduced_loss,
                       model,
                       y,
                       y_pred,
                       iteration,
                       model_name="",
                       log_embedding=False):
        """Log validation loss, parameter histograms and plots for one random item.

        The plot arrays have axis 2 moved to the front before logging. When
        ``log_embedding`` is true, ``model.speaker_embedding`` is also logged as
        an embedding tagged ``emb_<model_name>``, with row indices as labels.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        pick = random.randint(0, alignments.size(0) - 1)

        def channels_first(image):
            # move axis 2 to the front (HWC -> CHW, assuming the plot helpers
            # return channels-last arrays)
            return np.moveaxis(image, 2, 0)

        self.add_image("alignment", channels_first(
            plot_alignment_to_numpy(alignments[pick].data.cpu().numpy().T)), iteration)
        self.add_image("mel_target", channels_first(
            plot_spectrogram_to_numpy(mel_targets[pick].data.cpu().numpy())), iteration)
        self.add_image("mel_predicted", channels_first(
            plot_spectrogram_to_numpy(mel_outputs[pick].data.cpu().numpy())), iteration)
        self.add_image("gate", channels_first(
            plot_gate_outputs_to_numpy(
                gate_targets[pick].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[pick]).data.cpu().numpy())), iteration)

        if not log_embedding:
            return
        table = model.speaker_embedding
        labels = [str(row) for row in range(table.num_embeddings)]
        self.add_embedding(table.weight.detach().cpu().numpy(),
                           labels,
                           global_step=iteration,
                           tag='emb_{}'.format(model_name))
Example #13
0
def generate_mels_by_sytle_tokens(model,
                                  waveglow,
                                  hparams,
                                  sequence,
                                  denoiser,
                                  denoiser_strength=0.01,
                                  device=torch.device('cpu')):
    """Synthesize one utterance per style token, writing plots and WAV files.

    For each output of ``model.inference_by_style_tokens(sequence)`` (index ``i``)
    this writes the following files into the current working directory:

    * ``mel_{i}.png``: the pre-postnet mel.
    * ``output_{i}.wav``: the raw WaveGlow output.
    * ``denoised_output_{i}.wav``: int16 PCM, written only when
      ``denoiser_strength > 0``.

    Args:
        model: provides ``inference_by_style_tokens``. Each yielded item is
            unpacked as (mel, postnet mel, _, alignments).
        waveglow: vocoder whose ``infer(mel, sigma=..., device=...)`` returns audio.
        hparams: provides ``sampling_rate`` and ``max_wav_value``.
        sequence: model input passed through unchanged.
        denoiser: callable ``denoiser(audio, strength=...)``.
        denoiser_strength: denoiser strength; ``<= 0`` skips denoising.
        device: passed to ``waveglow.infer``.
    """
    outputs_by_tokens = model.inference_by_style_tokens(sequence)
    # Denoised audio is clipped to this range before the int16 cast. Without
    # clipping, samples at or beyond full scale wrap around (+32768 -> -32768).
    pcm16 = torch.iinfo(torch.int16)

    for i, (mel_outputs, mel_outputs_postnet, _,
            alignments) in enumerate(outputs_by_tokens):
        # Plot results
        plot_data("mel_{}.png".format(i),
                  plot_spectrogram_to_numpy(mel_outputs.data.cpu().numpy()[0]))

        # Synthesize audio from spectrogram using WaveGlow
        with torch.no_grad():
            audio = waveglow.infer(mel_outputs_postnet,
                                   sigma=0.666,
                                   device=device)
        write("output_{}.wav".format(i), hparams.sampling_rate,
              audio[0].data.cpu().numpy())

        # (Optional) Remove WaveGlow bias
        if denoiser_strength > 0:
            # output is only written to disk, so no autograd graph is needed
            with torch.no_grad():
                audio_denoised = denoiser(audio, strength=denoiser_strength)[:, 0]
            audio_denoised = (audio_denoised * hparams.max_wav_value).clamp(
                pcm16.min, pcm16.max)
            write("denoised_output_{}.wav".format(i), hparams.sampling_rate,
                  audio_denoised.squeeze().cpu().numpy().astype('int16'))
Example #14
0
    def log_validation(self, reduced_loss_dict, reduced_bestval_loss_dict,
                       model, y, y_pred, iteration):
        """Log validation / best-validation losses and, once, the ground-truth mels.

        Parameter histograms are only written when ``iteration`` is a multiple
        of 5000. Target mels (up to ``self.n_items``) are plotted only while
        ``self.plotted_targets`` is falsy, after which the flag is set to True.
        """
        if iteration % 5000 == 0:
            for name, param in model.named_parameters():
                self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        scalar_groups = (
            ("validation", reduced_loss_dict),
            ("validation_best", reduced_bestval_loss_dict),
        )
        for group, losses in scalar_groups:
            for loss_name, value in losses.items():
                self.add_scalar(f"{group}/{loss_name}", value, iteration)

        # y_pred must hold at least three entries; none of them is plotted here
        melglow_package, durglow_package, varglow_package, *_ = y_pred
        mel_targets, *_ = y

        item_count = min(self.n_items, mel_targets.shape[0])
        if self.plotted_targets:
            return

        # targets never change between passes, so they are drawn a single time
        for item in range(item_count):
            self.add_image(f"mel_target/{item}",
                           plot_spectrogram_to_numpy(mel_targets[item].data.cpu().numpy()),
                           iteration,
                           dataformats='HWC')
        self.plotted_targets = True
Example #15
0
    def log_validation(self, mel_real, mel_real_noisy, mel_fake, iteration):
        """Log 4x8 grids built from the first 32 real, noisy-real and generated mels."""
        batches = (
            ("mel_real", mel_real),
            ("mel_real+noise", mel_real_noisy),
            ("mel_fake", mel_fake),
        )
        grids = {}
        for tag, batch in batches:
            grids[tag] = reshape_to_matrix(batch.data[:32].cpu().numpy(), 4, 8)

        for tag, grid in grids.items():
            self.add_image(tag, plot_spectrogram_to_numpy(grid, 6, 5), iteration)
Example #16
0
 def log_infer(self, reduced_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob):
     """Log inference metrics plus alignment, mel and gate plots for two batch items.

     Args:
         reduced_loss: inference loss.
         model, val_teacher_force_till, val_p_teacher_forcing: unused here.
         y: sequence starting with (mel targets, gate targets, ...).
         y_pred: 4-tuple; entries 1..3 are mels, gate logits and alignments.
         iteration: global step for every logged item.
         diagonality: attention-diagonality metric logged as a scalar.
         avg_prob: average max attention weight logged as a scalar.

     Items 0 and 1 are plotted (tags without suffix and with suffix ``2``).
     Smaller batches plot only the items that exist instead of raising
     IndexError.
     """
     self.add_scalar("infer.loss", reduced_loss, iteration)
     self.add_scalar("infer.attention_alignment_diagonality", diagonality, iteration)
     self.add_scalar("infer.average_max_attention_weight", avg_prob, iteration)
     _, mel_outputs, gate_outputs, alignments = y_pred
     mel_targets, gate_targets, *_ = y

     # Item 0 is the longest file and item 1 the 2nd longest (per the original
     # comments; presumably the batch is sorted by length — confirm).
     n_plots = min(2, alignments.size(0))
     for idx in range(n_plots):
         suffix = '' if idx == 0 else str(idx + 1)
         self.add_image(
             f"infer_alignment{suffix}",
             plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
             iteration, dataformats='HWC')
         self.add_image(
             f"infer_mel_target{suffix}",
             plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
             iteration, dataformats='HWC')
         self.add_image(
             f"infer_mel_predicted{suffix}",
             plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
             iteration, dataformats='HWC')
         self.add_image(
             f"infer_gate{suffix}",
             plot_gate_outputs_to_numpy(
                 gate_targets[idx].data.cpu().numpy(),
                 torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
             iteration, dataformats='HWC')
Example #17
0
    def log_infer(self, reduced_loss, model, y, y_pred, iteration,
                  val_teacher_force_till, val_p_teacher_forcing, diagonality,
                  avg_prob):
        """Log inference metrics and plots for up to the first 5 batch items.

        Args:
            reduced_loss: inference loss.
            model, val_teacher_force_till, val_p_teacher_forcing: unused here.
            y: sequence starting with (mel targets, gate targets, ...).
            y_pred: sequence whose first four entries are (mels, postnet mels,
                gate logits, alignments). ``y_pred[8][0]`` is read as the GAN mel
                output and is plotted when not None.
            iteration: global step for every logged item.
            diagonality, avg_prob: attention metrics logged as scalars.

        Postnet mels are preferred when present and are truncated to the target
        length. Batches with fewer than 5 items plot only the items that exist
        (previously this raised IndexError).
        """
        self.add_scalar("infer.loss", reduced_loss, iteration)
        self.add_scalar("infer.attention_alignment_diagonality", diagonality,
                        iteration)
        self.add_scalar("infer.average_max_attention_weight", avg_prob,
                        iteration)
        mel_outputs, mel_outputs_postnet, gate_outputs, alignments, *_ = y_pred
        if mel_outputs_postnet is not None:
            mel_outputs = mel_outputs_postnet
        mel_outputs_GAN = y_pred[8][0]
        mel_targets, gate_targets, *_ = y
        # crop predictions to the target length along dim 1
        mel_outputs = mel_outputs[:, :mel_targets.shape[1], :]

        plot_n_files = min(5, mel_targets.size(0))
        # plot infer alignment, mel target and predicted, gate predicted
        for idx in range(plot_n_files):  # plot longest x audio files
            str_idx = '' if idx == 0 else idx
            self.add_image(f"infer_alignment{str_idx}",
                           plot_alignment_to_numpy(
                               alignments[idx].data.cpu().numpy().T),
                           iteration,
                           dataformats='HWC')
            self.add_image(f"infer_mel_target{str_idx}",
                           plot_spectrogram_to_numpy(
                               mel_targets[idx].data.cpu().numpy()),
                           iteration,
                           dataformats='HWC')
            self.add_image(f"infer_mel_predicted{str_idx}",
                           plot_spectrogram_to_numpy(
                               mel_outputs[idx].data.cpu().numpy()),
                           iteration,
                           dataformats='HWC')
            if mel_outputs_GAN is not None:
                self.add_image(f"mel_predicted_GAN{str_idx}",
                               plot_spectrogram_to_numpy(
                                   mel_outputs_GAN[idx].data.cpu().numpy()),
                               iteration,
                               dataformats='HWC')
            self.add_image(f"infer_gate{str_idx}",
                           plot_gate_outputs_to_numpy(
                               gate_targets[idx].data.cpu().numpy(),
                               torch.sigmoid(
                                   gate_outputs[idx]).data.cpu().numpy()),
                           iteration,
                           dataformats='HWC')
Example #18
0
    def log_validation(self,
                       reduced_loss,
                       model,
                       y,
                       y_pred,
                       iteration,
                       speaker_acc=0,
                       augment_acc=0):
        """Log loss, classifier accuracies, parameter histograms and plots for one random item.

        ``y_pred`` must unpack into 8 entries (1..3 are mels, gate logits,
        alignments). ``y`` must unpack into 4 entries (mel targets, gate targets,
        speaker ids, labels). Plots are converted to channels-first torch tensors
        before logging.
        """
        scalars = (
            ("validation.loss", reduced_loss),
            ("Speaker_classifier_ACC", speaker_acc),
            ("Augment_classifier_ACC", augment_acc),
        )
        for tag, value in scalars:
            self.add_scalar(tag, value, iteration)
        _, mel_outputs, gate_outputs, alignments, _, _, _, _ = y_pred
        mel_targets, gate_targets, _speaker_id, _labels = y

        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), param.data.cpu().numpy(), iteration)

        pick = random.randint(0, alignments.size(0) - 1)

        def as_chw(image):
            # channels-last numpy figure -> channels-first torch tensor
            return torch.from_numpy(image).permute(2, 0, 1)

        self.add_image("alignment", as_chw(
            plot_alignment_to_numpy(alignments[pick].data.cpu().numpy().T)), iteration)
        self.add_image("mel_target", as_chw(
            plot_spectrogram_to_numpy(mel_targets[pick].data.cpu().numpy())), iteration)
        self.add_image("mel_predicted", as_chw(
            plot_spectrogram_to_numpy(mel_outputs[pick].data.cpu().numpy())), iteration)
        self.add_image("gate", as_chw(
            plot_gate_outputs_to_numpy(
                gate_targets[pick].data.cpu().numpy(),
                gate_outputs[pick].sigmoid().data.cpu().numpy())), iteration)
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Log validation loss, histograms, plots for one random item, and vocoded audio.

        The images use a randomly chosen batch item. The audio clip is vocoded
        from batch item 0 with ``self.melgan.inference`` and logged at
        ``self.sampling_rate``.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image("alignment",
                       plot_alignment_to_numpy(
                           alignments[idx].data.cpu().numpy().T),
                       iteration,
                       dataformats='HWC')
        self.add_image("mel_target",
                       plot_spectrogram_to_numpy(
                           mel_targets[idx].data.cpu().numpy()),
                       iteration,
                       dataformats='HWC')
        self.add_image("mel_predicted",
                       plot_spectrogram_to_numpy(
                           mel_outputs[idx].data.cpu().numpy()),
                       iteration,
                       dataformats='HWC')
        self.add_image("gate",
                       plot_gate_outputs_to_numpy(
                           gate_targets[idx].data.cpu().numpy(),
                           torch.sigmoid(
                               gate_outputs[idx]).data.cpu().numpy()),
                       iteration,
                       dataformats='HWC')

        # Vocode batch item 0 (not necessarily the randomly chosen `idx` above).
        # Index before moving to CPU so only one item is copied, and run under
        # no_grad since the audio is only logged.
        with torch.no_grad():
            mel = mel_outputs[0].detach().cpu()
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)  # add the batch axis the vocoder call uses
            audio = self.melgan.inference(mel)
        self.add_audio('audio',
                       audio,
                       global_step=iteration,
                       sample_rate=self.sampling_rate,
                       walltime=None)
Example #20
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Write validation summaries to TensorBoard and Weights & Biases.

        The loss and the alignment / target-mel / predicted-mel images are sent
        to both backends; parameter histograms and the gate plot are sent to
        TensorBoard only. All images come from one randomly chosen batch item.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        wandb.log({'validation.loss': reduced_loss}, step=iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # one histogram per parameter, '.' in the name replaced by '/'
        for param_name, param in model.named_parameters():
            self.add_histogram(param_name.replace('.', '/'),
                               param.data.cpu().numpy(), iteration)

        def publish(key, image, caption):
            # same image, same key, to TensorBoard first and then to wandb
            self.add_image(key, image, iteration)
            wandb.log({key: [wandb.Image(image, caption=caption)]},
                      step=iteration)

        pick = random.randint(0, alignments.size(0) - 1)

        publish("alignment",
                plot_alignment_to_numpy(alignments[pick].data.cpu().numpy().T),
                "Alignment")
        publish("mel_target",
                plot_spectrogram_to_numpy(mel_targets[pick].data.cpu().numpy()),
                "Mel target")
        publish("mel_predicted",
                plot_spectrogram_to_numpy(mel_outputs[pick].data.cpu().numpy()),
                "Mel predicted")

        gate_image = plot_gate_outputs_to_numpy(
            gate_targets[pick].data.cpu().numpy(),
            torch.sigmoid(gate_outputs[pick]).data.cpu().numpy())
        self.add_image("gate", gate_image, iteration)
Example #21
0
    def log_validation(self, reduced_loss, model, x, y, y_pred, iteration,
                       hparams):
        """Log validation loss, histograms, plots, Griffin-Lim audio and text.

        Args:
            reduced_loss: validation loss, written as "validation.loss".
            model: module whose parameters are logged as histograms.
            x: model inputs; ``x[0]`` holds the symbol-id sequences.
            y: pair ``(mel_targets, gate_targets)``.
            y_pred: 4-tuple; items 1-3 are ``(mel_outputs, gate_outputs,
                alignments)``.
            iteration: global step attached to every summary.
            hparams: passed to ``synthesis_griffin_lim``; its
                ``sampling_rate`` is used for the audio summaries.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration)
        # Tensor.sigmoid() replaces the deprecated torch.nn.functional.sigmoid
        # and needs no extra import.
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                gate_outputs[idx].sigmoid().data.cpu().numpy()), iteration)

        # Griffin-Lim audio from the target and the predicted mel of the same
        # item.
        self.add_audio(
            "audio_from_target",
            synthesis_griffin_lim(mel_targets[idx].unsqueeze(0), hparams),
            iteration, hparams.sampling_rate)
        self.add_audio(
            "audio_from_predicted",
            synthesis_griffin_lim(mel_outputs[idx].unsqueeze(0), hparams),
            iteration, hparams.sampling_rate)

        # Decode the whole row x[0][idx] back to symbols, so any padding ids
        # present in that row are decoded too.
        self.add_text(
            "text",
            ''.join(_id_to_symbol[symbol_id]
                    for symbol_id in x[0][idx].data.cpu().numpy()),
            iteration)
Example #22
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Log validation loss, parameter histograms and diagnostic plots.

        When ``self.use_vae`` is truthy, ``y_pred`` is unpacked as an 8-tuple
        (predicted mels, gates and alignments plus latent means and emotion
        labels), and two extra latent-space plots are written. Otherwise it is
        unpacked as the plain 4-tuple. All images use ``self.dataformat``.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        if not self.use_vae:
            _, mel_outputs, gate_outputs, alignments = y_pred
        else:
            (_, mel_outputs, gate_outputs, alignments,
             mus, _, _, emotions) = y_pred
        mel_targets, gate_targets = y

        # one histogram per parameter, '.' in the name replaced by '/'
        for name, weights in model.named_parameters():
            self.add_histogram(name.replace('.', '/'),
                               weights.data.cpu().numpy(), iteration)

        def emit(tag, picture):
            self.add_image(tag, picture, iteration,
                           dataformats=self.dataformat)

        # diagnostics for one randomly chosen batch item
        item = random.randint(0, alignments.size(0) - 1)
        emit("alignment",
             plot_alignment_to_numpy(alignments[item].data.cpu().numpy().T))
        emit("mel_target",
             plot_spectrogram_to_numpy(mel_targets[item].data.cpu().numpy()))
        emit("mel_predicted",
             plot_spectrogram_to_numpy(mel_outputs[item].data.cpu().numpy()))
        emit("gate",
             plot_gate_outputs_to_numpy(
                 gate_targets[item].data.cpu().numpy(),
                 torch.sigmoid(gate_outputs[item]).data.cpu().numpy()))

        if self.use_vae:
            emit("latent_dim (regular)", plot_scatter(mus, emotions))
            emit("latent_dim (t-sne)", plot_tsne(mus, emotions))
Example #23
0
 def log_validation(self, reduced_loss_dict, reduced_bestval_loss_dict, model,
                    y, y_pred, iteration):
     """Plot validation losses and per-item mel spectrograms.

     At most ``self.n_items`` batch items are logged. Predicted mels are logged
     on every call. Ground-truth mels are logged only while
     ``self.plotted_targets_val`` is below 2, and that counter is incremented
     once per call. ``reduced_bestval_loss_dict`` and ``model`` are accepted
     but not used here.
     """
     prefix = 'validation'

     # loss curves
     self.plot_loss_dict(reduced_loss_dict, iteration, prefix)

     # spectrogram images, one tag group per batch item
     gt_mels = y['gt_mel']
     for item in range(min(self.n_items, gt_mels.shape[0])):
         self.add_image(
             f"{prefix}_{item}/mel_pred",
             plot_spectrogram_to_numpy(
                 y_pred['pred_mel_postnet'][item].data.cpu().numpy()),
             iteration, dataformats='HWC')
         if self.plotted_targets_val >= 2:
             continue
         self.add_image(
             f"{prefix}_{item}/mel_gt",
             plot_spectrogram_to_numpy(gt_mels[item].data.cpu().numpy()),
             iteration, dataformats='HWC')
     # the ground truth does not change between calls, so only the first
     # calls plot it
     self.plotted_targets_val += 1
Example #24
0
    def log_validation(self, reduced_loss, model, y, y_pred, iteration, stft):
        """Log validation loss and histograms, save item 0 to disk, plot a random item.

        Args:
            reduced_loss: validation loss, written as "validation.loss".
            model: module whose parameters are logged as histograms.
            y: pair ``(mel_targets, gate_targets)``.
            y_pred: ``(decoder_outputs, mel_outputs, gate_outputs, alignment)``.
            iteration: global step used for the summaries and saved files.
            stft: forwarded to ``save_audio``.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        decoder_outputs, mel_outputs, gate_outputs, alignment = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # random batch item for the TensorBoard images further below
        idx = random.randint(0, mel_outputs.size(0) - 1)

        # The first batch item is always written under ``self.logdir``: a
        # figure with target/decoder/postnet mels and alignment, plus audio.
        index = 0

        plot_spectrogram(mel_targets[index].data.cpu().numpy(),
                         decoder_outputs[index].data.cpu().numpy(),
                         mel_outputs[index].data.cpu().numpy(),
                         alignment[index].data.cpu().numpy(),
                         self.logdir, iteration,
                         append="eval")

        save_audio(mel_outputs[index], self.logdir, iteration, stft, False)

        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration)
        # Tensor.sigmoid() replaces the deprecated torch.nn.functional.sigmoid
        # and needs no extra import.
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                gate_outputs[idx].sigmoid().data.cpu().numpy()),
            iteration)
Example #25
0
    def log_alignment(self, model, mel_predict, mel_tgt, iteration):
        """Log parameter histograms and target/predicted mel images for item 0.

        NOTE(review): only the beginning of this method is visible in this
        chunk; anything after the two mel images is not described here.
        """

        # plot distribution of parameters (one histogram per parameter, with
        # '.' in the parameter name replaced by '/')
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot target and predicted mel spectrograms of the first batch item
        # (a random index is disabled below; item 0 is always used)
        #idx = random.randint(0, enc_slf_attn.size(0) - 1)
        idx = 0
        # transpose(1, 2) swaps the last two axes; for a 3-D input the `.T`
        # applied to each 2-D item below swaps them back, so the plotted array
        # has the item's original axis order
        mel_tgt = mel_tgt.transpose(1, 2)
        mel_predict = mel_predict.transpose(1, 2)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_tgt[idx].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "mel_predict",
            plot_spectrogram_to_numpy(mel_predict[idx].data.cpu().numpy().T),
            iteration)
        '''self.add_image(
    def log_validation(self, params, iteration):
        """Log loss and diagnostic images for each named evaluation set.

        ``params`` maps a set name to a dict holding ``'loss'``, ``'y_pred'``
        (a 4-tuple whose items 1 and 3 are predicted mels and alignments) and
        ``'y'`` (a ``(mel_targets, gate_targets)`` pair). Every tag is prefixed
        with the set's name.
        """
        def emit(tag, picture):
            self.add_image(tag, picture, iteration, dataformats='HWC')

        for name, entry in params.items():
            self.add_scalar(f'{name}.loss', entry['loss'], iteration)

            _, predicted_mels, _gate_outputs, attention = entry['y_pred']
            target_mels, _gate_targets = entry['y']

            # images for one randomly chosen batch item of this set
            pick = random.randint(0, predicted_mels.size(0) - 1)
            emit(f'{name}.alignment',
                 plot_alignment_to_numpy(attention[pick].data.cpu().numpy().T))
            emit(f'{name}.mel_target',
                 plot_spectrogram_to_numpy(
                     target_mels[pick].data.cpu().numpy()))
            emit(f'{name}.mel_predicted',
                 plot_spectrogram_to_numpy(
                     predicted_mels[pick].data.cpu().numpy()))
Example #27
0
    def log_alignment(self, model, dec_enc_attn, alignment, mel_padded, mel_predict, test_attn, iteration):
        """Log parameter histograms, attention maps and mels for one random item.

        For each position ``k`` in ``dec_enc_attn`` two images are logged: that
        attention map, and the ``test_attn`` entry counted ``k`` from the end.
        ``alignment``, ``mel_padded`` and ``mel_predict`` add one image each.
        """
        for param_name, param in model.named_parameters():
            self.add_histogram(param_name.replace('.', '/'),
                               param.data.cpu().numpy(), iteration)

        pick = random.randint(0, dec_enc_attn[0].size(0) - 1)
        # permute(0, 2, 1) swaps the last two axes; the `.T` applied to each
        # 2-D item below swaps them back
        target_mels = mel_padded.permute(0, 2, 1)
        predicted_mels = mel_predict.permute(0, 2, 1)

        def as_alignment(attn):
            return plot_alignment_to_numpy(attn[pick].data.cpu().numpy().T)

        n_test = len(test_attn)
        for layer in range(len(dec_enc_attn)):
            self.add_image(f"decoder_encoder_alignment_{layer}",
                           as_alignment(dec_enc_attn[layer]), iteration)
            self.add_image(f"test_alignment_{layer}",
                           as_alignment(test_attn[n_test - layer - 1]),
                           iteration)

        self.add_image("target_alignment", as_alignment(alignment), iteration)
        self.add_image(
            "target_mel",
            plot_spectrogram_to_numpy(target_mels[pick].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "predict_mel",
            plot_spectrogram_to_numpy(
                predicted_mels[pick].data.cpu().numpy().T),
            iteration)
Example #28
0
def inference(args):
    """Synthesize mel spectrograms for the given sentences and save them.

    For every batch from the evaluation loader, the first item's alignment plot
    and mel-spectrogram plot are saved as JPEGs, and its mel features, clipped
    to the model's output range, are saved transposed as ``.npy``. All files go
    to the directory ``args.out_filename``. Uses the module-level ``hparams``.

    Args:
        args: parsed CLI arguments; ``checkpoint`` and ``out_filename`` are
            read here, plus whatever ``get_sentences`` reads.
    """
    sentences = get_sentences(args)

    model = load_model(hparams)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint)['state_dict'])
    model.cuda().eval()

    test_set = TextMelLoaderEval(sentences, hparams)
    test_collate_fn = TextMelCollateEval(hparams)
    # Shard the test set itself when running distributed.
    test_sampler = DistributedSampler(
        test_set) if hparams.distributed_run else None
    test_loader = DataLoader(test_set,
                             num_workers=0,
                             sampler=test_sampler,
                             batch_size=hparams.synth_batch_size,
                             pin_memory=False,
                             # keep a trailing partial batch so no sentence
                             # is silently skipped
                             drop_last=False,
                             collate_fn=test_collate_fn)

    T2_output_range = (-hparams.max_abs_value,
                       hparams.max_abs_value) if hparams.symmetric_mels else (
                           0, hparams.max_abs_value)

    os.makedirs(args.out_filename, exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            _, mel_outputs_postnet, _, alignments = model.inference(batch)
            # NOTE(review): only item 0 of each batch is saved and files are
            # numbered by batch index, so this presumably assumes
            # synth_batch_size == 1 — confirm.
            align_img = Image.fromarray(
                plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T))
            spec_img = Image.fromarray(
                plot_spectrogram_to_numpy(
                    mel_outputs_postnet[0].data.cpu().numpy()))
            align_img.save(
                os.path.join(args.out_filename,
                             'sentence_{}_alignment.jpg'.format(i)))
            spec_img.save(
                os.path.join(args.out_filename,
                             'sentence_{}_mel-spectrogram.jpg'.format(i)))

            mel_path = os.path.join(args.out_filename,
                                    'sentence_{}_mel-feats.npy'.format(i))
            mels = np.clip(mel_outputs_postnet[0].cpu().numpy(),
                           T2_output_range[0], T2_output_range[1])
            np.save(mel_path, mels.T, allow_pickle=False)

            print('CHECK MEL SHAPE:', mels.T.shape)
def train_(args, model, opt, latent_loss_weight, criterion, loader, epochs,
           inf_iterator_test, logger, iteration):
    """Train ``model`` on multi-resolution reconstruction, validating every 50 batches.

    Each batch is normalized and time-stretched segment by segment. The model
    then reconstructs it at three resolutions; each lower one halves the time
    axis and keeps the first half of the channels. The summed reconstruction
    loss plus ``latent_loss_weight`` times the mean latent loss is optimized
    through ``OptimStep``. On every 50th batch of an epoch, one batch from
    ``inf_iterator_test`` is run in eval mode, the results are logged, and a
    checkpoint is saved.

    NOTE(review): ``epochs`` is accepted but unused; the loop always runs
    8000 epochs. ``iteration`` is advanced locally and not returned.
    """
    for _ in range(8000):
        for i, (audio, name) in enumerate(loader):
            audio = audio.cuda()
            # normalization; undone by (a * 50 - 50) / 25 before vocoding below
            audio = (audio * 25 + 50) / 50

            # Stretch each `factor`-sample segment along the last axis by a
            # random scale in [0.5, 2]; samples beyond the last whole segment
            # are dropped.
            time_step = audio.size(2)
            factor = 32
            n_segments = time_step // factor
            audio_shuffle = [[] for _ in range(n_segments)]
            nums = list(range(n_segments))
            random.shuffle(nums)
            # NOTE(review): segment n is written back to slot n, so the shuffle
            # only changes the order in which scale factors are drawn, not the
            # segment order. Confirm whether reordering was intended.
            for n in nums:
                sf = random.uniform(0.5, 2)
                audio_shuffle[n] = F.interpolate(
                    audio[..., factor * n:factor * (n + 1)],
                    scale_factor=sf,
                    mode='nearest')

            audio = torch.cat(audio_shuffle, dim=2)
            # trim the last axis to a multiple of 16
            audio = audio[..., :audio.size(2) // 16 * 16]

            # lower-resolution targets: halve the time axis and keep the
            # first half of the channels, twice
            audio_middile = F.interpolate(audio, scale_factor=1 / 2)
            audio_middile = audio_middile[:, :audio_middile.size(1) // 2, :]

            audio_low = F.interpolate(audio_middile, scale_factor=1 / 2)
            audio_low = audio_low[:, :audio_low.size(1) // 2, :]

            audio_list = [audio_low, audio_middile, audio]

            out, _, _, latent_loss = model(audio, name)

            recon_loss = sum(criterion(out[num], audio_list[num])
                             for num in range(3))

            latent_loss = latent_loss.mean()
            OptimStep([(model, opt,
                        recon_loss + latent_loss_weight * latent_loss, False)],
                      3)

            if i % 50 == 0:
                logger.log_training(iteration=iteration,
                                    loss_recon=recon_loss,
                                    latent_loss=latent_loss)

                model.eval()

                # Validation needs forward passes only; skip building the
                # autograd graph for the model and the vocoder.
                with torch.no_grad():
                    val_audio, val_name = next(inf_iterator_test)
                    val_audio = val_audio.cuda()
                    val_audio = (val_audio * 25 + 50) / 50

                    val_out, val_conversion, _, _ = model(val_audio, val_name)

                    a = torch.stack(
                        [val_audio[0], val_out[-1][0], val_conversion[-1][0]],
                        dim=0)
                    a = (a * 50 - 50) / 25
                    a = vocoder.inverse(a)
                    a = a.detach().cpu().numpy()

                # NOTE(review): plot_spectrogram_to_numpy is called with no
                # arguments; presumably this project's version returns a
                # plotting helper the logger applies to the tensor — verify.
                logger.log_validation(
                    iteration=iteration,
                    mel_ori=("image", plot_spectrogram_to_numpy(),
                             val_audio[0]),
                    mel_recon=("image", plot_spectrogram_to_numpy(),
                               val_out[-1][0]),
                    mel_conversion=("image", plot_spectrogram_to_numpy(),
                                    val_conversion[-1][0]),
                    audio_ori=("audio", 22050, a[0]),
                    audio_recon=("audio", 22050, a[1]),
                    audio_conversion=("audio", 22050, a[2]),
                )
                # NOTE(review): the logger is closed after every validation;
                # whether later writes reopen it depends on the logger class.
                logger.close()
                save_checkpoint(
                    model, opt, iteration,
                    f'checkpoint/{args.model}_n{args.n_embed}_ch{args.channel}_{args.trainer}/gen'
                )

                model.train()
            iteration += 1
Example #30
0
    def log_validation(self, reduced_loss, model, x, y, y_pred, iteration,
                       epoch, sample_rate):
        """Log validation results for one random batch item to Weights & Biases.

        Logs parameter histograms, the loss, alignment / mel / gate images,
        vocoded audio captioned with the input text, and forward attention
        ratios. Every ``wandb.log`` call carries ``epoch`` and ``iteration``
        and uses ``step=iteration``.

        Args:
            reduced_loss: validation loss, logged as ``val/loss``.
            model: module whose parameters are logged as histograms.
            x: 5-tuple ``(text_padded, input_lengths, mel_padded, max_len,
                output_lengths)``; only the first two are used here.
            y: pair ``(mel_targets, gate_targets)``.
            y_pred: 4-tuple; items 1-3 are ``(mel_outputs, gate_outputs,
                alignments)``.
            iteration: global step.
            epoch: current epoch, logged alongside every entry.
            sample_rate: sample rate of the logged audio.
        """
        text_padded, input_lengths, mel_padded, max_len, output_lengths = x
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters, all histograms in one wandb call
        histograms = {
            tag.replace('.', '/'): wandb.Histogram(value.data.cpu().numpy())
            for tag, value in model.named_parameters()
        }
        if histograms:
            wandb.log({**histograms, "epoch": epoch, "iteration": iteration},
                      step=iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)

        text_len = input_lengths[idx].item()
        # Drop padding ids *before* decoding. If any id decodes to something
        # other than exactly one character, truncating the decoded string by
        # the token count would cut the caption in the wrong place.
        text_string = sequence_to_text(text_padded[idx][:text_len].tolist())

        mel_len = get_mel_length(alignments, idx, text_len)
        mel = mel_outputs[idx:idx + 1, :, :mel_len]

        # NOTE(review): the vocoder input is cast to a CUDA half tensor, so
        # this path requires a GPU.
        np_wav = self.mel2wav(mel.type('torch.cuda.HalfTensor'))

        np_alignment = plot_alignment_to_numpy(
            alignments[idx].data.cpu().numpy().T, decoding_len=mel_len)
        np_mel_target = plot_spectrogram_to_numpy(
            mel_targets[idx].data.cpu().numpy())
        np_mel_predicted = plot_spectrogram_to_numpy(
            mel_outputs[idx].data.cpu().numpy())
        np_gate = plot_gate_outputs_to_numpy(
            gate_targets[idx].data.cpu().numpy(),
            torch.sigmoid(gate_outputs[idx]).data.cpu().numpy())

        # wandb log
        wandb.log(
            {
                "val/loss":
                reduced_loss,
                "val/alignment":
                [wandb.Image(np_alignment, caption=text_string)],
                "val/audio": [
                    wandb.Audio(np_wav.astype(np.float32),
                                caption=text_string,
                                sample_rate=sample_rate)
                ],
                "val/mel_target": [wandb.Image(np_mel_target)],
                "val/mel_predicted": [wandb.Image(np_mel_predicted)],
                "val/gate": [wandb.Image(np_gate)],
                "epoch":
                epoch,
                "iteration":
                iteration
            },
            step=iteration)

        # forward attention ratio: batch mean as a scalar, per-item values as
        # a histogram, for each hop size
        hop_list = [1]
        for hop_size in hop_list:
            mean_far, batch_far = forward_attention_ratio(
                alignments, input_lengths, hop_size)
            log_name = "mean_forward_attention_ratio.val/hop_size={}".format(
                hop_size)
            wandb.log(
                {
                    log_name: mean_far,
                    "epoch": epoch,
                    "iteration": iteration
                },
                step=iteration)
            log_name = "forward_attention_ratio.val/hop_size={}".format(
                hop_size)
            wandb.log(
                {
                    log_name: wandb.Histogram(batch_far.data.cpu().numpy()),
                    "epoch": epoch,
                    "iteration": iteration
                },
                step=iteration)