Example #1
0
 def labels_to_atoms(np_labels, k=6, frame_size=5, amp_threshold=0.3):
     """
     Convert labels to GammaAtoms by delegating to
     AtomLabelGen.labels_to_atoms.

     The entry at index 0 along axis 1 holds the vuv information and is
     removed before delegating, so only the atom part is converted.

     :param np_labels:      Label array (at least 3-D); index 0 of axis 1
                            is the vuv information.
     :param k:              K value of the atoms.
     :param frame_size:     Frame size forwarded to the base method.
     :param amp_threshold:  Amplitude threshold forwarded to the base method.
     :return:               The result of AtomLabelGen.labels_to_atoms.
     """
     atom_labels = np_labels[:, 1:, :]  # Drop the vuv component.
     return AtomLabelGen.labels_to_atoms(atom_labels, k, frame_size,
                                         amp_threshold)
Example #2
0
    def test_save_load(self):
        """
        Round-trip test for AtomLabelGen: generate labels, compare the
        returned and the stored normalisation parameters, and check pre- and
        post-processing of one sample. Generated files are removed in a
        finally clause so that a failing assertion does not leave them behind.
        """
        dir_out = self._get_test_dir()
        wcad_file = os.path.join(self.dir_database, "wcad_.txt")

        try:
            theta_start = 0.01
            theta_stop = 0.055
            theta_step = 0.005
            thetas = np.arange(theta_start, theta_stop, theta_step)
            k = 6
            frame_size_ms = 5

            atom_gen = AtomLabelGen(self.dir_wcad_root, dir_out, thetas, k,
                                    frame_size_ms)
            label_dict, *extracted_norm_params = atom_gen.gen_data(
                self.dir_wav, dir_out, id_list=self.id_list, return_dict=True)

            # Call this once before starting the pre-processing.
            norm_params = atom_gen.get_normalisation_params(dir_out)
            self.assertTrue((extracted_norm_params[0] == norm_params[0]).all())
            self.assertTrue((extracted_norm_params[1] == norm_params[1]).all())

            test_label = label_dict[self.id_list[1]]
            test_label_pre = atom_gen.preprocess_sample(test_label)
            self.assertTrue(
                np.isclose(test_label_pre, atom_gen[self.id_list[1]]).all())

            test_label_post = atom_gen.postprocess_sample(test_label_pre)
            # Post-processing does peak selection, so pre and post labels
            # are not the same and we cannot check for equality here.
            self.assertTrue(np.isclose(-0.2547, test_label_post.sum(),
                                       atol=0.0001))
        finally:
            # The wcad file may be missing if the test failed early; do not
            # let its absence mask the original failure.
            if os.path.isfile(wcad_file):
                os.remove(wcad_file)
            shutil.rmtree(dir_out, ignore_errors=True)
Example #3
0
    def postprocess_sample(self, sample, norm_params=None):
        """
        Identify the peaks in the position flag (remove peaks with
        absolute value lower than 0.1). Set all amplitude outputs to
        zero except the highest amplitude for positive peaks and the
        lowest amplitude for negative peaks. Then denormalise the
        amplitudes with the base class method. Set all vuv values < 0.5
        to 0 and the rest to 1. This function is used after inference of
        a network. The given sample itself is not modified; all parts of
        it are copied before being changed.

        :param sample:       The sample to post-process. Column 0 holds the
                             vuv value, the last column the position flag
                             and the columns in between the amplitudes.
        :param norm_params:  Use this normalisation parameters instead
                             of self.norm_params.
        :return:             Post-processed sample: the binarised vuv (with
                             -1 in the lf0 slot) concatenated with the
                             denormalised amplitudes along axis 1.
        """

        # Split off vuv and pos information so that the superclass
        # postprocessing can be applied to the amplitudes only. Slices are
        # views, so copy them before the in-place edits below; otherwise the
        # caller's sample would be overwritten.
        vuv = np.copy(sample[:, 0])
        vuv[vuv < 0.5] = 0.0
        vuv[vuv >= 0.5] = 1.0
        pos = np.copy(sample[:, -1])
        amps = np.copy(sample[:, 1:-1])

        # Extract atom positions and discard weak peaks.
        pos = AtomLabelGen.identify_peaks(np.expand_dims(pos, -1), 50)
        pos[abs(pos) < 0.1] = 0

        # Use the sign of the pos flag to select one amplitude per frame:
        # the maximum for positive peaks, the minimum for negative peaks and
        # nothing where there is no peak.
        amps_max = np.max(amps, axis=1)
        amps_min = np.min(amps, axis=1)
        pos_flag_negative = (pos < 0).squeeze()
        amps_max[pos_flag_negative] = amps_min[pos_flag_negative]
        amps_max[(pos == 0).squeeze()] = 0.0
        # Zero every amplitude that is not the selected one.
        mask = np.not_equal(
            np.expand_dims(amps_max, axis=-1).repeat(amps.shape[1], axis=1),
            amps)
        amps[mask] = 0.0

        # Denormalise amplitudes.
        amps = super().postprocess_sample(amps,
                                          norm_params,
                                          identify_peaks=False)

        # Combine vuv with amps again.
        vuv = np.repeat(vuv[:, np.newaxis, np.newaxis], 2, 2)
        vuv[:, :, 0] = -1  # Set invalid value for lf0.
        return np.concatenate((vuv, amps), axis=1)
Example #4
0
    def load_sample(id_name, dir_atoms, num_thetas, dir_world):
        """
        Load the atoms of an utterance from dir_atoms and its VUV flags from
        dir_world and combine both into a single array.

        Directory and extension are stripped from id_name first. If the two
        sources differ in length, the longer one is shortened to match: VUV by
        slicing, atoms via AtomVUVDistPosLabelGen.trim_end_sample.

        :param id_name:     Utterance id (may contain a path/extension).
        :param dir_atoms:   Directory of the atom files.
        :param num_thetas:  Number of theta values of the atoms.
        :param dir_world:   Directory of the WORLD features.
        :return:            The VUV repeated twice along axis 2, put in front
                            of the atoms along axis 1.
        """
        base_name = os.path.splitext(os.path.basename(id_name))[0]

        atoms = AtomLabelGen.load_sample(base_name, dir_atoms, num_thetas)
        vuv = LF0LabelGen.load_vuv(base_name, dir_world)

        # Positive: atoms are longer; negative: vuv is longer.
        surplus_atoms = len(atoms) - len(vuv)
        if surplus_atoms < 0:
            vuv = vuv[:len(atoms)]
        elif surplus_atoms > 0:
            atoms = AtomVUVDistPosLabelGen.trim_end_sample(atoms, surplus_atoms)

        vuv_pair = np.repeat(vuv[:, None, :], 2, axis=2)
        return np.concatenate((vuv_pair, atoms), axis=1)
Example #5
0
    def synth_ref_wcad(self, file_id_list, hparams):
        """
        Synthesise reference audio from the atoms extracted by wcad.

        Loads the extracted atoms of every id, builds the full feature set
        with run_atom_synth and runs the WORLD synthesiser on it. While the
        synthesiser runs, hparams.synth_file_suffix is extended by
        "_wcad_ref"; the previous suffix is restored afterwards, also when
        synthesis raises.

        :param file_id_list:  Ids of the utterances to synthesise.
        :param hparams:       Hyper-parameter container.
        """
        # Load extracted atoms.
        synth_output = {
            id_name: AtomLabelGen.load_sample(
                id_name, self.OutputGen.dir_labels, len(hparams.thetas))
            for id_name in file_id_list
        }

        full_output = self.run_atom_synth(file_id_list, synth_output, hparams)

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += "_wcad_ref"
        try:
            # Run the WORLD synthesizer.
            Synthesiser.run_world_synth(full_output, hparams)
        finally:
            # Restore identifier so hparams is not left modified on failure.
            hparams.synth_file_suffix = old_synth_file_suffix
Example #6
0
    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """
        Compute the average F0 RMSE in Hz between the predicted atoms and the
        F0 extracted from the audio, over all utterances.

        For each utterance the lf0 is reconstructed from the atom labels, the
        phrase curve is added, and both the predicted and the extracted F0 are
        multiplied by the extracted vuv. All curves are cut to their common
        length first. The squared error is normalised by the sum of the vuv
        flags in that same range (the voiced-frame count for 0/1 flags). The
        worst utterance and the average are logged.

        :param dict_outputs_post:  Dict of id_name -> post-processed atom labels.
        :param dict_hiddens:       Unused.
        :param hparams:            Hyper-parameter container; k, frame_size_ms,
                                   min_atom_amp and num_coded_sps are read.
        :return:                   Average F0 RMSE over all utterances.
        """

        # Get data for comparison.
        dict_original_post = self.load_extracted_audio_features(
            dict_outputs_post, hparams)

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        for id_name, labels in dict_outputs_post.items():
            output_lf0 = AtomLabelGen.labels_to_lf0(
                labels,
                k=hparams.k,
                frame_size=hparams.frame_size_ms,
                amp_threshold=hparams.min_atom_amp)

            # Get data for comparison.
            org_lf0 = dict_original_post[id_name][:, hparams.num_coded_sps]
            org_vuv = dict_original_post[id_name][:, hparams.num_coded_sps + 1]
            phrase_curve = self.get_phrase_curve(id_name)

            # Fix minor length mismatches by comparing only the common range.
            # The vuv used for normalisation must be cut as well, otherwise
            # voiced frames outside the compared range inflate the count.
            num_frames = min(len(output_lf0), len(org_lf0), len(phrase_curve))
            output_lf0 = output_lf0[:num_frames]
            org_lf0 = org_lf0[:num_frames].squeeze()
            org_vuv = org_vuv[:num_frames]
            phrase_curve = phrase_curve[:num_frames].squeeze()

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0) * org_vuv
            output_f0 = np.exp(output_lf0 + phrase_curve) * org_vuv

            # Compute RMSE, keep track of worst RMSE.
            f0_mse = (org_f0 - output_f0)**2
            current_f0_rmse = math.sqrt(f0_mse.sum() / org_vuv.sum())
            if current_f0_rmse > f0_rmse_max:
                f0_rmse_max_id = id_name
                f0_rmse_max = current_f0_rmse
            f0_rmse += current_f0_rmse

        f0_rmse /= len(dict_outputs_post)
        self.logger.info("Worst F0 RMSE: " + f0_rmse_max_id +
                         " {:4.2f}Hz".format(f0_rmse_max))
        self.logger.info("Benchmark score: F0 RMSE " +
                         "{:4.2f}Hz".format(f0_rmse))

        return f0_rmse
Example #7
0
    def load_sample(id_name, dir_atoms, num_thetas, dir_world):
        """
        Load atoms, VUV and LF0 of one utterance and combine them.

        Atoms are loaded from dir_atoms via AtomLabelGen.load_sample; VUV and
        LF0 are loaded from dir_world via LF0LabelGen. All three are cut to
        the shortest of their lengths (atoms via
        AtomVUVDistPosLabelGen.trim_end_sample, the others by slicing).

        :param id_name:     Utterance id (may contain a path/extension).
        :param dir_atoms:   Directory of the atom files.
        :param num_thetas:  Number of theta values of the atoms.
        :param dir_world:   Directory of the WORLD features.
        :return:            LF0 and VUV concatenated along axis 1, given a new
                            axis 1, and put in front of the atoms along axis 1.
        """
        base_name = os.path.splitext(os.path.basename(id_name))[0]

        atoms = AtomLabelGen.load_sample(base_name, dir_atoms, num_thetas)
        vuv = LF0LabelGen.load_vuv(base_name, dir_world)
        lf0 = LF0LabelGen.load_lf0(base_name, dir_world)

        common_length = min(map(len, (atoms, vuv, lf0)))
        surplus_atoms = len(atoms) - common_length
        if surplus_atoms > 0:
            atoms = AtomVUVDistPosLabelGen.trim_end_sample(atoms, surplus_atoms)

        lf0_vuv = np.concatenate(
            (lf0[:common_length], vuv[:common_length]), axis=1)
        return np.concatenate((lf0_vuv[:, None], atoms), axis=1)
Example #8
0
    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 hparams=None):
        """Default constructor.

        Fills in missing hyper-parameters, creates the question-label input
        generator and the atom output generator (loading their normalisation
        parameters), builds train/validation datasets and, if none is set, a
        WeightedNonzeroMSELoss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values.
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Derive defaults for parameters that were left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # Loss weights for zero and non-zero targets, used when none are
        # given: an expected non-zero share of 0.02 / len(thetas) (capped at
        # 0.99) is assumed and each weight is the inverse of its share.
        non_zero_share = min(0.99, 0.02 / len(thetas))
        default_weight_non_zero = 1 / non_zero_share
        default_weight_zero = 1 / (1 - non_zero_share)
        if not hasattr(hparams, "weight_zero"):
            hparams.add_hparam("weight_non_zero", default_weight_non_zero)
            hparams.add_hparam("weight_zero", default_weight_zero)
        elif hparams.weight_zero is None:
            hparams.weight_non_zero = default_weight_non_zero
            hparams.weight_zero = default_weight_zero

        super().__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        self.OutputGen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k,
                                      hparams.frame_size_ms)
        self.OutputGen.get_normalisation_params(
            dir_atom_labels, hparams.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Train and validation sets differ only in their id list.
            return PyTorchLabelGensDataset(ids,
                                           self.InputGen,
                                           self.OutputGen,
                                           hparams,
                                           match_lengths=True)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(
                hparams.use_gpu, hparams.weight_zero, hparams.weight_non_zero,
                size_average=False, reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_patience=10, plateau_factor=0.5,
                                plateau_verbose=True)
Example #9
0
class AtomModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data through atom prediction.
    Output labels for atoms have dimension: T x |thetas| x 2 (amp, theta).

    Use question labels as input and extracted wcad atoms as output. Synthesize audio from model
    output by generating F0 from atoms. MGC and BAP is either generated by a pre-trained acoustic
    model or loaded from the original extracted files.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 hparams=None):
        """Default constructor.

        Fills in missing hyper-parameters, creates the question-label input
        generator and the atom output generator (loading their normalisation
        parameters), builds train/validation datasets and, if none is set, a
        WeightedNonzeroMSELoss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values.
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # Loss weights for zero and non-zero targets, used when none are
        # given: an expected non-zero share of 0.02 / len(thetas) (capped at
        # 0.99) is assumed and each weight is the inverse of its share.
        non_zero_occurrence = min(0.99, 0.02 / len(thetas))
        zero_occurrence = 1 - non_zero_occurrence
        if not hasattr(hparams, "weight_zero"):
            hparams.add_hparam("weight_non_zero", 1 / non_zero_occurrence)
            hparams.add_hparam("weight_zero", 1 / zero_occurrence)
        elif hparams.weight_zero is None:
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence

        super().__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        self.OutputGen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k,
                                      hparams.frame_size_ms)
        self.OutputGen.get_normalisation_params(
            dir_atom_labels, hparams.output_norm_params_file_prefix)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train,
                                                     self.InputGen,
                                                     self.OutputGen,
                                                     hparams,
                                                     match_lengths=True)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val,
                                                   self.InputGen,
                                                   self.OutputGen,
                                                   hparams,
                                                   match_lengths=True)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(
                hparams.use_gpu,
                hparams.weight_zero,
                hparams.weight_non_zero,
                size_average=False,
                reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_patience=10,
                                plateau_factor=0.5,
                                plateau_verbose=True)

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """
        Create the base class hyper-parameters and add the atom specific
        ones: thetas, k, min_atom_amp (default 0.3) and num_questions.

        :param hparams_string:  Forwarded to ModelTrainer.create_hparams.
        :param verbose:         If True, log the resulting parameters.
        :return:                Hyper-parameter container.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        hparams.add_hparams(thetas=None,
                            k=None,
                            min_atom_amp=0.3,
                            num_questions=None)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """
        Plot the raw network output, the post-processed output, the target
        atoms and the wcad/original/predicted lf0 curves of one utterance and
        save the figure as
        <hparams.out_dir>/<id_name>.<model name>.BASE<hparams.gen_figure_ext>.

        :param id_name:  Id of the utterance.
        :param labels:   Network output; 1-D input gets a new axis 1.
        :param hidden:   Unused.
        :param hparams:  Hyper-parameter container.
        """
        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels,
                                                        identify_peaks=True,
                                                        peak_range=100)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        # Builtin bool instead of the np.bool alias, which NumPy 1.24 removed.
        vuv = vuv.astype(bool)

        # Load original lf0 and vuv.
        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            world_dir = os.path.join(self.OutputGen.dir_labels,
                                     self.dir_extracted_acoustic_features)
        org_labels = WorldFeatLabelGen.load_sample(
            id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        phrase_curve = np.fromfile(os.path.join(
            self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0 -= phrase_curve
        # NOTE(review): the two trims below remove int(len_diff / 2) and
        # int(len_diff / 2) + 1 frames, i.e. one frame more than len_diff for
        # a non-negative even len_diff (run_atom_synth trims exactly
        # len_diff). Confirm whether the extra frame is intended.
        len_diff = len(original_lf0) - len(lf0)
        original_lf0 = WorldFeatLabelGen.trim_end_sample(
            original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0,
                                                         int(len_diff / 2.0) +
                                                         1,
                                                         reverse=True)

        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels,
                                                len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0) + 1,
                                                    reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(
            org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        graphs_output = list()
        grid_idx = 0
        for idx in reversed(range(labels.shape[1])):
            graphs_output.append(
                (labels[:, idx],
                 r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx])))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)

        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(labels_post.shape[1])):
            graphs_peaks.append((labels_post[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(
            labels_post,
            hparams.k,
            hparams.frame_size_ms,
            amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(
            np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE" + hparams.gen_figure_ext)

    def get_recon_from_synth_output(self, synth_output, hparams):
        """
        Reconstruct LF0 from atoms.

        The labels of every utterance are converted to atoms (2-D labels get
        a new axis 1 first) and rendered to lf0. The extracted phrase curve is
        added, and values at or below log(WorldFeatLabelGen.f0_silence_threshold)
        are set to WorldFeatLabelGen.lf0_zero.

        :param synth_output:  Dict of id_name -> atom labels.
        :param hparams:       Hyper-parameter container.
        :return:              Dict of id_name -> reconstructed lf0.
        """
        recon_dict = dict()
        lf0_silence = None  # Computed lazily, once, on the first utterance.
        for id_name, label in synth_output.items():
            if len(label.shape) == 2:
                label = np.expand_dims(label, axis=1)

            # Transform output to GammaAtoms.
            atoms = self.OutputGen.labels_to_atoms(
                label,
                k=hparams.k,
                frame_size=hparams.frame_size_ms,
                amp_threshold=hparams.min_atom_amp)
            reconstruction = self.OutputGen.atoms_to_lf0(atoms,
                                                         num_frames=len(label))

            # Add extracted phrase.
            phrase_curve = np.fromfile(os.path.join(
                self.OutputGen.dir_labels,
                id_name + self.OutputGen.ext_phrase),
                                       dtype=np.float32)[:len(reconstruction)]
            reconstruction[:len(phrase_curve)] += phrase_curve

            if lf0_silence is None:
                lf0_silence = math.log(WorldFeatLabelGen.f0_silence_threshold)
            reconstruction[reconstruction <= lf0_silence] = WorldFeatLabelGen.lf0_zero

            recon_dict[id_name] = reconstruction

        return recon_dict

    def get_phrase_curve(self, id_name):
        """
        Load the extracted phrase curve of id_name from
        OutputGen.dir_labels as float32 column vector of shape (T, 1).
        """
        return np.fromfile(os.path.join(self.OutputGen.dir_labels,
                                        id_name + self.OutputGen.ext_phrase),
                           dtype=np.float32).reshape(-1, 1)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """
        Compute the average F0 RMSE in Hz between the predicted atoms and the
        F0 extracted from the audio, over all utterances.

        For each utterance the lf0 is reconstructed from the atom labels, the
        phrase curve is added, and both the predicted and the extracted F0 are
        multiplied by the extracted vuv. All curves are cut to their common
        length first. The squared error is normalised by the sum of the vuv
        flags in that same range (the voiced-frame count for 0/1 flags). The
        worst utterance and the average are logged.

        :param dict_outputs_post:  Dict of id_name -> post-processed atom labels.
        :param dict_hiddens:       Unused.
        :param hparams:            Hyper-parameter container; k, frame_size_ms,
                                   min_atom_amp and num_coded_sps are read.
        :return:                   Average F0 RMSE over all utterances.
        """

        # Get data for comparison.
        dict_original_post = self.load_extracted_audio_features(
            dict_outputs_post, hparams)

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        for id_name, labels in dict_outputs_post.items():
            output_lf0 = AtomLabelGen.labels_to_lf0(
                labels,
                k=hparams.k,
                frame_size=hparams.frame_size_ms,
                amp_threshold=hparams.min_atom_amp)

            # Get data for comparison.
            org_lf0 = dict_original_post[id_name][:, hparams.num_coded_sps]
            org_vuv = dict_original_post[id_name][:, hparams.num_coded_sps + 1]
            phrase_curve = self.get_phrase_curve(id_name)

            # Fix minor length mismatches by comparing only the common range.
            # The vuv used for normalisation must be cut as well, otherwise
            # voiced frames outside the compared range inflate the count.
            num_frames = min(len(output_lf0), len(org_lf0), len(phrase_curve))
            output_lf0 = output_lf0[:num_frames]
            org_lf0 = org_lf0[:num_frames].squeeze()
            org_vuv = org_vuv[:num_frames]
            phrase_curve = phrase_curve[:num_frames].squeeze()

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0) * org_vuv
            output_f0 = np.exp(output_lf0 + phrase_curve) * org_vuv

            # Compute RMSE, keep track of worst RMSE.
            f0_mse = (org_f0 - output_f0)**2
            current_f0_rmse = math.sqrt(f0_mse.sum() / org_vuv.sum())
            if current_f0_rmse > f0_rmse_max:
                f0_rmse_max_id = id_name
                f0_rmse_max = current_f0_rmse
            f0_rmse += current_f0_rmse

        f0_rmse /= len(dict_outputs_post)
        self.logger.info("Worst F0 RMSE: " + f0_rmse_max_id +
                         " {:4.2f}Hz".format(f0_rmse_max))
        self.logger.info("Benchmark score: F0 RMSE " +
                         "{:4.2f}Hz".format(f0_rmse))

        return f0_rmse

    def load_extracted_audio_features(self, synth_output, hparams):
        """
        Load the audio features extracted from audio for every id in
        synth_output (values of synth_output are ignored).

        Features are read from hparams.world_dir if it is set, otherwise from
        the extracted acoustic features directory below OutputGen.dir_labels.

        :return:  Dict of id_name -> loaded WORLD features (no deltas).
        """
        self.logger.info("Load extracted mgc, lf0, vuv, bap data.")

        # The directory is the same for every id; resolve it once.
        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            world_dir = os.path.realpath(
                os.path.join(self.OutputGen.dir_labels,
                             self.dir_extracted_acoustic_features))

        org_output = dict()
        for id_name in synth_output.keys():
            org_output[id_name] = WorldFeatLabelGen.load_sample(
                id_name,
                world_dir,
                add_deltas=False,
                num_coded_sps=hparams.num_coded_sps)  # Load extracted data.

        return org_output

    def generate_audio_features(self, id_list,
                                hparams):  # TODO: This function is untested.
        """
        Generate mgc, vuv and bap data with an acoustic model.
        The name of the acoustic model is saved in hparams.synth_acoustic_model_path.
        If synth_acoustic_model_path is None this method is not called; instead
        load_extracted_audio_features reloads the original data extracted from the audio.

        Detailed execution process:
        This method reuses the synth method of the ModelTrainer base class. It temporarily
        replaces model_handler, f_synthesize and OutputGen to accomplish the audio generation;
        all three are restored afterwards, also when synthesis raises. The base class synth
        method forwards the question labels of each utterance in id_list and calls
        f_synthesize at the end, which is set to save_audio_features here.

        NOTE(review): this method returns None, but run_atom_synth uses its result
        as a dict of features. Confirm what save_audio_features produces and return it.
        """
        self.logger.info("Generate mgc, vuv and bap with " +
                         hparams.synth_acoustic_model_path)

        acoustic_model_hparams = AcousticModelTrainer.create_hparams()
        acoustic_model_hparams.model_name = os.path.basename(
            hparams.synth_acoustic_model_path)
        acoustic_model_hparams.model_path = hparams.synth_acoustic_model_path
        acoustic_model_handler = AcousticModelTrainer(acoustic_model_hparams)

        org_model_handler = self.model_handler
        org_output_gen = self.OutputGen
        try:
            self.model_handler = acoustic_model_handler
            # Switch f_synthesize method and OutputGen for mgc, vuv and bap creation.
            # f_synthesize is called at the end of synth.
            self.f_synthesize = self.save_audio_features
            self.OutputGen = self.AudioGen

            # Explicitly synthesize with the acoustic model.
            self.synth(hparams, id_list)
        finally:
            # Switch back to atom creation, even if generation failed.
            self.f_synthesize = self.synthesize
            self.OutputGen = org_output_gen
            self.model_handler = org_model_handler

    def synthesize(self, id_list, synth_output, hparams):
        """
        Synthesise audio from predicted atoms. Sub classes may overwrite this.

        lf0 is created from the atoms of the output; the other acoustic
        features are either loaded from the original labels or generated with
        the model at hparams.synth_acoustic_model_path (see run_atom_synth).
        The result is passed to the WORLD synthesiser.
        """
        full_output = self.run_atom_synth(id_list, synth_output, hparams)
        # Run the WORLD synthesizer.
        Synthesiser.run_world_synth(full_output, hparams)

    def synth_ref_wcad(self, file_id_list, hparams):
        """
        Synthesise reference audio from the atoms extracted by wcad.

        While the synthesiser runs, hparams.synth_file_suffix is extended by
        "_wcad_ref"; the previous suffix is restored afterwards, also when
        synthesis raises.
        """
        # Load extracted atoms.
        synth_output = {
            id_name: AtomLabelGen.load_sample(
                id_name, self.OutputGen.dir_labels, len(hparams.thetas))
            for id_name in file_id_list
        }

        full_output = self.run_atom_synth(file_id_list, synth_output, hparams)

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += "_wcad_ref"
        try:
            # Run the WORLD synthesizer.
            Synthesiser.run_world_synth(full_output, hparams)
        finally:
            # Restore identifier.
            hparams.synth_file_suffix = old_synth_file_suffix

    def synth_phrase(self, file_id_list, hparams):
        """
        Synthesise audio whose lf0 is replaced by the extracted phrase curve.

        The extracted audio features are loaded, their lf0 column (index
        hparams.num_coded_sps, the same column compute_score and
        run_atom_synth use) is overwritten by the phrase curve, and the result
        is vocoded with ModelTrainer.synthesize. While synthesising,
        hparams.synth_file_suffix is extended by '_phrase'; the previous
        suffix is restored afterwards, also when synthesis raises.
        """
        self.logger.info("Synthesise phrase curve for [{0}].".format(
            ", ".join(file_id_list)))

        # Fill a dictionary with the extracted audio features of every id.
        synth_output = dict.fromkeys(file_id_list)
        full_output = self.load_extracted_audio_features(synth_output, hparams)

        # Override the lf0 component by the phrase curve. Indexing by
        # num_coded_sps (instead of -3) stays correct for multi-dim bap.
        lf0_idx = hparams.num_coded_sps
        for id_name in file_id_list:
            labels = full_output[id_name]
            phrase_curve = np.fromfile(
                os.path.join(self.OutputGen.dir_labels,
                             id_name + self.OutputGen.ext_phrase),
                dtype=np.float32)
            labels[:, lf0_idx] = phrase_curve[:len(labels)]

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += '_phrase'
        try:
            # Run the vocoder.
            ModelTrainer.synthesize(self, file_id_list, full_output, hparams)
        finally:
            # Restore identifier.
            hparams.synth_file_suffix = old_synth_file_suffix

    def run_atom_synth(self, file_id_list, synth_output, hparams):
        """
        Reconstruct lf0 from atoms and combine it with mgc and bap data.

        mgc, vuv and bap are loaded from the extracted features when
        hparams.synth_acoustic_model_path is None, otherwise generated with
        generate_audio_features (untested). Each feature sample is trimmed to
        the length of its reconstructed lf0, and its lf0 and vuv columns are
        replaced by the reconstruction and the vuv derived from it.

        :param file_id_list:  Ids of the utterances.
        :param synth_output:  Dict of id_name -> atom labels.
        :param hparams:       Hyper-parameter container.
        :return:              Dict of id_name -> trimmed feature sample.
        """

        # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
        if hparams.synth_acoustic_model_path is None:
            full_output = self.load_extracted_audio_features(
                synth_output, hparams)
        else:
            self.logger.warning("This method is untested.")
            full_output = self.generate_audio_features(file_id_list, hparams)

        # Reconstruct lf0 from generated atoms and write it to synth output.
        recon_dict = self.get_recon_from_synth_output(synth_output, hparams)
        lf0_silence = None  # Computed lazily, once, on the first utterance.
        for id_name, lf0 in recon_dict.items():
            if lf0_silence is None:
                lf0_silence = math.log(WorldFeatLabelGen.f0_silence_threshold)
            full_sample = full_output[id_name]
            len_diff = len(full_sample) - len(lf0)
            full_sample = WorldFeatLabelGen.trim_end_sample(full_sample,
                                                            int(len_diff / 2),
                                                            reverse=True)
            full_sample = WorldFeatLabelGen.trim_end_sample(
                full_sample, len_diff - int(len_diff / 2))
            vuv = np.ones(lf0.shape)
            vuv[lf0 <= lf0_silence] = 0.0
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv
            # The trimming above rebinds a local name only; store the trimmed
            # sample so the returned dict actually contains it.
            full_output[id_name] = full_sample

        return full_output
Example #10
0
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """
        Plot the network output, its post-processed peaks, the target atom
        labels and the resulting lf0 curves for one utterance, then save it.

        The figure is written to
        ``<hparams.out_dir>/<id_name>.<basename(hparams.model_name)>.BASE<hparams.gen_figure_ext>``.

        :param id_name:  Id of the utterance.
        :param labels:   Network output. A 1-D array is expanded to a single
                         column; column idx is labelled with hparams.thetas[idx].
        :param hidden:   Unused.
        :param hparams:  Hyper-parameter container. Reads k, frame_size_ms,
                         thetas, num_coded_sps, min_atom_amp, model_name,
                         out_dir, gen_figure_ext and optionally world_dir.
        """
        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels,
                                                        identify_peaks=True,
                                                        peak_range=100)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        vuv = vuv.astype(bool)

        # Load original lf0 and vuv.
        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            world_dir = os.path.join(self.OutputGen.dir_labels,
                                     self.dir_extracted_acoustic_features)
        org_labels = WorldFeatLabelGen.load_sample(
            id_name, world_dir, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(
            org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        # Remove the phrase curve. Slicing both sides tolerates a length
        # mismatch between the two files instead of failing on broadcasting.
        phrase_curve = np.fromfile(os.path.join(
            self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0[:len(phrase_curve)] -= phrase_curve[:len(original_lf0)]

        len_diff = len(original_lf0) - len(lf0)
        # NOTE(review): the two trims use int(len_diff / 2) and
        # int(len_diff / 2) + 1 frames. Together that is one more than
        # len_diff when len_diff is even — confirm this offset is intended.
        original_lf0 = WorldFeatLabelGen.trim_end_sample(
            original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0,
                                                         int(len_diff / 2.0) +
                                                         1,
                                                         reverse=True)

        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels,
                                                len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0) + 1,
                                                    reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(
            org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)
        # The same x-axis label is used on every grid.
        frames_label = 'frames [' + str(hparams.frame_size_ms) + ' ms]'

        # Raw network output, one curve per theta.
        graphs_output = list()
        grid_idx = 0
        for idx in reversed(range(labels.shape[1])):
            graphs_output.append(
                (labels[:, idx],
                 r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx])))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel=frames_label,
                          ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        # plotter.set_lim(grid_idx=0, ymin=-1.8, ymax=1.8)

        # Post-processed peaks with predicted unvoiced regions shaded.
        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(labels_post.shape[1])):
            graphs_peaks.append((labels_post[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel=frames_label,
                          ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Target atom labels with reference unvoiced regions shaded.
        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0], ))
        plotter.set_label(grid_idx=grid_idx,
                          xlabel=frames_label,
                          ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # lf0 reconstructed from target atoms, original lf0 and predicted lf0.
        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(
            labels_post,
            hparams.k,
            hparams.frame_size_ms,
            amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(original_vuv), '0.8', 1.0)
                                         ])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel=frames_label,
                          ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(
            np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # plotter.set_lim(xmin=300, xmax=1100)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE" + hparams.gen_figure_ext)

    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """
        Plot the predicted atom amplitudes, position flag, post-processed
        peaks, target atom labels and lf0 curves for one utterance, then save.

        The figure is written to
        ``<hparams.out_dir>/<id_name>.<basename(hparams.model_name)>.VUV_DIST_POS<hparams.gen_figure_ext>``.

        :param id_name:  Id of the utterance.
        :param label:    Network output. Columns 1..-2 are plotted as atom
                         amplitudes (one per theta) and the last column as the
                         position flag. Column 0 is not plotted directly; the
                         V/UV curve is taken from the post-processed sample.
        :param hidden:   Unused.
        :param hparams:  Hyper-parameter container. Reads k, frame_size_ms,
                         thetas, min_atom_amp, model_name, out_dir,
                         gen_figure_ext and optionally world_dir.
        """
        # Retrieve data from label.
        output_amps = label[:, 1:-1]
        output_pos = label[:, -1]
        labels_post = self.OutputGen.postprocess_sample(label)
        output_vuv = labels_post[:, 0, 1].astype(bool)
        # Use the same frame_size as the reference atoms below so predicted
        # and reference lf0 are placed on the same time grid.
        output_atoms = self.OutputGen.labels_to_atoms(
            labels_post, k=hparams.k, frame_size=hparams.frame_size_ms,
            amp_threshold=hparams.min_atom_amp)
        output_lf0 = self.OutputGen.atoms_to_lf0(output_atoms, len(label))

        # Load original lf0 and vuv.
        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            world_dir = os.path.join(self.OutputGen.dir_labels,
                                     self.dir_extracted_acoustic_features)
        org_labels = LF0LabelGen.load_sample(id_name, world_dir)
        original_lf0, _ = LF0LabelGen.convert_to_world_features(org_labels)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Remove the phrase curve, tolerating a length mismatch between files.
        phrase_curve = np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0[:len(phrase_curve)] -= phrase_curve[:len(original_lf0)]
        original_lf0 = original_lf0[:len(output_lf0)]

        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels,
                                                len(hparams.thetas),
                                                self.OutputGen.dir_world_labels)
        org_vuv = org_labels[:, 0, 0].astype(bool)
        # Drop the V/UV column so only atom labels remain.
        org_labels = org_labels[:, 1:]
        len_diff = len(org_labels) - len(labels_post)
        # NOTE(review): both trims use the default (non-reversed) direction and
        # together use 2 * int(len_diff / 2) + 1 frames — confirm intended.
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0) + 1)
        org_atoms = AtomLabelGen.labels_to_atoms(org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)
        wcad_lf0 = self.OutputGen.atoms_to_lf0(org_atoms, len(org_labels))

        # Get a data plotter
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        # Predicted amplitudes, one curve per theta.
        grid_idx = 0
        graphs_output = list()
        for idx in reversed(range(output_amps.shape[1])):
            graphs_output.append((output_amps[:, idx], r'$\theta$={0:.3f}'.format(hparams.thetas[idx])))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        plotter.set_label(grid_idx=grid_idx, ylabel='NN amps')
        amp_max = np.max(output_amps) * 1.1
        amp_min = np.min(output_amps) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=amp_min, ymax=amp_max)

        # Predicted position flag.
        grid_idx += 1
        graphs_pos_flag = list()
        graphs_pos_flag.append((output_pos,))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_pos_flag)
        plotter.set_label(grid_idx=grid_idx, ylabel='NN pos')

        # Post-processed peaks with predicted unvoiced regions shaded.
        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(label.shape[1] - 2)):
            graphs_peaks.append((labels_post[:, 1 + idx, 0],))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(output_vuv), '0.75', 1.0, 'Unvoiced')])
        plotter.set_label(grid_idx=grid_idx, ylabel='NN peaks')
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Target atom labels with reference unvoiced regions shaded.
        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0],))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(org_vuv), '0.75', 1.0, 'Reference unvoiced')])
        plotter.set_label(grid_idx=grid_idx, ylabel='target')
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # lf0 reconstructed from target atoms, original lf0 and predicted lf0.
        grid_idx += 1
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(org_vuv), '0.75', 1.0)])
        plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
        plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # NOTE(review): the disabled block below references lf0 and vuv, which
        # are not defined in this method; it would need adapting before use.
        # # Compute F0 RMSE for sample and add it to title.
        # org_f0 = (np.exp(lf0.squeeze() + phrase_curve[:len(lf0)].squeeze()) * vuv)[:len(output_lf0)]  # Fix minor negligible length mismatch.
        # output_f0 = np.exp(output_lf0 + phrase_curve[:len(output_lf0)].squeeze()) * output_vuv[:len(output_lf0)]
        # f0_mse = (org_f0 - output_f0) ** 2
        # # non_zero_count = np.logical_and(vuv[:len(output_lf0)], output_vuv).sum()
        # f0_rmse = math.sqrt(f0_mse.sum() / (np.logical_and(vuv[:len(output_lf0)], output_vuv).sum()))

        # # Compute vuv error rate.
        # num_errors = (vuv[:len(output_lf0)] != output_vuv)
        # vuv_error_rate = float(num_errors.sum()) / len(output_lf0)
        # plotter.set_title(id_name + " - " + net_name + " - F0_RMSE_" + "{:4.2f}Hz".format(f0_rmse) + " - VUV_" + "{:2.2f}%".format(vuv_error_rate * 100))
        # plotter.set_lim(xmin=300, xmax=1100)
        # NOTE(review): gen_plot is called twice and only the second
        # (non-monochrome) call precedes save_to_file — confirm whether the
        # monochrome call is still needed.
        plotter.gen_plot(monochrome=True)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".VUV_DIST_POS" + hparams.gen_figure_ext)
Example #12
0
 def test_load(self):
     sample = AtomLabelGen.load_sample(self.id_list[0],
                                       self.dir_atoms,
                                       num_thetas=5)
     self.assertEqual(1931, sample.shape[0])