Example #1
0
    def synth_ref(self, hparams, file_id_list):
        """Synthesise reference audio files containing only the vocoder degradation.

        For every id the extracted acoustic features are loaded from disk and handed
        unchanged to ``ModelTrainer.synthesize``. The file without deltas is tried
        first; if it is missing, the version with deltas is loaded, converted with
        ``WorldFeatLabelGen.convert_to_world_features`` and re-assembled into one
        array. While synthesising, ``hparams.synth_file_suffix`` is temporarily
        replaced by ``_ref<num_coded_sps>sp``.

        :param hparams:         Hyper-parameter container; num_coded_sps and synth_file_suffix are used.
        :param file_id_list:    Ids of the samples to synthesise.
        :raises Exception:      If the features of an id are found neither without nor with deltas.
        """
        self.logger.info("Synthesise references for [{0}].".format(", ".join(file_id_list)))

        # The feature directory does not depend on the id, so build it only once.
        world_dir = os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)

        synth_dict = dict()
        for id_name in file_id_list:
            # Load reference audio features.
            try:
                output = WorldFeatLabelGen.load_sample(id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
            except FileNotFoundError as e1:
                try:
                    output = WorldFeatLabelGen.load_sample(id_name, world_dir, add_deltas=True, num_coded_sps=hparams.num_coded_sps)
                    coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(output, True, hparams.num_coded_sps)
                    length = len(output)
                    # Make the one-dimensional features column vectors so they can be concatenated with coded_sp.
                    lf0 = lf0.reshape(length, 1)
                    vuv = vuv.reshape(length, 1)
                    bap = bap.reshape(length, 1)
                    output = np.concatenate((coded_sp, lf0, vuv, bap), axis=1)
                except FileNotFoundError as e2:
                    self.logger.error("Cannot find extracted acoustic feature with or without deltas labels in {}.".format(world_dir))
                    raise Exception([e1, e2])
            synth_dict[id_name] = output

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix = '_ref' + str(hparams.num_coded_sps) + 'sp'

        try:
            # Run the vocoder.
            ModelTrainer.synthesize(self, file_id_list, synth_dict, hparams)
        finally:
            # Restore the suffix even if the synthesis fails, so hparams is never left modified.
            hparams.synth_file_suffix = old_synth_file_suffix
    def synthesize(self, id_list, synth_output, hparams):
        """Optionally replace parts of the network output by extracted features, then synthesise.

        If any of ``hparams.synth_load_org_sp``, ``synth_load_org_lf0``,
        ``synth_load_org_vuv`` or ``synth_load_org_bap`` is set, the extracted
        features of every id are loaded and the selected columns of
        ``synth_output[id_name]`` are overwritten in place. Afterwards the
        (possibly modified) output is passed to ``ModelTrainer.synthesize``.

        :param id_list:         Ids of the samples to synthesise.
        :param synth_output:    Dictionary mapping ids to the network output; modified in place.
        :param hparams:         Hyper-parameter container.
        """
        override_requested = (hparams.synth_load_org_sp
                              or hparams.synth_load_org_lf0
                              or hparams.synth_load_org_vuv
                              or hparams.synth_load_org_bap)

        if override_requested:
            for id_name in id_list:
                feature_dir = os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
                extracted = WorldFeatLabelGen.load_sample(id_name, feature_dir, num_coded_sps=hparams.num_coded_sps)
                generated = synth_output[id_name]

                # Extracted features longer than the network output are trimmed:
                # half of the surplus with reverse=True, the remainder without.
                surplus = len(extracted) - len(generated)
                if surplus > 0:
                    half = int(surplus / 2)
                    extracted = WorldFeatLabelGen.trim_end_sample(extracted, half, reverse=True)
                    extracted = WorldFeatLabelGen.trim_end_sample(extracted, surplus - half)

                num_frames = len(extracted)

                if hparams.synth_load_org_sp:
                    generated[:num_frames, :self.OutputGen.num_coded_sps] = extracted[:, :self.OutputGen.num_coded_sps]

                # lf0, vuv and bap are taken from the last three columns.
                if hparams.synth_load_org_lf0:
                    generated[:num_frames, -3] = extracted[:, -3]

                if hparams.synth_load_org_vuv:
                    generated[:num_frames, -2] = extracted[:, -2]

                if hparams.synth_load_org_bap:
                    generated[:num_frames, -1] = extracted[:, -1]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted versus original lf0 and spectrograms of one sample and save the figure.

        The figure has three grids: the lf0 curves with the unvoiced regions of
        prediction and original marked as areas, the original spectrogram, and
        the spectrogram of the network output. If ``hparams`` provides
        ``phoneme_indices`` and ``question_file``, phoneme annotations are added
        to the last grid.

        :param id_name:     Id of the sample; used to load the original features and to name the figure.
        :param label:       Network output of the sample, passed to OutputGen.postprocess_sample.
        :param hidden:      Not used by this method.
        :param hparams:     Hyper-parameter container.
        """
        # Features predicted by the network.
        predicted_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            predicted_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Features of the original recording.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Indices of the three grids of the figure.
        lf0_grid = 0
        org_spec_grid = 1
        nn_spec_grid = 2

        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        frames_label = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        # Grid with the lf0 curves and the unvoiced regions.
        plotter.set_label(grid_idx=lf0_grid, xlabel=frames_label, ylabel='log(f0)')
        curves = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=lf0_grid, data_list=curves)
        unvoiced_areas = [(np.invert(vuv.astype(bool)), '0.8', 1.0),
                          (np.invert(original_vuv.astype(bool)), 'red', 0.2)]
        plotter.set_area_list(grid_idx=lf0_grid, area_list=unvoiced_areas)

        # Grid with the spectrogram of the original features.
        plotter.set_label(grid_idx=org_spec_grid, xlabel=frames_label, ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=org_spec_grid, spec=WorldFeatLabelGen.mgc_to_sp(original_mgc, hparams.synth_fs))

        # Grid with the spectrogram of the network output.
        plotter.set_label(grid_idx=nn_spec_grid, xlabel=frames_label, ylabel='NN spectrogram')
        plotter.set_specshow(grid_idx=nn_spec_grid, spec=WorldFeatLabelGen.mgc_to_sp(coded_sp, hparams.synth_fs))

        # Phoneme annotations are only possible when both settings are available.
        if getattr(hparams, "phoneme_indices", None) is not None \
           and getattr(hparams, "question_file", None) is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(nn_spec_grid, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)
Example #4
0
    def synthesize(self, id_list, synth_output, hparams):
        """Merge the network's lf0 and V/UV output into extracted features and run the WORLD synthesis.

        For every sample the phrase curve is added to the predicted lf0 and the
        V/UV values are thresholded at 0.5. The remaining features are loaded
        from the extracted acoustic features in ``hparams.out_dir``; using a
        trained acoustic model instead is not implemented. The merged samples
        are stored under ``<id_name>_E2E`` and passed to ``run_world_synth``.

        :param id_list:         Not used by this method; the ids are taken from synth_output.
        :param synth_output:    Dictionary mapping ids to the network output (column 0: lf0, column 1: V/UV).
        :param hparams:         Hyper-parameter container.
        :raises NotImplementedError:  If hparams.synth_acoustic_model_path is not None.
        """
        full_output = dict()
        for id_name, labels in synth_output.items():
            # Take lf0 and vuv from network output.
            net_lf0 = labels[:, 0]
            net_vuv = labels[:, 1]

            phrase_curve = self.OutputGen.get_phrase_curve(id_name)
            net_lf0 = net_lf0 + phrase_curve[:len(net_lf0)].squeeze()

            # Threshold the V/UV values at 0.5 by masked assignment.
            net_vuv[net_vuv < 0.5] = 0.0
            net_vuv[net_vuv >= 0.5] = 1.0

            # Only features extracted from the audio are supported as source of mgc and bap.
            if hparams.synth_acoustic_model_path is not None:
                raise NotImplementedError()

            feature_dir = os.path.realpath(os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features))
            extracted = WorldFeatLabelGen.load_sample(id_name,
                                                      feature_dir,
                                                      add_deltas=False,
                                                      num_coded_sps=hparams.num_coded_sps)  # Load extracted data.

            # Split the length difference between the two trim calls.
            surplus = len(extracted) - len(net_lf0)
            num_front = surplus // 2
            num_end = surplus - num_front
            extracted = WorldFeatLabelGen.trim_end_sample(extracted, num_end)
            extracted = WorldFeatLabelGen.trim_end_sample(extracted, num_front, reverse=True)

            # Overwrite lf0 and vuv by network output.
            extracted[:, hparams.num_coded_sps] = net_lf0
            extracted[:, hparams.num_coded_sps + 1] = net_vuv

            # Fill a dictionary with the samples.
            full_output[id_name + "_E2E"] = extracted

        # Run the WORLD synthesizer.
        self.run_world_synth(full_output, hparams)
Example #5
0
    def synthesize(self, id_list, synth_output, hparams):
        """Run the atom synthesis, write the network's V/UV values into its result and vocode it.

        :param id_list:         Ids of the samples to synthesise.
        :param synth_output:    Dictionary mapping ids to the network output; V/UV is read from ``[:, 0, 1]``.
        :param hparams:         Hyper-parameter container.
        """
        full_output = self.run_atom_synth(id_list, synth_output, hparams)

        for id_name, sample in full_output.items():
            # NOTE(review): the return value of this call is not used afterwards;
            # confirm whether the interpolation is still needed here.
            interpolate_lin(sample[:, -3])

            net_vuv = synth_output[id_name][:, 0, 1]

            # Split the length difference between the two trim calls.
            surplus = len(sample) - len(net_vuv)
            num_front = int(surplus / 2)
            sample = WorldFeatLabelGen.trim_end_sample(sample, num_front, reverse=True)
            sample = WorldFeatLabelGen.trim_end_sample(sample, surplus - num_front)

            # NOTE(review): the trimmed sample is not stored back into full_output, so this
            # assignment only reaches full_output if trim_end_sample returns a view - TODO confirm.
            sample[:, -2] = net_vuv

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, full_output, hparams)
Example #6
0
    def load_extracted_audio_features(self, synth_output, hparams):
        """Load the audio features extracted from audio for every id in ``synth_output``.

        :param synth_output:    Dictionary whose keys are the ids to load; its values are not read.
        :param hparams:         Hyper-parameter container; out_dir and num_coded_sps are used.
        :return:                Dictionary mapping each id to the sample loaded without deltas.
        """
        self.logger.info("Load extracted mgc, lf0, vuv, bap data.")

        # The directory is the same for all ids, so resolve it once instead of once per id.
        path = os.path.realpath(os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features))

        return {id_name: WorldFeatLabelGen.load_sample(id_name,
                                                       path,
                                                       add_deltas=False,
                                                       num_coded_sps=hparams.num_coded_sps)  # Load extracted data.
                for id_name in synth_output}
Example #7
0
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot the predicted lf0 against the original lf0 of one sample and save the figure.

        The unvoiced regions of the prediction are marked as areas. If ``hparams``
        provides ``phoneme_indices`` and ``question_file``, phoneme annotations
        are added to the plot.

        :param id_name:     Id of the sample; used to load the original features and to name the figure.
        :param labels:      Network output of the sample, passed to OutputGen.postprocess_sample.
        :param hidden:      Not used by this method.
        :param hparams:     Hyper-parameter container.
        """
        labels_post = self.OutputGen.postprocess_sample(labels)
        _, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(labels_post, num_coded_sps=self.OutputGen.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, *_ = WorldFeatLabelGen.convert_to_world_features(org_labels_post, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # Axis labels, curves, areas and annotations all belong to the same grid,
        # so grid_idx must not be advanced between these calls.
        plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='log(f0)')

        graphs = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0)])

        # Phoneme annotations are only possible when both settings are available.
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)
Example #8
0
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Default constructor.

        Fills in missing hyper-parameter defaults, calls the base class
        constructor and creates the input (question labels) and output (WORLD
        features without deltas) generators together with the training and
        validation datasets.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container. If None, the result of create_hparams()
                                        is used with out_dir set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        # Variable sequence lengths default to True only for batch sizes greater than one.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        # Synthesised audio goes to <out_dir>/synth unless a directory is given.
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input generator: question labels.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        # Output generator: WORLD features without deltas.
        self.OutputGen = WorldFeatLabelGen(dir_world_features, add_deltas=False, num_coded_sps=hparams.num_coded_sps)
        self.OutputGen.get_normalisation_params(dir_world_features)

        # id_list_train / id_list_val are presumably set by the base class constructor - TODO confirm.
        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        # Fall back to an unreduced MSE loss when no loss function is set.
        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        # Map the "default" scheduler to the verbose "Plateau" scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_verbose = True
Example #9
0
    def run_atom_synth(self, file_id_list, synth_output, hparams):
        """Combine the lf0 reconstructed from the network output with the remaining acoustic features.

        The remaining features come from the data extracted from the audio when
        ``hparams.synth_acoustic_model_path`` is None, otherwise from
        ``generate_audio_features``. The reconstructed lf0 and a V/UV flag
        derived from it are written into the columns ``num_coded_sps`` and
        ``num_coded_sps + 1`` of each sample.

        :param file_id_list:    Ids of the samples; only used when an acoustic model generates the features.
        :param synth_output:    Dictionary mapping ids to the network output.
        :param hparams:         Hyper-parameter container.
        :return:                Dictionary mapping ids to the acoustic feature samples.
        """
        # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
        use_extracted_features = hparams.synth_acoustic_model_path is None
        if use_extracted_features:
            full_output = self.load_extracted_audio_features(synth_output, hparams)
        else:
            full_output = self.generate_audio_features(file_id_list, hparams)

        # Reconstruct lf0 from generated atoms and write it to synth output.
        recon_dict = self.get_recon_from_synth_output(synth_output, hparams)
        for id_name, recon_lf0 in recon_dict.items():
            sample = full_output[id_name]

            # Split the length difference between the two trim calls.
            surplus = len(sample) - len(recon_lf0)
            num_front = int(surplus / 2)
            sample = WorldFeatLabelGen.trim_end_sample(sample, num_front, reverse=True)
            sample = WorldFeatLabelGen.trim_end_sample(sample, surplus - num_front)

            # Frames at or below the log silence threshold are flagged with 0, all others with 1.
            log_silence_threshold = math.log(WorldFeatLabelGen.f0_silence_threshold)
            recon_vuv = np.where(recon_lf0 <= log_silence_threshold, 0.0, 1.0)

            # NOTE(review): the trimmed sample is not stored back into full_output, so these
            # assignments only reach full_output if trim_end_sample returns a view - TODO confirm.
            sample[:, hparams.num_coded_sps] = recon_lf0
            sample[:, hparams.num_coded_sps + 1] = recon_vuv

        return full_output
Example #10
0
    def run_world_synth(self, synth_output, hparams):
        """Run the WORLD synthesize method on every sample and save the audio.

        The output directory is ``hparams.synth_dir``, falling back to
        ``hparams.out_dir`` and then to the current directory. Audio is always
        written as wav first and converted with PyDub when ``hparams.synth_ext``
        names another format; the intermediate wav file is removed in that case.

        :param synth_output:    Dictionary mapping ids to samples containing coded spectrum, lf0, vuv and bap.
        :param hparams:         Hyper-parameter container.
        """
        fft_size = pyworld.get_cheaptrick_fft_size(hparams.synth_fs)

        if hparams.synth_dir is not None:
            save_dir = hparams.synth_dir
        elif hparams.out_dir is not None:
            save_dir = hparams.out_dir
        else:
            save_dir = os.path.curdir

        for id_name, output in synth_output.items():
            logging.info("Synthesise {} with the WORLD vocoder.".format(id_name))

            coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(output, contains_deltas=False, num_coded_sps=hparams.num_coded_sps)
            ln_sp = pysptk.mgc2sp(np.ascontiguousarray(coded_sp, dtype=np.float64), alpha=WorldFeatLabelGen.mgc_alpha, gamma=0.0, fftlen=fft_size)
            sp = np.exp(ln_sp.real)
            sp = np.power(sp.real / 32768.0, 2)
            f0 = np.exp(lf0, dtype=np.float64)
            vuv[f0 < WorldFeatLabelGen.f0_silence_threshold] = 0  # WORLD throws an error for too small f0 values.
            f0[vuv == 0] = 0.0
            ap = pyworld.decode_aperiodicity(np.ascontiguousarray(bap.reshape(-1, 1), np.float64), hparams.synth_fs, fft_size)

            waveform = pyworld.synthesize(f0, sp, ap, hparams.synth_fs)
            waveform = waveform.astype(np.float32, copy=False)  # Does inplace conversion, if possible.

            # Always save as wav file first and convert afterwards if necessary.
            # NOTE(review): a "_WORLD" part used to be passed to format() without a placeholder
            # and therefore never appeared in the file name; the file name is kept unchanged.
            model_part = "_" + hparams.model_name if hparams.model_name is not None else ""
            wav_file_name = "{}{}{}.wav".format(os.path.basename(id_name), model_part, hparams.synth_file_suffix)
            wav_file_path = os.path.join(save_dir, wav_file_name)

            # Create the directory that is actually written to; hparams.synth_dir may be None.
            makedirs_safe(save_dir)
            soundfile.write(wav_file_path, waveform, hparams.synth_fs)

            # Use PyDub for special audio formats.
            if hparams.synth_ext.lower() != 'wav':
                as_wave = pydub.AudioSegment.from_wav(wav_file_path)
                as_wave.export(os.path.join(save_dir, id_name + "." + hparams.synth_ext), format=hparams.synth_ext)
                os.remove(wav_file_path)
Example #11
0
class AcousticModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic features.

    Use question labels as input and WORLD features as output. Synthesize audio from model output.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Default constructor.

        Fills in missing hyper-parameter defaults, calls the base class
        constructor and creates the input (question labels) and output (WORLD
        features without deltas) generators together with the training and
        validation datasets.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container. If None, the result of create_hparams()
                                        is used with out_dir set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        # Variable sequence lengths default to True only for batch sizes greater than one.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        # Synthesised audio goes to <out_dir>/synth unless a directory is given.
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input generator: question labels.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        # Output generator: WORLD features without deltas.
        self.OutputGen = WorldFeatLabelGen(dir_world_features, add_deltas=False, num_coded_sps=hparams.num_coded_sps)
        self.OutputGen.get_normalisation_params(dir_world_features)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        # Fall back to an unreduced MSE loss when no loss function is set.
        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        # Map the "default" scheduler to the verbose "Plateau" scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_verbose = True

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the hyper-parameter container of the base class and set num_coded_sps to 60.

        :param hparams_string:  Forwarded to ModelTrainer.create_hparams.
        :param verbose:         If True, log the final hyper-parameter values.
        :return:                The hyper-parameter container.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)
        hparams.num_coded_sps = 60

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot the predicted lf0 against the original lf0 of one sample and save the figure.

        The unvoiced regions of the prediction are marked as areas. If ``hparams``
        provides ``phoneme_indices`` and ``question_file``, phoneme annotations
        are added to the plot.

        :param id_name:     Id of the sample; used to load the original features and to name the figure.
        :param labels:      Network output of the sample, passed to OutputGen.postprocess_sample.
        :param hidden:      Not used by this method.
        :param hparams:     Hyper-parameter container.
        """
        labels_post = self.OutputGen.postprocess_sample(labels)
        _, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(labels_post, num_coded_sps=self.OutputGen.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, *_ = WorldFeatLabelGen.convert_to_world_features(org_labels_post, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # Axis labels, curves, areas and annotations all belong to the same grid,
        # so grid_idx must not be advanced between these calls.
        plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='log(f0)')

        graphs = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0)])

        # Phoneme annotations are only possible when both settings are available.
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)
class AcousticDeltasModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data.

    Use question labels as input and WORLD features with deltas/double deltas as output.
    Synthesize audio from model output with MLPG smoothing.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Default constructor.

        Fills in missing hyper-parameter defaults, calls the base class
        constructor and creates the input (question labels) and output (WORLD
        features with deltas) generators together with the training and
        validation datasets.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List of ids, can contain a speaker directory.
        :param num_questions:           Number of questions in question file.
        :param hparams:                 Set of hyper parameters. If None, the result of create_hparams()
                                        is used with out_dir set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        # Variable sequence lengths default to True only for batch sizes greater than one.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        # Synthesised audio goes to <out_dir>/synth unless a directory is given.
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super(AcousticDeltasModelTrainer, self).__init__(id_list, hparams)

        # Input generator: question labels.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        # Output generator: WORLD features with deltas.
        self.OutputGen = WorldFeatLabelGen(dir_world_features, add_deltas=True, num_coded_sps=hparams.num_coded_sps)
        self.OutputGen.get_normalisation_params(dir_world_features)

        # id_list_train / id_list_val are presumably set by the base class constructor - TODO confirm.
        self.dataset_train = LabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = LabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        # Fall back to an unreduced MSE loss when no loss function is set.
        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        # Map the "default" scheduler to the verbose "Plateau" scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_verbose = True

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the hyper-parameter container and add the ``synth_load_org_*`` switches.

        The container is created by ``ModelTrainer.create_hparams``. The switches
        ``synth_load_org_sp``, ``synth_load_org_lf0``, ``synth_load_org_vuv`` and
        ``synth_load_org_bap`` are then set to False.

        :param hparams_string:  Forwarded to ModelTrainer.create_hparams.
        :param verbose:         If True, log the final hyper-parameter values.
        :return:                The hyper-parameter container.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        # By default none of the original features replaces the network output at synthesis.
        for feature in ("sp", "lf0", "vuv", "bap"):
            setattr(hparams, "synth_load_org_" + feature, False)

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    # def set_train_params(self, learning_rate):
    #     """Overwrite baseclass method to change non_zero_occurrence value."""
    #
    #     # Defaults for criterion, optimiser_type and scheduler_type.
    #     # If not None it means that the model was loaded from a file.
    #     if self.model_handler.loss_function is not None:
    #         loss_function = self.model_handler.loss_function
    #     else:
    #         loss_function = torch.nn.MSELoss(reduction='none')
    #
    #     if self.model_handler.optimiser is not None:
    #         optimiser = self.model_handler.optimiser
    #         if learning_rate is not None:
    #             for g in optimiser.param_groups:
    #                 g['lr'] = learning_rate
    #     else:
    #         optimiser = torch.optim.Adam(self.model_handler.model.parameters(), lr=learning_rate)
    #
    #     if self.model_handler.scheduler is not None:
    #         scheduler = self.model_handler.scheduler
    #     else:
    #         scheduler = ReduceLROnPlateau(optimiser, verbose=True)
    #
    #     self.model_handler.set_train_params(loss_function, optimiser, scheduler, batch_loss=True)

    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted versus original lf0 and spectrograms of one sample and save the figure.

        The figure has three grids: the lf0 curves with the unvoiced regions of
        prediction and original marked as areas, the original spectrogram, and
        the spectrogram of the network output. If ``hparams`` provides
        ``phoneme_indices`` and ``question_file``, phoneme annotations are added
        to the last grid.

        :param id_name:     Id of the sample; used to load the original features and to name the figure.
        :param label:       Network output of the sample, passed to OutputGen.postprocess_sample.
        :param hidden:      Not used by this method.
        :param hparams:     Hyper-parameter container.
        """
        # Features predicted by the network.
        predicted_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            predicted_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Features of the original recording.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Indices of the three grids of the figure.
        lf0_grid = 0
        org_spec_grid = 1
        nn_spec_grid = 2

        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        frames_label = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        # Grid with the lf0 curves and the unvoiced regions.
        plotter.set_label(grid_idx=lf0_grid, xlabel=frames_label, ylabel='log(f0)')
        curves = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=lf0_grid, data_list=curves)
        unvoiced_areas = [(np.invert(vuv.astype(bool)), '0.8', 1.0),
                          (np.invert(original_vuv.astype(bool)), 'red', 0.2)]
        plotter.set_area_list(grid_idx=lf0_grid, area_list=unvoiced_areas)

        # Grid with the spectrogram of the original features.
        plotter.set_label(grid_idx=org_spec_grid, xlabel=frames_label, ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=org_spec_grid, spec=WorldFeatLabelGen.mgc_to_sp(original_mgc, hparams.synth_fs))

        # Grid with the spectrogram of the network output.
        plotter.set_label(grid_idx=nn_spec_grid, xlabel=frames_label, ylabel='NN spectrogram')
        plotter.set_specshow(grid_idx=nn_spec_grid, spec=WorldFeatLabelGen.mgc_to_sp(coded_sp, hparams.synth_fs))

        # Phoneme annotations are only possible when both settings are available.
        if getattr(hparams, "phoneme_indices", None) is not None \
           and getattr(hparams, "question_file", None) is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(nn_spec_grid, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """
        Compute objective scores of post-processed outputs against the stored reference features.

        Every output is compared frame by frame with the reference sample of the same id, which is
        loaded from self.OutputGen.dir_labels. The reference is cut to the length of the output,
        no alignment is performed. The worst utterance per metric and the averaged scores are
        written to self.logger.

        :param dict_outputs_post:  Dictionary of post-processed outputs indexed by id name.
        :param dict_hiddens:       Not used by this method.
        :param hparams:            Hyper-parameter container, only num_coded_sps is read here.
        :return:                   Tuple (mcd, f0_rmse, vuv_error_rate, bap_error) averaged over utterances.
                                   vuv_error_rate is a fraction, it is multiplied by 100 only for logging.
        :raises ZeroDivisionError: If dict_outputs_post is empty.
        """

        # Get data for comparison.
        dict_original_post = {id_name: WorldFeatLabelGen.load_sample(id_name,
                                                                     self.OutputGen.dir_labels,
                                                                     True,
                                                                     num_coded_sps=hparams.num_coded_sps)
                              for id_name in dict_outputs_post}

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        num_valid_f0 = 0  # Number of utterances with a usable (non-NaN) F0 RMSE.
        vuv_error_rate = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        mcd = 0.0
        mcd_max_id = "None"
        mcd_max = 0.0
        bap_error = 0.0
        bap_error_max_id = "None"
        bap_error_max = 0.0

        for id_name, labels in dict_outputs_post.items():
            output_coded_sp, output_lf0, output_vuv, output_bap = self.OutputGen.convert_to_world_features(labels, False, num_coded_sps=hparams.num_coded_sps)
            output_vuv = output_vuv.astype(bool)
            num_frames = len(output_lf0)

            # Get data for comparison.
            org_coded_sp, org_lf0, org_vuv, org_bap = self.OutputGen.convert_to_world_features(dict_original_post[id_name],
                                                                                               self.OutputGen.add_deltas,
                                                                                               num_coded_sps=hparams.num_coded_sps)
            org_vuv = org_vuv[:num_frames]  # Fix minor negligible length mismatch.

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0.squeeze())[:num_frames]  # Fix minor negligible length mismatch.
            output_f0 = np.exp(output_lf0)

            # Compute MCD.
            org_coded_sp = org_coded_sp[:len(output_coded_sp)]
            current_mcd = metrics.melcd(output_coded_sp, org_coded_sp)  # TODO: Use aligned mcd.
            if current_mcd > mcd_max:
                mcd_max_id = id_name
                mcd_max = current_mcd
            mcd += current_mcd

            # Compute RMSE of F0, weighted by the reference V/UV flag.
            f0_mse = (org_f0 - output_f0) ** 2
            current_f0_rmse = math.sqrt((f0_mse * org_vuv).sum() / org_vuv.sum())
            if math.isnan(current_f0_rmse):
                # Presumably caused by a reference without voiced frames (0 / 0); skip the utterance.
                self.logger.error("Computed NaN for F0 RMSE for {}.".format(id_name))
            else:
                if current_f0_rmse > f0_rmse_max:
                    f0_rmse_max_id = id_name
                    f0_rmse_max = current_f0_rmse
                f0_rmse += current_f0_rmse
                num_valid_f0 += 1

            # Compute error of VUV as fraction of frames.
            num_errors = (org_vuv != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / num_frames
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp

            # Compute aperiodicity distortion.
            org_bap = org_bap[:len(output_bap)]
            if len(output_bap.shape) > 1 and output_bap.shape[1] > 1:
                current_bap_error = metrics.melcd(output_bap, org_bap)  # TODO: Use aligned mcd?
            else:
                # Single coefficient: scaled RMSE, printed as "db" in the log below.
                current_bap_error = math.sqrt(((org_bap - output_bap) ** 2).mean()) * (10.0 / np.log(10) * np.sqrt(2.0))
            if current_bap_error > bap_error_max:
                bap_error_max_id = id_name
                bap_error_max = current_bap_error
            bap_error += current_bap_error

        num_utts = len(dict_outputs_post)
        # Utterances skipped because of a NaN did not add to the F0 sum,
        # so they must not be counted in its average either.
        if num_valid_f0 > 0:
            f0_rmse /= num_valid_f0
        vuv_error_rate /= num_utts
        mcd /= num_utts
        bap_error /= num_utts

        self.logger.info("Worst MCD: {} {:4.2f}dB".format(mcd_max_id, mcd_max))
        self.logger.info("Worst F0 RMSE: {} {:4.2f}Hz".format(f0_rmse_max_id, f0_rmse_max))
        self.logger.info("Worst VUV error: {} {:2.2f}%".format(vuv_error_max_id, vuv_error_max * 100))
        self.logger.info("Worst BAP error: {} {:4.2f}db".format(bap_error_max_id, bap_error_max))
        self.logger.info("Benchmark score: MCD {:4.2f}dB, F0 RMSE {:4.2f}Hz, VUV {:2.2f}%, BAP error {:4.2f}db".format(mcd, f0_rmse, vuv_error_rate * 100, bap_error))

        return mcd, f0_rmse, vuv_error_rate, bap_error

    def synthesize(self, id_list, synth_output, hparams):
        """
        Optionally overwrite parts of the network output with extracted features, then run the normal synthesis.

        For each hparams.synth_load_org_{sp,lf0,vuv,bap} flag that is set, the matching columns of
        synth_output[id_name] are replaced in place by the features loaded from the extracted acoustic
        features directory. Loaded features longer than the output are trimmed at both ends first.

        :param id_list:       Ids of the samples to synthesise.
        :param synth_output:  Dictionary of outputs indexed by id name, modified in place.
        :param hparams:       Hyper-parameter container holding the synth_load_org_* flags and num_coded_sps.
        """
        override_requested = (hparams.synth_load_org_sp
                              or hparams.synth_load_org_lf0
                              or hparams.synth_load_org_vuv
                              or hparams.synth_load_org_bap)

        if override_requested:
            for utt_id in id_list:
                feature_dir = os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
                org_feats = WorldFeatLabelGen.load_sample(utt_id, feature_dir, num_coded_sps=hparams.num_coded_sps)
                predicted = synth_output[utt_id]

                # Cut surplus frames of the loaded features, split between both ends.
                surplus = len(org_feats) - len(predicted)
                if surplus > 0:
                    first_cut = int(surplus / 2)
                    org_feats = WorldFeatLabelGen.trim_end_sample(org_feats, first_cut, reverse=True)
                    org_feats = WorldFeatLabelGen.trim_end_sample(org_feats, surplus - first_cut)

                num_rows = len(org_feats)

                if hparams.synth_load_org_sp:
                    num_sps = self.OutputGen.num_coded_sps
                    predicted[:num_rows, :num_sps] = org_feats[:, :num_sps]

                # The last three columns are copied under the lf0, vuv and bap flags respectively.
                trailing_columns = ((hparams.synth_load_org_lf0, -3),
                                    (hparams.synth_load_org_vuv, -2),
                                    (hparams.synth_load_org_bap, -1))
                for enabled, column in trailing_columns:
                    if enabled:
                        predicted[:num_rows, column] = org_feats[:, column]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """
        Compute objective scores of post-processed outputs against the stored reference features.

        Every output is compared frame by frame with the reference sample of the same id, which is
        loaded from self.OutputGen.dir_labels. The reference is cut to the length of the output,
        no alignment is performed. The worst utterance per metric and the averaged scores are
        written to self.logger.

        :param dict_outputs_post:  Dictionary of post-processed outputs indexed by id name.
        :param dict_hiddens:       Not used by this method.
        :param hparams:            Hyper-parameter container, only num_coded_sps is read here.
        :return:                   Tuple (mcd, f0_rmse, vuv_error_rate, bap_error) averaged over utterances.
                                   vuv_error_rate is a fraction, it is multiplied by 100 only for logging.
        :raises ZeroDivisionError: If dict_outputs_post is empty.
        """

        # Get data for comparison.
        dict_original_post = {id_name: WorldFeatLabelGen.load_sample(id_name,
                                                                     self.OutputGen.dir_labels,
                                                                     True,
                                                                     num_coded_sps=hparams.num_coded_sps)
                              for id_name in dict_outputs_post}

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        num_valid_f0 = 0  # Number of utterances with a usable (non-NaN) F0 RMSE.
        vuv_error_rate = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        mcd = 0.0
        mcd_max_id = "None"
        mcd_max = 0.0
        bap_error = 0.0
        bap_error_max_id = "None"
        bap_error_max = 0.0

        for id_name, labels in dict_outputs_post.items():
            output_coded_sp, output_lf0, output_vuv, output_bap = self.OutputGen.convert_to_world_features(labels, False, num_coded_sps=hparams.num_coded_sps)
            output_vuv = output_vuv.astype(bool)
            num_frames = len(output_lf0)

            # Get data for comparison.
            org_coded_sp, org_lf0, org_vuv, org_bap = self.OutputGen.convert_to_world_features(dict_original_post[id_name],
                                                                                               self.OutputGen.add_deltas,
                                                                                               num_coded_sps=hparams.num_coded_sps)
            org_vuv = org_vuv[:num_frames]  # Fix minor negligible length mismatch.

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0.squeeze())[:num_frames]  # Fix minor negligible length mismatch.
            output_f0 = np.exp(output_lf0)

            # Compute MCD.
            org_coded_sp = org_coded_sp[:len(output_coded_sp)]
            current_mcd = metrics.melcd(output_coded_sp, org_coded_sp)  # TODO: Use aligned mcd.
            if current_mcd > mcd_max:
                mcd_max_id = id_name
                mcd_max = current_mcd
            mcd += current_mcd

            # Compute RMSE of F0, weighted by the reference V/UV flag.
            f0_mse = (org_f0 - output_f0) ** 2
            current_f0_rmse = math.sqrt((f0_mse * org_vuv).sum() / org_vuv.sum())
            if math.isnan(current_f0_rmse):
                # Presumably caused by a reference without voiced frames (0 / 0); skip the utterance.
                self.logger.error("Computed NaN for F0 RMSE for {}.".format(id_name))
            else:
                if current_f0_rmse > f0_rmse_max:
                    f0_rmse_max_id = id_name
                    f0_rmse_max = current_f0_rmse
                f0_rmse += current_f0_rmse
                num_valid_f0 += 1

            # Compute error of VUV as fraction of frames.
            num_errors = (org_vuv != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / num_frames
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp

            # Compute aperiodicity distortion.
            org_bap = org_bap[:len(output_bap)]
            if len(output_bap.shape) > 1 and output_bap.shape[1] > 1:
                current_bap_error = metrics.melcd(output_bap, org_bap)  # TODO: Use aligned mcd?
            else:
                # Single coefficient: scaled RMSE, printed as "db" in the log below.
                current_bap_error = math.sqrt(((org_bap - output_bap) ** 2).mean()) * (10.0 / np.log(10) * np.sqrt(2.0))
            if current_bap_error > bap_error_max:
                bap_error_max_id = id_name
                bap_error_max = current_bap_error
            bap_error += current_bap_error

        num_utts = len(dict_outputs_post)
        # Utterances skipped because of a NaN did not add to the F0 sum,
        # so they must not be counted in its average either.
        if num_valid_f0 > 0:
            f0_rmse /= num_valid_f0
        vuv_error_rate /= num_utts
        mcd /= num_utts
        bap_error /= num_utts

        self.logger.info("Worst MCD: {} {:4.2f}dB".format(mcd_max_id, mcd_max))
        self.logger.info("Worst F0 RMSE: {} {:4.2f}Hz".format(f0_rmse_max_id, f0_rmse_max))
        self.logger.info("Worst VUV error: {} {:2.2f}%".format(vuv_error_max_id, vuv_error_max * 100))
        self.logger.info("Worst BAP error: {} {:4.2f}db".format(bap_error_max_id, bap_error_max))
        self.logger.info("Benchmark score: MCD {:4.2f}dB, F0 RMSE {:4.2f}Hz, VUV {:2.2f}%, BAP error {:4.2f}db".format(mcd, f0_rmse, vuv_error_rate * 100, bap_error))

        return mcd, f0_rmse, vuv_error_rate, bap_error
Example #14
0
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """
        Plot the network output of one sample next to its target and save the figure.

        The figure has four grids: raw network output per theta, post-processed output, target labels,
        and the LF0 curves (from target atoms, original LF0 minus phrase curve, and predicted atoms).
        It is saved as <hparams.out_dir>/<id_name>.<model base name>.BASE.png.

        :param id_name:  Id of the sample, used to load references and to name the figure.
        :param labels:   Network output of the sample; a 1-D array is expanded to a single column.
        :param hidden:   Not used by this method.
        :param hparams:  Hyper-parameter container; k, thetas, min_atom_amp, num_coded_sps, frame_size_ms,
                         out_dir and model_name are read here.
        """

        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        # Builtin bool instead of the np.bool alias, which was removed in NumPy 1.24.
        vuv = vuv.astype(bool)

        # Load original lf0 and vuv.
        org_labels = WorldFeatLabelGen.load_sample(id_name, os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features), num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        # Remove the phrase curve stored next to the labels from the original lf0.
        phrase_curve = np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0 -= phrase_curve
        # NOTE(review): the two trims use int(len_diff / 2.0) and int(len_diff / 2.0) + 1;
        # confirm against trim_end_sample that this removes the intended number of frames.
        len_diff = len(original_lf0) - len(lf0)
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0, int(len_diff / 2.0) + 1, reverse=True)

        # Load the target labels (rebinds org_labels) and trim them the same way.
        org_labels = self.OutputGen.load_sample(id_name, self.OutputGen.dir_labels, len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0) + 1, reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)
        frames_label = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        # Grid 0: raw network output, one graph per theta in reversed column order.
        grid_idx = 0
        graphs_output = [(labels[:, idx], r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx]))
                         for idx in reversed(range(labels.shape[1]))]
        plotter.set_label(grid_idx=grid_idx, xlabel=frames_label, ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        # plotter.set_lim(grid_idx=0, ymin=-1.8, ymax=1.8)

        # Grid 1: post-processed network output, with the frames where vuv is False shaded.
        grid_idx += 1
        graphs_peaks = [(labels_post[:, idx, 0],) for idx in reversed(range(labels_post.shape[1]))]
        plotter.set_label(grid_idx=grid_idx, xlabel=frames_label, ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Grid 2: target labels, with the frames where original_vuv is False shaded.
        grid_idx += 1
        graphs_target = [(org_labels[:, idx, 0],) for idx in reversed(range(org_labels.shape[1]))]
        plotter.set_label(grid_idx=grid_idx, xlabel=frames_label, ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(original_vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Grid 3: lf0 rebuilt from target atoms, original lf0, and lf0 rebuilt from predicted atoms.
        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(labels_post, hparams.k, hparams.frame_size_ms, amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = [(wcad_lf0, "wcad lf0"),
                      (original_lf0, "org lf0"),
                      (output_lf0, "predicted lf0")]
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(original_vuv), '0.8', 1.0)])
        plotter.set_label(grid_idx=grid_idx, xlabel=frames_label, ylabel='lf0')
        # Symmetric y-limits with 10% headroom around the largest reconstructed amplitude.
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # plotter.set_lim(xmin=300, xmax=1100)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE.png")
Example #15
0
    def gen_figure_phrase(self, hparams, ids_input):
        """
        Plot the original LF0 against the stored phrase curve for each given id and save one figure per id.

        The model is run on the ids, but only the length of its post-processed output is used to cut
        the references. For every id the F0 RMSE between the original LF0 and the phrase curve is
        logged, and the figure is saved as
        <hparams.out_dir>/<id_name>.<model base name>.PHRASE<hparams.gen_figure_ext>.

        :param hparams:    Hyper-parameter container; batch_size_gen_figure, out_dir, num_coded_sps,
                           model_name, frame_size_ms and gen_figure_ext are read here.
        :param ids_input:  Ids to plot, converted by ModelTrainer._input_to_str_list.
        """
        id_list = ModelTrainer._input_to_str_list(ids_input)
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_gen_figure,
            synth=False,
            benchmark=False,
            gen_figure=False)

        for id_name, outputs_post in model_output_post.items():

            if outputs_post.ndim < 2:
                outputs_post = np.expand_dims(outputs_post, axis=1)

            lf0 = outputs_post[:, 0]
            output_lf0, _ = interpolate_lin(lf0)
            # Binarise the second column at 0.5; the assignments write into outputs_post itself.
            output_vuv = outputs_post[:, 1]
            output_vuv[output_vuv < 0.5] = 0.0
            output_vuv[output_vuv >= 0.5] = 1.0
            # Builtin bool instead of the np.bool alias, which was removed in NumPy 1.24.
            output_vuv = output_vuv.astype(bool)

            # Load original lf0 and vuv.
            world_dir = os.path.join(hparams.out_dir,
                                     self.dir_extracted_acoustic_features)
            org_labels = WorldFeatLabelGen.load_sample(
                id_name, world_dir,
                num_coded_sps=hparams.num_coded_sps)[:len(output_lf0)]
            _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
                org_labels, num_coded_sps=hparams.num_coded_sps)
            original_lf0, _ = interpolate_lin(original_lf0)
            original_vuv = original_vuv.astype(bool)

            # Phrase curve stored as raw float32 values, cut to the length of the original lf0.
            phrase_curve = np.fromfile(os.path.join(
                self.flat_trainer.atom_trainer.OutputGen.dir_labels,
                id_name + self.OutputGen.ext_phrase),
                                       dtype=np.float32).reshape(
                                           -1, 1)[:len(original_lf0)]

            # F0 RMSE between original and phrase curve, weighted by the original V/UV flag.
            f0_mse = (np.exp(original_lf0.squeeze(-1)) -
                      np.exp(phrase_curve.squeeze(-1)))**2
            f0_rmse = math.sqrt(
                (f0_mse * original_vuv[:len(output_lf0)]).sum() /
                original_vuv[:len(output_lf0)].sum())
            self.logger.info("RMSE of {} phrase curve: {} Hz.".format(
                id_name, f0_rmse))

            # NOTE(review): the two trims use int(len_diff / 2.0) and int(len_diff / 2.0) + 1;
            # confirm against trim_end_sample that this removes the intended number of frames.
            len_diff = len(original_lf0) - len(lf0)
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0))
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0) + 1, reverse=True)

            # Get a data plotter.
            net_name = os.path.basename(hparams.model_name)
            filename = str(
                os.path.join(hparams.out_dir, id_name + '.' + net_name))
            plotter = DataPlotter()
            # plotter.set_title(id_name + " - " + net_name)

            grid_idx = 0
            # NOTE(review): the graph labelled "Predicted" is the phrase curve read from disk,
            # not the model output computed above - confirm this is intended.
            graphs_lf0 = [(original_lf0, "Original"),
                          (phrase_curve, "Predicted")]
            plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(original_vuv), '0.8',
                                              1.0, 'Reference unvoiced')])
            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [' + str(hparams.frame_size_ms) +
                              ' ms]',
                              ylabel='LF0')
            # amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
            # plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
            plotter.set_lim(grid_idx=grid_idx, ymin=4.2, ymax=5.4)
            # plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

            # plotter.set_lim(xmin=300, xmax=1100)
            plotter.gen_plot()
            plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)