def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted vs. original lf0 and both spectrograms, then save the figure.

        Grid 0 shows the interpolated original and predicted lf0 with unvoiced regions shaded.
        Grid 1 shows the spectrogram computed from the original MGCs, grid 2 the one from the
        predicted coded spectrum. If phoneme indices and a question file are configured, phoneme
        annotations are added to the last grid.

        :param id_name:   Id of the utterance; used to load the original features and to name the file.
        :param label:     Network output for this utterance; post-processed by self.OutputGen.
        :param hidden:    Not used here.
        :param hparams:   Hyper-parameter container (reads num_coded_sps, model_name, out_dir,
                          frame_size_ms, synth_fs, gen_figure_ext and optionally phoneme_indices,
                          question_file, voice, num_questions).
        """
        # Predicted WORLD features (converted without deltas).
        predicted_post = self.OutputGen.postprocess_sample(label)
        predicted_sp, predicted_lf0, predicted_vuv, _bap = WorldFeatLabelGen.convert_to_world_features(
            predicted_post, contains_deltas=False, num_coded_sps=hparams.num_coded_sps)
        predicted_lf0, _ = interpolate_lin(predicted_lf0)

        # Reference WORLD features, loaded from the output generator's label directory.
        reference_post = WorldFeatLabelGen.load_sample(
            id_name, self.OutputGen.dir_labels,
            add_deltas=self.OutputGen.add_deltas, num_coded_sps=hparams.num_coded_sps)
        reference_mgc, reference_lf0, reference_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(
            reference_post, contains_deltas=self.OutputGen.add_deltas, num_coded_sps=hparams.num_coded_sps)
        reference_lf0, _ = interpolate_lin(reference_lf0)

        plotter = DataPlotter()
        model_tag = os.path.basename(hparams.model_name)
        out_path = str(os.path.join(hparams.out_dir, id_name + '.' + model_tag))
        frame_xlabel = 'frames [' + str(hparams.frame_size_ms) + ' ms]'
        plotter.set_title(id_name + ' - ' + model_tag)
        plotter.set_num_colors(3)

        # Grid 0: lf0 curves with predicted (grey) and original (red) unvoiced areas.
        plotter.set_label(grid_idx=0, xlabel=frame_xlabel, ylabel='log(f0)')
        plotter.set_data_list(grid_idx=0, data_list=[(reference_lf0, 'Original lf0'),
                                                     (predicted_lf0, 'PyTorch lf0')])
        unvoiced_areas = [(~predicted_vuv.astype(bool), '0.8', 1.0),
                          (~reference_vuv.astype(bool), 'red', 0.2)]
        plotter.set_area_list(grid_idx=0, area_list=unvoiced_areas)

        # Grids 1 and 2: spectrograms derived from original and predicted coded spectra.
        spectrograms = (('Original spectrogram', reference_mgc), ('NN spectrogram', predicted_sp))
        for grid, (caption, coded) in enumerate(spectrograms, start=1):
            plotter.set_label(grid_idx=grid, xlabel=frame_xlabel, ylabel=caption)
            plotter.set_specshow(grid_idx=grid, spec=WorldFeatLabelGen.mgc_to_sp(coded, hparams.synth_fs))
        last_grid = len(spectrograms)

        if getattr(hparams, "phoneme_indices", None) is not None \
                and getattr(hparams, "question_file", None) is not None:
            question_dir = "experiments/" + hparams.voice + "/questions/"
            questions = QuestionLabelGen.load_sample(id_name, question_dir,
                                                     num_questions=hparams.num_questions)
            questions = questions[:len(predicted_lf0)]
            phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices,
                                                              hparams.question_file)
            plotter.set_annotations(last_grid, phonemes)

        plotter.gen_plot()
        plotter.save_to_file(out_path + '.Org-PyTorch' + hparams.gen_figure_ext)
Пример #2
0
    def __init__(self, wcad_root, dir_atom_labels, dir_question_labels, id_list, thetas, k,
                 num_questions, hparams=None):
        """Default constructor.

        Uses question labels as input and atom labels as output, builds the training and
        validation datasets and, if no loss function is set yet, a WeightedNonzeroMSELoss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values.
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # If the weights for zero / non-zero frames are not given, use inverse occurrence rates.
        # The non-zero occurrence is assumed to be 0.02 / len(thetas), capped at 0.99, so both
        # classes contribute equally to the loss.
        if not hasattr(hparams, "weight_zero") or hparams.weight_zero is None:
            non_zero_occurrence = min(0.99, 0.02 / len(thetas))
            zero_occurrence = 1 - non_zero_occurrence
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence

        # Call the constructor of AtomModelTrainer's base class.
        super(AtomModelTrainer, self).__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        self.OutputGen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k, hparams.frame_size_ms)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(hparams.use_gpu, hparams.weight_zero, hparams.weight_non_zero,
                                                        size_average=False, reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_patience = 10
            hparams.plateau_factor = 0.5
            # Same attribute name the other trainers use (was misspelled "plateau_verbos").
            hparams.plateau_verbose = True
Пример #3
0
    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot predicted vs. original (interpolated) lf0 in one grid and save the figure.

        :param id_name:   Id of the utterance; used to load the original features and to name the file.
        :param labels:    Network output for this utterance; post-processed by self.OutputGen.
        :param hidden:    Not used here.
        :param hparams:   Hyper-parameter container (reads num_coded_sps, model_name, out_dir,
                          frame_size_ms, gen_figure_ext and optionally phoneme_indices,
                          question_file, voice, num_questions).
        """
        labels_post = self.OutputGen.postprocess_sample(labels)
        _, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(labels_post, num_coded_sps=self.OutputGen.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, *_ = WorldFeatLabelGen.convert_to_world_features(org_labels_post, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=grid_idx, ymin=math.log(60), ymax=math.log(250))

        # Label, curves, unvoiced areas and annotations all go to the same grid. grid_idx must not be
        # incremented after set_label, otherwise the label ends up on an empty grid without the curves.
        plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='log(f0)')

        graphs = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0)])

        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)
Пример #4
0
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Build a trainer mapping question labels to WORLD features (without deltas).

        Fills in unset hyper-parameter defaults, creates input/output label generators with their
        normalisation parameters, the training/validation datasets, an element-wise MSE loss if
        none is set, and switches a "default" scheduler to "Plateau".

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Variable sequence lengths are only needed when batching more than one sample.
        for flag_name, batch_name in (("variable_sequence_length_train", "batch_size_train"),
                                      ("variable_sequence_length_test", "batch_size_test")):
            if getattr(hparams, flag_name) is None:
                setattr(hparams, flag_name, getattr(hparams, batch_name) > 1)
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        self.InputGen = input_gen = QuestionLabelGen(dir_question_labels, num_questions)
        input_gen.get_normalisation_params(dir_question_labels)

        self.OutputGen = output_gen = WorldFeatLabelGen(dir_world_features, add_deltas=False,
                                                        num_coded_sps=hparams.num_coded_sps)
        output_gen.get_normalisation_params(dir_world_features)

        shared_args = (self.InputGen, self.OutputGen, hparams)
        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, *shared_args)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, *shared_args)

        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type != "default":
            return
        hparams.scheduler_type = "Plateau"
        hparams.plateau_verbose = True
Пример #5
0
class AcousticModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic features.

    Use question labels as input and WORLD features as output. Synthesize audio from model output.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Default constructor.

        Fills in unset hyper-parameter defaults, creates input/output label generators with their
        normalisation parameters, the training/validation datasets, an element-wise MSE loss if
        none is set, and switches a "default" scheduler to "Plateau".

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        self.OutputGen = WorldFeatLabelGen(dir_world_features, add_deltas=False, num_coded_sps=hparams.num_coded_sps)
        self.OutputGen.get_normalisation_params(dir_world_features)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_verbose = True

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the hyper-parameter container with the acoustic-model defaults.

        :param hparams_string:  Passed through to ModelTrainer.create_hparams.
        :param verbose:         If True, log the final parsed hparams.
        :return:                Hyper-parameter container with num_coded_sps set to 60.
        """
        # The base class is always called quietly, so the final values are logged at most once (below).
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)
        hparams.num_coded_sps = 60

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot predicted vs. original (interpolated) lf0 in one grid and save the figure.

        :param id_name:   Id of the utterance; used to load the original features and to name the file.
        :param labels:    Network output for this utterance; post-processed by self.OutputGen.
        :param hidden:    Not used here.
        :param hparams:   Hyper-parameter container (reads num_coded_sps, model_name, out_dir,
                          frame_size_ms, gen_figure_ext and optionally phoneme_indices,
                          question_file, voice, num_questions).
        """
        labels_post = self.OutputGen.postprocess_sample(labels)
        _, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(labels_post, num_coded_sps=self.OutputGen.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, *_ = WorldFeatLabelGen.convert_to_world_features(org_labels_post, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        # plotter.set_lim(grid_idx=grid_idx, ymin=math.log(60), ymax=math.log(250))

        # Label, curves, unvoiced areas and annotations all go to the same grid. grid_idx must not be
        # incremented after set_label, otherwise the label ends up on an empty grid without the curves.
        plotter.set_label(grid_idx=grid_idx, xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]', ylabel='log(f0)')

        graphs = [(original_lf0, 'Original lf0'), (lf0, 'PyTorch lf0')]
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0)])

        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)
Пример #6
0
    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams=None):
        """Default constructor.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .wav files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Width of the distribution surrounding each atom spike. Must be odd;
                                        an even value is incremented by one and a warning is logged.
                                        The window is only used for amps. Thetas are surrounded by a window of 5.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # If the weights for zero / non-zero frames are not given, use inverse occurrence rates.
        # The non-zero occurrence is assumed to be 0.015 / len(thetas), capped at 0.99.
        if not hasattr(hparams, "weight_zero") or hparams.weight_zero is None:
            non_zero_occurrence = min(0.99, 0.015 / len(thetas))
            zero_occurrence = 1 - non_zero_occurrence
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence
        if not hasattr(hparams, "weight_vuv") or hparams.weight_vuv is None:
            hparams.weight_vuv = 0.5
        if not hasattr(hparams,
                       "atom_loss_theta") or hparams.atom_loss_theta is None:
            hparams.atom_loss_theta = 0.01

        # Explicitly call only the constructor of the baseclass of AtomModelTrainer.
        super(AtomModelTrainer, self).__init__(id_list, hparams)

        if hparams.dist_window_size % 2 == 0:
            hparams.dist_window_size += 1
            self.logger.warning(
                "hparams.dist_window_size should be odd, changed it to " +
                str(hparams.dist_window_size))
        # The window handed to the label generator below is the constructor argument,
        # so the odd-size requirement has to be enforced on it as well.
        if dist_window_size % 2 == 0:
            dist_window_size += 1
            self.logger.warning(
                "dist_window_size should be odd, changed it to " +
                str(dist_window_size))

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        # Overwrite OutputGen by the one with beta distribution.
        self.OutputGen = AtomVUVDistPosLabelGen(wcad_root,
                                                dir_atom_labels,
                                                dir_lf0_labels,
                                                thetas,
                                                k,
                                                hparams.frame_size_ms,
                                                window_size=dist_window_size)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train,
                                                     self.InputGen,
                                                     self.OutputGen,
                                                     hparams,
                                                     length_check=True)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val,
                                                   self.InputGen,
                                                   self.OutputGen,
                                                   hparams,
                                                   length_check=True)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroWMSEAtomLoss(
                hparams.use_gpu,
                hparams.atom_loss_theta,
                hparams.weight_vuv,
                hparams.weight_zero,
                hparams.weight_non_zero,
                reduce=False)

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"
Пример #7
0
class AtomVUVDistPosModelTrainer(AtomModelTrainer):
    """
    Subclass of AtomModelTrainer, which uses one amplitude per theta plus position flag,
    format is T x (|thetas| + 1). Each amplitude in the target labels is surrounded by a distribution.
    Positions of atoms are identified by finding the peaks of the position flag prediction. For positive peaks
    the theta with the highest amplitude is used, for negative peaks the theta with the lowest amplitude.
    Acoustic data is generated from these atoms. MGC and BAP is either generated by a pre-trained acoustic model
    or loaded from the original extracted files. Question labels are used as input.
    """
    logger = logging.getLogger(__name__)

    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams=None):
        """Default constructor.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .wav files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Width of the distribution surrounding each atom spike. Must be odd;
                                        an even value is incremented by one and a warning is logged.
                                        The window is only used for amps. Thetas are surrounded by a window of 5.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # If the weights for zero / non-zero frames are not given, use inverse occurrence rates.
        # The non-zero occurrence is assumed to be 0.015 / len(thetas), capped at 0.99.
        if not hasattr(hparams, "weight_zero") or hparams.weight_zero is None:
            non_zero_occurrence = min(0.99, 0.015 / len(thetas))
            zero_occurrence = 1 - non_zero_occurrence
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence
        if not hasattr(hparams, "weight_vuv") or hparams.weight_vuv is None:
            hparams.weight_vuv = 0.5
        if not hasattr(hparams,
                       "atom_loss_theta") or hparams.atom_loss_theta is None:
            hparams.atom_loss_theta = 0.01

        # Explicitly call only the constructor of the baseclass of AtomModelTrainer.
        super(AtomModelTrainer, self).__init__(id_list, hparams)

        if hparams.dist_window_size % 2 == 0:
            hparams.dist_window_size += 1
            self.logger.warning(
                "hparams.dist_window_size should be odd, changed it to " +
                str(hparams.dist_window_size))
        # The window handed to the label generator below is the constructor argument,
        # so the odd-size requirement has to be enforced on it as well.
        if dist_window_size % 2 == 0:
            dist_window_size += 1
            self.logger.warning(
                "dist_window_size should be odd, changed it to " +
                str(dist_window_size))

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        # Overwrite OutputGen by the one with beta distribution.
        self.OutputGen = AtomVUVDistPosLabelGen(wcad_root,
                                                dir_atom_labels,
                                                dir_lf0_labels,
                                                thetas,
                                                k,
                                                hparams.frame_size_ms,
                                                window_size=dist_window_size)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train,
                                                     self.InputGen,
                                                     self.OutputGen,
                                                     hparams,
                                                     length_check=True)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val,
                                                   self.InputGen,
                                                   self.OutputGen,
                                                   hparams,
                                                   length_check=True)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroWMSEAtomLoss(
                hparams.use_gpu,
                hparams.atom_loss_theta,
                hparams.weight_vuv,
                hparams.weight_zero,
                hparams.weight_non_zero,
                reduce=False)

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the AtomModelTrainer hyper-parameters plus dist_window_size = 51.

        :param hparams_string:  Passed through to AtomModelTrainer.create_hparams.
        :param verbose:         If True, log the final parsed hparams.
        :return:                Hyper-parameter container.
        """
        # The base class is always called quietly, so the final values are logged at most once (below).
        hparams = AtomModelTrainer.create_hparams(hparams_string,
                                                  verbose=False)
        hparams.dist_window_size = 51

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted amplitudes, position flag, peaks, targets and lf0 curves and save a PNG.

        :param id_name:   Id of the utterance; used to load reference data and to name the file.
        :param label:     Network output; label[:, 1:-1] is plotted as amplitudes and label[:, -1]
                          as the position flag.
        :param hidden:    Not used here.
        :param hparams:   Hyper-parameter container (reads k, min_atom_amp, thetas, frame_size_ms,
                          model_name, out_dir).
        """
        # Retrieve data from label.
        output_amps = label[:, 1:-1]
        output_pos = label[:, -1]
        labels_post = self.OutputGen.postprocess_sample(label)
        output_vuv = labels_post[:, 0, 1].astype(bool)
        output_atoms = self.OutputGen.labels_to_atoms(
            labels_post, k=hparams.k, amp_threshold=hparams.min_atom_amp)
        output_lf0 = self.OutputGen.atoms_to_lf0(output_atoms, len(label))

        # Load original lf0, subtract the phrase curve and cut it to the predicted length.
        org_labels = LF0LabelGen.load_sample(
            id_name,
            os.path.join(hparams.out_dir,
                         self.dir_extracted_acoustic_features))
        original_lf0, _ = LF0LabelGen.convert_to_world_features(org_labels)
        original_lf0, _ = interpolate_lin(original_lf0)

        phrase_curve = np.fromfile(os.path.join(
            self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0[:len(phrase_curve)] -= phrase_curve[:len(original_lf0)]
        original_lf0 = original_lf0[:len(output_lf0)]

        # Load the reference atom labels and their vuv, then trim them towards the predicted length.
        org_labels = self.OutputGen.load_sample(
            id_name, self.OutputGen.dir_labels, len(hparams.thetas),
            self.OutputGen.dir_world_labels)
        org_vuv = org_labels[:, 0, 0].astype(bool)
        org_labels = org_labels[:, 1:]
        len_diff = len(org_labels) - len(labels_post)
        # NOTE(review): both trims use the same direction (synthesize passes reverse=True for its
        # first trim) and together remove 2 * int(len_diff / 2) + 1 frames — confirm this is intended.
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels,
                                                    int(len_diff / 2.0) + 1)
        org_atoms = AtomLabelGen.labels_to_atoms(
            org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)
        wcad_lf0 = self.OutputGen.atoms_to_lf0(org_atoms, len(org_labels))

        # Get a data plotter
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        # Grid 0: predicted amplitudes, one curve per theta.
        grid_idx = 0
        graphs_output = list()
        for idx in reversed(range(output_amps.shape[1])):
            graphs_output.append(
                (output_amps[:, idx],
                 r'$\theta$={0:.3f}'.format(hparams.thetas[idx])))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)
        plotter.set_label(grid_idx=grid_idx, ylabel='NN amps')
        amp_max = np.max(output_amps) * 1.1
        amp_min = np.min(output_amps) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=amp_min, ymax=amp_max)

        # Grid 1: predicted position flag.
        grid_idx += 1
        graphs_pos_flag = list()
        graphs_pos_flag.append((output_pos, ))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_pos_flag)
        plotter.set_label(grid_idx=grid_idx, ylabel='NN pos')

        # Grid 2: post-processed peaks with predicted unvoiced regions.
        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(label.shape[1] - 2)):
            graphs_peaks.append((labels_post[:, 1 + idx, 0], ))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(output_vuv), '0.75', 1.0,
                                          'Unvoiced')])
        plotter.set_label(grid_idx=grid_idx, ylabel='NN peaks')
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Grid 3: reference targets with reference unvoiced regions.
        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0], ))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(org_vuv.astype(bool)),
                                          '0.75', 1.0, 'Reference unvoiced')])
        plotter.set_label(grid_idx=grid_idx, ylabel='target')
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Grid 4: wcad, original and predicted lf0 (phrase curve removed from the original).
        grid_idx += 1
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(org_vuv.astype(bool)),
                                          '0.75', 1.0)])
        plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(
            np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        # F0 RMSE and VUV error rate are computed in compute_score.
        # NOTE(review): gen_plot is called twice (monochrome, then default); presumably only one
        # of them is intended — confirm which figure should be saved.
        plotter.gen_plot(monochrome=True)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".VUV_DIST_POS.png")

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """Compute F0 RMSE and VUV error rate of post-processed outputs against extracted features.

        :param dict_outputs_post:  Dictionary id -> post-processed network output; column [:, 0, 1]
                                   is used as VUV and [:, 1:] as atom labels.
        :param dict_hiddens:       Not used here.
        :param hparams:            Hyper-parameter container (reads k, frame_size_ms, min_atom_amp).
        :return:                   Tuple (mean F0 RMSE in Hz over all ids, mean VUV error rate in [0, 1]).
        """
        # Get data for comparison.
        dict_original_post = self.load_extracted_audio_features(
            dict_outputs_post, hparams)

        f0_rmse = 0.0
        vuv_error_rate = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        for id_name, label in dict_outputs_post.items():

            output_vuv = label[:, 0, 1].astype(bool)
            output_atom_labels = label[:, 1:]
            output_lf0 = self.OutputGen.labels_to_lf0(
                output_atom_labels,
                k=hparams.k,
                frame_size=hparams.frame_size_ms,
                amp_threshold=hparams.min_atom_amp)

            # Get data for comparison. Columns 60/61 are taken as lf0/vuv
            # (assumes 60 coded spectral features — TODO confirm).
            org_lf0 = dict_original_post[id_name][:, 60]
            org_vuv = dict_original_post[id_name][:, 61]
            phrase_curve = self.get_phrase_curve(id_name)

            # Compute f0 from lf0; the prediction excludes the phrase curve, so add it back.
            org_f0 = np.exp(org_lf0.squeeze())[:len(
                output_lf0)]  # Fix minor negligible length mismatch.
            output_f0 = np.exp(output_lf0 +
                               phrase_curve[:len(output_lf0)].squeeze())

            # Compute RMSE over voiced reference frames, keep track of worst RMSE.
            f0_mse = (org_f0 - output_f0)**2
            current_f0_rmse = math.sqrt(
                (f0_mse * org_vuv[:len(output_lf0)]).sum() /
                org_vuv[:len(output_lf0)].sum())
            if current_f0_rmse > f0_rmse_max:
                f0_rmse_max_id = id_name
                f0_rmse_max = current_f0_rmse
            f0_rmse += current_f0_rmse

            # Compute vuv error rate.
            num_errors = (org_vuv[:len(output_lf0)] != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / len(output_lf0)
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp

        f0_rmse /= len(dict_outputs_post)
        vuv_error_rate /= len(dict_outputs_post)

        self.logger.info("Worst F0 RMSE: " + f0_rmse_max_id +
                         " {:4.2f}Hz".format(f0_rmse_max))
        self.logger.info("Worst VUV error: " + vuv_error_max_id +
                         " {:2.2f}%".format(vuv_error_max * 100))
        self.logger.info("Benchmark score: F0 RMSE " +
                         "{:4.2f}Hz".format(f0_rmse) + ", VUV " +
                         "{:2.2f}%".format(vuv_error_rate * 100))

        return f0_rmse, vuv_error_rate

    def synthesize(self, id_list, synth_output, hparams):
        """Build full acoustic features via run_atom_synth, align them to the predicted VUV and vocode.

        Each id's features are trimmed by int(len_diff / 2) frames with reverse=True and by the
        remainder without it, so that they match the predicted VUV length. The predicted VUV then
        replaces column -2. The trimmed features are passed to ModelTrainer.synthesize.

        :param id_list:       Ids to synthesize.
        :param synth_output:  Dictionary id -> post-processed network output; [:, 0, 1] is used as VUV.
        :param hparams:       Hyper-parameter container.
        """
        full_output = self.run_atom_synth(id_list, synth_output, hparams)

        for id_name, labels in full_output.items():
            vuv = synth_output[id_name][:, 0, 1]
            len_diff = len(labels) - len(vuv)
            labels = WorldFeatLabelGen.trim_end_sample(labels,
                                                       int(len_diff / 2),
                                                       reverse=True)
            labels = WorldFeatLabelGen.trim_end_sample(
                labels, len_diff - int(len_diff / 2))
            labels[:, -2] = vuv
            # Store the trimmed features back; rebinding only the loop variable would hand the
            # untrimmed features to the vocoder. Replacing an existing key is safe during iteration.
            full_output[id_name] = labels

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, full_output, hparams)
class AcousticDeltasModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data.

    Use question labels as input and WORLD features with deltas/double deltas as output.
    Synthesize audio from model output with MLPG smoothing.
    """
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, dir_question_labels, id_list, num_questions, hparams=None):
        """Default constructor.

        :param dir_world_features:      Path to the directory containing the world features.
        :param dir_question_labels:     Path to the directory containing the question labels.
        :param id_list:                 List of ids, can contain a speaker directory.
        :param num_questions:           Number of questions in question file.
        :param hparams:                 Set of hyper parameters. When None, the defaults of
                                        create_hparams() are used with out_dir set to the
                                        current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super(AcousticDeltasModelTrainer, self).__init__(id_list, hparams)

        # Inputs are question labels, outputs are WORLD features with deltas.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        self.OutputGen = WorldFeatLabelGen(dir_world_features, add_deltas=True, num_coded_sps=hparams.num_coded_sps)
        self.OutputGen.get_normalisation_params(dir_world_features)

        self.dataset_train = LabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = LabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        # Only install the default loss if none has been set yet.
        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_verbose = True

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyperparameters. Parse nondefault from given string.

        Adds the synth_load_org_* flags (all False by default), which synthesize() reads to
        replace parts of the network output with the extracted features.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        hparams.synth_load_org_sp = False
        hparams.synth_load_org_lf0 = False
        hparams.synth_load_org_vuv = False
        hparams.synth_load_org_bap = False

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot predicted vs. original lf0/VUV and both spectrograms, and save the figure.

        The figure is written to <out_dir>/<id_name>.<model name>.Org-PyTorch<gen_figure_ext>.

        :param id_name:  Id of the utterance.
        :param label:    Network output for the utterance (normalised, no deltas).
        :param hidden:   Unused here.
        :param hparams:  Hyper-parameter container.
        """
        labels_post = self.OutputGen.postprocess_sample(label)
        coded_sp, lf0, vuv, _ = WorldFeatLabelGen.convert_to_world_features(labels_post, contains_deltas=False, num_coded_sps=hparams.num_coded_sps)
        lf0, _ = interpolate_lin(lf0)

        # Load original lf0.
        org_labels_post = WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, add_deltas=self.OutputGen.add_deltas, num_coded_sps=hparams.num_coded_sps)
        original_mgc, original_lf0, original_vuv, *_ = WorldFeatLabelGen.convert_to_world_features(org_labels_post, contains_deltas=self.OutputGen.add_deltas, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Get a data plotter.
        xlabel = 'frames [' + str(hparams.frame_size_ms) + ' ms]'
        grid_idx = 0
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter.set_title(id_name + ' - ' + net_name)
        plotter.set_num_colors(3)
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='log(f0)')

        graphs = list()
        graphs.append((original_lf0, 'Original lf0'))
        graphs.append((lf0, 'PyTorch lf0'))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        # Shade unvoiced regions: predicted in grey, original in translucent red.
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv.astype(bool)), '0.8', 1.0),(np.invert(original_vuv.astype(bool)), 'red', 0.2)])

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='Original spectrogram')
        plotter.set_specshow(grid_idx=grid_idx, spec=WorldFeatLabelGen.mgc_to_sp(original_mgc, hparams.synth_fs))

        grid_idx += 1
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='NN spectrogram')
        plotter.set_specshow(grid_idx=grid_idx, spec=WorldFeatLabelGen.mgc_to_sp(coded_sp, hparams.synth_fs))

        # Phoneme annotations are only added when the question file and indices are configured.
        if hasattr(hparams, "phoneme_indices") and hparams.phoneme_indices is not None \
           and hasattr(hparams, "question_file") and hparams.question_file is not None:
            questions = QuestionLabelGen.load_sample(id_name,
                                                     "experiments/" + hparams.voice + "/questions/",
                                                     num_questions=hparams.num_questions)[:len(lf0)]
            np_phonemes = QuestionLabelGen.questions_to_phonemes(questions, hparams.phoneme_indices, hparams.question_file)
            plotter.set_annotations(grid_idx, np_phonemes)

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Org-PyTorch' + hparams.gen_figure_ext)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """Compare post-processed network outputs with the original features.

        :param dict_outputs_post:  Dictionary id_name -> post-processed network output (no deltas).
        :param dict_hiddens:       Unused here.
        :param hparams:            Hyper-parameter container.
        :return:                   Tuple (mcd, f0_rmse, vuv_error_rate, bap_error), each averaged
                                   over utterances. F0 RMSE is averaged only over utterances for
                                   which it is not NaN, and is NaN if there are none.
        """
        # Get data for comparison.
        dict_original_post = {
            id_name: WorldFeatLabelGen.load_sample(id_name, self.OutputGen.dir_labels, True,
                                                   num_coded_sps=hparams.num_coded_sps)
            for id_name in dict_outputs_post.keys()}

        mcd = 0.0
        mcd_max_id = "None"
        mcd_max = 0.0
        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        num_valid_f0_rmse = 0  # Utterances that contributed a non-NaN F0 RMSE.
        vuv_error_rate = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0
        bap_error = 0.0
        bap_error_max_id = "None"
        bap_error_max = 0.0

        for id_name, labels in dict_outputs_post.items():
            output_coded_sp, output_lf0, output_vuv, output_bap = self.OutputGen.convert_to_world_features(
                labels, False, num_coded_sps=hparams.num_coded_sps)
            output_vuv = output_vuv.astype(bool)

            org_coded_sp, org_lf0, org_vuv, org_bap = self.OutputGen.convert_to_world_features(
                dict_original_post[id_name], self.OutputGen.add_deltas, num_coded_sps=hparams.num_coded_sps)

            # Cut the original to the output length (fixes minor negligible length mismatch).
            num_frames = len(output_lf0)
            org_vuv = org_vuv[:num_frames]

            # Compute f0 from lf0.
            org_f0 = np.exp(org_lf0.squeeze())[:num_frames]
            output_f0 = np.exp(output_lf0)

            # Compute MCD.
            org_coded_sp = org_coded_sp[:len(output_coded_sp)]
            current_mcd = metrics.melcd(output_coded_sp, org_coded_sp)  # TODO: Use aligned mcd.
            if current_mcd > mcd_max:
                mcd_max_id = id_name
                mcd_max = current_mcd
            mcd += current_mcd

            # Compute F0 RMSE over the frames that are voiced in the original.
            f0_mse = (org_f0 - output_f0) ** 2
            current_f0_rmse = math.sqrt((f0_mse * org_vuv).sum() / org_vuv.sum())
            if math.isnan(current_f0_rmse):
                self.logger.error("Computed NaN for F0 RMSE for {}.".format(id_name))
            else:
                if current_f0_rmse > f0_rmse_max:
                    f0_rmse_max_id = id_name
                    f0_rmse_max = current_f0_rmse
                f0_rmse += current_f0_rmse
                num_valid_f0_rmse += 1

            # Compute error of VUV in percentage.
            num_errors = (org_vuv != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / num_frames
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp

            # Compute aperiodicity distortion.
            org_bap = org_bap[:len(output_bap)]
            if len(output_bap.shape) > 1 and output_bap.shape[1] > 1:
                current_bap_error = metrics.melcd(output_bap, org_bap)  # TODO: Use aligned mcd?
            else:
                current_bap_error = math.sqrt(((org_bap - output_bap) ** 2).mean()) * (10.0 / np.log(10) * np.sqrt(2.0))
            if current_bap_error > bap_error_max:
                bap_error_max_id = id_name
                bap_error_max = current_bap_error
            bap_error += current_bap_error

        num_utterances = len(dict_outputs_post)
        mcd /= num_utterances
        # Average F0 RMSE only over the utterances that actually contributed to the sum.
        f0_rmse = f0_rmse / num_valid_f0_rmse if num_valid_f0_rmse > 0 else float("nan")
        vuv_error_rate /= num_utterances
        bap_error /= num_utterances

        self.logger.info("Worst MCD: {} {:4.2f}dB".format(mcd_max_id, mcd_max))
        self.logger.info("Worst F0 RMSE: {} {:4.2f}Hz".format(f0_rmse_max_id, f0_rmse_max))
        self.logger.info("Worst VUV error: {} {:2.2f}%".format(vuv_error_max_id, vuv_error_max * 100))
        self.logger.info("Worst BAP error: {} {:4.2f}db".format(bap_error_max_id, bap_error_max))
        self.logger.info("Benchmark score: MCD {:4.2f}dB, F0 RMSE {:4.2f}Hz, VUV {:2.2f}%, BAP error {:4.2f}db".format(mcd, f0_rmse, vuv_error_rate * 100, bap_error))

        return mcd, f0_rmse, vuv_error_rate, bap_error

    def synthesize(self, id_list, synth_output, hparams):
        """Depending on hparams override the network output with the extracted features, then continue with normal synthesis pipeline.

        Each enabled synth_load_org_* flag copies the matching columns of the extracted features
        into synth_output (modified in place) before the vocoder is run.
        """
        if hparams.synth_load_org_sp or hparams.synth_load_org_lf0 or hparams.synth_load_org_vuv or hparams.synth_load_org_bap:
            for id_name in id_list:

                labels = WorldFeatLabelGen.load_sample(id_name, os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features), num_coded_sps=hparams.num_coded_sps)
                # Trim the extracted features on both ends when they are longer than the output.
                len_diff = len(labels) - len(synth_output[id_name])
                if len_diff > 0:
                    labels = WorldFeatLabelGen.trim_end_sample(labels, int(len_diff / 2), reverse=True)
                    labels = WorldFeatLabelGen.trim_end_sample(labels, len_diff - int(len_diff / 2))

                if hparams.synth_load_org_sp:
                    synth_output[id_name][:len(labels), :self.OutputGen.num_coded_sps] = labels[:, :self.OutputGen.num_coded_sps]

                if hparams.synth_load_org_lf0:
                    synth_output[id_name][:len(labels), -3] = labels[:, -3]

                if hparams.synth_load_org_vuv:
                    synth_output[id_name][:len(labels), -2] = labels[:, -2]

                if hparams.synth_load_org_bap:
                    synth_output[id_name][:len(labels), -1] = labels[:, -1]

        # Run the vocoder.
        ModelTrainer.synthesize(self, id_list, synth_output, hparams)
Пример #9
0
class AtomModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of acoustic data through atom prediction.
    Output labels for atoms have dimension: T x |thetas| x 2 (amp, theta).

    Use question labels as input and extracted wcad atoms as output. Synthesize audio from model
    output by generating F0 from atoms. MGC and BAP is either generated by a pre-trained acoustic
    model or loaded from the original extracted files.
    """
    logger = logging.getLogger(__name__)

    def __init__(self, wcad_root, dir_atom_labels, dir_question_labels, id_list, thetas, k,
                 num_questions, hparams=None):
        """Default constructor.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values.
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # If no loss weights are given, use the inverse of an assumed occurrence rate of
        # zero / non-zero targets so that both contribute equally.
        if not hasattr(hparams, "weight_zero") or hparams.weight_zero is None:
            non_zero_occurrence = min(0.99, 0.02 / len(thetas))
            zero_occurrence = 1 - non_zero_occurrence
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence

        super(AtomModelTrainer, self).__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        self.OutputGen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k, hparams.frame_size_ms)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train, self.InputGen, self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val, self.InputGen, self.OutputGen, hparams)

        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(hparams.use_gpu, hparams.weight_zero, hparams.weight_non_zero,
                                                        size_average=False, reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.plateau_patience = 10
            hparams.plateau_factor = 0.5
            hparams.plateau_verbose = True

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot raw/post-processed atom outputs, target atoms and lf0 curves, then save the figure.

        The figure is written to <out_dir>/<id_name>.<model name>.BASE.png.

        :param id_name:  Id of the utterance.
        :param labels:   Network output; a 1-D array is expanded to a single column.
        :param hidden:   Unused here.
        :param hparams:  Hyper-parameter container.
        """
        if labels.ndim < 2:
            labels = np.expand_dims(labels, axis=1)
        labels_post = self.OutputGen.postprocess_sample(labels)
        lf0 = self.OutputGen.labels_to_lf0(labels_post, hparams.k)
        lf0, vuv = interpolate_lin(lf0)
        vuv = vuv.astype(bool)  # np.bool was removed in NumPy 1.24; the builtin is equivalent.

        # Load original lf0 and vuv.
        org_labels = WorldFeatLabelGen.load_sample(id_name, os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features), num_coded_sps=hparams.num_coded_sps)
        _, original_lf0, original_vuv, _ = WorldFeatLabelGen.convert_to_world_features(org_labels, num_coded_sps=hparams.num_coded_sps)
        original_lf0, _ = interpolate_lin(original_lf0)
        original_vuv = original_vuv.astype(bool)

        # Remove the extracted phrase curve so original_lf0 is comparable to the atom curves.
        phrase_curve = np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                   dtype=np.float32).reshape(-1, 1)
        original_lf0 -= phrase_curve
        # NOTE(review): the two trims remove int(len_diff / 2.0) and int(len_diff / 2.0) + 1
        # frames, i.e. one frame more than len_diff when len_diff is even — confirm intended.
        len_diff = len(original_lf0) - len(lf0)
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0, int(len_diff / 2.0))
        original_lf0 = WorldFeatLabelGen.trim_end_sample(original_lf0, int(len_diff / 2.0) + 1, reverse=True)

        org_labels = self.OutputGen.load_sample(id_name, self.OutputGen.dir_labels, len(hparams.thetas))
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0))
        org_labels = self.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0) + 1, reverse=True)
        org_atoms = self.OutputGen.labels_to_atoms(org_labels, k=hparams.k, frame_size=hparams.frame_size_ms)

        # Get a data plotter.
        xlabel = 'frames [' + str(hparams.frame_size_ms) + ' ms]'
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        # Raw network output, one curve per theta.
        graphs_output = list()
        grid_idx = 0
        for idx in reversed(range(labels.shape[1])):
            graphs_output.append((labels[:, idx], r'$\theta$=' + "{0:.3f}".format(hparams.thetas[idx])))
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='NN output')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_output)

        # Post-processed network output with predicted unvoiced regions shaded.
        grid_idx += 1
        graphs_peaks = list()
        for idx in reversed(range(labels_post.shape[1])):
            graphs_peaks.append((labels_post[:, idx, 0],))
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='NN post-processed')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_peaks)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # Target atoms with original unvoiced regions shaded.
        grid_idx += 1
        graphs_target = list()
        for idx in reversed(range(org_labels.shape[1])):
            graphs_target.append((org_labels[:, idx, 0],))
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='target')
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_target)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(original_vuv), '0.8', 1.0)])
        plotter.set_lim(grid_idx=grid_idx, ymin=-1.8, ymax=1.8)

        # lf0 curves reconstructed from target atoms, the original and the prediction.
        grid_idx += 1
        output_atoms = AtomLabelGen.labels_to_atoms(labels_post, hparams.k, hparams.frame_size_ms, amp_threshold=hparams.min_atom_amp)
        wcad_lf0 = AtomLabelGen.atoms_to_lf0(org_atoms, len(labels))
        output_lf0 = AtomLabelGen.atoms_to_lf0(output_atoms, len(labels))
        graphs_lf0 = list()
        graphs_lf0.append((wcad_lf0, "wcad lf0"))
        graphs_lf0.append((original_lf0, "org lf0"))
        graphs_lf0.append((output_lf0, "predicted lf0"))
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs_lf0)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[(np.invert(original_vuv), '0.8', 1.0)])
        plotter.set_label(grid_idx=grid_idx, xlabel=xlabel, ylabel='lf0')
        amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=[':', '--', '-'])

        plotter.gen_plot()
        plotter.save_to_file(filename + ".BASE.png")

    def get_recon_from_synth_output(self, synth_output, hparams):
        """Reconstruct LF0 from atoms.

        Atom labels are converted to atoms (amplitudes below hparams.min_atom_amp are dropped),
        turned into an lf0 curve, and the extracted phrase curve is added. Frames at or below
        log(WorldFeatLabelGen.f0_silence_threshold) are set to WorldFeatLabelGen.lf0_zero.

        :param synth_output:  Dictionary id_name -> atom labels; 2-D labels get a middle axis added.
        :param hparams:       Hyper-parameter container.
        :return:              Dictionary id_name -> reconstructed lf0.
        """
        recon_dict = dict()
        for id_name, label in synth_output.items():
            if len(label.shape) == 2:
                label = np.expand_dims(label, axis=1)

            atoms = self.OutputGen.__class__.labels_to_atoms(label, k=hparams.k, frame_size=hparams.frame_size_ms, amp_threshold=hparams.min_atom_amp)
            reconstruction = self.OutputGen.__class__.atoms_to_lf0(atoms, num_frames=len(label))

            # Add extracted phrase.
            phrase_curve = np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                       dtype=np.float32)[:len(reconstruction)]
            reconstruction[:len(phrase_curve)] += phrase_curve
            reconstruction[reconstruction <= math.log(WorldFeatLabelGen.f0_silence_threshold)] = WorldFeatLabelGen.lf0_zero

            recon_dict[id_name] = reconstruction

        return recon_dict

    def get_phrase_curve(self, id_name):
        """Load the extracted phrase curve of id_name as a float32 column vector."""
        return np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                           dtype=np.float32).reshape(-1, 1)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """Compute the F0 RMSE (Hz) of atom-based outputs against the extracted features.

        Both F0 curves are multiplied by the original VUV flags, and the RMSE is normalised by
        the number of original voiced frames within the output length.

        :param dict_outputs_post:  Dictionary id_name -> post-processed atom labels.
        :param dict_hiddens:       Unused here.
        :param hparams:            Hyper-parameter container.
        :return:                   F0 RMSE averaged over utterances.
        """
        # Get data for comparison.
        dict_original_post = self.load_extracted_audio_features(dict_outputs_post, hparams)

        f0_rmse = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        for id_name, labels in dict_outputs_post.items():
            # Use the same amplitude threshold as the other atom reconstructions of this class.
            output_lf0 = AtomLabelGen.labels_to_lf0(labels,
                                                    k=hparams.k,
                                                    frame_size=hparams.frame_size_ms,
                                                    amp_threshold=hparams.min_atom_amp)
            num_frames = len(output_lf0)

            # Get data for comparison.
            org_lf0 = dict_original_post[id_name][:, hparams.num_coded_sps]
            org_vuv = dict_original_post[id_name][:, hparams.num_coded_sps + 1]
            phrase_curve = self.get_phrase_curve(id_name)

            # Compute f0 from lf0; cut to output length (fixes minor negligible length mismatch).
            org_f0 = (np.exp(org_lf0.squeeze()) * org_vuv)[:num_frames]
            output_f0 = np.exp(output_lf0 + phrase_curve[:num_frames].squeeze()) * org_vuv[:num_frames]

            # Compute RMSE over the same frames as the numerator, keep track of worst RMSE.
            f0_mse = (org_f0 - output_f0) ** 2
            current_f0_rmse = math.sqrt(f0_mse.sum() / org_vuv[:num_frames].sum())
            if current_f0_rmse > f0_rmse_max:
                f0_rmse_max_id = id_name
                f0_rmse_max = current_f0_rmse
            f0_rmse += current_f0_rmse

        f0_rmse /= len(dict_outputs_post)
        self.logger.info("Worst F0 RMSE: " + f0_rmse_max_id + " {:4.2f}Hz".format(f0_rmse_max))
        self.logger.info("Benchmark score: F0 RMSE " + "{:4.2f}Hz".format(f0_rmse))

        return f0_rmse

    def load_extracted_audio_features(self, synth_output, hparams):
        """Load the audio features extracted from audio.

        :param synth_output:  Dictionary whose keys are the ids to load (values are ignored).
        :param hparams:       Hyper-parameter container.
        :return:              Dictionary id_name -> extracted WORLD features without deltas.
        """
        self.logger.info("Load extracted mgc, lf0, vuv, bap data.")

        path = os.path.realpath(os.path.join(hparams.out_dir, self.dir_extracted_acoustic_features))
        org_output = dict()
        for id_name in synth_output.keys():
            org_output[id_name] = WorldFeatLabelGen.load_sample(id_name, path, add_deltas=False, num_coded_sps=hparams.num_coded_sps)  # Load extracted data.

        return org_output

    def generate_audio_features(self, id_list, hparams):  # TODO: Test
        """
        Generate mgc, vuv and bap data with an acoustic model.
        The name of the acoustic model is saved in hparams.synth_acoustic_model_path and given in the constructor.
        If the synth_acoustic_model_path is 'None' this method will not be called but the method
        load_extracted_audio_features, which reloads the original data extracted from the audio.

        If you want to generate audio directly from wcad atom extraction, uncomment the first block
        in the get_recon_from_synth_output method.

        Detailed execution process:
        This method reuses the synth method of the ModelTrainer base class. It overwrites the internal
        f_synthesize method and the OutputGen to accomplish the audio generation. Both are restored
        after finishing the generation, also when the generation raises. The base class synth method
        loads the acoustic model network by its name and forwards the question labels for each
        utterance in the id_list. At the end the method calls the f_synthesize method. Therefore the
        f_synthesize method is overwritten by the save_audio_features which saves the generate
        output mgc, vuv and bap files in the self.synth_dir folder.

        NOTE(review): this method returns None, yet run_atom_synth uses its return value as the
        feature dictionary — confirm how the generated features are meant to be loaded.
        """
        self.logger.info("Generate mgc, vuv and bap with " + hparams.synth_acoustic_model_path)

        acoustic_model_hparams = AcousticModelTrainer.create_hparams()
        acoustic_model_hparams.model_dir = os.path.dirname(hparams.synth_acoustic_model_path)
        acoustic_model_hparams.model_name = os.path.basename(hparams.synth_acoustic_model_path)
        acoustic_model_handler = AcousticModelTrainer(acoustic_model_hparams)

        org_model_handler = self.model_handler
        org_output_gen = self.OutputGen
        try:
            self.model_handler = acoustic_model_handler

            # Switch f_synthesize method and OutputGen for mgc, vuv and bap creation.
            # f_synthesize is called at the end of synth.
            self.f_synthesize = self.save_audio_features
            self.OutputGen = self.AudioGen

            # Explicitly synthesize with acoustic_model_name.
            # This method calls f_synthesize at the end which will save the mgc, vuv and bap.
            self.synth(hparams, id_list)
        finally:
            # Switch back to atom creation.
            self.f_synthesize = self.synthesize
            self.OutputGen = org_output_gen
            self.model_handler = org_model_handler

    def synthesize(self, id_list, synth_output, hparams):
        """Create lf0 from atoms and synthesize audio with WORLD.

        The other acoustic features are either loaded from the extracted labels or generated
        with the model at hparams.synth_acoustic_model_path (see run_atom_synth).
        """
        full_output = self.run_atom_synth(id_list, synth_output, hparams)
        # Run the WORLD synthesizer.
        self.run_world_synth(full_output, hparams)

    def synth_ref_wcad(self, file_id_list, hparams):
        """Synthesize reference audio from the extracted (wcad) atoms.

        Output files get the suffix "_wcad_ref" appended; hparams.synth_file_suffix is restored
        afterwards, also when synthesis raises.
        """
        synth_output = dict()
        # Load extracted atoms.
        for id_name in file_id_list:
            synth_output[id_name] = AtomLabelGen.load_sample(id_name, self.OutputGen.dir_labels,
                                                             len(hparams.thetas))

        full_output = self.run_atom_synth(file_id_list, synth_output, hparams)

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += "_wcad_ref"
        try:
            # Run the WORLD synthesizer.
            self.run_world_synth(full_output, hparams)
        finally:
            # Restore identifier.
            hparams.synth_file_suffix = old_synth_file_suffix

    def synth_phrase(self, file_id_list, hparams):
        """Synthesize audio whose lf0 is only the extracted phrase curve.

        Uses the extracted audio features with column -3 replaced by the phrase curve. Output
        files get the suffix "_phrase" appended; hparams.synth_file_suffix is restored afterwards,
        also when synthesis raises.
        """
        # Create reference audio files containing only the vocoder degradation.
        self.logger.info("Synthesise phrase curve for [{0}].".format(", ".join(file_id_list)))

        # Create an empty dictionary which can be filled with extracted audio features.
        synth_output = dict.fromkeys(file_id_list)
        # Fill dictionary with extracted audio features.
        full_output = self.load_extracted_audio_features(synth_output, hparams)

        # Override the lf0 component by the phrase curve.
        for id_name in file_id_list:
            labels = full_output[id_name]
            phrase_curve = np.fromfile(os.path.join(self.OutputGen.dir_labels, id_name + self.OutputGen.ext_phrase),
                                       dtype=np.float32)[:len(labels)]
            labels[:, -3] = phrase_curve

        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += '_phrase'
        try:
            # Run the vocoder.
            ModelTrainer.synthesize(self, file_id_list, full_output, hparams)
        finally:
            # Restore identifier.
            hparams.synth_file_suffix = old_synth_file_suffix

    def run_atom_synth(self, file_id_list, synth_output, hparams):
        """
        Reconstruct lf0, get mgc and bap data, and store all in files in self.synth_dir.

        Columns num_coded_sps (lf0) and num_coded_sps + 1 (vuv) of the aligned frames are
        overwritten with the reconstruction; vuv is 0 where lf0 <= log(f0_silence_threshold).

        :return: Dictionary id_name -> acoustic features.
        """
        # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
        if hparams.synth_acoustic_model_path is None:
            full_output = self.load_extracted_audio_features(synth_output, hparams)
        else:
            full_output = self.generate_audio_features(file_id_list, hparams)

        # Reconstruct lf0 from generated atoms and write it to synth output.
        recon_dict = self.get_recon_from_synth_output(synth_output, hparams)
        for id_name, lf0 in recon_dict.items():
            full_sample = full_output[id_name]
            # Remove len_diff frames split over both ends so full_sample aligns with lf0.
            len_diff = len(full_sample) - len(lf0)
            full_sample = WorldFeatLabelGen.trim_end_sample(full_sample, int(len_diff / 2), reverse=True)
            full_sample = WorldFeatLabelGen.trim_end_sample(full_sample, len_diff - int(len_diff / 2))
            vuv = np.ones(lf0.shape)
            vuv[lf0 <= math.log(WorldFeatLabelGen.f0_silence_threshold)] = 0.0
            # NOTE(review): full_sample is not stored back into full_output; these writes only
            # reach full_output if trim_end_sample returns a view of its input — confirm.
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv

        return full_output
Пример #10
0
    def __init__(self,
                 wcad_root,
                 dir_audio,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams=None):
        """Build label generators, datasets, the nested atom trainer and the default loss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_audio:               Path to directory that contains the .wav files (not used here).
        :param dir_atom_labels:         Path to directory that contains the .atoms files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of the used atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Size of distribution around atom amplitudes when training the atom model.
        :param hparams:                 Hyper-parameter container; defaults from create_hparams() with
                                        out_dir set to the current directory when None.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Hyper-parameters of the nested atom trainer. When not given they are cloned from ours
        # *before* the defaults below are filled in, with figure generation and the acoustic
        # model disabled.
        atom_hparams = hparams.hparams_atom
        if atom_hparams is None:
            atom_hparams = copy.deepcopy(hparams)
            atom_hparams.synth_gen_figure = False
            atom_hparams.synth_acoustic_model_path = None
            hparams.atom_model_path = os.path.join(hparams.out_dir,
                                                   hparams.networks_dir,
                                                   hparams.model_name + "_atoms")

        # Fill in defaults that were left unset: variable-length sequences only with batches > 1.
        for phase in ("train", "test"):
            length_attr = "variable_sequence_length_" + phase
            if getattr(hparams, length_attr) is None:
                setattr(hparams, length_attr, getattr(hparams, "batch_size_" + phase) > 1)
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        input_gen = QuestionLabelGen(dir_question_labels, num_questions)
        input_gen.get_normalisation_params(dir_question_labels)
        self.InputGen = input_gen

        output_gen = FlatLF0LabelGen(dir_lf0_labels, dir_atom_labels)
        output_gen.get_normalisation_params(dir_atom_labels)
        self.OutputGen = output_gen

        self.dataset_train, self.dataset_val = (
            PyTorchLabelGensDataset(subset_ids, self.InputGen, self.OutputGen, hparams)
            for subset_ids in (self.id_list_train, self.id_list_val))

        self.atom_trainer = AtomVUVDistPosModelTrainer(wcad_root,
                                                       dir_atom_labels,
                                                       dir_lf0_labels,
                                                       dir_question_labels,
                                                       id_list,
                                                       thetas,
                                                       k,
                                                       num_questions,
                                                       dist_window_size,
                                                       atom_hparams)

        # Keep an already configured loss; otherwise use the weighted VUV MSE with L1 term.
        if self.loss_function is None:
            self.loss_function = L1WeightedVUVMSELoss(weight=hparams.vuv_weight,
                                                      vuv_loss_weight=hparams.vuv_loss_weight,
                                                      L1_weight=hparams.L1_loss_weight,
                                                      reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"

        # Batches are collated and split up by this trainer's own methods.
        self.batch_collate_fn = self.prepare_batch
        self.batch_decollate_fn = self.decollate_network_output
Пример #11
0
class AtomNeuralFilterModelTrainer(ModelTrainer):
    """
    Implementation of a ModelTrainer for the generation of intonation curves with an end-to-end system.
    The first part of the architecture runs atom position prediction, and the output layer contains neural filters.
    The network output holds the final (lf0, V/UV) in its first two feature columns, followed by the internal
    command signals (one per atom / theta), see decollate_network_output.

    Use question labels as input and extracted lf0 as output.
    """
    logger = logging.getLogger(__name__)
    dir_extracted_acoustic_features = "../WORLD/"  # TODO: Move to hparams.

    def __init__(self,
                 wcad_root,
                 dir_audio,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams=None):
        """Default constructor.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_audio:               Path to directory that contains the .wav files (not used by this class).
        :param dir_atom_labels:         Path to directory that contains the .atoms files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of the used atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Size of distribution around atom amplitudes when training the atom model.
        :param hparams:                 Hyper-parameter container. ``hparams.hparams_atom`` (if set) configures the
                                        nested atom trainer, otherwise a deep copy of ``hparams`` is used.
        """

        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        hparams_atom = hparams.hparams_atom

        if hparams_atom is None:
            # Derive the atom trainer's hyper-parameters from the main ones, without figures/acoustic model.
            hparams_atom = copy.deepcopy(hparams)
            hparams_atom.synth_gen_figure = False
            hparams_atom.synth_acoustic_model_path = None
            hparams.atom_model_path = os.path.join(
                hparams.out_dir, hparams.networks_dir,
                hparams.model_name + "_atoms")

        # Write missing default parameters.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)

        self.OutputGen = FlatLF0LabelGen(dir_lf0_labels, dir_atom_labels)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train,
                                                     self.InputGen,
                                                     self.OutputGen, hparams)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val,
                                                   self.InputGen,
                                                   self.OutputGen, hparams)

        self.atom_trainer = AtomVUVDistPosModelTrainer(
            wcad_root, dir_atom_labels, dir_lf0_labels, dir_question_labels,
            id_list, thetas, k, num_questions, dist_window_size, hparams_atom)

        if self.loss_function is None:
            self.loss_function = L1WeightedVUVMSELoss(
                weight=hparams.vuv_weight,
                vuv_loss_weight=hparams.vuv_loss_weight,
                L1_weight=hparams.L1_loss_weight,
                reduce=False)
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"

        # Override the collate and decollate methods of batches.
        self.batch_collate_fn = self.prepare_batch
        self.batch_decollate_fn = self.decollate_network_output

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the base hyper-parameters and add the ones specific to this trainer.

        :param hparams_string:  Optional string of overrides, forwarded to ModelTrainer.create_hparams.
        :param verbose:         If True, log the final hyper-parameters (the base call is kept silent).
        :return:                Hyper-parameter container.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)

        hparams.synth_gen_figure = False
        hparams.complex_poles = True
        hparams.phase_init = 0.0

        hparams.vuv_loss_weight = 1
        hparams.L1_loss_weight = 1
        hparams.vuv_weight = 0.5

        if verbose:
            logging.info('Final parsed hparams: %s', hparams.values())

        return hparams

    @staticmethod
    def prepare_batch(batch, common_divisor=1, batch_first=False):
        """Collate a batch via ModelHandler.prepare_batch and extend the mask for the weighted loss.

        The mask gets three channels (lf0, vuv, L1 error); the first channel is scaled by the output sequence
        lengths (see the TODO below).
        """
        inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation = ModelHandler.prepare_batch(
            batch, common_divisor=common_divisor, batch_first=batch_first)

        if mask is None:
            # NOTE(review): shape (T, 1, 1) assumes a single time-major sequence — confirm for batch_first=True.
            mask = torch.ones((seq_lengths_output[0], 1, 1))
        mask = mask.expand(*mask.shape[:2], 2)
        # mask: T x B x 2 (lf0, vuv), add L1 error dimension.
        mask = torch.cat((mask, mask[..., -1:]), dim=-1).contiguous()

        # TODO: This is a dirty hack, it works but only for VUV weight of 0 (it completes the loss function WMSELoss).
        mask[..., 0] = mask[..., 0] * seq_lengths_output.float()
        ################################################

        return inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation

    @staticmethod
    def decollate_network_output(output,
                                 _,
                                 seq_lengths=None,
                                 permutation=None,
                                 batch_first=True):
        """Split output into LF0, V/UV and command signals. Return command signals as hidden state.

        Feature columns ``[:2]`` are the final (lf0, V/UV); columns ``[2:]`` are the command signals.
        """

        # Split pre-net output (command signals).
        intern_amps, _ = ModelTrainer.split_batch(output[:, :, 2:], None,
                                                  seq_lengths, permutation,
                                                  batch_first)
        # Split final LF0, V/UV.
        output, _ = ModelTrainer.split_batch(output[:, :, :2], None,
                                             seq_lengths, permutation,
                                             batch_first)

        return output, intern_amps

    def init_atom(self, hparams_atom):
        """
        Initialize the atom model through the nested atom trainer.

        If ``hparams_atom.model_type`` is None, training of the atom model is disabled by resetting
        ``hparams_atom.epochs`` to 0 (per the warning below, the old model is loaded in that case).

        :param hparams_atom:    Hyper-parameter container of atom trainer. Modified in place (``epochs``).
        :return:                Nothing
        """
        if hparams_atom.model_type is None and hparams_atom.epochs != 0:
            # Separate the two sentences explicitly; adjacent literals are concatenated without whitespace.
            self.logger.warning(
                "When hparams_atom.model_type=None the old model is loaded. "
                "This means that training the atom model by hparams_atom.epochs="
                + str(hparams_atom.epochs) + " has no effect.")
            hparams_atom.epochs = 0

        self.logger.info("Create atom model.")
        self.atom_trainer.init(hparams_atom)

    def init(self, hparams):
        """Create the end-to-end model (delegates to ModelTrainer.init)."""
        self.logger.info("Create E2E model.")
        super().init(hparams)

    def train_atom(self, hparams_atom):
        """Train the nested atom model and benchmark it if at least one epoch was requested.

        :return:    Whatever the atom trainer's ``train`` returns.
        """
        output = self.atom_trainer.train(hparams_atom)
        if hparams_atom.epochs > 0:
            self.atom_trainer.benchmark(hparams_atom)
        return output

    def filters_forward(self,
                        in_tensor,
                        hparams,
                        batch_seq_lengths=None,
                        max_seq_length=None):
        """Get output of each filter without their superposition.

        Puts the model into eval mode (this is not reverted afterwards).

        :param in_tensor:           Input as numpy array or torch tensor; presumably time-major (T x B x ...),
                                    since the default sequence length is ``len(in_tensor)`` — TODO confirm.
        :param hparams:             Hyper-parameter container (``use_gpu`` is read).
        :param batch_seq_lengths:   Sequence lengths of the batch, defaults to a single sequence of full length.
        :param max_seq_length:      Maximum sequence length, defaults to ``max(batch_seq_lengths)``.
        :return:                    Filter outputs as a numpy array on the CPU.
        """
        self.model_handler.model.eval()

        # If input is numpy array convert it to torch tensor.
        if isinstance(in_tensor, np.ndarray):
            in_tensor = torch.from_numpy(in_tensor)

        if hparams.use_gpu:
            in_tensor = in_tensor.cuda()

        if batch_seq_lengths is None:
            batch_seq_lengths = (len(in_tensor), )

        if max_seq_length is None:
            max_seq_length = max(batch_seq_lengths)

        hidden = self.model_handler.model.init_hidden(len(batch_seq_lengths))
        output = self.model_handler.model.filters_forward(
            in_tensor, hidden, batch_seq_lengths, max_seq_length)

        return output.detach().cpu().numpy()

    # FIXME
    # def gen_animation(self, id_name, labels=None):
    #
    #     if labels is None:
    #         input_labels = self.InputGen.__getitem__(id_name)[:, None, :]
    #         labels, _ = self.model_handler.forward(input_labels)
    #
    #     # Retrieve data from label.
    #     labels_post = self.OutputGen.postprocess_sample(labels)
    #     output_vuv = labels_post[:, 1]
    #     output_vuv[output_vuv < 0.5] = 0.0
    #     output_vuv[output_vuv >= 0.5] = 1.0
    #
    #     output_lf0 = labels_post[:, 0]
    #
    #     # Load original lf0 and vuv.
    #     org_labels = self.OutputGen.load_sample(id_name, self.OutputGen.dir_labels)
    #     original_lf0, _ = self.OutputGen.convert_to_world_features(org_labels)
    #     # original_lf0, _ = interpolate_lin(original_lf0)
    #
    #     phrase_curve = self.OutputGen.get_phrase_curve(id_name)
    #     original_lf0 -= phrase_curve[:len(original_lf0)]
    #     original_lf0 = original_lf0[:len(output_lf0)]
    #
    #     org_labels = self.atom_trainer.OutputGen.load_sample(id_name,
    #                                                          self.atom_trainer.OutputGen.dir_labels,
    #                                                          len(self.atom_trainer.OutputGen.theta_interval),
    #                                                          self.atom_trainer.OutputGen.dir_world_labels)
    #
    #     org_labels = org_labels[:, 1:]
    #     len_diff = len(org_labels) - len(labels_post)
    #     org_labels = self.atom_trainer.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0))
    #     org_labels = self.atom_trainer.OutputGen.trim_end_sample(org_labels, int(len_diff / 2.0) + 1)
    #     org_atoms = AtomLabelGen.labels_to_atoms(org_labels, k=self.atom_trainer.OutputGen.k, frame_size=self.atom_trainer.OutputGen.frame_size)
    #     wcad_lf0 = self.atom_trainer.OutputGen.atoms_to_lf0(org_atoms, len(org_labels))
    #
    #     phrase_curve = self.OutputGen.get_phrase_curve(id_name)[:len(wcad_lf0)]
    #     original_lf0 = original_lf0[:len(wcad_lf0)] + phrase_curve.squeeze()
    #
    #     for index in range(len(org_atoms)+1):
    #         plotter = DataPlotter()
    #         plot_id = 0
    #         wcad_lf0 = self.atom_trainer.OutputGen.atoms_to_lf0(org_atoms[:index], len(org_labels))
    #         reconstruction = phrase_curve + wcad_lf0
    #
    #         graphs_lf0 = list()
    #         graphs_lf0.append((original_lf0, "Original"))
    #         graphs_lf0.append((reconstruction, "Reconstruction"))
    #         plotter.set_data_list(grid_idx=plot_id, data_list=graphs_lf0)
    #         plotter.set_label(grid_idx=plot_id, xlabel='frames [' + str(self.atom_trainer.OutputGen.frame_size) + ' ms]',
    #                           ylabel='lf0')
    #         plotter.set_lim(grid_idx=plot_id, ymin=4)
    #         plotter.set_linestyles(grid_idx=plot_id, linestyles=['-.', '-','-'])
    #         plotter.set_colors(grid_idx=plot_id, colors=['C3', 'C2'], alpha=1)
    #         plot_id += 1
    #
    #         graphs_atoms = list()
    #         # graphs_atoms.append((phrase_curve[:len(original_lf0)], ))
    #         plotter.set_data_list(grid_idx=plot_id, data_list=graphs_atoms)
    #         plotter.set_atom_list(grid_idx=plot_id, atom_list=org_atoms[:index])
    #         plotter.set_label(grid_idx=plot_id, xlabel='frames [' + str(self.atom_trainer.OutputGen.frame_size) + ' ms]',
    #                           ylabel='Atoms')
    #         plotter.set_lim(grid_idx=plot_id, ymin=-0.5, ymax=0.3)
    #         plotter.set_colors(grid_idx=plot_id, colors=['C1',], alpha=1)
    #
    #         plotter.gen_plot(sharex=True)

    def gen_figure_from_output(self,
                               id_name,
                               output,
                               intern_amps,
                               hparams,
                               clustering=None,
                               filters_out=None):
        """Plot command signals, filter outputs and predicted vs. original lf0 of one utterance.

        Saves ``<out_dir>/<id_name>.<net_name>.FILTERS.png``. If ``clustering`` is given, a second figure is saved
        as ``.CLUSTERS.png`` in which commands and filter outputs of the atoms of each cluster are summed and the
        thetas of each cluster averaged.

        :param id_name:       Id of the utterance.
        :param output:        Network output with (lf0, V/UV) columns. If this or ``filters_out`` is None,
                              ``output``, ``intern_amps`` and ``filters_out`` are all recomputed by the model.
        :param intern_amps:   Command signals, one column per atom.
        :param hparams:       Hyper-parameter container.
        :param clustering:    Optional iterable of index collections; each selects the atom indices of one cluster.
        :param filters_out:   Output of each filter without superposition (see filters_forward).
        :return:              Nothing.
        """

        if output is None or filters_out is None:
            input_labels = self.InputGen[id_name][:, None, ...]
            output_full = self.model_handler.forward(input_labels, hparams)[0]
            output = output_full[:, 0, :2]
            intern_amps = output_full[:, 0, 2:]
            filters_out = self.filters_forward(input_labels, hparams)[:, 0,
                                                                      ...]

        # Retrieve data from label. Threshold V/UV into a new boolean array instead of
        # overwriting the post-processed output in place.
        labels_post = self.OutputGen.postprocess_sample(output)
        output_vuv = labels_post[:, 1] >= 0.5
        output_lf0 = labels_post[:, 0]

        # Load original lf0 and vuv.
        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels)
        original_lf0, _ = self.OutputGen.convert_to_world_features(org_labels)
        # original_lf0, _ = interpolate_lin(original_lf0)

        # Compare in the flat (phrase-removed) domain, like the network output.
        phrase_curve = self.OutputGen.get_phrase_curve(id_name)
        original_lf0 -= phrase_curve[:len(original_lf0)]
        original_lf0 = original_lf0[:len(output_lf0)]

        org_labels = self.atom_trainer.OutputGen.load_sample(
            id_name, self.atom_trainer.OutputGen.dir_labels,
            len(self.atom_trainer.OutputGen.theta_interval),
            self.atom_trainer.OutputGen.dir_world_labels)
        org_vuv = org_labels[:, 0, 0].astype(bool)

        # Unvoiced masks used for the shaded areas of all panels.
        output_unvoiced = np.invert(output_vuv)
        org_unvoiced = np.invert(org_vuv)

        thetas = self.model_handler.model.thetas_approx()

        # Get a data plotter
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()

        plot_id = 0

        graphs_intern = list()

        for idx in reversed(range(intern_amps.shape[1])):
            graphs_intern.append(
                (intern_amps[:, idx], r'$\theta$={0:.3f}'.format(thetas[idx])))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_intern)
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(output_unvoiced, '0.75', 1.0)])
        plotter.set_label(grid_idx=plot_id, ylabel='command')
        amp_max = 0.04
        amp_min = -amp_max
        plotter.set_lim(grid_idx=plot_id, ymin=amp_min, ymax=amp_max)
        plot_id += 1

        graphs_filters = list()
        for idx in reversed(range(filters_out.shape[1])):
            graphs_filters.append((filters_out[:, idx], ))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_filters)
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(output_unvoiced, '0.75', 1.0,
                                          'Unvoiced')])
        plotter.set_label(grid_idx=plot_id, ylabel='filtered')
        amp_max = 0.1
        amp_min = -amp_max
        plotter.set_lim(grid_idx=plot_id, ymin=amp_min, ymax=amp_max)
        plot_id += 1

        graphs_lf0 = list()
        graphs_lf0.append((original_lf0, "Original"))
        graphs_lf0.append((output_lf0, "Predicted"))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_lf0)
        plotter.set_hatchstyles(grid_idx=plot_id, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(org_unvoiced,
                                          '0.75', 1.0, 'Reference unvoiced')])
        plotter.set_label(grid_idx=plot_id,
                          xlabel='frames [' + str(hparams.frame_size_ms) +
                          ' ms]',
                          ylabel='lf0')
        amp_lim = 1
        plotter.set_lim(grid_idx=plot_id, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=plot_id, linestyles=['-.', '-'])
        plotter.set_colors(grid_idx=plot_id,
                           colors=['C3', 'C2', 'C0'],
                           alpha=1)

        plotter.gen_plot()
        # plotter.gen_plot(True)
        plotter.save_to_file(filename + ".FILTERS.png")

        plotter.plt.show()

        if clustering is None:
            return

        plotter = DataPlotter()

        def cluster(array, mean=False):
            """Aggregate the last axis of ``array`` per cluster (sum, or mean of all selected values)."""
            if mean:
                return np.array([
                    np.take(array, i, axis=-1).mean() for i in clustering
                ]).transpose()
            return np.array([
                np.take(array, i, axis=-1).sum(-1) for i in clustering
            ]).transpose()

        clustered_amps = cluster(intern_amps)
        clustered_thetas = cluster(thetas, True)
        clustered_filters = cluster(filters_out)

        plot_id = 0
        graphs_intern = list()

        for idx in reversed(range(clustered_amps.shape[1])):
            graphs_intern.append(
                (clustered_amps[:, idx],
                 r'$\theta$={0:.3f}'.format(clustered_thetas[idx])))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_intern)
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(output_unvoiced, '0.75', 1.0,
                                          'Unvoiced')])
        plotter.set_label(grid_idx=plot_id, ylabel='cluster command')
        amp_max = 0.04
        amp_min = -amp_max
        plotter.set_lim(grid_idx=plot_id, ymin=amp_min, ymax=amp_max)
        plot_id += 1

        graphs_filters = list()
        for idx in reversed(range(clustered_filters.shape[1])):
            graphs_filters.append((clustered_filters[:, idx], ))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_filters)
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(output_unvoiced, '0.75', 1.0)])
        plotter.set_label(grid_idx=plot_id, ylabel='filtered')
        amp_max = 0.175
        amp_min = -amp_max
        plotter.set_lim(grid_idx=plot_id, ymin=amp_min, ymax=amp_max)
        plot_id += 1

        graphs_lf0 = list()
        graphs_lf0.append((original_lf0, "Original"))
        graphs_lf0.append((output_lf0, "Predicted"))
        plotter.set_data_list(grid_idx=plot_id, data_list=graphs_lf0)
        plotter.set_hatchstyles(grid_idx=plot_id, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=plot_id,
                              area_list=[(org_unvoiced,
                                          '0.75', 1.0, 'Reference unvoiced')])
        # NOTE(review): the FILTERS figure labels frames with hparams.frame_size_ms while this one uses the atom
        # generator's frame_size — confirm both refer to the same frame shift.
        plotter.set_label(grid_idx=plot_id,
                          xlabel='frames [' +
                          str(self.atom_trainer.OutputGen.frame_size) + ' ms]',
                          ylabel='lf0')
        # amp_lim = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0))) * 1.1
        amp_lim = 1
        plotter.set_lim(grid_idx=plot_id, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=plot_id, linestyles=['-.', '-'])
        plotter.set_colors(grid_idx=plot_id,
                           colors=['C3', 'C2', 'C0'],
                           alpha=1)

        plotter.gen_plot()
        # plotter.gen_plot(True)
        plotter.save_to_file(filename + ".CLUSTERS.png")

        plotter.plt.show()

    def gen_figure_atoms(self, hparams, ids_input):
        """Delegate figure generation to the nested atom trainer."""
        self.atom_trainer.gen_figure(hparams, ids_input)

    def synthesize(self, id_list, synth_output, hparams):
        """Save output of model to .lf0 and (.vuv) files and call Merlin synth which reads those files.

        The phrase curve is added back to the predicted (flat) lf0, V/UV is thresholded at 0.5, and both overwrite
        the corresponding columns of the extracted WORLD features before synthesis. ``synth_output`` is not
        modified.

        :raises NotImplementedError:  If ``hparams.synth_acoustic_model_path`` is set.
        """

        # Reconstruct lf0 from generated atoms and write it to synth output.
        # recon_dict = self.get_recon_from_synth_output(synth_output)
        full_output = dict()
        for id_name, labels in synth_output.items():
            # Take lf0 and vuv from network output.
            lf0 = labels[:, 0]
            # Binarise V/UV into a new array so the caller's synth_output stays untouched.
            vuv = np.where(labels[:, 1] >= 0.5, 1.0, 0.0)

            phrase_curve = self.OutputGen.get_phrase_curve(id_name)
            lf0 = lf0 + phrase_curve[:len(lf0)].squeeze()

            # Get mgc, vuv and bap data either through a trained acoustic model or from data extracted from the audio.
            if hparams.synth_acoustic_model_path is None:
                path = os.path.realpath(
                    os.path.join(hparams.out_dir,
                                 self.dir_extracted_acoustic_features))
                full_sample: np.ndarray = WorldFeatLabelGen.load_sample(
                    id_name,
                    path,
                    add_deltas=False,
                    num_coded_sps=hparams.num_coded_sps
                )  # Load extracted data.
                # Trim the extracted features symmetrically to the length of the prediction.
                len_diff = len(full_sample) - len(lf0)
                trim_front = len_diff // 2
                trim_end = len_diff - trim_front
                full_sample = WorldFeatLabelGen.trim_end_sample(
                    full_sample, trim_end)
                full_sample = WorldFeatLabelGen.trim_end_sample(full_sample,
                                                                trim_front,
                                                                reverse=True)
            else:
                raise NotImplementedError()

            # Overwrite lf0 and vuv by network output.
            full_sample[:, hparams.num_coded_sps] = lf0
            full_sample[:, hparams.num_coded_sps + 1] = vuv
            # Fill a dictionary with the samples.
            full_output[id_name + "_E2E"] = full_sample

        # Run the merlin synthesizer
        self.run_world_synth(full_output, hparams)

    def compute_score(self, dict_outputs_post, dict_hiddens, hparams):
        """Compute F0 RMSE (Hz, over voiced reference frames) and V/UV error rate against the reference labels.

        Output and reference are compared over their common length, so a length mismatch in either direction
        is tolerated. ``dict_outputs_post`` is not modified.

        :param dict_outputs_post:  Dict id_name -> post-processed output with flat lf0 in column 0 and V/UV in
                                   column 1.
        :param dict_hiddens:       Unused.
        :param hparams:            Unused.
        :return:                   Tuple (mean F0 RMSE, mean V/UV error rate) over all utterances.
        """

        f0_rmse = 0.0
        vuv_error_rate = 0.0
        f0_rmse_max_id = "None"
        f0_rmse_max = 0.0
        vuv_error_max_id = "None"
        vuv_error_max = 0.0

        for id_name, labels in dict_outputs_post.items():
            # Get data for comparison.
            org_labels = self.OutputGen.load_sample(id_name,
                                                    self.OutputGen.dir_labels)
            num_frames = min(len(labels), len(org_labels))

            output_lf0 = labels[:num_frames, 0]
            # Threshold without writing into the caller's array.
            output_vuv = labels[:num_frames, 1] >= 0.5

            org_lf0 = org_labels[:num_frames, 0]
            org_vuv = org_labels[:num_frames, 1]
            phrase_curve = self.OutputGen.get_phrase_curve(id_name)

            # Compute f0 from lf0; the output is flat, so add the phrase curve back.
            org_f0 = np.exp(org_lf0.squeeze())
            output_f0 = np.exp(output_lf0 +
                               phrase_curve[:num_frames].squeeze())

            # Compute RMSE over voiced reference frames, keep track of worst RMSE.
            f0_se = (org_f0 - output_f0)**2
            current_f0_rmse = math.sqrt(
                (f0_se * org_vuv).sum() / org_vuv.sum())
            if current_f0_rmse > f0_rmse_max:
                f0_rmse_max_id = id_name
                f0_rmse_max = current_f0_rmse
            f0_rmse += current_f0_rmse

            num_errors = (org_vuv != output_vuv)
            vuv_error_rate_tmp = float(num_errors.sum()) / num_frames
            if vuv_error_rate_tmp > vuv_error_max:
                vuv_error_max_id = id_name
                vuv_error_max = vuv_error_rate_tmp
            vuv_error_rate += vuv_error_rate_tmp

        f0_rmse /= len(dict_outputs_post)
        vuv_error_rate /= len(dict_outputs_post)

        self.logger.info("Worst F0 RMSE: " + f0_rmse_max_id +
                         " {:4.2f}Hz".format(f0_rmse_max))
        self.logger.info("Worst VUV error: " + vuv_error_max_id +
                         " {:2.2f}%".format(vuv_error_max * 100))
        self.logger.info("Benchmark score: F0 RMSE " +
                         "{:4.2f}Hz".format(f0_rmse) + ", VUV " +
                         "{:2.2f}%".format(vuv_error_rate * 100))

        return f0_rmse, vuv_error_rate
Пример #12
0
    def __init__(self,
                 wcad_root,
                 dir_audio,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams_phrase=None):
        """Build the phrase-level trainer and its nested flat-curve (AtomNeuralFilterModelTrainer) trainer.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_audio:               Path to directory that contains the .wav files.
        :param dir_atom_labels:         Path to directory that contains the .atoms files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of used theta values.
        :param k:                       k-order of each atom.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Width of the distribution surrounding each atom spike.
                                        The window is only used for amps. Thetas are surrounded by a window of 5.
        :param hparams_phrase:          Hyper-parameter container. ``hparams_phrase.hparams_flat`` (if not None)
                                        configures the nested flat trainer, otherwise a deep copy is used.
        """
        if hparams_phrase is None:
            hparams_phrase = self.create_hparams()
            hparams_phrase.out_dir = os.path.curdir

        # The copy for the nested trainer is taken before the defaults below are filled in.
        hparams_flat = hparams_phrase.hparams_flat
        if hparams_flat is None:
            hparams_flat = copy.deepcopy(hparams_phrase)

        hp = hparams_phrase

        # Fill in defaults that depend on other hyper-parameters.
        if hp.variable_sequence_length_train is None:
            hp.variable_sequence_length_train = hp.batch_size_train > 1
        if hp.variable_sequence_length_test is None:
            hp.variable_sequence_length_test = hp.batch_size_test > 1
        if hp.synth_dir is None:
            hp.synth_dir = os.path.join(hp.out_dir, "synth")

        super().__init__(id_list, hp)

        # Inputs: question labels; outputs: lf0 curves that keep the phrase component.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(dir_question_labels)
        self.OutputGen = FlatLF0LabelGen(dir_lf0_labels,
                                         dir_atom_labels,
                                         remove_phrase=False)
        self.OutputGen.get_normalisation_params(dir_atom_labels)

        self.dataset_train, self.dataset_val = (
            PyTorchLabelGensDataset(ids, self.InputGen, self.OutputGen, hp)
            for ids in (self.id_list_train, self.id_list_val))

        self.flat_trainer = AtomNeuralFilterModelTrainer(
            wcad_root, dir_audio, dir_atom_labels, dir_lf0_labels,
            dir_question_labels, id_list, thetas, k, num_questions,
            dist_window_size, hparams_flat)

        if self.loss_function is None:
            self.loss_function = L1WeightedVUVMSELoss(
                weight=hp.vuv_weight,
                vuv_loss_weight=hp.vuv_loss_weight,
                L1_weight=hp.L1_loss_weight,
                reduce=False)
        if hp.scheduler_type == "default":
            hp.scheduler_type = "None"

        # Use this trainer's own batch (de)collation.
        self.batch_collate_fn = self.prepare_batch
        self.batch_decollate_fn = self.decollate_network_output