def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Save a figure comparing pre-network, VTLN and reference spectra.

        Four stacked panels are handed to a ``DataPlotter``:

        0. amplitude spectrum of the output after applying the dummy warping
           layer with negated alphas (labelled "Pre-network"),
        1. amplitude spectrum of the post-processed network output ("VTLN"),
        2. the alpha values with the predicted and the reference V/UV areas,
           plus phoneme annotations when ``hparams`` provides them,
        3. amplitude spectrum of the reference sample loaded from disk.

        The figure is written to
        ``<out_dir>/<id_name>.<basename(model_name)>.VTLN<gen_figure_ext>``.

        Args:
            id_name: Sample id; used to load the reference labels and to
                build the plot title and the output file name.
            label: Network output of the sample. It is post-processed through
                ``self.OutputGen.postprocess_sample`` and also passed to the
                warping layer.
            hidden: Two-element sequence; only the second element (the
                alphas) is used.
            hparams: Hyper-parameter container. Reads ``num_coded_sps``,
                ``synth_fs``, ``model_name``, ``out_dir``, ``frame_size_ms``
                and ``gen_figure_ext``; optionally ``phoneme_indices``,
                ``question_file``, ``voice`` and ``num_questions``.
        """
        _, alphas = hidden

        # Network output -> WORLD features (converted with contains_deltas=False).
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        sp = sp[:, :150]  # Zoom into spectral features.
        # lf0 itself is not plotted; its length trims the phoneme questions below.
        lf0, _ = interpolate_lin(lf0)

        # Load the reference sample; only its spectral features and V/UV flag
        # are shown in this figure.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            dir_out=self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, _, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            sample=org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)

        # Get a data plotter.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # Shared x-axis label; the unit belongs inside the brackets.
        xlabel = 'frames [{} ms]'.format(hparams.frame_size_ms)

        # Reverse the warping by running the warping layer with negated alphas.
        wl = self._get_dummy_warping_layer(hparams)
        pre_net_output, _ = wl.forward_sample(label, -alphas)
        pre_net_output = pre_net_output.detach().cpu().numpy()

        # Postprocess sample manually: multiply by the second normalisation
        # parameter and add the first, restricted to the coded spectral part.
        norm_offset = self.OutputGen.norm_params[0][:hparams.num_coded_sps]
        norm_scale = self.OutputGen.norm_params[1][:hparams.num_coded_sps]
        pre_net_mgc = (pre_net_output[:, 0, :hparams.num_coded_sps] * norm_scale
                       + norm_offset)

        # Plot spectral features predicted by pre-network, cropped to the
        # same number of bins as the zoomed VTLN spectrum.
        grid_idx = 0
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='Pre-network')
        pre_net_sp = WorldFeatLabelGen.mcep_to_amp_sp(pre_net_mgc,
                                                      hparams.synth_fs)
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(pre_net_sp[:, :sp.shape[1]]))

        # Plot final predicted spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='VTLN')
        plotter.set_specshow(grid_idx=grid_idx, spec=np.abs(sp))

        # Plot predicted alpha value and V/UV flag.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='alpha')
        plotter.set_data_list(grid_idx=grid_idx,
                              data_list=[(alphas, 'NN alpha')])
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # Add phoneme annotations to the alpha panel if given.
        if getattr(hparams, "phoneme_indices", None) is not None \
           and getattr(hparams, "question_file", None) is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        # Plot reference spectral features with the same cropping.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel=xlabel,
                          ylabel='Original spectrogram')
        original_sp = WorldFeatLabelGen.mcep_to_amp_sp(original_mgc,
                                                       hparams.synth_fs)
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(original_sp[:, :sp.shape[1]]))

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
Esempio n. 2
0
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Save a figure comparing predicted and reference LF0 and spectrograms.

        Three stacked panels are handed to a ``DataPlotter``: the reference
        and predicted LF0 curves with both V/UV areas, the reference
        spectrogram in dB and the predicted spectrogram in dB. Phoneme
        annotations are attached to the last panel when ``hparams`` provides
        ``phoneme_indices`` and ``question_file``. The figure is written to
        ``<out_dir>/<id_name>.<basename(model_name)>.Org-PyTorch<gen_figure_ext>``.

        Args:
            id_name: Sample id; used to load the reference labels and to
                build the plot title and the output file name.
            label: Network output of the sample, post-processed through
                ``self.OutputGen.postprocess_sample``.
            hidden: Not used by this method.
            hparams: Hyper-parameter container. Reads ``num_coded_sps``,
                ``synth_fs``, ``model_name``, ``out_dir``, ``frame_size_ms``
                and ``gen_figure_ext``; optionally ``phoneme_indices``,
                ``question_file``, ``voice`` and ``num_questions``.
        """
        # Predicted features (converted with contains_deltas=False).
        predicted = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            predicted,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Reference features loaded from the label directory.
        reference = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            reference,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Set up the plotter, its title and the target file name.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)

        # All panels share the same x-axis label.
        frame_axis = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        # Panel 0: LF0 curves with unvoiced regions of prediction and reference.
        plotter.set_label(grid_idx=0, xlabel=frame_axis, ylabel='log(f0)')
        plotter.set_data_list(grid_idx=0,
                              data_list=[(original_lf0, 'Original lf0'),
                                         (lf0, 'PyTorch lf0')])
        unvoiced_predicted = (np.invert(vuv.astype(bool)), '0.8', 1.0)
        unvoiced_reference = (np.invert(original_vuv.astype(bool)), 'red', 0.2)
        plotter.set_area_list(grid_idx=0,
                              area_list=[unvoiced_predicted,
                                         unvoiced_reference])

        # Panels 1 and 2: reference and predicted spectrograms in dB.
        import librosa
        spectrogram_panels = (('Original spectrogram', original_mgc),
                              ('NN spectrogram', coded_sp))
        for grid_idx, (panel_name, mgc) in enumerate(spectrogram_panels,
                                                     start=1):
            plotter.set_label(grid_idx=grid_idx,
                              xlabel=frame_axis,
                              ylabel=panel_name)
            amp_sp = WorldFeatLabelGen.mcep_to_amp_sp(mgc, hparams.synth_fs)
            plotter.set_specshow(grid_idx=grid_idx,
                                 spec=librosa.amplitude_to_db(
                                     np.absolute(amp_sp), top_db=None))

        # grid_idx still points at the last spectrogram panel here, which is
        # where the optional phoneme annotations are attached.
        has_phoneme_indices = getattr(hparams, "phoneme_indices", None) is not None
        has_question_file = getattr(hparams, "question_file", None) is not None
        if has_phoneme_indices and has_question_file:
            questions = QuestionLabelGen.load_sample(
                id_name,
                "experiments/" + hparams.voice + "/questions/",
                num_questions=hparams.num_questions)
            questions = questions[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' +
                             hparams.gen_figure_ext)