示例#1
0
    def plot_phoneme_annotations(plotter: DataPlotter,
                                 id_name: str,
                                 hparams: ExtendedHParams,
                                 num_questions: int,
                                 phoneme_indices: np.ndarray,
                                 question_dir: os.PathLike,
                                 question_file: os.PathLike,
                                 grid_indices: List[int] = None):
        """Attach the phoneme labels of one utterance to plot grids.

        The question labels of ``id_name`` are loaded from ``question_dir``,
        converted to phonemes with ``QuestionLabelGen.questions_to_phonemes``
        and set as annotations on every requested grid of ``plotter``.

        :param plotter: DataPlotter whose grids receive the annotations.
        :param id_name: Identifier of the utterance to load.
        :param hparams: Not used by this function.
        :param num_questions: Number of questions stored per frame.
        :param phoneme_indices: Forwarded to ``questions_to_phonemes``.
        :param question_dir: Directory containing the question labels.
        :param question_file: Forwarded to ``questions_to_phonemes``.
        :param grid_indices: Grids to annotate; ``None`` selects every grid
            reported by ``plotter.get_all_grid_indices()``.
        """
        sample_questions = QuestionLabelGen.load_sample(
            id_name=id_name, dir_out=question_dir, num_questions=num_questions)
        phoneme_annotations = QuestionLabelGen.questions_to_phonemes(
            sample_questions, phoneme_indices, question_file)

        target_grids = (plotter.get_all_grid_indices()
                        if grid_indices is None else grid_indices)
        for target in target_grids:
            plotter.set_annotations(target, phoneme_annotations)
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot the VTLN output of one utterance and save it to a file.

        The figure contains four panels, top to bottom:

        0. Spectrogram of the pre-network output, recovered by running the
           warping layer on ``label`` with the negated predicted alphas.
        1. Final predicted spectrogram.
        2. Predicted alphas, with unvoiced regions of the prediction and of
           the original features shaded; phoneme annotations are added when
           ``hparams`` provides ``phoneme_indices`` and ``question_file``.
        3. Spectrogram of the original (reference) features.

        The predicted spectrogram is cropped to its first 150 bins and the
        other spectrograms to the same width. The figure is written to
        ``<out_dir>/<id_name>.<net_name>.VTLN<gen_figure_ext>``.

        :param id_name: Identifier of the utterance.
        :param label: Network output for the utterance before postprocessing;
            also fed to the warping layer.
        :param hidden: Pair whose second element holds the predicted alphas.
        :param hparams: Hyper-parameters (num_coded_sps, synth_fs,
            model_name, out_dir, frame_size_ms, gen_figure_ext, and
            optionally phoneme_indices, question_file, voice, num_questions).
        """
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        # The interpolated LF0 is only used to crop the phoneme labels below.
        lf0, _ = interpolate_lin(lf0)

        # Load the original features (reference spectrum and V/UV flags).
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            dir_out=self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, _, original_vuv, *_ = \
            WorldFeatLabelGen.convert_to_world_features(
                sample=org_labels_post,
                contains_deltas=self.OutputGen.add_deltas,
                num_coded_sps=hparams.num_coded_sps)

        sp = sp[:, :150]  # Zoom into the first 150 spectral bins.
        num_bins = sp.shape[1]

        # Get a data plotter.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # The unit belongs inside the brackets: "frames [<N> ms]".
        xlabel = 'frames [{} ms]'.format(hparams.frame_size_ms)

        def add_spectrogram(grid_idx, mgc, ylabel):
            """Show |amplitude spectrum| of ``mgc``, cropped to ``num_bins``."""
            plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel=ylabel)
            amp_sp = WorldFeatLabelGen.mcep_to_amp_sp(mgc, hparams.synth_fs)
            plotter.set_specshow(grid_idx=grid_idx,
                                 spec=np.abs(amp_sp[:, :num_bins]))

        # Reverse the warping to recover the pre-network output.
        wl = self._get_dummy_warping_layer(hparams)
        norm_params_no_deltas = (
            self.OutputGen.norm_params[0][:hparams.num_coded_sps],
            self.OutputGen.norm_params[1][:hparams.num_coded_sps])
        pre_net_output, _ = wl.forward_sample(label, -alphas)

        # Postprocess sample manually: de-normalise the coded spectrum of the
        # first entry along axis 1 of the warping-layer output.
        pre_net_output = pre_net_output.detach().cpu().numpy()
        pre_net_mgc = (pre_net_output[:, 0, :hparams.num_coded_sps]
                       * norm_params_no_deltas[1] + norm_params_no_deltas[0])

        # Spectral features predicted by the pre-network.
        grid_idx = 0
        add_spectrogram(grid_idx, pre_net_mgc, 'Pre-network')

        # Final predicted spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='VTLN')
        plotter.set_specshow(grid_idx=grid_idx, spec=np.abs(sp))

        # Predicted alpha values with unvoiced regions of prediction/original.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='alpha')
        plotter.set_data_list(grid_idx=grid_idx,
                              data_list=[(alphas, 'NN alpha')])
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # Add phoneme annotations if the required hparams are given.
        if getattr(hparams, "phoneme_indices", None) is not None \
           and getattr(hparams, "question_file", None) is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        # Reference spectral features.
        grid_idx += 1
        add_spectrogram(grid_idx, original_mgc, 'Original spectrogram')

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
示例#3
0
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted vs. original LF0 and alpha of one utterance.

        The figure has two panels:

        0. Original LF0 and predicted LF0 (both linearly interpolated).
        1. Reference alphas, looked up per frame from the phonemes via
           ``self.phonemes_to_alpha_tensor``, and the predicted alphas.
           Phoneme annotations are added when ``hparams`` provides
           ``phoneme_indices`` and ``question_file``.

        Both panels shade the unvoiced regions of the prediction and of the
        original features. The figure is written to
        ``<out_dir>/<id_name>.<net_name>.VTLN<gen_figure_ext>``.

        :param id_name: Identifier of the utterance.
        :param label: Network output for the utterance before postprocessing.
        :param hidden: Pair whose second element holds the predicted alphas.
        :param hparams: Hyper-parameters (num_coded_sps, voice, num_questions,
            phoneme_indices, model_name, out_dir, frame_size_ms,
            gen_figure_ext, and optionally question_file).
        """
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        _, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, *_ = \
            WorldFeatLabelGen.convert_to_world_features(
                org_labels_post,
                contains_deltas=self.OutputGen.add_deltas,
                num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Read the question labels from disk once; they are cropped separately
        # to the alpha length here and to the LF0 length for the annotations.
        all_questions = QuestionLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "questions"),
            num_questions=hparams.num_questions)
        phoneme_indices = QuestionLabelGen.questions_to_phoneme_indices(
            all_questions[:len(alphas)], hparams.phoneme_indices)
        # Per-frame reference alpha; the modulo keeps indices inside the table.
        alpha_vec = self.phonemes_to_alpha_tensor[phoneme_indices % len(
            self.phonemes_to_alpha_tensor)]

        # Get a data plotter.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        xlabel = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        def unvoiced_areas():
            """Return a fresh area list for unvoiced prediction/original."""
            return [(np.invert(vuv.astype(bool)), '0.8', 1.0),
                    (np.invert(original_vuv.astype(bool)), 'red', 0.2)]

        # LF0 panel.
        grid_idx = 0
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='log(f0)')
        plotter.set_data_list(grid_idx=grid_idx,
                              data_list=[(original_lf0, 'Original LF0'),
                                         (lf0, 'NN LF0')])
        plotter.set_area_list(grid_idx=grid_idx, area_list=unvoiced_areas())

        # Alpha panel.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='alpha')
        plotter.set_data_list(grid_idx=grid_idx,
                              data_list=[(alpha_vec, 'Original alpha'),
                                         (alphas, 'NN alpha')])
        plotter.set_area_list(grid_idx=grid_idx, area_list=unvoiced_areas())
        if getattr(hparams, "phoneme_indices", None) is not None \
           and getattr(hparams, "question_file", None) is not None:
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                all_questions[:len(lf0)], hparams.phoneme_indices,
                hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)