def compute_score(self, data, output, hparams):

        dict_original_post = self.get_output_dict(data.keys(),
                                                  hparams,
                                                  chunk_size=hparams.get_value(
                                                      "n_frames_per_step",
                                                      default=1))

        metric_dict = {}
        for label_name in next(iter(data.values())).keys():
            metric = Metrics(hparams.metrics)
            for id_name, labels in data.items():
                labels = labels[label_name]
                output = WorldFeatLabelGen.convert_to_world_features(
                    sample=labels,
                    contains_deltas=False,
                    num_coded_sps=hparams.num_coded_sps,
                    num_bap=hparams.num_bap)

                org = WorldFeatLabelGen.convert_to_world_features(
                    sample=dict_original_post[id_name],
                    contains_deltas=hparams.add_deltas,
                    num_coded_sps=hparams.num_coded_sps,
                    num_bap=hparams.num_bap)

                current_metrics = metric.get_metrics(hparams.metrics, *org,
                                                     *output)
                metric.accumulate(id_name, current_metrics)

            metric.log()
            metric_dict[label_name] = metric.get_cum_values()

        return metric_dict
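
# A minimal, hedged usage sketch for the compute_score variant above. It assumes
# `trainer` is an instance of the surrounding trainer class and that `data` and
# `output` follow the {id_name: {label_name: array}} layout iterated over above:
#
#     scores = trainer.compute_score(data, output, hparams)
#     for label_name, cum_values in scores.items():
#         print(label_name, cum_values)  # Cumulative metric values per output label.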
Example #2
    def synth_ref(hparams, file_id_list, feature_dir=None):
        # Create reference audio files containing only the vocoder degradation.
        logging.info("Synthesise references with {} for [{}].".format(
            hparams.synth_vocoder,
            ", ".join([id_name for id_name in file_id_list])))

        synth_dict = dict()
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix = '_ref'
        if hparams.synth_vocoder == "WORLD":
            for id_name in file_id_list:
                # Load reference audio features.
                try:
                    output = WorldFeatLabelGen.load_sample(
                        id_name,
                        feature_dir,
                        num_coded_sps=hparams.num_coded_sps)
                except FileNotFoundError as e1:
                    try:
                        output = WorldFeatLabelGen.load_sample(
                            id_name,
                            feature_dir,
                            add_deltas=True,
                            num_coded_sps=hparams.num_coded_sps)
                        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                            output,
                            contains_deltas=True,
                            num_coded_sps=hparams.num_coded_sps)
                        length = len(output)
                        lf0 = lf0.reshape(length, 1)
                        vuv = vuv.reshape(length, 1)
                        bap = bap.reshape(length, 1)
                        output = np.concatenate((coded_sp, lf0, vuv, bap),
                                                axis=1)
                    except FileNotFoundError as e2:
                        logging.error(
                            "Cannot find extracted WORLD features with or without deltas in {}."
                            .format(feature_dir))
                        raise Exception([e1, e2])
                synth_dict[id_name] = output

            # Add identifier to suffix (the original suffix was already saved above).
            hparams.synth_file_suffix += str(hparams.num_coded_sps) + 'sp'
            Synthesiser.run_world_synth(synth_dict, hparams)
        elif hparams.synth_vocoder == "raw":
            for id_name in file_id_list:
                # Use extracted data. Useful to create a reference.
                raw = RawWaveformLabelGen.load_sample(
                    id_name, hparams.frame_rate_output_Hz)
                synth_dict[id_name] = raw
            Synthesiser.run_raw_synth(synth_dict, hparams)
        else:
            raise NotImplementedError("Unknown vocoder type {}.".format(
                hparams.synth_vocoder))

        # Restore the original suffix.
        hparams.synth_file_suffix = old_synth_file_suffix
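
# A hedged usage sketch for synth_ref. It assumes synth_ref is a static method of the
# Synthesiser class used in the other examples and that hparams already provides
# synth_vocoder, synth_fs, num_coded_sps and synth_file_suffix; the id names and the
# feature directory below are purely illustrative:
#
#     hparams.synth_vocoder = "WORLD"
#     Synthesiser.synth_ref(hparams,
#                           ["p225/p225_001", "p225/p225_002"],
#                           feature_dir="experiments/demo/WORLD")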
Example #3
    def run_world_synth(synth_output: Dict[str, np.ndarray],
                        hparams: ExtendedHParams,
                        epoch: int = None,
                        step: int = None,
                        use_model_name: bool = True,
                        has_deltas: bool = False) -> None:
        """Run the WORLD synthesize method."""

        fft_size = pyworld.get_cheaptrick_fft_size(hparams.synth_fs)

        save_dir = Synthesiser._get_synth_dir(hparams,
                                              use_model_name,
                                              epoch=epoch,
                                              step=step)

        for id_name, output in synth_output.items():
            logging.info(
                "Synthesise {} with the WORLD vocoder.".format(id_name))

            coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                output,
                contains_deltas=has_deltas,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)
            amp_sp = AudioProcessing.decode_sp(
                coded_sp,
                hparams.sp_type,
                hparams.synth_fs,
                post_filtering=hparams.do_post_filtering).astype(np.double,
                                                                 copy=False)
            args = dict()
            for attr in "preemphasis", "f0_silence_threshold", "lf0_zero":
                if hasattr(hparams, attr):
                    args[attr] = getattr(hparams, attr)
            waveform = WorldFeatLabelGen.world_features_to_raw(
                amp_sp,
                lf0,
                vuv,
                bap,
                fs=hparams.synth_fs,
                n_fft=fft_size,
                **args)

            # Always save as wav file first and convert afterwards if necessary.
            file_name = (os.path.basename(id_name) +
                         hparams.synth_file_suffix + '_' +
                         str(hparams.num_coded_sps) + hparams.sp_type +
                         "_WORLD")
            file_path = os.path.join(save_dir, file_name)
            soundfile.write(file_path + ".wav", waveform, hparams.synth_fs)

            # Use PyDub for special audio formats.
            if hparams.synth_ext.lower() != 'wav':
                as_wave = pydub.AudioSegment.from_wav(file_path + ".wav")
                file = as_wave.export(file_path + "." + hparams.synth_ext,
                                      format=hparams.synth_ext)
                file.close()
                os.remove(file_path + ".wav")
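
# A hedged usage sketch for run_world_synth. It assumes the features were extracted by
# WorldFeatLabelGen (as in the examples above) and that hparams carries synth_fs,
# num_coded_sps, num_bap, sp_type, synth_file_suffix and synth_ext; the id name and
# feature directory are illustrative only:
#
#     synth_dict = {
#         id_name: WorldFeatLabelGen.load_sample(id_name,
#                                                "experiments/demo/WORLD",
#                                                num_coded_sps=hparams.num_coded_sps)
#         for id_name in ["p225/p225_001"]
#     }
#     Synthesiser.run_world_synth(synth_dict, hparams)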
Example #4
    def run_r9y9wavenet_mulaw_world_feats_synth(synth_output, hparams):

        # If no path is given, use pre-trained model.
        if not hasattr(
                hparams,
                "synth_vocoder_path") or hparams.synth_vocoder_path is None:
            parent_dirs = os.path.realpath(__file__).split(os.sep)
            dir_root = str.join(
                os.sep, parent_dirs[:parent_dirs.index("IdiapTTS") + 1])
            hparams.synth_vocoder_path = os.path.join(
                dir_root, "idiaptts", "misc", "pretrained",
                "r9y9wavenet_quantized_16k_world_feats_English.nn")

        # Default quantization is with mu=255.
        if not hasattr(hparams, "mu") or hparams.mu is None:
            hparams.add_hparam("mu", 255)

        if hasattr(hparams, 'frame_rate_output_Hz'):
            org_frame_rate_output_Hz = hparams.frame_rate_output_Hz
            hparams.frame_rate_output_Hz = 16000
        else:
            org_frame_rate_output_Hz = None
            hparams.add_hparam("frame_rate_output_Hz", 16000)

        synth_output = copy.copy(synth_output)

        if hparams.do_post_filtering:
            for id_name, output in synth_output.items():
                coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                    output,
                    contains_deltas=False,
                    num_coded_sps=hparams.num_coded_sps)
                coded_sp = merlin_post_filter(
                    coded_sp,
                    WorldFeatLabelGen.fs_to_mgc_alpha(hparams.synth_fs))
                synth_output[
                    id_name] = WorldFeatLabelGen.convert_from_world_features(
                        coded_sp, lf0, vuv, bap)

        if hasattr(hparams, 'bit_depth'):
            org_bit_depth = hparams.bit_depth
            hparams.bit_depth = 16
        else:
            org_bit_depth = None
            hparams.add_hparam("bit_depth", 16)

        Synthesiser.run_wavenet_vocoder(synth_output, hparams)

        # Restore original hparams values.
        hparams.setattr_no_type_check(
            "bit_depth", org_bit_depth)  # Can be None, thus no type check.
        hparams.setattr_no_type_check("frame_rate_output_Hz",
                                      org_frame_rate_output_Hz)  # Can be None.
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):

        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels,
                                                        identify_peaks=True,
                                                        peak_range=100)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        vuv = vuv.astype(bool)

        # Load original lf0 and vuv.
        world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                      else os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
        org_labels = WorldFeatLabelGen.load_sample(
            id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        phrase_curve = np.fromfile(os.path.join(
            self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0 -= phrase_curve
        len_diff = len(original_lf0) - len(lf0)
        original_lf0 = WorldFeatLabelGen.trim_end_sample(
            original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0,
                                                         int(len_diff / 2.0) +
                                                         1,
                                                         reverse=True)

        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels,
                                                len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0) + 1,
                                                    reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(
            org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        graphs_output = list()
        grid_idx = 0
        for idx in reversed(range(labels.shape[1])):
            graphs_output.append(
                (labels[:, idx],
                 r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx])))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        # plotter.set_lim(grid_idx=0, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(labels_post.shape[1])):
            graphs_peaks.append((labels_post[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(
            labels_post,
            hparams.k,
            hparams.frame_size_ms,
            amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(
            np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # plotter.set_lim(xmin=300, xmax=1100)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE" + hparams.gen_figure_ext)
Example #6
    def gen_figure_phrase(self, hparams, ids_input):
        id_list = ModelTrainer._input_to_str_list(ids_input)
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_gen_figure,
            synth=False,
            benchmark=False,
            gen_figure=False)

        for id_name, outputs_post in model_output_post.items():

            if outputs_post.ndim < 2:
                outputs_post = np.expand_dims(outputs_post, axis=1)

            lf0 = outputs_post[:, 0]
            output_lf0, _ = interpolate_lin(lf0)
            output_vuv = outputs_post[:, 1]
            output_vuv[output_vuv < 0.5] = 0.0
            output_vuv[output_vuv >= 0.5] = 1.0
            output_vuv = output_vuv.astype(bool)

            # Load original lf0 and vuv.
            world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                          else os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features)
            org_labels = WorldFeatLabelGen.load_sample(
                id_name,
                world_dir,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)[:len(output_lf0)]
            _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
                org_labels,
                num_coded_sps=hparams.num_coded_sps,
                num_bap=hparams.num_bap)
            original_lf0, _ = interpolate_lin(original_lf0)
            original_vuv = original_vuv.astype(bool)

            phrase_curve = np.fromfile(os.path.join(
                self.flat_trainer.atom_trainer.OutputGen.dir_labels,
                id_name + self.OutputGen.ext_phrase),
                                       dtype=np.float32).reshape(
                                           -1, 1)[:len(original_lf0)]

            f0_mse = (np.exp(original_lf0.squeeze(-1)) -
                      np.exp(phrase_curve.squeeze(-1)))**2
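            # RMSE in Hz between the original F0 and the phrase-curve F0, restricted
            # to voiced frames via the original V/UV mask below.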
            f0_rmse = math.sqrt(
                (f0_mse * original_vuv[:len(output_lf0)]).sum() /
                original_vuv[:len(output_lf0)].sum())
            self.logger.info("RMSE of {} phrase curve: {} Hz.".format(
                id_name, f0_rmse))

            len_diff = len(original_lf0) - len(lf0)
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0))
            original_lf0 = WorldFeatLabelGen.trim_end_sample(
                original_lf0, int(len_diff / 2.0) + 1, reverse=True)

            # Get a data plotter.
            net_name = os.path.basename(hparams.model_name)
            filename = str(
                os.path.join(hparams.out_dir, id_name + '.' + net_name))
            plotter = DataPlotter()
            # plotter.set_title(id_name + " - " + net_name)

            grid_idx = 0
            graphs_lf0 = list()
            graphs_lf0.append((original_lf0, "Original"))
            graphs_lf0.append((phrase_curve, "Predicted"))
            plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(original_vuv), '0.8',
                                              1.0, 'Reference unvoiced')])
            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [' + str(hparams.frame_size_ms) +
                              ' ms]',
                              ylabel='LF0')
            # amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
            # plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
            plotter.set_lim(grid_idx=grid_idx, ymin=4.2, ymax=5.4)
            # plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

            # plotter.set_lim(xmin=300, xmax=1100)
            plotter.gen_plot()
            plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        lf0, _ = interpolate_lin(lf0)

        # Load original LF0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            dir_out=self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            sample=org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        sp = sp[:, :150]  # Zoom into spectral features.

        # Get a data plotter.
        grid_idx = -1
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))

        # # Plot LF0
        # grid_idx += 1
        # graphs.append((original_lf0, 'Original LF0'))
        # graphs.append((lf0, 'NN LF0'))
        # plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        # plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0),
        #                                                     (np.invert(original_vuv.astype(bool)), 'red', 0.2)])
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [{}] ms'.format(hparams.frame_length), ylabel='log(f0)')

        # Reverse the warping.
        wl = self._get_dummy_warping_layer(hparams)
        norm_params_no_deltas = (
            self.OutputGen.norm_params[0][:hparams.num_coded_sps],
            self.OutputGen.norm_params[1][:hparams.num_coded_sps])
        pre_net_output, _ = wl.forward_sample(label, -alphas)

        # Postprocess sample manually.
        pre_net_output = pre_net_output.detach().cpu().numpy()
        pre_net_mgc = pre_net_output[:, 0, :hparams.
                                     num_coded_sps] * norm_params_no_deltas[
                                         1] + norm_params_no_deltas[0]

        # Plot spectral features predicted by pre-network.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='Pre-network')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     pre_net_mgc,
                                     hparams.synth_fs)[:, :sp.shape[1]]))

        # Plot final predicted spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='VTLN')
        plotter.set_specshow(grid_idx=grid_idx, spec=np.abs(sp))

        # Plot predicted alpha value and V/UV flag.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='alpha')
        graphs = list()
        graphs.append((alphas, 'NN alpha'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # Add phoneme annotations if given.
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        # Plot reference spectral features.
        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [{}] ms'.format(
                              hparams.frame_size_ms),
                          ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=np.abs(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     original_mgc,
                                     hparams.synth_fs)[:, :sp.shape[1]]))

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
Example #8
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        _, alphas = hidden
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        sp = WorldFeatLabelGen.mcep_to_amp_sp(coded_sp, hparams.synth_fs)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        questions = QuestionLabelGen.load_sample(
            id_name,
            os.path.join("experiments", hparams.voice, "questions"),
            num_questions=hparams.num_questions)[:len(alphas)]
        phoneme_indices = QuestionLabelGen.questions_to_phoneme_indices(
            questions, hparams.phoneme_indices)
        alpha_vec = self.phonemes_to_alpha_tensor[phoneme_indices % len(
            self.phonemes_to_alpha_tensor)]

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='log(f0)')

        graphs = list()
        graphs.append((original_lf0, 'Original LF0'))
        graphs.append((lf0, 'NN LF0'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        # grid_idx += 1
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='Original spectrogram')
        # plotter.set_specshow(grid_idx=grid_idx, spec=WorldFeatLabelGen.mgc_to_sp(original_mgc, hparams.synth_fs))
        #
        # grid_idx += 1
        # plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='NN spectrogram')
        # plotter.set_specshow(grid_idx=grid_idx, spec=sp)

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='alpha')
        graphs = list()
        graphs.append((alpha_vec, 'Original alpha'))
        graphs.append((alphas, 'NN alpha'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                os.path.join("experiments", hparams.voice, "questions"),
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.VTLN' + hparams.gen_figure_ext)
Example #9
    def run_world_synth(synth_output, hparams):
        """Run the WORLD synthesize method."""

        fft_size = pyworld.get_cheaptrick_fft_size(hparams.synth_fs)

        save_dir = hparams.synth_dir if hparams.synth_dir is not None\
                                     else hparams.out_dir if hparams.out_dir is not None\
                                     else os.path.curdir
        for id_name, output in synth_output.items():
            logging.info(
                "Synthesise {} with the WORLD vocoder.".format(id_name))

            coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
                output,
                contains_deltas=False,
                num_coded_sps=hparams.num_coded_sps)
            amp_sp = WorldFeatLabelGen.decode_sp(
                coded_sp,
                hparams.sp_type,
                hparams.synth_fs,
                post_filtering=hparams.do_post_filtering).astype(np.double,
                                                                 copy=False)
            args = dict()
            for attr in "preemphasize", "f0_silence_threshold", "lf0_zero":
                if hasattr(hparams, attr):
                    args[attr] = getattr(hparams, attr)
            waveform = WorldFeatLabelGen.world_features_to_raw(
                amp_sp,
                lf0,
                vuv,
                bap,
                fs=hparams.synth_fs,
                n_fft=fft_size,
                **args)

            # f0 = np.exp(lf0, dtype=np.float64)
            # vuv[f0 < WorldFeatLabelGen.f0_silence_threshold] = 0  # WORLD throws an error for too small f0 values.
            # f0[vuv == 0] = 0.0
            # ap = pyworld.decode_aperiodicity(np.ascontiguousarray(bap.reshape(-1, 1), np.float64),
            #                                  hparams.synth_fs,
            #                                  fft_size)
            #
            # waveform = pyworld.synthesize(f0, amp_sp, ap, hparams.synth_fs)
            # waveform = waveform.astype(np.float32, copy=False)  # Does inplace conversion, if possible.

            # Always save as wav file first and convert afterwards if necessary.
            file_path = os.path.join(
                save_dir, "{}{}{}{}".format(
                    os.path.basename(id_name), "_" + hparams.model_name
                    if hparams.model_name is not None else "",
                    hparams.synth_file_suffix, "_WORLD"))
            makedirs_safe(save_dir)
            soundfile.write(file_path + ".wav", waveform, hparams.synth_fs)

            # Use PyDub for special audio formats.
            if hparams.synth_ext.lower() != 'wav':
                as_wave = pydub.AudioSegment.from_wav(file_path + ".wav")
                file = as_wave.export(file_path + "." + hparams.synth_ext,
                                      format=hparams.synth_ext)
                file.close()
                os.remove(file_path + ".wav")
class AcousticModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data.

    Uses question labels as input and WORLD features with or without deltas/double deltas (controlled by hparams.add_deltas) as output.
    Synthesises audio from the model output with MLPG smoothing. A hedged construction sketch is given after create_hparams below.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self,
                 dir_world_features,
                 dir_question_labels,
                 id_list,
                 num_questions,
                 hparams=None):
        """Default constructor.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List of ids, can contain a speaker directory.
        :param num_questions:           Number of questions in question file.
        :param hparams:                 Set of hyper parameters.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super(AcousticModelTrainer, self).__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        self.OutputGen = WorldFeatLabelGen(dir_world_features,
                                           add_deltas=hparams.add_deltas,
                                           num_coded_sps=hparams.num_coded_sps,
                                           sp_type=hparams.sp_type)
        self.OutputGen.get_normalisation_params(
            dir_world_features, hparams.output_norm_params_file_prefix)

        self.dataset_train = LabelGensDataset(self.id_list_train,
                                              self.InputGen,
                                              self.OutputGen,
                                              hparams,
                                              match_lengths=True)
        self.dataset_val = LabelGensDataset(self.id_list_val,
                                            self.InputGen,
                                            self.OutputGen,
                                            hparams,
                                            match_lengths=True)

        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_verbose=True)

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyper parameter container. Parse non default from given string."""
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        hparams.add_hparams(
            num_questions=None,
            question_file=None,  # Used to add labels in plot.
            num_coded_sps=60,
            sp_type="mcep",
            add_deltas=True,
            synth_load_org_sp=False,
            synth_load_org_lf0=False,
            synth_load_org_vuv=False,
            synth_load_org_bap=False)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams
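
    # A hedged construction sketch (referenced in the class docstring). The directory
    # paths, the id list and the question count are illustrative; only parameters
    # documented in __init__ and create_hparams above are used:
    #
    #     hparams = AcousticModelTrainer.create_hparams()
    #     hparams.out_dir = "experiments/demo"
    #     hparams.num_questions = 425
    #     trainer = AcousticModelTrainer(
    #         dir_world_features="experiments/demo/WORLD",
    #         dir_question_labels="experiments/demo/questions",
    #         id_list=["p225/p225_001", "p225/p225_002"],
    #         num_questions=hparams.num_questions,
    #         hparams=hparams)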

    def gen_figure_from_output(self, id_name, label, hidden, hparams):

        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, bap = WorldFeatLabelGen.convert_to_world_features(
            labels_post,
            contains_deltas=False,
            num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(
            id_name,
            self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            org_labels_post,
            contains_deltas=self.OutputGen.add_deltas,
            num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='log(f0)')

        graphs = list()
        graphs.append((original_lf0, 'Original lf0'))
        graphs.append((lf0, 'PyTorch lf0'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv.astype(bool)), '0.8',
                                          1.0),
                                         (np.invert(original_vuv.astype(bool)),
                                          'red', 0.2)])

        grid_idx += 1
        import librosa
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=librosa.amplitude_to_db(np.absolute(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     original_mgc, hparams.synth_fs)),
                                                          top_db=None))

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN spectrogram')
        plotter.set_specshow(grid_idx=grid_idx,
                             spec=librosa.amplitude_to_db(np.absolute(
                                 WorldFeatLabelGen.mcep_to_amp_sp(
                                     coded_sp, hparams.synth_fs)),
                                                          top_db=None))

        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(
                id_name,
                "experiments/" + hparams.voice + "/questions/",
                num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(
                questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' +
                             hparams.gen_figure_ext)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):

        # Get data for comparison.
        dict_original_post = dict()
        for id_name in dict_outputs_post.keys():
            dict_original_post[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                dir_out=self.OutputGen.dir_labels,
                add_deltas=True,
                num_coded_sps=hparams.num_coded_sps)

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        all_rmse = []
        vuv_error_rate = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        all_vuv = []
        mcd = 0.0
        mcd_max_id = "None"
        mcd_max = 0.0
        all_mcd = []
        bap_error = 0.0
        bap_error_max_id = "None"
        bap_error_max = 0.0
        all_bap_error = []

        for id_name, labels in dict_outputs_post.items():
            output_coded_sp, output_lf0, output_vuv, output_bap = self.OutputGen.convert_to_world_features(
                sample=labels,
                contains_deltas=False,
                num_coded_sps=hparams.num_coded_sps)
            output_vuv = output_vuv.astype(bool)

            # Get data for comparison.
            org_coded_sp, org_lf0, org_vuv, org_bap = self.OutputGen.convert_to_world_features(
                sample=dict_original_post[id_name],
                contains_deltas=self.OutputGen.add_deltas,
                num_coded_sps=hparams.num_coded_sps)

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0.squeeze())[:len(
                output_lf0)]  # Fix minor negligible length mismatch.
            output_f0 = np.exp(output_lf0)

            # Compute MCD.
            org_coded_sp = org_coded_sp[:len(output_coded_sp)]
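            # The 0th cepstral coefficient (energy/gain) is excluded from the MCD.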
            current_mcd = metrics.melcd(
                output_coded_sp[:, 1:],
                org_coded_sp[:, 1:])  # TODO: Use aligned mcd.
            if current_mcd > mcd_max:
                mcd_max_id = id_name
                mcd_max = current_mcd
            mcd += current_mcd
            all_mcd.append(current_mcd)

            # Compute RMSE.
            f0_mse = (org_f0 - output_f0)**2
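            # The RMSE is computed over voiced frames only: org_vuv acts as a 0/1
            # mask, so unvoiced frames contribute neither to the sum nor to the count.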
            current_f0_rmse = math.sqrt(
                (f0_mse * org_vuv[:len(output_lf0)]).sum() /
                org_vuv[:len(output_lf0)].sum())
            if current_f0_rmse != current_f0_rmse:  # NaN check.
                logging.error(
                    "Computed NaN for F0 RMSE for {}.".format(id_name))
            else:
                if current_f0_rmse > f0_rmse_max:
                    f0_rmse_max_id = id_name
                    f0_rmse_max = current_f0_rmse
                f0_rmse += current_f0_rmse
                all_rmse.append(current_f0_rmse)

            # Compute error of VUV in percentage.
            num_errors = (org_vuv[:len(output_lf0)] != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / len(output_lf0)
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp
            all_vuv.append(vuv_error_rate_tmp)

            # Compute aperiodicity distortion.
            org_bap = org_bap[:len(output_bap)]
            if len(output_bap.shape) > 1 and output_bap.shape[1] > 1:
                current_bap_error = metrics.melcd(
                    output_bap, org_bap)  # TODO: Use aligned mcd?
            else:
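                # For a single BAP coefficient the distortion is an RMSE scaled by
                # 10 / ln(10) * sqrt(2), the constant of the standard mel-cepstral
                # distortion formula, so the value is reported in dB as well.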
                current_bap_error = math.sqrt(
                    ((org_bap - output_bap)**
                     2).mean()) * (10.0 / np.log(10) * np.sqrt(2.0))
            if current_bap_error > bap_error_max:
                bap_error_max_id = id_name
                bap_error_max = current_bap_error
            bap_error += current_bap_error
            all_bap_error.append(current_bap_error)

        f0_rmse /= len(dict_outputs_post)
        vuv_error_rate /= len(dict_outputs_post)
        mcd /= len(dict_original_post)
        bap_error /= len(dict_original_post)

        self.logger.info("Worst MCD: {} {:4.2f}dB".format(mcd_max_id, mcd_max))
        self.logger.info("Worst F0 RMSE: {} {:4.2f}Hz".format(
            f0_rmse_max_id, f0_rmse_max))
        self.logger.info("Worst VUV error: {} {:2.2f}%".format(
            vuv_error_max_id, vuv_error_max * 100))
        self.logger.info("Worst BAP error: {} {:4.2f}db".format(
            bap_error_max_id, bap_error_max))
        self.logger.info(
            "Benchmark score: MCD {:4.2f}dB, F0 RMSE {:4.2f}Hz, VUV {:2.2f}%, BAP error {:4.2f}db"
            .format(mcd, f0_rmse, vuv_error_rate * 100, bap_error))

        return mcd, f0_rmse, vuv_error_rate, bap_error

    def synthesize(self, id_list, synth_output, hparams):
        """
        Depending on hparams override the network output with the extracted features,
        then continue with normal synthesis pipeline.
        """

        if hparams.synth_load_org_sp\
                or hparams.synth_load_org_lf0\
                or hparams.synth_load_org_vuv\
                or hparams.synth_load_org_bap:
            for id_name in id_list:

                world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                              else os.path.join(self.OutputGen.dir_labels,
                                                                self.dir_extracted_acoustic_features)
                labels = WorldFeatLabelGen.load_sample(
                    id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
                len_diff = len(labels) - len(synth_output[id_name])
                if len_diff > 0:
                    labels = WorldFeatLabelGen.trim_end_sample(labels,
                                                               int(len_diff /
                                                                   2),
                                                               reverse=True)
                    labels = WorldFeatLabelGen.trim_end_sample(
                        labels, len_diff - int(len_diff / 2))

                if hparams.synth_load_org_sp:
                    synth_output[
                        id_name][:len(labels), :self.OutputGen.
                                 num_coded_sps] = labels[:, :self.OutputGen.
                                                         num_coded_sps]

                if hparams.synth_load_org_lf0:
                    synth_output[id_name][:len(labels), -3] = labels[:, -3]

                if hparams.synth_load_org_vuv:
                    synth_output[id_name][:len(labels), -2] = labels[:, -2]

                if hparams.synth_load_org_bap:
                    synth_output[id_name][:len(labels), -1] = labels[:, -1]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
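
    # A hedged sketch of how the synth_load_org_* overrides might be used: enabling
    # synth_load_org_lf0 and synth_load_org_vuv keeps the predicted spectrum but
    # vocodes with the extracted F0 and V/UV (flags are defined in create_hparams);
    # `trainer`, `id_list` and `synth_output` are assumed to come from the pipeline:
    #
    #     hparams.synth_load_org_lf0 = True
    #     hparams.synth_load_org_vuv = True
    #     trainer.synthesize(id_list, synth_output, hparams)
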
    def plot_world_features(plotter: DataPlotter,
                            plotter_config: DataPlotter.Config,
                            grid_indices: List[int],
                            id_name: str,
                            features: np.ndarray,
                            contains_deltas: bool,
                            num_coded_sps: int,
                            num_bap: int,
                            hparams: ExtendedHParams,
                            plot_mgc: bool = True,
                            mgc_label: str = "Original spectrogram",
                            plot_lf0: bool = True,
                            sps_slices: slice = None,
                            lf0_label: str = "Original LF0",
                            plot_vuv: bool = True,
                            vuv_colour_alpha: List[Union[str, float]] = ('red',
                                                                         0.2),
                            *args,
                            **kwargs):

        mgc, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            features,
            contains_deltas=contains_deltas,
            num_coded_sps=num_coded_sps,
            num_bap=num_bap)
        lf0, _ = interpolate_lin(lf0)

        if grid_indices is None:
            grid_idx = plotter.get_next_free_grid_idx()
        else:
            grid_idx = grid_indices[0]

        if plot_lf0:

            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [{} ms]'.format(
                                  hparams.frame_size_ms),
                              ylabel='log(f0)')
            # plotter.set_lim(grid_idx=0, ymin=math.log(60), ymax=math.log(250))
            plotter.set_data_list(grid_idx=grid_idx,
                                  data_list=[(lf0, lf0_label)])

        if plot_vuv:
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(vuv.astype(bool)),
                                              *vuv_colour_alpha)])

        if plot_mgc:
            if grid_indices is None:
                grid_idx = plotter.get_next_free_grid_idx()
            else:
                grid_idx = grid_indices[1]

            AcousticModelTrainer.plot_mgc(plotter=plotter,
                                          plotter_config=plotter_config,
                                          grid_indices=[grid_idx],
                                          id_name=id_name,
                                          features=mgc,
                                          synth_fs=hparams.synth_fs,
                                          spec_slice=sps_slices,
                                          labels=('frames [{} ms]'.format(
                                              hparams.frame_size_ms),
                                                  mgc_label),
                                          *args,
                                          **kwargs)
    def compute_score(self, data, output, hparams):
        metrics_dict = super().compute_score(data, output, hparams)

        dict_original_post = self.get_output_dict(
            data.keys(),
            hparams,
            chunk_size=hparams.n_frames_per_step
            if hparams.has_value("n_frames_per_step") else 1)

        # Create a warping layer for manual warping.
        wl = self._get_dummy_warping_layer(hparams)
        # norm_params_no_deltas = (self.OutputGen.norm_params[0][:hparams.num_coded_sps],
        #                          self.OutputGen.norm_params[1][:hparams.num_coded_sps])

        # Compute MCD for different set of coefficients.
        batch_size = len(data.keys())
        for cep_coef_start in [1]:
            for cep_coef_end in itertools.chain(range(10, 19), [-1]):
                org_to_output_mcd = 0.0
                org_to_pre_net_output_mcd = 0.0
                for label_name in next(iter(data.values())).keys():
                    for id_name, labels in data.items():
                        labels = labels[label_name]
                        # alphas = output[id_name]['alphas']
                        alphas = [
                            output[id_name][key]
                            for key in output[id_name].keys()
                            if "alphas" in key
                        ]
                        # batch_size = len(dict_outputs_post)
                        # for cep_coef_start in [1]:
                        #     for cep_coef_end in itertools.chain(range(10, 19), [-1]):

                        #         for id_name, labels in dict_outputs_post.items():
                        #             # Split NN output.
                        #             _, (output_alphas,) = dict_hiddens[id_name]
                        output_mgc_post, *_ = WorldFeatLabelGen.convert_to_world_features(
                            sample=labels,
                            contains_deltas=False,
                            num_coded_sps=hparams.num_coded_sps,
                            num_bap=hparams.num_bap)
                        # Reverse the warping.
                        pre_net_output, _ = wl.forward_sample(
                            labels[:, :hparams.num_coded_sps],
                            [-alpha for alpha in alphas])
                        # pre_net_output, _ = wl.forward_sample(labels, -output_alphas)
                        pre_net_output = pre_net_output.detach().cpu().numpy()
                        pre_net_mgc = pre_net_output[
                            0, :, :hparams.
                            num_coded_sps]  # * norm_params_no_deltas[1] + norm_params_no_deltas[0]
                        # Load the original warped sample.
                        org_mgc_post = dict_original_post[id_name][:len(
                            output_mgc_post), :hparams.num_coded_sps]

                        # Compute mcd difference.
                        org_to_output_mcd += Metrics.mcd_k(
                            org_mgc_post,
                            output_mgc_post,
                            k=cep_coef_end,
                            start_bin=cep_coef_start)
                        org_to_pre_net_output_mcd += Metrics.mcd_k(
                            org_mgc_post,
                            pre_net_mgc,
                            k=cep_coef_end,
                            start_bin=cep_coef_start)
                org_to_pre_net_output_mcd /= batch_size
                org_to_output_mcd /= batch_size

                self.logger.info("MCep from {} to {}:".format(
                    cep_coef_start, cep_coef_end))
                self.logger.info(
                    "Original mgc to pre-net mgc error: {:4.2f}dB".format(
                        org_to_pre_net_output_mcd))
                self.logger.info(
                    "Original mgc to nn mgc error: {:4.2f}dB".format(
                        org_to_output_mcd))

        return metrics_dict