Exemplo n.º 1
0
    def test_synth_wav(self):
        """Synthesise WAV files with the fixture model and sanity-check them.

        Points the trainer at the ``test_model_in409_out67`` fixture,
        synthesises the first ``num_test_files`` ids and checks that

        * one ``*_WORLD.wav`` file per id was written to ``hparams.synth_dir``,
        * the sampling rate of a written file equals ``hparams.synth_fs``,
        * the audio length, converted to frames of ``hparams.frame_size_ms``,
          is within 10 frames of the length of the acoustic features.

        The output directory is removed afterwards, also when a check fails.
        """
        num_test_files = 2

        hparams = self._get_hparams()
        hparams.out_dir = os.path.join(
            hparams.out_dir, "test_synth_wav")  # Add function name to path
        hparams.model_name = "test_model_in409_out67"
        hparams.model_path = os.path.join("integration", "fixtures",
                                          hparams.model_name,
                                          hparams.networks_dir)
        hparams.synth_fs = 16000
        hparams.frame_size_ms = 5
        hparams.synth_ext = "wav"
        hparams.synth_load_org_sp = True
        hparams.synth_load_org_lf0 = True
        hparams.synth_load_org_vuv = True
        hparams.synth_load_org_bap = True

        try:
            trainer = AcousticModelTrainer(
                **AcousticModelTrainer.legacy_support_init(
                    self.dir_world_features, self.dir_question_labels,
                    self.id_list, hparams.num_questions, hparams))
            trainer.init(hparams)
            hparams.synth_dir = os.path.join(hparams.out_dir,
                                             hparams.model_name)
            trainer.synth(hparams, self.id_list[:num_test_files])
            synthesised_ids = self.id_list[:num_test_files]

            expected_suffix = "_WORLD." + hparams.synth_ext
            found_files = [
                name for name in os.listdir(hparams.synth_dir)
                if os.path.isfile(os.path.join(hparams.synth_dir, name))
                and name.endswith(expected_suffix)
            ]
            # Check number of created files.
            self.assertEqual(
                len(synthesised_ids),
                len(found_files),
                msg="Number of {} files in synth_dir directory does not match."
                .format(hparams.synth_ext))

            # Check readability and length of one created file.
            checked_file = found_files[0]
            raw, fs = soundfile.read(
                os.path.join(hparams.synth_dir, checked_file))
            self.assertEqual(
                hparams.synth_fs,
                fs,
                msg="Desired sampling frequency of output doesn't match.")

            # The file name contains the id it was synthesised from.
            checked_id = next(id_name for id_name in synthesised_ids
                              if id_name in checked_file)
            labels = trainer.datareaders["acoustic_features"][checked_id]

            # samples -> seconds -> milliseconds -> number of frames.
            expected_length = len(
                raw) / hparams.synth_fs / hparams.frame_size_ms * 1000
            self.assertLess(
                abs(expected_length - len(labels["acoustic_features"])),
                10,
                msg=
                "Saved raw audio file length does not roughly match length of labels."
            )
        finally:
            # Always clean up so a failing run does not leave artefacts behind.
            # ignore_errors: the directory may not exist yet if the failure
            # happened before anything was written.
            shutil.rmtree(hparams.out_dir, ignore_errors=True)
Exemplo n.º 2
0
    def run_DM_AM(hparams, input_strings):
        """
        A function for TTS with a pre-trained duration and acoustic model.

        Steps: write the texts to a file, run the front end on it, strip the
        timings from the mono labels, predict durations with the duration
        model, write state-aligned full labels, generate question labels and
        finally synthesise with the acoustic model. All intermediate files are
        created in a temporary directory which is deleted before returning.

        Note that hparams is modified in place: out_dir, batch_size_test,
        test_set_perc, val_set_perc, phoneme_label_type,
        output_norm_params_file_prefix, model_path, model_name and
        load_from_checkpoint are overwritten.

        :param hparams:            Hyper-parameter container. The following parameters are used:
                                   front_end:                    Full path to the makeLabels.sh script in scripts/tts_frontend, depends on the language.
                                   festival_dir:                 Full path to the directory with the festival bin/ folder.
                                   front_end_accent (optional):  Give an accent to the front_end, used in tts_frontend.
                                   duration_labels_dir:          Full path to the folder containing the normalisation parameters used to train the duration model.
                                   duration_norm_file_name (optional):  Used as output_norm_params_file_prefix while running the duration model.
                                   file_symbol_dict:             A file containing all the used phonemes (has been used to train the duration model, usually mono_phone.list).
                                   duration_model:               Full path to the pre-trained duration model.
                                   num_phoneme_states:           Number of states per phoneme, for each a duration is predicted by the duration model.
                                   question_file:                Full path to question file used to train the acoustic model.
                                   question_labels_norm_file:    Full path to normalisation file of questions used to train the acoustic model.
                                   num_questions:                Number of questions which form the input dimension to the acoustic model.
                                   world_features_dir:           Passed as first argument to the AcousticModelTrainer.
                                   acoustic_model:               Full path to acoustic model.
                                   synth_dir:                    Only read for the final log message.
        :param input_strings:      Sized sequence of texts to synthesise. The i-th text gets the id "synth<i>".
        :return:                   Always 0.
        """
        # Create a temporary directory to store all files.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            hparams.out_dir = tmp_dir_name
            print("Created temporary directory", tmp_dir_name)
            id_list = ["synth" + str(idx) for idx in range(len(input_strings))]

            # Write the text to synthesise into a single synth.txt file with ids.
            # The ids are taken from id_list so both can never get out of sync.
            utts_file = os.path.join(tmp_dir_name, "synth.txt")
            with open(utts_file, "w") as text_file:
                for id_name, text in zip(id_list, input_strings):
                    text_file.write("{}\t{}\n".format(
                        id_name, text))  # TODO: Remove parenthesis etc.

            # Call the front end on the synth.txt file.
            front_end_arguments = [
                hparams.front_end, hparams.festival_dir, utts_file
            ]
            front_end_accent = getattr(hparams, "front_end_accent", None)
            if front_end_accent is not None:
                front_end_arguments.append(front_end_accent)
            front_end_arguments.append(tmp_dir_name)
            subprocess.check_call(front_end_arguments)

            # Remove durations from mono labels.
            dir_mono_no_align = os.path.join(tmp_dir_name, "mono_no_align")
            dir_mono = os.path.join(tmp_dir_name, "labels", "mono")

            if os.path.isdir(dir_mono_no_align):
                shutil.rmtree(dir_mono_no_align)
            os.rename(dir_mono, dir_mono_no_align)
            for id_name in id_list:
                label_path = os.path.join(dir_mono_no_align, id_name + ".lab")
                with open(label_path, "r") as f:
                    # Keep every third whitespace-separated token starting at
                    # the third one, i.e. drop the two leading fields per label.
                    monophones = f.read().split()[2::3]
                with open(label_path, "w") as f:
                    f.write("\n".join(monophones))

            # Run duration model.
            hparams.batch_size_test = len(input_strings)
            hparams.test_set_perc = 0.0
            hparams.val_set_perc = 0.0
            hparams.phoneme_label_type = "mono_no_align"
            hparams.output_norm_params_file_prefix = getattr(
                hparams, "duration_norm_file_name", None)
            duration_model_trainer = DurationModelTrainer(
                dir_mono_no_align, hparams.duration_labels_dir, id_list,
                hparams.file_symbol_dict, hparams)
            assert hparams.duration_model is not None, "Path to duration model in hparams.duration_model is needed."
            hparams.model_path = hparams.duration_model
            hparams.model_name = os.path.basename(hparams.duration_model)

            # Predict durations. Durations are already converted to multiples of hparams.min_phoneme_length.
            hparams.load_from_checkpoint = True
            duration_model_trainer.init(hparams)
            _, output_dict_post = duration_model_trainer.forward(
                hparams, id_list)
            hparams.output_norm_params_file_prefix = None  # Reset again.

            # Write duration to full labels.
            dir_full = os.path.join(tmp_dir_name, "labels", "full")
            dir_label_state_align = os.path.join(tmp_dir_name, "labels",
                                                 "label_state_align")
            makedirs_safe(dir_label_state_align)
            for id_name in id_list:
                with open(os.path.join(dir_full, id_name + ".lab"), "r") as f:
                    full_labels = f.read().split()[2::3]
                timings = output_dict_post[id_name]
                with open(
                        os.path.join(dir_label_state_align, id_name + ".lab"),
                        "w") as f:
                    # One line per (label, state) with accumulated start and
                    # end time; states are written as [2], [3], ...
                    current_time = 0
                    for idx, full_label in enumerate(full_labels):
                        for state in range(hparams.num_phoneme_states):
                            next_time = current_time + int(timings[idx, state])
                            f.write("{}\t{}\t{}[{}]\n".format(
                                current_time, next_time, full_label,
                                state + 2))
                            current_time = next_time

            # Generate questions from HTK full labels.
            QuestionLabelGen.gen_data(dir_label_state_align,
                                      hparams.question_file,
                                      dir_out=tmp_dir_name,
                                      file_id_list="synth",
                                      id_list=id_list,
                                      return_dict=False)

            # Run acoustic model and synthesise.
            # Get normalisation parameters in same directory.
            shutil.copy2(hparams.question_labels_norm_file,
                         os.path.join(tmp_dir_name, "min-max.bin"))
            acoustic_model_trainer = AcousticModelTrainer(
                hparams.world_features_dir, tmp_dir_name, id_list,
                hparams.num_questions, hparams)
            assert hparams.acoustic_model is not None, "Path to acoustic model in hparams.acoustic_model is needed."
            hparams.model_path = hparams.acoustic_model
            hparams.model_name = os.path.basename(hparams.acoustic_model)
            hparams.load_from_checkpoint = True
            acoustic_model_trainer.init(hparams)
            hparams.model_name = ""  # No suffix in synthesised files.
            acoustic_model_trainer.synth(hparams, id_list)

            logging.info("Synthesized files are in {}.".format(
                hparams.synth_dir))

        return 0