Example #1
    def compute_score(self, data, output, hparams):

        dict_original_post = self.get_output_dict(data.keys(),
                                                  hparams,
                                                  chunk_size=hparams.get_value(
                                                      "n_frames_per_step",
                                                      default=1))

        metric_dict = {}
        for label_name in next(iter(data.values())).keys():
            metric = Metrics(hparams.metrics)
            for id_name, labels in data.items():
                sample = labels[label_name]
                # Use a new name to avoid shadowing the unused `output` argument.
                predicted = WorldFeatLabelGen.convert_to_world_features(
                    sample=sample,
                    contains_deltas=False,
                    num_coded_sps=hparams.num_coded_sps,
                    num_bap=hparams.num_bap)

                org = WorldFeatLabelGen.convert_to_world_features(
                    sample=dict_original_post[id_name],
                    contains_deltas=hparams.add_deltas,
                    num_coded_sps=hparams.num_coded_sps,
                    num_bap=hparams.num_bap)

                current_metrics = metric.get_metrics(hparams.metrics, *org,
                                                     *predicted)
                metric.accumulate(id_name, current_metrics)

            metric.log()
            metric_dict[label_name] = metric.get_cum_values()

        return metric_dict
Example #2
    def run_atom_synth(self, file_id_list, synth_output, hparams):
        """
        Reconstruct lf0, get mgc and bap data, and store all in files in self.synth_dir.
        """

        # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
        if hparams.synth_acoustic_model_path is None:
            full_output = self.load_extracted_audio_features(
                synth_output, hparams)
        else:
            self.logger.warning("This method is untested.")
            full_output = self.generate_audio_features(file_id_list, hparams)

        # Reconstruct lf0 from generated atoms and write it to synth output.
        recon_dict = self.get_recon_from_synth_output(synth_output, hparams)
        for id_name, lf0 in recon_dict.items():
            full_sample = full_output[id_name]
            len_diff = len(full_sample) - len(lf0)
            full_sample = WorldFeatLabelGen.trim_end_sample(full_sample,
                                                            int(len_diff / 2),
                                                            reverse=True)
            full_sample = WorldFeatLabelGen.trim_end_sample(
                full_sample, len_diff - int(len_diff / 2))
            vuv = np.ones(lf0.shape)
            vuv[lf0 <= math.log(WorldFeatLabelGen.f0_silence_threshold)] = 0.0
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv
            # Write the trimmed, updated sample back; otherwise the changes
            # are lost whenever trim_end_sample returned a copy.
            full_output[id_name] = full_sample

        return full_output
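
A hedged usage sketch for this method: `trainer` stands for an instance of the surrounding trainer class and `predict_atoms` is a placeholder for the network inference step, neither of which is part of this example.

# Hypothetical driver for run_atom_synth; names are placeholders.
file_ids = ["p225/p225_051"]
synth_output = predict_atoms(file_ids)  # Placeholder: id_name -> atom predictions.
full_output = trainer.run_atom_synth(file_ids, synth_output, hparams)
# full_output now maps each id to features with lf0 and vuv overwritten.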
Example #3
    @staticmethod
    def run_world_synth(synth_output: Dict[str, np.ndarray],
                        hparams: ExtendedHParams,
                        epoch: int = None,
                        step: int = None,
                        use_model_name: bool = True,
                        has_deltas: bool = False) -> None:
        """Run the WORLD synthesize method."""

        fft_size = pyworld.get_cheaptrick_fft_size(hparams.synth_fs)

        save_dir = Synthesiser._get_synth_dir(hparams,
                                              use_model_name,
                                              epoch=epoch,
                                              step=step)

        for id_name, output in synth_output.items():
            logging.info(
                "Synthesise {} with the WORLD vocoder.".format(id_name))

            coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                output,
                contains_deltas=has_deltas,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)
            amp_sp = AudioProcessing.decode_sp(
                coded_sp,
                hparams.sp_type,
                hparams.synth_fs,
                post_filtering=hparams.do_post_filtering).astype(np.double,
                                                                 copy=False)
            args = dict()
            for attr in "preemphasis", "f0_silence_threshold", "lf0_zero":
                if hasattr(hparams, attr):
                    args[attr] = getattr(hparams, attr)
            waveform = WorldFeatLabelGen.world_features_to_raw(
                amp_sp,
                lf0,
                vuv,
                bap,
                fs=hparams.synth_fs,
                n_fft=fft_size,
                **args)

            # Always save as wav file first and convert afterwards if necessary.
            file_name = (os.path.basename(id_name) +
                         hparams.synth_file_suffix + '_' +
                         str(hparams.num_coded_sps) + hparams.sp_type +
                         "_WORLD")
            file_path = os.path.join(save_dir, file_name)
            soundfile.write(file_path + ".wav", waveform, hparams.synth_fs)

            # Use PyDub for special audio formats.
            if hparams.synth_ext.lower() != 'wav':
                as_wave = pydub.AudioSegment.from_wav(file_path + ".wav")
                file = as_wave.export(file_path + "." + hparams.synth_ext,
                                      format=hparams.synth_ext)
                file.close()
                os.remove(file_path + ".wav")
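
A minimal sketch of how this vocoder entry point is typically driven (Examples #15 and #17 below call it the same way); the id and directory are placeholders:

# Hypothetical usage of run_world_synth.
id_name = "p225/p225_051"
output = WorldFeatLabelGen.load_sample(id_name,
                                       "experiments/English/WORLD",
                                       num_coded_sps=hparams.num_coded_sps)
Synthesiser.run_world_synth({id_name: output}, hparams)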
Example #4
    @staticmethod
    def synth_ref(hparams, file_id_list, feature_dir=None):
        # Create reference audio files containing only the vocoder degradation.
        logging.info("Synthesise references with {} for [{}].".format(
            hparams.synth_vocoder,
            ", ".join([id_name for id_name in file_id_list])))

        synth_dict = dict()
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix = '_ref'
        if hparams.synth_vocoder == "WORLD":
            for id_name in file_id_list:
                # Load reference audio features.
                try:
                    output = WorldFeatLabelGen.load_sample(
                        id_name,
                        feature_dir,
                        num_coded_sps=hparams.num_coded_sps)
                except FileNotFoundError as e1:
                    try:
                        output = WorldFeatLabelGen.load_sample(
                            id_name,
                            feature_dir,
                            add_deltas=True,
                            num_coded_sps=hparams.num_coded_sps)
                        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                            output,
                            contains_deltas=True,
                            num_coded_sps=hparams.num_coded_sps)
                        length = len(output)
                        lf0 = lf0.reshape(length, 1)
                        vuv = vuv.reshape(length, 1)
                        bap = bap.reshape(length, 1)
                        output = np.concatenate((coded_sp, lf0, vuv, bap),
                                                axis=1)
                    except FileNotFoundError as e2:
                        logging.error(
                            "Cannot find extracted WORLD features with or without deltas in {}."
                            .format(feature_dir))
                        raise Exception([e1, e2])
                synth_dict[id_name] = output

            # Add identifier to suffix. The original suffix was already saved
            # above, so it must not be overwritten here.
            hparams.synth_file_suffix += str(hparams.num_coded_sps) + 'sp'
            Synthesiser.run_world_synth(synth_dict, hparams)
        elif hparams.synth_vocoder == "raw":
            for id_name in file_id_list:
                # Use extracted data. Useful to create a reference.
                raw = RawWaveformLabelGen.load_sample(
                    id_name, hparams.frame_rate_output_Hz)
                synth_dict[id_name] = raw
            Synthesiser.run_raw_synth(synth_dict, hparams)
        else:
            raise NotImplementedError("Unknown vocoder type {}.".format(
                hparams.synth_vocoder))

        # Restore identifier.
        hparams.synth_file_suffix = old_synth_file_suffix
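
A hedged call sketch: reference synthesis for a single utterance, with placeholder paths.

# Hypothetical usage of synth_ref.
Synthesiser.synth_ref(hparams,
                      ["p225/p225_051"],
                      feature_dir="experiments/English/WORLD")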
Example #5
    def __init__(self,
                 dir_world_features,
                 dir_question_labels,
                 id_list,
                 num_questions,
                 hparams=None):
        """Default constructor.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List of ids, can contain a speaker directory.
        :param num_questions:           Number of questions in question file.
        :param hparams:                 Set of hyper parameters.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super(AcousticModelTrainer, self).__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        self.OutputGen = WorldFeatLabelGen(dir_world_features,
                                           add_deltas=hparams.add_deltas,
                                           num_coded_sps=hparams.num_coded_sps,
                                           sp_type=hparams.sp_type)
        self.OutputGen.get_normalisation_params(
            dir_world_features, hparams.output_norm_params_file_prefix)

        self.dataset_train = LabelGensDataset(self.id_list_train,
                                              self.InputGen,
                                              self.OutputGen,
                                              hparams,
                                              match_lengths=True)
        self.dataset_val = LabelGensDataset(self.id_list_val,
                                            self.InputGen,
                                            self.OutputGen,
                                            hparams,
                                            match_lengths=True)

        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_verbose=True)
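
A minimal construction sketch, assuming `create_hparams` provides usable defaults (as the constructor above does when `hparams` is None); the paths, id list, and question count are placeholders borrowed from the other examples on this page:

# Hypothetical construction of the trainer.
hparams = AcousticModelTrainer.create_hparams()
hparams.num_coded_sps = 30
trainer = AcousticModelTrainer(
    dir_world_features="experiments/English/WORLD",
    dir_question_labels="experiments/English/questions",
    id_list=["p225/p225_051"],
    num_questions=425,
    hparams=hparams)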
Example #6
    @staticmethod
    def run_r9y9wavenet_mulaw_world_feats_synth(synth_output, hparams):

        # If no path is given, use pre-trained model.
        if not hasattr(hparams, "synth_vocoder_path") \
                or hparams.synth_vocoder_path is None:
            parent_dirs = os.path.realpath(__file__).split(os.sep)
            dir_root = str.join(
                os.sep, parent_dirs[:parent_dirs.index("IdiapTTS") + 1])
            hparams.synth_vocoder_path = os.path.join(
                dir_root, "idiaptts", "misc", "pretrained",
                "r9y9wavenet_quantized_16k_world_feats_English.nn")

        # Default quantization is with mu=255.
        if not hasattr(hparams, "mu") or hparams.mu is None:
            hparams.add_hparam("mu", 255)

        if hasattr(hparams, 'frame_rate_output_Hz'):
            org_frame_rate_output_Hz = hparams.frame_rate_output_Hz
            hparams.frame_rate_output_Hz = 16000
        else:
            org_frame_rate_output_Hz = None
            hparams.add_hparam("frame_rate_output_Hz", 16000)

        synth_output = copy.copy(synth_output)

        if hparams.do_post_filtering:
            for id_name, output in synth_output.items():
                coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                    output,
                    contains_deltas=False,
                    num_coded_sps=hparams.num_coded_sps)
                coded_sp = merlin_post_filter(
                    coded_sp,
                    WorldFeatLabelGen.fs_to_mgc_alpha(hparams.synth_fs))
                synth_output[id_name] = \
                    WorldFeatLabelGen.convert_from_world_features(
                        coded_sp, lf0, vuv, bap)

        if hasattr(hparams, 'bit_depth'):
            org_bit_depth = hparams.bit_depth
            hparams.bit_depth = 16
        else:
            org_bit_depth = None
            hparams.add_hparam("bit_depth", 16)

        Synthesiser.run_wavenet_vocoder(synth_output, hparams)

        # Restore the original values (they can be None, thus no type check).
        hparams.setattr_no_type_check("bit_depth", org_bit_depth)
        hparams.setattr_no_type_check("frame_rate_output_Hz",
                                      org_frame_rate_output_Hz)
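
Because the method falls back to the bundled pre-trained model and mu of 255, a call can be as small as the following sketch; `features` is a placeholder WORLD feature matrix.

# Hypothetical call; synth_output maps ids to WORLD feature matrices.
Synthesiser.run_r9y9wavenet_mulaw_world_feats_synth(
    {"p225/p225_051": features}, hparams)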
Example #7
    @staticmethod
    def legacy_support_init(dir_world_features: os.PathLike,
                            dir_question_labels: os.PathLike,
                            id_list: List[str], num_questions: int,
                            hparams: ExtendedHParams):
        """Get arguments for new init.

        :param dir_world_features:   Path to the directory containing
                                     the world features.
        :param dir_question_labels:  Path to the directory containing
                                     the question labels.
        :param id_list:              List of ids, can contain a speaker
                                     directory.
        :param num_questions:        Number of questions in question
                                     file (only needed for legacy code).
        :param hparams:              Set of hyper parameters.
        """
        data_reader_configs = []
        from idiaptts.src.data_preparation.DataReaderConfig import DataReaderConfig
        data_reader_configs.append(
            DataReaderConfig(name="questions",
                             feature_type="QuestionLabelGen",
                             directory=dir_question_labels,
                             features="questions",
                             num_questions=num_questions,
                             match_length=["cmp_features"]))
        # if hasattr(hparams, "add_deltas") and hparams.add_deltas:
        data_reader_configs.append(
            WorldFeatLabelGen.Config(
                name="cmp_features",
                # feature_type="WorldFeatLabelGen",
                directory=dir_world_features,
                features=["cmp_mcep" + str(hparams.num_coded_sps)],
                output_names=["acoustic_features"],
                add_deltas=hparams.add_deltas,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap,
                sp_type=hparams.sp_type,
                requires_seq_mask=True,
                match_length=["questions"]))
        hparams.world_dir = dir_world_features
        # else:
        # # TODO: How to load them separately?
        # datareader_configs.append(
        #     DataReader.Config(
        #         name="cmp_features",
        #         feature_type="WorldFeatLabelGen",
        #         directory=dir_world_features,
        #         features=["cmp_mcep" + str(hparams.num_coded_sps)],
        #         output_names=["acoustic_features"],
        #         add_deltas=hparams.add_deltas,
        #         num_coded_sps=hparams.num_coded_sps,
        #         num_bap=hparams.num_bap,
        #         sp_type=hparams.sp_type,
        #         requires_seq_mask=True
        #     )
        # )

        return dict(data_reader_configs=data_reader_configs,
                    hparams=hparams,
                    id_list=id_list)
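
A hedged sketch of the intended adaptation path: the returned dictionary presumably feeds the new config-based constructor of the surrounding trainer class (assumed here to be AcousticModelTrainer); whether it accepts exactly these keyword arguments is an assumption.

# Hypothetical use of legacy_support_init.
init_kwargs = AcousticModelTrainer.legacy_support_init(
    "experiments/English/WORLD",
    "experiments/English/questions",
    id_list=["p225/p225_051"],
    num_questions=425,
    hparams=hparams)
trainer = AcousticModelTrainer(**init_kwargs)  # Assumed new-style init.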
Example #8
    def _get_trainer(self, hparams):
        dir_world_features = "integration/fixtures/WORLD"
        dir_question_labels = "integration/fixtures/questions"

        trainer = ModelTrainer(self.id_list, hparams)

        # Create datasets to work on.
        trainer.InputGen = QuestionLabelGen(dir_question_labels,
                                            hparams.num_questions)
        trainer.InputGen.get_normalisation_params(dir_question_labels)

        trainer.OutputGen = WorldFeatLabelGen(
            dir_world_features,
            num_coded_sps=hparams.num_coded_sps,
            add_deltas=True)
        trainer.OutputGen.get_normalisation_params(dir_world_features)

        trainer.dataset_train = LabelGensDataset(trainer.id_list_train,
                                                 trainer.InputGen,
                                                 trainer.OutputGen,
                                                 hparams,
                                                 match_lengths=True)
        trainer.dataset_val = LabelGensDataset(trainer.id_list_val,
                                               trainer.InputGen,
                                               trainer.OutputGen,
                                               hparams,
                                               match_lengths=True)

        trainer.loss_function = torch.nn.MSELoss(reduction='none')

        return trainer
Example #9
    def synthesize(self, id_list, synth_output, hparams):

        # Reconstruct lf0 from generated atoms and write it to synth output.
        # recon_dict = self.get_recon_from_synth_output(synth_output)
        full_output = dict()
        for id_name, labels in synth_output.items():
            # Take lf0 and vuv from network output.
            lf0 = labels[:, 0]
            vuv = labels[:, 1]

            phrase_curve = self.OutputGen.get_phrase_curve(id_name)
            lf0 = lf0 + phrase_curve[:len(lf0)].squeeze()

            vuv[vuv < 0.5] = 0.0
            vuv[vuv >= 0.5] = 1.0

            # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
            if hparams.synth_acoustic_model_path is None:

                if hasattr(hparams, "world_dir") and hparams.world_dir is not None:
                    world_dir = hparams.world_dir
                else:
                    world_dir = os.path.join(
                        self.OutputGen.dir_labels,
                        self.dir_extracted_acoustic_features)
                full_sample: np.ndarray = WorldFeatLabelGen.load_sample(
                    id_name,
                    world_dir,
                    add_deltas=False,
                    num_coded_sps=hparams.num_coded_sps,
                    num_bap=hparams.num_bap)  # Load extracted data.
                len_diff = len(full_sample) - len(lf0)
                trim_front = len_diff // 2
                trim_end = len_diff - trim_front
                full_sample = WorldFeatLabelGen.trim_end_sample(
                    full_sample, trim_end)
                full_sample = WorldFeatLabelGen.trim_end_sample(full_sample,
                                                                trim_front,
                                                                reverse=True)
            else:
                raise NotImplementedError()

            # Overwrite lf0 and vuv by network output.
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv
            # Fill a dictionary with the samples.
            full_output[id_name + "_E2E"] = full_sample

        # Run the merlin synthesizer
        Synthesiser.run_world_synth(full_output, hparams)
Example #10
    def synthesize(self, id_list, synth_output, hparams):
        """
        Synthesise LF0 from atoms. The run_atom_synth function either loads the original acoustic features or uses an
        acoustic model to predict them.
        """
        full_output = self.run_atom_synth(id_list, synth_output, hparams)

        for id_name, labels in full_output.items():
            lf0 = labels[:, -3]
            lf0, _ = interpolate_lin(lf0)
            vuv = synth_output[id_name][:, 0, 1]
            len_diff = len(labels) - len(vuv)
            labels = WorldFeatLabelGen.trim_end_sample(labels, int(len_diff / 2), reverse=True)
            labels = WorldFeatLabelGen.trim_end_sample(labels, len_diff - int(len_diff / 2))
            labels[:, -2] = vuv
            # Presumably the trimmed labels should replace the stored sample;
            # without this, the trimming above has no effect on full_output.
            full_output[id_name] = labels

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, full_output, hparams)
Example #11
    def synthesize(self, id_list, synth_output, hparams):
        """Save output of model to .lf0 and (.vuv) files and call Merlin synth which reads those files."""

        # Reconstruct lf0 from generated atoms and write it to synth output.
        # recon_dict = self.get_recon_from_synth_output(synth_output)
        full_output = dict()
        for id_name, labels in synth_output.items():
            # Take lf0 and vuv from network output.
            lf0 = labels[:, 0]
            vuv = labels[:, 1]

            vuv[vuv < 0.5] = 0.0
            vuv[vuv >= 0.5] = 1.0

            # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
            if hparams.synth_acoustic_model_path is None:
                if hasattr(hparams, "world_dir") and hparams.world_dir is not None:
                    world_dir = hparams.world_dir
                else:
                    world_dir = os.path.realpath(
                        os.path.join(hparams.out_dir,
                                     self.dir_extracted_acoustic_features))
                full_sample: np.ndarray = WorldFeatLabelGen.load_sample(
                    id_name,
                    world_dir,
                    add_deltas=False,
                    num_coded_sps=hparams.num_coded_sps
                )  # Load extracted data.
                len_diff = len(full_sample) - len(lf0)
                trim_front = len_diff // 2
                trim_end = len_diff - trim_front
                full_sample = WorldFeatLabelGen.trim_end_sample(
                    full_sample, trim_end)
                full_sample = WorldFeatLabelGen.trim_end_sample(full_sample,
                                                                trim_front,
                                                                reverse=True)
            else:
                raise NotImplementedError()

            # Overwrite lf0 and vuv by network output.
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv
            # Fill a dictionary with the samples.
            full_output[id_name + "_E2E_Phrase"] = full_sample

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, full_output, hparams)
Example #12
    def synthesize(self, id_list, synth_output, hparams):
        """
        Depending on hparams override the network output with the extracted features,
        then continue with normal synthesis pipeline.
        """

        if hparams.synth_load_org_sp\
                or hparams.synth_load_org_lf0\
                or hparams.synth_load_org_vuv\
                or hparams.synth_load_org_bap:
            for id_name in id_list:

                if hasattr(hparams, "world_dir") and hparams.world_dir is not None:
                    world_dir = hparams.world_dir
                else:
                    world_dir = os.path.join(
                        self.OutputGen.dir_labels,
                        self.dir_extracted_acoustic_features)
                labels = WorldFeatLabelGen.load_sample(
                    id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
                len_diff = len(labels) - len(synth_output[id_name])
                if len_diff > 0:
                    labels = WorldFeatLabelGen.trim_end_sample(
                        labels, int(len_diff / 2), reverse=True)
                    labels = WorldFeatLabelGen.trim_end_sample(
                        labels, len_diff - int(len_diff / 2))

                if hparams.synth_load_org_sp:
                    num_sps = self.OutputGen.num_coded_sps
                    synth_output[id_name][:len(labels), :num_sps] = \
                        labels[:, :num_sps]

                if hparams.synth_load_org_lf0:
                    synth_output[id_name][:len(labels), -3] = labels[:, -3]

                if hparams.synth_load_org_vuv:
                    synth_output[id_name][:len(labels), -2] = labels[:, -2]

                if hparams.synth_load_org_bap:
                    synth_output[id_name][:len(labels), -1] = labels[:, -1]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
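
The four synth_load_org_* flags select which streams are copied from the extracted features. A hedged sketch of a copy-synthesis setup that keeps the predicted spectrum but takes the source parameters from the originals; `trainer` is a placeholder instance:

# Hypothetical flag setup for partial copy-synthesis.
hparams.synth_load_org_sp = False   # Keep the predicted spectrum.
hparams.synth_load_org_lf0 = True   # Use extracted lf0.
hparams.synth_load_org_vuv = True   # Use extracted vuv.
hparams.synth_load_org_bap = True   # Use extracted bap.
trainer.synthesize(id_list, synth_output, hparams)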
Example #13
    def load_extracted_audio_features(self, synth_output, hparams):
        """Load the audio features extracted from audio."""
        self.logger.info("Load extracted mgc, lf0, vuv, bap data.")

        org_output = dict()
        for id_name in synth_output.keys():
            if hasattr(hparams, "world_dir") and hparams.world_dir is not None:
                world_dir = hparams.world_dir
            else:
                world_dir = os.path.realpath(
                    os.path.join(self.OutputGen.dir_labels,
                                 self.dir_extracted_acoustic_features))
            org_output[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                world_dir,
                add_deltas=False,
                num_coded_sps=hparams.num_coded_sps)  # Load extracted data.

        return org_output
Example #14
    def get_output_dict(self, id_list, hparams, chunk_size=1):
        assert hparams.has_value("world_dir"), \
            "hparams.world_dir must be set for this operation."
        dict_original_post = dict()
        for id_name in id_list:
            sample = WorldFeatLabelGen.load_sample(
                id_name,
                dir_out=hparams.world_dir,
                add_deltas=hparams.add_deltas,
                num_coded_sps=hparams.num_coded_sps,
                sp_type=hparams.sp_type,
                num_bap=hparams.num_bap,
                load_sp=hparams.load_sp,
                load_lf0=hparams.load_lf0,
                load_vuv=hparams.load_vuv,
                load_bap=hparams.load_bap)
            if chunk_size > 1:
                # Pad so the frame count becomes a multiple of chunk_size.
                sample = WorldFeatLabelGen.pad(
                    None, sample,
                    _get_padding_sizes(sample, chunk_size),
                    pad_mode='constant')
            dict_original_post[id_name] = sample
        return dict_original_post
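
`_get_padding_sizes` is referenced above but not shown. A minimal sketch of what it presumably computes, assuming time-major 2-D samples and numpy-style per-axis (before, after) pad widths:

def _get_padding_sizes(sample, chunk_size):
    """Hypothetical helper: pad widths so the number of frames in
    `sample` becomes a multiple of `chunk_size`."""
    pad_frames = (chunk_size - len(sample) % chunk_size) % chunk_size
    # Pad only at the end of the time axis, nothing on the feature axis.
    return [(0, pad_frames), (0, 0)]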
Example #15
def main():
    from idiaptts.src.model_trainers.vtln.VTLNSpeakerAdaptionModelTrainer import VTLNSpeakerAdaptionModelTrainer
    hparams = VTLNSpeakerAdaptionModelTrainer.create_hparams()
    hparams.use_gpu = False
    hparams.voice = "English"
    hparams.model_name = "WarpingLayerTest.nn"
    hparams.add_deltas = True
    hparams.num_coded_sps = 30
    # hparams.num_questions = 505
    hparams.num_questions = 425
    hparams.out_dir = "experiments/" + hparams.voice + "/VTLNArtificiallyWarped/"
    hparams.data_dir = os.path.realpath("database")
    hparams.model_name = "warping_layer_test"
    hparams.synth_dir = hparams.out_dir
    batch_size = 2
    dir_world_labels = os.path.join("experiments", hparams.voice, "WORLD")

    from idiaptts.src.data_preparation.world.WorldFeatLabelGen import WorldFeatLabelGen
    gen_in = WorldFeatLabelGen(dir_world_labels,
                               add_deltas=hparams.add_deltas,
                               num_coded_sps=hparams.num_coded_sps)
    gen_in.get_normalisation_params(gen_in.dir_labels)

    from idiaptts.src.model_trainers.AcousticModelTrainer import AcousticModelTrainer
    trainer = AcousticModelTrainer(
        "experiments/" + hparams.voice + "/WORLD",
        "experiments/" + hparams.voice + "/questions", "ignored",
        hparams.num_questions, hparams)

    sp_mean = gen_in.norm_params[0][:hparams.num_coded_sps *
                                    (3 if hparams.add_deltas else 1)]
    sp_std_dev = gen_in.norm_params[1][:hparams.num_coded_sps *
                                       (3 if hparams.add_deltas else 1)]
    wl = WarpingLayer((hparams.num_coded_sps, ), (hparams.num_coded_sps, ),
                      hparams)
    wl.set_norm_params(sp_mean, sp_std_dev)

    # id_list = ["dorian/doriangray_16_00199"]
    id_list = ["p225/p225_051"]
    hparams.num_speakers = 1

    t_benchmark = 0
    for id_name in id_list:
        for idx, alpha in enumerate(np.arange(-0.15, 0.2, 0.05)):
            out_dir = hparams.out_dir + "alpha_{0:0.2f}/".format(alpha)
            makedirs_safe(out_dir)

            sample = WorldFeatLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "WORLD"),
                add_deltas=True,
                num_coded_sps=hparams.num_coded_sps)
            sample_pre = gen_in.preprocess_sample(sample)
            coded_sps = sample_pre[:, :hparams.num_coded_sps *
                                   (3 if hparams.add_deltas else 1)]

            alpha_vec = np.ones((coded_sps.shape[0], 1)) * alpha

            coded_sps = coded_sps[:len(alpha_vec), None, ...].repeat(
                batch_size, 1)  # Copy data in batch dimension.
            alpha_vec = alpha_vec[:, None, None].repeat(
                batch_size, 1)  # Copy data in batch dimension.

            t_start = timer()
            mfcc_warped, (_, nn_alpha) = wl(torch.from_numpy(coded_sps),
                                            None, (len(coded_sps), ),
                                            (len(coded_sps), ),
                                            alphas=torch.from_numpy(alpha_vec))
            mfcc_warped.sum().backward()
            t_benchmark += timer() - t_start
            # Compare results for cloned coded_sps within batch.
            assert (mfcc_warped[:, 0] == mfcc_warped[:, 1]).all()
            if np.isclose(alpha, 0.0):  # Exact float comparison would miss 0.
                # Warping with alpha 0 should be the identity.
                assert (mfcc_warped == coded_sps).all()
            sample_pre[:len(mfcc_warped), :hparams.num_coded_sps * (
                3 if hparams.add_deltas else 1)] = mfcc_warped[:, 0].detach()

            sample_post = gen_in.postprocess_sample(sample_pre)
            # Manually create samples without normalisation but with deltas.
            sample_pre = (sample_pre * gen_in.norm_params[1] +
                          gen_in.norm_params[0]).astype(np.float32)

            if np.isnan(sample_pre).any():
                raise ValueError(
                    "Detected nan values in output features for {}.".format(
                        id_name))
            # Save warped features.
            makedirs_safe(os.path.dirname(os.path.join(out_dir, id_name)))
            sample_pre.tofile(
                os.path.join(out_dir, id_name + WorldFeatLabelGen.ext_deltas))

            hparams.synth_dir = out_dir
            Synthesiser.run_world_synth({id_name: sample_post}, hparams)

    print("Process time for {} runs: {}".format(
        len(id_list) * idx, timedelta(seconds=t_benchmark)))
Example #16
    def process_file(self,
                     file,
                     dir_audio,
                     dir_out,
                     silence_threshold_db=-50,
                     hop_size_ms=None):
        # sound = AudioSegment.from_file(os.path.join(dir_audio, file), format=audio_format)
        # trim_start = self._detect_leading_silence(sound, silence_threshold_db, chunk_size_ms)
        # trim_end = self._detect_leading_silence(sound.reverse(), silence_threshold_db, chunk_size_ms)

        raw, fs = soundfile.read(os.path.join(dir_audio, file))

        frame_length = WorldFeatLabelGen.fs_to_frame_length(fs)
        if hop_size_ms is None:
            hop_size_ms = min(self.min_silence_ms, 32)

        _, indices = librosa.effects.trim(raw,
                                          top_db=abs(silence_threshold_db),
                                          frame_length=frame_length,
                                          hop_length=int(fs / 1000 *
                                                         hop_size_ms))
        trim_start = indices[0] / fs * 1000
        trim_end = (len(raw) - indices[1]) / fs * 1000

        # Add silence to the front if audio starts too early.
        if trim_start < self.min_silence_ms:
            # TODO: Find a robust way to create silence so that HTK alignment still works (maybe concat mirrored segments).
            logging.warning(
                "File {} has only {} ms of silence in the beginning.".format(
                    file, trim_start))
            # AudioSegment.silent(duration=self.min_silence_ms-trim_start)
            # if trim_start > 0:
            #     silence = (sound[:trim_start] * (math.ceil(self.min_silence_ms / trim_start) - 1))[:self.min_silence_ms-trim_start]
            #     sound = silence + sound
            # elif trim_end > 0:
            #     silence = (sound[-trim_end:] * (math.ceil(self.min_silence_ms / trim_end) - 1))[:self.min_silence_ms-trim_end]
            #     sound = silence + sound
            # else:
            #     self.logger.warning("Cannot append silence to the front of " + file + ". No silence exists at front or end which can be copied.")
            trim_start = 0
        else:
            trim_start -= self.min_silence_ms

        # Append silence if audio ends too late.
        if trim_end < self.min_silence_ms:
            logging.warning(
                "File {} has only {} ms of silence in the end.".format(
                    file, trim_end))
            # silence = AudioSegment.silent(duration=self.min_silence_ms-trim_end)
            # if trim_end > 0:
            #     silence = (sound[-trim_end:] * (math.ceil(self.min_silence_ms / trim_end) - 1))[:self.min_silence_ms-trim_end]
            #     sound = sound + silence
            # elif trim_start > 0:
            #     silence = (sound[:trim_start] * (math.ceil(self.min_silence_ms / trim_start) - 1))[:self.min_silence_ms-trim_start]
            #     sound = sound + silence
            # else:
            #     self.logger.warning("Cannot append silence to the end of " + file + ". No silence exists at front or end which can be copied.")
            trim_end = 0
        else:
            trim_end -= self.min_silence_ms

        # Trim the sound. Note that the -1 in the end index always drops at
        # least one trailing sample, even when trim_end is 0.
        trimmed_raw = raw[int(trim_start * fs /
                              1000):int(-trim_end * fs / 1000 - 1)]
        # trimmed_sound = sound[trim_start:-trim_end-1]

        # Save trimmed sound to file.
        out_file = os.path.join(dir_out, file)
        makedirs_safe(os.path.dirname(out_file))
        soundfile.write(out_file, trimmed_raw, samplerate=fs)

        return trimmed_raw
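
A hedged call sketch; `processor` stands for an instance of the surrounding class (which must define min_silence_ms), and the file names are placeholders:

# Hypothetical usage of process_file.
trimmed = processor.process_file("p225_051.wav",
                                 dir_audio="database/wav",
                                 dir_out="database/wav_trimmed",
                                 silence_threshold_db=-50)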
Example #17
def main():
    """Create samples with artificial alpha for each phoneme."""
    from idiaptts.src.model_trainers.vtln.VTLNSpeakerAdaptionModelTrainer import VTLNSpeakerAdaptionModelTrainer
    hparams = VTLNSpeakerAdaptionModelTrainer.create_hparams()
    hparams.use_gpu = False
    hparams.voice = sys.argv[1]
    hparams.model_name = "WarpingLayerTest.nn"
    hparams.add_deltas = True
    hparams.num_coded_sps = 30
    alpha_range = 0.2
    num_phonemes = 70

    num_random_alphas = 7
    # num_random_alphas = 53

    # Randomly pick alphas for each phoneme.
    np.random.seed(42)
    # phonemes_to_alpha_tensor = ((np.random.choice(np.random.rand(num_random_alphas), num_phonemes) - 0.5) * 2 * alpha_range)
    phonemes_to_alpha_tensor = ((np.random.rand(num_phonemes) - 0.5) * 2 *
                                alpha_range)

    # hparams.num_questions = 505
    hparams.num_questions = 609
    # hparams.num_questions = 425

    hparams.out_dir = os.path.join("experiments", hparams.voice,
                                   "WORLD_artificially_warped")
    hparams.data_dir = os.path.realpath("database")
    hparams.model_name = "warping_layer_test"
    hparams.synth_dir = hparams.out_dir
    dir_world_labels = os.path.join("experiments", hparams.voice, "WORLD")

    print(
        "Create artificially warped MGCs for {} in {} for {} questions, {} random alphas, and an alpha range of {}."
        .format(hparams.voice, hparams.out_dir, hparams.num_questions,
                len(np.unique(phonemes_to_alpha_tensor)), alpha_range))

    from idiaptts.src.data_preparation.world.WorldFeatLabelGen import WorldFeatLabelGen
    gen_in = WorldFeatLabelGen(dir_world_labels,
                               add_deltas=hparams.add_deltas,
                               num_coded_sps=hparams.num_coded_sps)
    gen_in.get_normalisation_params(gen_in.dir_labels)

    from idiaptts.src.model_trainers.AcousticModelTrainer import AcousticModelTrainer
    trainer = AcousticModelTrainer(
        os.path.join("experiments", hparams.voice, "WORLD"),
        os.path.join("experiments", hparams.voice, "questions"), "ignored",
        hparams.num_questions, hparams)

    hparams.num_speakers = 1
    speaker = "p276"
    num_synth_files = 5  # Number of files to synthesise to check warping manually.

    sp_mean = gen_in.norm_params[0][:hparams.num_coded_sps *
                                    (3 if hparams.add_deltas else 1)]
    sp_std_dev = gen_in.norm_params[1][:hparams.num_coded_sps *
                                       (3 if hparams.add_deltas else 1)]
    wl = WarpingLayer((hparams.num_coded_sps, ), (hparams.num_coded_sps, ),
                      hparams)
    wl.set_norm_params(sp_mean, sp_std_dev)

    def _question_to_phoneme_index(questions):
        """Helper function to convert questions to their current phoneme index."""
        if questions.shape[-1] == 505:  # German question set.
            # Plain int instead of np.int, which was removed in NumPy 1.24.
            indices = np.arange(86, 347, 5, dtype=int)
        elif questions.shape[-1] == 425:  # English radio question set.
            indices = np.arange(58, 107, dtype=int)
        elif questions.shape[-1] == 609:  # English unilex question set.
            indices = np.arange(92, 162, dtype=int)
        else:
            raise NotImplementedError(
                "Unknown question set with {} questions.".format(
                    questions.shape[-1]))
        return QuestionLabelGen.questions_to_phoneme_indices(
            questions, indices)

    # with open(os.path.join(hparams.data_dir, "file_id_list_{}_train.txt".format(hparams.voice))) as f:
    with open(
            os.path.join(hparams.data_dir, "file_id_list_{}_adapt.txt".format(
                hparams.voice))) as f:
        id_list = f.readlines()
    id_list[:] = [s.strip(' \t\n\r') for s in id_list
                  if speaker in s]  # Strip line endings, keep only the speaker's ids.

    out_dir = hparams.out_dir
    makedirs_safe(out_dir)
    makedirs_safe(os.path.join(out_dir,
                               "cmp_mgc" + str(hparams.num_coded_sps)))
    t_benchmark = 0
    org_to_warped_mcd = 0.0
    for idx, id_name in enumerate(id_list):

        sample = WorldFeatLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "WORLD"),
            add_deltas=True,
            num_coded_sps=hparams.num_coded_sps)
        sample_pre = gen_in.preprocess_sample(sample)
        coded_sps = sample_pre[:, :hparams.num_coded_sps *
                               (3 if hparams.add_deltas else 1)]

        questions = QuestionLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "questions"),
            num_questions=hparams.num_questions)
        questions = questions[:len(coded_sps)]
        phoneme_indices = _question_to_phoneme_index(questions)
        alpha_vec = phonemes_to_alpha_tensor[phoneme_indices %
                                             len(phonemes_to_alpha_tensor),
                                             None]

        coded_sps = coded_sps[:len(alpha_vec), None,
                              ...]  # Create a batch dimension.
        alpha_vec = alpha_vec[:, None,
                              None]  # Create a batch and feature dimension.

        t_start = timer()
        mfcc_warped, (_, nn_alpha) = wl(torch.from_numpy(coded_sps),
                                        None, (len(coded_sps), ),
                                        (len(coded_sps), ),
                                        alphas=torch.from_numpy(alpha_vec))
        t_benchmark += timer() - t_start
        sample_pre[:len(mfcc_warped), :hparams.num_coded_sps *
                   (3 if hparams.add_deltas else 1)] = mfcc_warped[:,
                                                                   0].detach()

        sample_post = gen_in.postprocess_sample(sample_pre)
        # Manually create samples without normalisation but with deltas.
        sample_pre = (sample_pre * gen_in.norm_params[1] +
                      gen_in.norm_params[0]).astype(np.float32)

        if np.isnan(sample_pre).any():
            raise ValueError(
                "Detected nan values in output features for {}.".format(
                    id_name))

        # Compute error between warped version and original one.
        org_to_warped_mcd += metrics.melcd(
            sample[:, 0:hparams.num_coded_sps],
            sample_pre[:, 0:hparams.num_coded_sps])

        # Save warped features.
        sample_pre.tofile(
            os.path.join(
                out_dir, "cmp_mgc" + str(hparams.num_coded_sps),
                os.path.basename(id_name + WorldFeatLabelGen.ext_deltas)))

        hparams.synth_dir = out_dir
        if idx < num_synth_files:  # Only synthesise a few samples.
            trainer.run_world_synth({id_name: sample_post}, hparams)

    print("Process time for {} warpings: {}. MCD caused by warping: {:.2f}".
          format(len(id_list), timedelta(seconds=t_benchmark),
                 org_to_warped_mcd / len(id_list)))

    # Copy normalisation files which are necessary for training.
    for feature in ["_bap", "_lf0", "_mgc{}".format(hparams.num_coded_sps)]:
        shutil.copyfile(
            os.path.join(
                gen_in.dir_labels, gen_in.dir_deltas,
                MeanCovarianceExtractor.file_name_appendix + feature + ".bin"),
            os.path.join(
                out_dir, "cmp_mgc" + str(hparams.num_coded_sps),
                MeanCovarianceExtractor.file_name_appendix + feature + ".bin"))
Example #18
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        lf0, _ = interpolate_lin(lf0)

        # Load original LF0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            dir_out=self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            sample=org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        sp = sp[:, :150]  # Zoom into spectral features.

        # Get a data plotter.
        grid_idx = -1
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))

        # # Plot LF0
        # grid_idx += 1
        # graphs.append((original_lf0, 'Original LF0'))
        # graphs.append((lf0, 'NN LF0'))
        # plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        # plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0),
        #                                                     (np.invert(original_vuv.astype(bool)), 'red', 0.2)])
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [{}] ms'.format(hparams.frame_length), ylabel='log(f0)')

        # Reverse the warping.
        wl = self._get_dummy_warping_layer(hparams)
        norm_params_no_deltas = (
            self.OutputGen.norm_params[0][:hparams.num_coded_sps],
            self.OutputGen.norm_params[1][:hparams.num_coded_sps])
        pre_net_output, _ = wl.forward_sample(label, -alphas)

        # Postprocess sample manually.
        pre_net_output = pre_net_output.detach().cpu().numpy()
        pre_net_mgc = pre_net_output[:, 0, :hparams.num_coded_sps] \
            * norm_params_no_deltas[1] + norm_params_no_deltas[0]

        # Plot spectral features predicted by pre-network.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='Pre-network')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     pre_net_mgc,
                                     hparams.synth_fs)[:, :sp.shape[1]]))

        # Plot final predicted spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='VTLN')
        plotter.set_specshow(grid_idx=grid_idx, spec=np.abs(sp))

        # Plot predicted alpha value and V/UV flag.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='alpha')
        graphs = list()
        graphs.append((alphas, 'NN alpha'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # Add phoneme annotations if given.
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        # Plot reference spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     original_mgc,
                                     hparams.synth_fs)[:, :sp.shape[1]]))

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
Example #19
    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        mcd, f0_rmse, vuv_error_rate, bap_mcd = super().compute_score(
            dict_outputs_post, dict_hiddens, hparams)

        # Get data for comparison.
        dict_original_post = dict()
        for id_name in dict_outputs_post.keys():
            dict_original_post[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                dir_out=self.OutputGen.dir_labels,
                add_deltas=True,
                num_coded_sps=hparams.num_coded_sps)

        # Create a warping layer for manual warping.
        wl = self._get_dummy_warping_layer(hparams)
        norm_params_no_deltas = (
            self.OutputGen.norm_params[0][:hparams.num_coded_sps],
            self.OutputGen.norm_params[1][:hparams.num_coded_sps])

        # Compute MCD for different set of coefficients.
        batch_size = len(dict_outputs_post)
        for cep_coef_start in [1]:
            for cep_coef_end in itertools.chain(range(10, 19), [-1]):
                org_to_output_mcd = 0.0
                org_to_pre_net_output_mcd = 0.0

                for id_name, labels in dict_outputs_post.items():
                    # Split NN output.
                    _, output_alphas = dict_hiddens[id_name]
                    output_mgc_post, *_ = self.OutputGen.convert_to_world_features(
                        labels, False, num_coded_sps=hparams.num_coded_sps)
                    # Reverse the warping.
                    pre_net_output, _ = wl.forward_sample(
                        labels, -output_alphas)
                    # Postprocess sample manually.
                    pre_net_output = pre_net_output.detach().cpu().numpy()
                    pre_net_mgc = \
                        pre_net_output[:, 0, :hparams.num_coded_sps] \
                        * norm_params_no_deltas[1] + norm_params_no_deltas[0]
                    # Load the original warped sample.
                    org_mgc_post = dict_original_post[
                        id_name][:len(output_mgc_post), :hparams.num_coded_sps]

                    # Compute mcd difference.
                    org_to_output_mcd += metrics.melcd(
                        org_mgc_post[:, cep_coef_start:cep_coef_end],
                        output_mgc_post[:, cep_coef_start:cep_coef_end])
                    org_to_pre_net_output_mcd += metrics.melcd(
                        org_mgc_post[:, cep_coef_start:cep_coef_end],
                        pre_net_mgc[:, cep_coef_start:cep_coef_end])

                org_to_pre_net_output_mcd /= batch_size
                org_to_output_mcd /= batch_size

                self.logger.info("MCep from {} to {}:".format(
                    cep_coef_start, cep_coef_end))
                self.logger.info(
                    "Original mgc to pre-net mgc error: {:4.2f}dB".format(
                        org_to_pre_net_output_mcd))
                self.logger.info(
                    "Original mgc to nn mgc error: {:4.2f}dB".format(
                        org_to_output_mcd))

        return mcd, f0_rmse, vuv_error_rate, bap_mcd
Example #20
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        questions = QuestionLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "questions"),
            num_questions=hparams.num_questions)[:len(alphas)]
        phoneme_indices = QuestionLabelGen.questions_to_phoneme_indices(
            questions, hparams.phoneme_indices)
        alpha_vec = self.phonemes_to_alpha_tensor[phoneme_indices % len(
            self.phonemes_to_alpha_tensor)]

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='log(f0)')

        graphs = list()
        graphs.append((original_lf0, 'Original LF0'))
        graphs.append((lf0, 'NN LF0'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # grid_idx += 1
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='Original spectrogram')
        # plotter.set_specshow(grid_idx=grid_idx, spec=WorldFeatLabelGen.mgc_to_sp(original_mgc, hparams.synth_fs))
        #
        # grid_idx += 1
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='NN spectrogram')
        # plotter.set_specshow(grid_idx=grid_idx, spec=sp)

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='alpha')
        graphs = list()
        graphs.append((alpha_vec, 'Original alpha'))
        graphs.append((alphas, 'NN alpha'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
Example #21
    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        mcd, f0_rmse, vuv_error_rate, bap_mcd = super().compute_score(
            dict_outputs_post, dict_hiddens, hparams)

        # Get data for comparison.
        dict_original_post = dict()
        for id_name in dict_outputs_post.keys():
            dict_original_post[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                self.OutputGen.dir_labels,
                True,
                num_coded_sps=hparams.num_coded_sps)

        # Create a warping layer for manual warping.
        wl = WarpingLayer((hparams.num_coded_sps, ), (hparams.num_coded_sps, ),
                          hparams)
        if hparams.use_gpu:
            wl = wl.cuda()
        wl.set_norm_params(*self.OutputGen.norm_params)
        batch_size = len(dict_outputs_post)

        for cep_coef_start in [0, 1]:
            for cep_coef_end in (range(10, 19)
                                 if cep_coef_start == 1 else [-1]):
                alphas_rmse = 0.0
                org_to_warped_mcd = 0.0
                org_to_nn_warping_mcd = 0.0
                output_to_warped_mcd = 0.0

                for id_name, labels in dict_outputs_post.items():
                    # Split NN output.
                    _, output_alphas = dict_hiddens[id_name]
                    output_mgc_post, *_ = self.OutputGen.convert_to_world_features(
                        labels, False, num_coded_sps=hparams.num_coded_sps)

                    # Load the original sample without warping.
                    org_output = self.OutputGen.load_sample(
                        id_name,
                        os.path.join("experiments", hparams.voice, "WORLD"),
                        add_deltas=True,
                        num_coded_sps=hparams.num_coded_sps)
                    org_output = org_output[:len(output_mgc_post)]
                    org_mgc_post = org_output[:, :hparams.num_coded_sps]
                    org_output_pre = self.OutputGen.preprocess_sample(
                        org_output)  # Preprocess the sample.
                    org_mgc_pre = org_output_pre[:, :hparams.num_coded_sps * (
                        3 if hparams.add_deltas else 1)]

                    # Load the original warped sample.
                    org_mgc_warped_post = dict_original_post[
                        id_name][:len(output_mgc_post), :hparams.num_coded_sps]
                    # org_mgc_warped_post = self.OutputGen.load_sample(
                    #                                         id_name,
                    #                                         os.path.join("experiments",
                    #                                                      hparams.voice,
                    #                                                      "vtln_speaker_static",
                    #                                                      "alpha_1.10"),
                    #                                         add_deltas=True,
                    #                                         num_coded_sps=hparams.num_coded_sps)[:len(output_mgc_post), :hparams.num_coded_sps]

                    # Compute error between warped version and NN output.
                    output_to_warped_mcd += metrics.melcd(
                        org_mgc_warped_post[:, cep_coef_start:cep_coef_end],
                        output_mgc_post[:, cep_coef_start:cep_coef_end])
                    # Compute error between warped version and original one.
                    org_to_warped_mcd += metrics.melcd(
                        org_mgc_warped_post[:, cep_coef_start:cep_coef_end],
                        org_mgc_post[:, cep_coef_start:cep_coef_end])

                    # Get original alphas from phonemes.
                    questions = QuestionLabelGen.load_sample(
                        id_name,
                        os.path.join("experiments", hparams.voice,
                                     "questions"),
                        num_questions=hparams.num_questions)[:len(output_alphas)]
                    phoneme_indices = QuestionLabelGen.questions_to_phoneme_indices(
                        questions, hparams.phoneme_indices)
                    org_alphas = self.phonemes_to_alpha_tensor[
                        phoneme_indices % len(self.phonemes_to_alpha_tensor),
                        None]

                    # Compute RMSE of alphas.
                    alphas_rmse += math.sqrt(
                        ((org_alphas - output_alphas)**2).sum())

                    # Warp the original mgcs with the alpha predicted by the network.
                    org_mgc_nn_warped, _ = wl.forward_sample(
                        org_mgc_pre, output_alphas)  # Warp with the NN alphas.
                    org_output_pre[:, :hparams.num_coded_sps * (3 if hparams.add_deltas else 1)]\
                        = org_mgc_nn_warped[:, 0, ...].detach()  # Write warped mgcs back.
                    org_mgc_nn_warped_post = self.OutputGen.postprocess_sample(
                        org_output_pre,
                        apply_mlpg=False)[:, :hparams.num_coded_sps]

                    # Compute error between correctly warped version and original mgcs warped with NN alpha.
                    org_to_nn_warping_mcd += metrics.melcd(
                        org_mgc_warped_post[:, cep_coef_start:cep_coef_end],
                        org_mgc_nn_warped_post[:, cep_coef_start:cep_coef_end])

                alphas_rmse /= batch_size
                output_to_warped_mcd /= batch_size
                org_to_warped_mcd /= batch_size
                org_to_nn_warping_mcd /= batch_size

                self.logger.info("MCep from {} to {}:".format(
                    cep_coef_start, cep_coef_end))
                self.logger.info("RMSE alphas: {:4.2f}".format(alphas_rmse))
                self.logger.info(
                    "Original mgc to warped mgc error: {:4.2f}dB".format(
                        org_to_warped_mcd))
                self.logger.info(
                    "Original mgc warped by network alpha to warped mgc error: {:4.2f}dB ({:2.2f}%)"
                    .format(org_to_nn_warping_mcd,
                            (1 - org_to_nn_warping_mcd / org_to_warped_mcd) *
                            100))
                self.logger.info(
                    "Network output to original warped mgc error: {:4.2f}dB".
                    format(output_to_warped_mcd))

        return mcd, f0_rmse, vuv_error_rate, bap_mcd
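
The mel-cepstral distortion (MCD) accumulated above via metrics.melcd can be
summarised by a minimal sketch (assuming two frame-aligned T x D cepstrum
arrays; this mirrors the usual MCD definition and is not code from the
original source):

import numpy as np

def melcd(x, y):
    """Mean mel-cepstral distortion in dB between two (T x D) sequences."""
    diff = x - y
    # 10 / ln(10) converts natural-log cepstra to dB; sqrt(2 * sum) per frame.
    return (10.0 / np.log(10.0)) * np.mean(
        np.sqrt(2.0 * np.sum(diff**2, axis=-1)))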
Example #23
0
    def gen_figure_phrase(self, hparams, ids_input):
        id_list = ModelTrainer._input_to_str_list(ids_input)
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_gen_figure,
            synth=False,
            benchmark=False,
            gen_figure=False)

        for id_name, outputs_post in model_output_post.items():

            if outputs_post.ndim < 2:
                outputs_post = np.expand_dims(outputs_post, axis=1)

            lf0 = outputs_post[:, 0]
            output_lf0, _ = interpolate_lin(lf0)
            output_vuv = outputs_post[:, 1]
            output_vuv[output_vuv < 0.5] = 0.0
            output_vuv[output_vuv >= 0.5] = 1.0
            output_vuv = output_vuv.astype(bool)  # np.bool was removed from NumPy.

            # Load original lf0 and vuv.
            world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                          else os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features)
            org_labels = WorldFeatLabelGen.load_sample(
                id_name,
                world_dir,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)[:len(output_lf0)]
            _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
                org_labels,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)
            original_lf0, _ = interpolate_lin(original_lf0)
            original_vuv = original_vuv.astype(bool)

            phrase_curve = np.fromfile(
                os.path.join(
                    self.flat_trainer.atom_trainer.OutputGen.dir_labels,
                    id_name + self.OutputGen.ext_phrase),
                dtype=np.float32).reshape(-1, 1)[:len(original_lf0)]

            f0_mse = (np.exp(original_lf0.squeeze(-1)) -
                      np.exp(phrase_curve.squeeze(-1)))**2
            f0_rmse = math.sqrt(
                (f0_mse * original_vuv[:len(output_lf0)]).sum() /
                original_vuv[:len(output_lf0)].sum())
            self.logger.info("RMSE of {} phrase curve: {} Hz.".format(
                id_name, f0_rmse))

            len_diff = len(original_lf0) - len(lf0)
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0))
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0) + 1, reverse=True)

            # Get a data plotter.
            net_name = os.path.basename(hparams.model_name)
            filename = str(
                os.path.join(hparams.out_dir, id_name + '.' + net_name))
            plotter = DataPlotter()
            # plotter.set_title(id_name + " - " + net_name)

            grid_idx = 0
            graphs_lf0 = list()
            graphs_lf0.append((original_lf0, "Original"))
            graphs_lf0.append((phrase_curve, "Predicted"))
            plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(original_vuv), '0.8',
                                              1.0, 'Reference unvoiced')])
            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [' + str(hparams.frame_size_ms) +
                              ' ms]',
                              ylabel='LF0')
            # amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
            # plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
            plotter.set_lim(grid_idx=grid_idx, ymin=4.2, ymax=5.4)
            # plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

            # plotter.set_lim(xmin=300, xmax=1100)
            plotter.gen_plot()
            plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)
Example #24
0
def main():
    from idiaptts.src.model_trainers.vtln.VTLNSpeakerAdaptionModelTrainer import VTLNSpeakerAdaptionModelTrainer
    hparams = VTLNSpeakerAdaptionModelTrainer.create_hparams()
    hparams.use_gpu = False
    hparams.voice = "English"
    hparams.model_name = "AllPassWarpModelTest.nn"
    hparams.add_deltas = True
    hparams.num_coded_sps = 30
    # hparams.num_questions = 505
    hparams.num_questions = 425
    hparams.out_dir = os.path.join("experiments", hparams.voice,
                                   "VTLNArtificiallyWarped")
    hparams.data_dir = os.path.realpath("database")
    hparams.model_name = "all_pass_warp_test"
    hparams.synth_dir = hparams.out_dir
    batch_size = 2
    dir_world_labels = os.path.join("experiments", hparams.voice, "WORLD")

    # hparams.add_hparam("warp_matrix_size", hparams.num_coded_sps)
    hparams.alpha_ranges = [
        0.2,
    ]

    from idiaptts.src.data_preparation.world.WorldFeatLabelGen import WorldFeatLabelGen
    gen_in = WorldFeatLabelGen(dir_world_labels,
                               add_deltas=hparams.add_deltas,
                               num_coded_sps=hparams.num_coded_sps,
                               num_bap=hparams.num_bap)
    gen_in.get_normalisation_params(gen_in.dir_labels)

    from idiaptts.src.model_trainers.AcousticModelTrainer import AcousticModelTrainer
    trainer = AcousticModelTrainer(
        "experiments/" + hparams.voice + "/WORLD",
        "experiments/" + hparams.voice + "/questions", "ignored",
        hparams.num_questions, hparams)

    sp_mean = gen_in.norm_params[0][:hparams.num_coded_sps *
                                    (3 if hparams.add_deltas else 1)]
    sp_std_dev = gen_in.norm_params[1][:hparams.num_coded_sps *
                                       (3 if hparams.add_deltas else 1)]
    all_pass_warp_model = AllPassWarpModel((hparams.num_coded_sps, ),
                                           (hparams.num_coded_sps, ), hparams)
    all_pass_warp_model.set_norm_params(sp_mean, sp_std_dev)

    # id_list = ["dorian/doriangray_16_00199"]
    # id_list = ["p225/p225_051", "p277/p277_012", "p278/p278_012", "p279/p279_012"]
    id_list = ["p225/p225_051"]

    t_benchmark = 0
    for id_name in id_list:
        sample = WorldFeatLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "WORLD"),
            add_deltas=True,
            num_coded_sps=hparams.num_coded_sps,
            num_bap=hparams.num_bap,
            sp_type=hparams.sp_type)
        sample_pre = gen_in.preprocess_sample(sample)
        coded_sps = sample_pre[:, :hparams.num_coded_sps *
                               (3 if hparams.add_deltas else 1)].copy()
        coded_sps = coded_sps[:, None, ...].repeat(
            batch_size, 1)  # Copy data in batch dimension.

        for idx, alpha in enumerate(np.arange(-0.2, 0.21, 0.05)):
            out_dir = os.path.join(hparams.out_dir,
                                   "alpha_{0:0.2f}".format(alpha))
            makedirs_safe(out_dir)

            alpha_vec = np.ones((coded_sps.shape[0], 1)) * alpha
            alpha_vec = alpha_vec[:, None].repeat(
                batch_size, 1)  # Copy data in batch dimension.

            t_start = timer()
            sp_warped, (_, nn_alpha) = all_pass_warp_model(
                torch.from_numpy(coded_sps.copy()),
                None, (len(coded_sps), ), (len(coded_sps), ),
                alphas=torch.tensor(alpha_vec, requires_grad=True))
            sp_warped.sum().backward()
            t_benchmark += timer() - t_start
            # assert((mfcc_warped[:, 0] == mfcc_warped[:, 1]).all())  # Compare results for cloned coded_sps within batch.
            if np.isclose(alpha, 0):
                # Zero warping should leave the features unchanged.
                assert np.isclose(sp_warped.detach().cpu().numpy(),
                                  coded_sps).all()
            sample_pre[:len(sp_warped), :hparams.num_coded_sps * (
                3 if hparams.add_deltas else 1)] = sp_warped[:, 0].detach()

            sample_post = gen_in.postprocess_sample(sample_pre,
                                                    apply_mlpg=False)
            # Manually create samples without normalisation but with deltas.
            sample_pre_with_deltas = (sample_pre * gen_in.norm_params[1] +
                                      gen_in.norm_params[0]).astype(np.float32)

            if np.isnan(sample_pre_with_deltas).any():
                raise ValueError(
                    "Detected nan values in output features for {}.".format(
                        id_name))
            # Save warped features.
            makedirs_safe(os.path.dirname(os.path.join(out_dir, id_name)))
            sample_pre_with_deltas.tofile(
                os.path.join(out_dir,
                             id_name + "." + WorldFeatLabelGen.ext_deltas))

            hparams.synth_dir = out_dir
            # sample_no_deltas = WorldFeatLabelGen.convert_from_world_features(*WorldFeatLabelGen.convert_to_world_features(sample, contains_deltas=hparams.add_deltas, num_coded_sps=hparams.num_coded_sps, num_bap=hparams.num_bap))
            Synthesiser.run_world_synth({id_name: sample_post}, hparams)

    print("Process time for {} runs: {}, average: {}".format(
        len(id_list) * idx, timedelta(seconds=t_benchmark),
        timedelta(seconds=t_benchmark) / (len(id_list) * idx)))
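
For intuition, the all-pass warp applied by AllPassWarpModel shifts spectral
content along the frequency axis with a single parameter alpha. A sketch of
the standard first-order all-pass (bilinear) frequency mapping used in VTLN
(background maths, not code from the original source):

import numpy as np

def warp_frequency(omega, alpha):
    """Warp normalised frequencies omega in [0, pi] by factor alpha."""
    return omega + 2.0 * np.arctan2(alpha * np.sin(omega),
                                    1.0 - alpha * np.cos(omega))

With alpha = 0 the mapping is the identity, which is what the assert on the
unwarped features above checks.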
class AcousticModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data.

    Uses question labels as input and WORLD features, with or without
    deltas/double-deltas (controlled by hparams.add_deltas), as output.
    Synthesises audio from the model output with MLPG smoothing.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self,
                 dir_world_features,
                 dir_question_labels,
                 id_list,
                 num_questions,
                 hparams=None):
        """Default constructor.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List of ids, can contain a speaker directory.
        :param num_questions:           Number of questions in question file.
        :param hparams:                 Set of hyper parameters.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super(AcousticModelTrainer, self).__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        self.OutputGen = WorldFeatLabelGen(dir_world_features,
                                           add_deltas=hparams.add_deltas,
                                           num_coded_sps=hparams.num_coded_sps,
                                           sp_type=hparams.sp_type)
        self.OutputGen.get_normalisation_params(
            dir_world_features, hparams.output_norm_params_file_prefix)

        self.dataset_train = LabelGensDataset(self.id_list_train,
                                              self.InputGen,
                                              self.OutputGen,
                                              hparams,
                                              match_lengths=True)
        self.dataset_val = LabelGensDataset(self.id_list_val,
                                            self.InputGen,
                                            self.OutputGen,
                                            hparams,
                                            match_lengths=True)

        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_verbose=True)

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyper parameter container. Parse non default from given string."""
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        hparams.add_hparams(
            num_questions=None,
            question_file=None,  # Used to add labels in plot.
            num_coded_sps=60,
            sp_type="mcep",
            add_deltas=True,
            synth_load_org_sp=False,
            synth_load_org_lf0=False,
            synth_load_org_vuv=False,
            synth_load_org_bap=False)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams

    def gen_figure_from_output(self, id_name, label, hidden, hparams):

        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='log(f0)')

        graphs = list()
        graphs.append((original_lf0, 'Original lf0'))
        graphs.append((lf0, 'PyTorch lf0'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        grid_idx += 1
        import librosa
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=librosa.amplitude_to_db(np.absolute(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     original_mgc, hparams.synth_fs)),
                                                          top_db=None))

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=librosa.amplitude_to_db(np.absolute(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     coded_sp, hparams.synth_fs)),
                                                          top_db=None))

        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                "experiments/" + hparams.voice + "/questions/",
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' +
                             hparams.gen_figure_ext)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):

        # Get data for comparison.
        dict_original_post = dict()
        for id_name in dict_outputs_post.keys():
            dict_original_post[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                dir_out=self.OutputGen.dir_labels,
                add_deltas=True,
                num_coded_sps=hparams.num_coded_sps)

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        all_rmse = []
        vuv_error_rate = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        all_vuv = []
        mcd = 0.0
        mcd_max_id = "None"
        mcd_max = 0.0
        all_mcd = []
        bap_error = 0.0
        bap_error_max_id = "None"
        bap_error_max = 0.0
        all_bap_error = []

        for id_name, labels in dict_outputs_post.items():
            output_coded_sp, output_lf0, output_vuv, output_bap = self.OutputGen.convert_to_world_features(
                sample=labels,
                contains_deltas=False,
                num_coded_sps=hparams.num_coded_sps)
            output_vuv = output_vuv.astype(bool)

            # Get data for comparison.
            org_coded_sp, org_lf0, org_vuv, org_bap = self.OutputGen.convert_to_world_features(
                sample=dict_original_post[id_name],
                contains_deltas=self.OutputGen.add_deltas,
                num_coded_sps=hparams.num_coded_sps)

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0.squeeze())[:len(output_lf0)]  # Fix minor length mismatch.
            output_f0 = np.exp(output_lf0)

            # Compute MCD.
            org_coded_sp = org_coded_sp[:len(output_coded_sp)]
            current_mcd = metrics.melcd(
                output_coded_sp[:, 1:],
                org_coded_sp[:, 1:])  # TODO: Use aligned mcd.
            if current_mcd > mcd_max:
                mcd_max_id = id_name
                mcd_max = current_mcd
            mcd += current_mcd
            all_mcd.append(current_mcd)

            # Compute RMSE.
            f0_mse = (org_f0 - output_f0)**2
            current_f0_rmse = math.sqrt(
                (f0_mse * org_vuv[:len(output_lf0)]).sum() /
                org_vuv[:len(output_lf0)].sum())
            if math.isnan(current_f0_rmse):
                logging.error(
                    "Computed NaN for F0 RMSE for {}.".format(id_name))
            else:
                if current_f0_rmse > f0_rmse_max:
                    f0_rmse_max_id = id_name
                    f0_rmse_max = current_f0_rmse
                f0_rmse += current_f0_rmse
                all_rmse.append(current_f0_rmse)

            # Compute error of VUV in percentage.
            num_errors = (org_vuv[:len(output_lf0)] != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / len(output_lf0)
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp
            all_vuv.append(vuv_error_rate_tmp)

            # Compute aperiodicity distortion.
            org_bap = org_bap[:len(output_bap)]
            if len(output_bap.shape) > 1 and output_bap.shape[1] > 1:
                current_bap_error = metrics.melcd(
                    output_bap, org_bap)  # TODO: Use aligned mcd?
            else:
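                # (10 / ln(10)) * sqrt(2) rescales the RMSE of the single
                # log-aperiodicity band to dB, matching the MCD convention.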
                current_bap_error = math.sqrt(
                    ((org_bap - output_bap)**2).mean()) * (
                        10.0 / np.log(10) * np.sqrt(2.0))
            if current_bap_error > bap_error_max:
                bap_error_max_id = id_name
                bap_error_max = current_bap_error
            bap_error += current_bap_error
            all_bap_error.append(current_bap_error)

        f0_rmse /= len(dict_outputs_post)
        vuv_error_rate /= len(dict_outputs_post)
        mcd /= len(dict_original_post)
        bap_error /= len(dict_original_post)

        self.logger.info("Worst MCD: {} {:4.2f}dB".format(mcd_max_id, mcd_max))
        self.logger.info("Worst F0 RMSE: {} {:4.2f}Hz".format(
            f0_rmse_max_id, f0_rmse_max))
        self.logger.info("Worst VUV error: {} {:2.2f}%".format(
            vuv_error_max_id, vuv_error_max * 100))
        self.logger.info("Worst BAP error: {} {:4.2f}db".format(
            bap_error_max_id, bap_error_max))
        self.logger.info(
            "Benchmark score: MCD {:4.2f}dB, F0 RMSE {:4.2f}Hz, VUV {:2.2f}%, BAP error {:4.2f}db"
            .format(mcd, f0_rmse, vuv_error_rate * 100, bap_error))

        return mcd, f0_rmse, vuv_error_rate, bap_error

    def synthesize(self, id_list, synth_output, hparams):
        """
        Depending on hparams override the network output with the extracted features,
        then continue with normal synthesis pipeline.
        """

        if hparams.synth_load_org_sp\
                or hparams.synth_load_org_lf0\
                or hparams.synth_load_org_vuv\
                or hparams.synth_load_org_bap:
            for id_name in id_list:

                world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                              else os.path.join(self.OutputGen.dir_labels,
                                                                self.dir_extracted_acoustic_features)
                labels = WorldFeatLabelGen.load_sample(
                    id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
                len_diff = len(labels) - len(synth_output[id_name])
                if len_diff > 0:
                    labels = WorldFeatLabelGen.trim_end_sample(
                        labels, int(len_diff / 2), reverse=True)
                    labels = WorldFeatLabelGen.trim_end_sample(
                        labels, len_diff - int(len_diff / 2))

                if hparams.synth_load_org_sp:
                    num_sps = self.OutputGen.num_coded_sps
                    synth_output[id_name][:len(labels), :num_sps] = \
                        labels[:, :num_sps]

                if hparams.synth_load_org_lf0:
                    synth_output[id_name][:len(labels), -3] = labels[:, -3]

                if hparams.synth_load_org_vuv:
                    synth_output[id_name][:len(labels), -2] = labels[:, -2]

                if hparams.synth_load_org_bap:
                    synth_output[id_name][:len(labels), -1] = labels[:, -1]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
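
A hypothetical usage sketch for this trainer (the paths, id list, and
hparams values below are placeholders, not taken from the original source):

hparams = AcousticModelTrainer.create_hparams()
hparams.voice = "English"
hparams.num_questions = 425
hparams.out_dir = os.path.join("experiments", hparams.voice, "AcousticModel")
trainer = AcousticModelTrainer(
    dir_world_features=os.path.join("experiments", hparams.voice, "WORLD"),
    dir_question_labels=os.path.join("experiments", hparams.voice,
                                     "questions"),
    id_list=["p225/p225_051"],
    num_questions=hparams.num_questions,
    hparams=hparams)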
Example #27
0
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):

        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels,
                                                        identify_peaks=True,
                                                        peak_range=100)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        vuv = vuv.astype(bool)  # np.bool was removed from NumPy.

        # Load original lf0 and vuv.
        world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                      else os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
        org_labels = WorldFeatLabelGen.load_sample(
            id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        phrase_curve = np.fromfile(
            os.path.join(self.OutputGen.dir_labels,
                         id_name + self.OutputGen.ext_phrase),
            dtype=np.float32).reshape(-1, 1)
        original_lf0 -= phrase_curve
        len_diff = len(original_lf0) - len(lf0)
        original_lf0 = WorldFeatLabelGen.trim_end_sample(
            original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(
            original_lf0, int(len_diff / 2.0) + 1, reverse=True)

        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels,
                                                len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0) + 1,
                                                    reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(
            org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        graphs_output = list()
        grid_idx = 0
        for idx in reversed(range(labels.shape[1])):
            graphs_output.append(
                (labels[:, idx],
                 r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx])))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        # plotter.set_lim(grid_idx=0, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(labels_post.shape[1])):
            graphs_peaks.append((labels_post[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(
            labels_post,
            hparams.k,
            hparams.frame_size_ms,
            amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(
            np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # plotter.set_lim(xmin=300, xmax=1100)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE" + hparams.gen_figure_ext)
Example #28
0
    def run_world_synth(synth_output, hparams):
        """Run the WORLD synthesize method."""

        fft_size = pyworld.get_cheaptrick_fft_size(hparams.synth_fs)
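        # With pyworld's default f0 floor this yields e.g. 1024 at 16 kHz
        # and 2048 at 44.1 or 48 kHz.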

        save_dir = hparams.synth_dir if hparams.synth_dir is not None\
                                     else hparams.out_dir if hparams.out_dir is not None\
                                     else os.path.curdir
        for id_name, output in synth_output.items():
            logging.info(
                "Synthesise {} with the WORLD vocoder.".format(id_name))

            coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                output,
                contains_deltas=False,
                num_coded_sps=hparams.num_coded_sps)
            amp_sp = WorldFeatLabelGen.decode_sp(
                coded_sp,
                hparams.sp_type,
                hparams.synth_fs,
                post_filtering=hparams.do_post_filtering).astype(np.double,
                                                                 copy=False)
            args = dict()
            for attr in "preemphasize", "f0_silence_threshold", "lf0_zero":
                if hasattr(hparams, attr):
                    args[attr] = getattr(hparams, attr)
            waveform = WorldFeatLabelGen.world_features_to_raw(
                amp_sp,
                lf0,
                vuv,
                bap,
                fs=hparams.synth_fs,
                n_fft=fft_size,
                **args)

            # f0 = np.exp(lf0, dtype=np.float64)
            # vuv[f0 < WorldFeatLabelGen.f0_silence_threshold] = 0  # WORLD throws an error for too small f0 values.
            # f0[vuv == 0] = 0.0
            # ap = pyworld.decode_aperiodicity(np.ascontiguousarray(bap.reshape(-1, 1), np.float64),
            #                                  hparams.synth_fs,
            #                                  fft_size)
            #
            # waveform = pyworld.synthesize(f0, amp_sp, ap, hparams.synth_fs)
            # waveform = waveform.astype(np.float32, copy=False)  # Does inplace conversion, if possible.

            # Always save as wav file first and convert afterwards if necessary.
            file_path = os.path.join(
                save_dir, "{}{}{}{}".format(
                    os.path.basename(id_name), "_" + hparams.model_name
                    if hparams.model_name is not None else "",
                    hparams.synth_file_suffix, "_WORLD"))
            makedirs_safe(save_dir)  # file_path lives in save_dir, which may differ from hparams.synth_dir.
            soundfile.write(file_path + ".wav", waveform, hparams.synth_fs)

            # Use PyDub for special audio formats.
            if hparams.synth_ext.lower() != 'wav':
                as_wave = pydub.AudioSegment.from_wav(file_path + ".wav")
                file = as_wave.export(file_path + "." + hparams.synth_ext,
                                      format=hparams.synth_ext)
                file.close()
                os.remove(file_path + ".wav")
Example #29
0
    def run_wavenet_vocoder(synth_output, hparams):
        # Import ModelHandlerPyTorch here to prevent circular dependencies.
        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch

        assert hparams.synth_vocoder_path is not None, "Please set path to neural vocoder in hparams.synth_vocoder_path"
        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += '_' + hparams.synth_vocoder

        if not hasattr(hparams, 'bit_depth'):
            hparams.add_hparam("bit_depth", 16)

        synth_output = copy.copy(synth_output)

        input_fs_Hz = 1000.0 / hparams.frame_size_ms
        assert hasattr(hparams, "frame_rate_output_Hz") and hparams.frame_rate_output_Hz is not None, \
            "hparams.frame_rate_output_Hz has to be set and match the trained WaveNet."
        in_to_out_multiplier = hparams.frame_rate_output_Hz / input_fs_Hz
        # # dir_world_features = os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
        input_gen = WorldFeatLabelGen(
            None,
            add_deltas=False,
            sampling_fn=partial(sample_linearly,
                                in_to_out_multiplier=in_to_out_multiplier,
                                dtype=np.float32))
        # Load normalisation parameters for wavenet input.
        try:
            norm_params_path = os.path.splitext(
                hparams.synth_vocoder_path)[0] + "_norm_params.npy"
            input_gen.norm_params = np.load(norm_params_path).reshape(2, -1)
        except FileNotFoundError:
            logging.error(
                "Cannot find normalisation parameters for WaveNet input at {}."
                "Please save them there with numpy.save().".format(
                    norm_params_path))
            raise

        model_handler = ModelHandlerPyTorch()
        model_handler.model, *_ = model_handler.load_model(
            hparams.synth_vocoder_path, hparams, verbose=False)

        for id_name, output in synth_output.items():
            logging.info("Synthesise {} with {} vocoder.".format(
                id_name, hparams.synth_vocoder_path))

            # Any other post-processing could be done here.

            # Normalize input.
            output = input_gen.preprocess_sample(output)

            # output (T x C) --transpose--> (C x T) --unsqueeze(0)--> (B x C x T).
            output = output.transpose()[None, ...]
            # Wavenet input has to be (B x C x T).
            output, _ = model_handler.forward(
                output, hparams, batch_seq_lengths=(output.shape[-1], ))
            # output, _ = model_handler.forward(output[:, :, :1000], hparams, batch_seq_lengths=(1000,))  # DEBUG
            # Remove the batch dimension and transpose back to (T x C).
            output = output[0].transpose()

            out_channels = output.shape[1]
            if out_channels > 1:  # One-hot (quantized) output rather than raw.
                # Revert mu-law quantization.
                output = output.argmax(axis=1)
                synth_output[id_name] = \
                    RawWaveformLabelGen.mu_law_companding_reversed(
                        output, out_channels)

            # Save the audio.
            wav_file_path = os.path.join(
                hparams.synth_dir, "".join(
                    (os.path.basename(id_name).rsplit('.', 1)[0], "_",
                     hparams.model_name, hparams.synth_file_suffix, ".",
                     hparams.synth_ext)))
            Synthesiser.raw_to_file(wav_file_path, synth_output[id_name],
                                    hparams.synth_fs, hparams.bit_depth)

        # Restore identifier.
        hparams.setattr_no_type_check(
            "synth_file_suffix",
            old_synth_file_suffix)  # Can be None, thus no type check.
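
For reference, the mu-law expansion performed by mu_law_companding_reversed
can be sketched as follows (assuming quantized values in [0, channels - 1];
this is the standard inverse companding formula, not code from the original
source):

import numpy as np

def mu_law_expand(y_quantized, channels=256):
    """Map quantized values back to waveform amplitudes in [-1, 1]."""
    mu = channels - 1
    y = 2.0 * y_quantized.astype(np.float64) / mu - 1.0  # To [-1, 1].
    return np.sign(y) * ((1.0 + mu)**np.abs(y) - 1.0) / mu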
class WaveNetVocoderTrainer(ModelTrainer):
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, id_list, hparams=None):

        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        in_to_out_multiplier = int(hparams.frame_rate_output_Hz /
                                   (1000.0 / hparams.frame_size_ms))
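        # E.g. frame_rate_output_Hz=16000 with frame_size_ms=5 means input
        # features at 200 Hz, i.e. 80 output samples per input frame.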
        max_frames_input_trainset = int(
            1000.0 / hparams.frame_size_ms * hparams.max_input_train_sec
        ) * in_to_out_multiplier  # Multiply by number of seconds.
        max_frames_input_testset = int(
            1000.0 / hparams.frame_size_ms * hparams.max_input_test_sec
        ) * in_to_out_multiplier  # Ensure that test takes all frames. NOTE: Had to limit it because of memory constraints.

        self.InputGen = WorldFeatLabelGen(
            dir_world_features,
            add_deltas=False,
            sampling_fn=partial(sample_linearly,
                                in_to_out_multiplier=in_to_out_multiplier,
                                dtype=np.float32),
            num_coded_sps=hparams.num_coded_sps,
            sp_type=hparams.sp_type,
            load_sp=hparams.load_sp,
            load_lf0=hparams.load_lf0,
            load_vuv=hparams.load_vuv,
            load_bap=hparams.load_bap)
        self.InputGen.get_normalisation_params(
            dir_world_features, hparams.input_norm_params_file_prefix)

        self.OutputGen = RawWaveformLabelGen(
            frame_rate_output_Hz=hparams.frame_rate_output_Hz,
            frame_size_ms=hparams.frame_size_ms,
            mu=hparams.mu if hparams.input_type == "mulaw-quantize" else None,
            silence_threshold_quantized=hparams.silence_threshold_quantized)
        # No normalisation parameters required.

        self.dataset_train = LabelGensDataset(
            self.id_list_train,
            self.InputGen,
            self.OutputGen,
            hparams,
            random_select=True,
            max_frames_input=max_frames_input_trainset)
        self.dataset_val = LabelGensDataset(
            self.id_list_val,
            self.InputGen,
            self.OutputGen,
            hparams,
            random_select=True,
            max_frames_input=max_frames_input_testset)

        if self.loss_function is None:
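            # mu-law-quantized targets are a 256-way classification problem;
            # raw targets are modelled with a discretized mixture of logistics
            # (see quantize_channels and log_scale_min in create_hparams).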
            if hparams.input_type == "mulaw-quantize":
                self.loss_function = OneHotCrossEntropyLoss(reduction='none',
                                                            shift=1)
            else:
                self.loss_function = DiscretizedMixturelogisticLoss(
                    hparams.quantize_channels,
                    hparams.log_scale_min,
                    reduction='none',
                    hinge_loss=hparams.hinge_regularizer)

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Noam"
            # hparams.scheduler_args['exponential_gamma'] = 0.99
            hparams.scheduler_args['wormup_steps'] = 4000

        # Override the collate and decollate methods of batches.
        self.batch_collate_fn = partial(self.prepare_batch,
                                        use_cond=hparams.use_cond,
                                        one_hot_target=True)
        self.batch_decollate_fn = self.decollate_network_output

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyper-parameters. Parse non-default from given string."""
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)
        hparams.synth_vocoder = "raw"

        hparams.add_hparams(
            batch_first=True,
            frame_rate_output_Hz=16000,
            mu=255,
            bit_depth=16,
            # Beginning and end of audio below the threshold are trimmed.
            silence_threshold_quantized=None,
            teacher_forcing_in_test=True,
            ema_decay=0.9999,

            # Model parameters.
            input_type="mulaw-quantize",
            # Only used in MoL prediction (input_type="raw").
            hinge_regularizer=True,
            # Only used for mixture of logistic distributions.
            log_scale_min=float(np.log(1e-14)),
            # 256 for input type mulaw-quantize, otherwise 65536.
            quantize_channels=256)
        if hparams.input_type == "mulaw-quantize":
            hparams.add_hparam("out_channels", hparams.quantize_channels)
        else:
            hparams.add_hparam("out_channels", 10 *
                               3)  # num_mixtures * 3 (pi, mean, log_scale)

        hparams.add_hparams(
            layers=24,  # 20
            stacks=4,  # 2
            residual_channels=512,
            gate_channels=512,
            skip_out_channels=256,
            dropout=1 - 0.95,
            kernel_size=3,
            weight_normalization=True,
            use_cond=True,  # Determines if conditioning is used.
            cin_channels=63,
            upsample_conditional_features=False,
            upsample_scales=[5, 4, 2])
        if hparams.upsample_conditional_features:
            hparams.len_in_out_multiplier = reduce(mul,
                                                   hparams.upsample_scales, 1)
        else:
            hparams.len_in_out_multiplier = 1
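        # For example, with the default upsample_scales [5, 4, 2] the
        # multiplier is 5 * 4 * 2 = 40, i.e. 40 waveform samples are
        # generated per conditioning frame; without upsampling, input
        # and output lengths match one to one.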

        hparams.add_hparams(freq_axis_kernel_size=3,
                            gin_channels=-1,
                            n_speakers=1,
                            use_speaker_embedding=False,
                            sp_type="mcep",
                            load_sp=True,
                            load_lf0=True,
                            load_vuv=True,
                            load_bap=True)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams
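
    # A minimal usage sketch (comment only; `TrainerClass` stands in for
    # the surrounding trainer class, and the override values are
    # illustrative):
    #
    #     hparams = TrainerClass.create_hparams(
    #         hparams_string="layers=20,stacks=2", verbose=True)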

    # Batch (de)collate helpers used by the data loader.
    @staticmethod
    def prepare_batch(batch,
                      common_divisor=1,
                      batch_first=False,
                      use_cond=True,
                      one_hot_target=True):
        inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation = ModelHandler.prepare_batch(
            batch, common_divisor=common_divisor, batch_first=batch_first)

        if batch_first:
            # inputs: (B x T x C) --permute--> (B x C x T)
            inputs = inputs.transpose(1, 2).contiguous()
        # TODO: Handle case where batch_first=False: inputs = inputs.permute(1, 2, 0).contiguous()?

        if targets is not None:
            if batch_first:
                # targets: (B x T x C) --permute--> (B x C x T)
                targets = targets.transpose(1, 2).contiguous()

            if not one_hot_target:
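                # Collapse one-hot targets to class indices (assuming
                # batch_first): (B x C x T) -> (B x 1 x T).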
                targets = targets.max(dim=1, keepdim=True)[1].float()

        if mask is not None:
            mask = mask[:, 1:].contiguous()

        return inputs if use_cond else None, targets, seq_lengths_input, seq_lengths_output, mask, permutation
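
    # Shape walk-through for prepare_batch with batch_first=True and
    # assumed sizes B=2, T=100, 63 conditioning channels, 256 one-hot
    # channels (illustrative values only):
    #   inputs:  (2, 100, 63)  --transpose(1, 2)--> (2, 63, 100)
    #   targets: (2, 100, 256) --transpose(1, 2)--> (2, 256, 100)
    #   mask:    (2, 100)      --[:, 1:]-->         (2, 99)
    # Dropping the first mask frame presumably aligns it with the
    # shift-by-one prediction targets (shift=1 in OneHotCrossEntropyLoss).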

    @staticmethod
    def decollate_network_output(output,
                                 hidden,
                                 seq_lengths=None,
                                 permutation=None,
                                 batch_first=True):

        # The r9y9 WaveNet output is batch first: B x C x T --transpose--> B x T x C.
        output = np.transpose(output, (0, 2, 1))
        if not batch_first:
            # output: B x T x C --transpose--> T x B x C
            output = np.transpose(output, (1, 0, 2))
        return ModelTrainer.split_batch(output,
                                        hidden,
                                        seq_length_output=seq_lengths,
                                        permutation=permutation,
                                        batch_first=batch_first)
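
    # Example: a batch-first network output of shape (2, 256, 100) becomes
    # (2, 100, 256) after the first transpose; split_batch then presumably
    # undoes the batch permutation and trims each sample to its true
    # length (sizes here are illustrative).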

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):

        # Labels come in as T x C.
        labels_post = self.dataset_train.postprocess_sample(labels)
        org_raw = RawWaveformLabelGen.load_sample(
            id_name, self.OutputGen.frame_rate_output_Hz)

        # Get a data plotter.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        id_name = os.path.basename(id_name).rsplit('.', 1)[0]
        filename = os.path.join(hparams.out_dir, id_name + "." + net_name)
        plotter.set_title(id_name + " - " + net_name)
        grid_idx = 0

        graphs = list()
        graphs.append((org_raw, 'Org'))
        graphs.append((labels_post, 'Wavenet'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_linewidth(grid_idx=grid_idx, linewidth=[0.1])
        plotter.set_colors(grid_idx=grid_idx, alpha=0.8)
        plotter.set_lim(grid_idx, ymin=-1, ymax=1)
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{} Hz]'.format(
                              hparams.frame_rate_output_Hz),
                          ylabel='raw')

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Raw' + hparams.gen_figure_ext)

    # def synthesize(self, file_id_list, synth_output, hparams):
    #     self.run_raw_synth(synth_output, hparams)

    # def synth_ref(self, hparams, file_id_list):
    #     # Can be different from the original by sampling frequency.
    #     self.logger.info("Synthesise references for [{0}].".format(
    #         ", ".join(file_id_list)))
    #
    #     synth_output = dict()
    #     for id_name in file_id_list:
    #         # Use extracted data. Useful to create a reference.
    #         raw = RawWaveformLabelGen.load_sample(id_name, self.OutputGen.frame_rate_output_Hz)
    #         synth_output[id_name] = raw
    #
    #     # Add identifier to suffix.
    #     old_synth_file_suffix = hparams.synth_file_suffix
    #     hparams.synth_file_suffix += '_ref'
    #
    #     # Run the raw waveform synthesiser.
    #     self.run_raw_synth(synth_output, hparams)
    #
    #     # Restore identifier.
    #     hparams.synth_file_suffix = old_synth_file_suffix

    def save_for_vocoding(self, filename):
        # Save the full model so that hyper-parameters are already set.
        self.model_handler.save_full_model(filename,
                                           self.model_handler.model,
                                           verbose=True)
        # Save an easily loadable version of the normalisation parameters on the input side used during training.
        np.save(
            os.path.splitext(filename)[0] + "_norm_params",
            np.concatenate(self.InputGen.norm_params, axis=0))
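
        # Loading the stored parameters back is a plain np.load call
        # (a sketch; np.save appends the ".npy" extension automatically):
        #
        #     norm_params = np.load(
        #         os.path.splitext(filename)[0] + "_norm_params.npy")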