Exemplo n.º 1
0
    def __init__(self,
                 dir_phoneme_labels,
                 dir_durations,
                 id_list,
                 file_symbol_dict,
                 hparams=None):
        """Set up label generators, datasets, loss and scheduler defaults.

        :param dir_phoneme_labels:      Path to the directory containing the label files with monophones.
        :param dir_durations:           Path to the directory containing the durations.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param file_symbol_dict:        Symbol dictionary of the used monophones, handed to PhonemeLabelGen.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in the defaults the caller left unset: variable sequence
        # lengths are enabled whenever the batch size exceeds one.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        # synth_dir may be missing entirely or present but None.
        if getattr(hparams, "synth_dir", None) is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input side: one-hot phoneme labels.
        self.InputGen = PhonemeLabelGen(dir_phoneme_labels,
                                        file_symbol_dict,
                                        hparams.phoneme_label_type,
                                        one_hot=True)
        # Output side: phoneme durations with their normalisation parameters.
        duration_gen = PhonemeDurationLabelGen(dir_durations)
        self.OutputGen = duration_gen
        duration_gen.get_normalisation_params(
            dir_durations, hparams.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Same generators and settings for both splits; only ids differ.
            return PyTorchLabelGensDataset(ids,
                                           self.InputGen,
                                           self.OutputGen,
                                           hparams,
                                           match_lengths=False)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        # Only install a loss when none has been set already.
        if self.loss_function is None:
            self.loss_function = torch.nn.MSELoss(reduction='none')

        # Replace the placeholder scheduler with a verbose plateau scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_verbose=True)
Exemplo n.º 2
0
    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 hparams=None):
        """Set up generators, datasets, weighted loss and scheduler defaults.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values.
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in the defaults the caller left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # Loss weights default to the inverse of the assumed occurrence of
        # non-zero and zero targets. The occurrences are always computed, the
        # weights only when none were supplied.
        p_non_zero = min(0.99, 0.02 / len(thetas))
        p_zero = 1 - p_non_zero
        if hasattr(hparams, "weight_zero"):
            if hparams.weight_zero is None:
                # The hyper-parameter exists but is unset: assign in place.
                hparams.weight_non_zero = 1 / p_non_zero
                hparams.weight_zero = 1 / p_zero
        else:
            # The hyper-parameter is unknown so far: register both weights.
            hparams.add_hparam("weight_non_zero", 1 / p_non_zero)
            hparams.add_hparam("weight_zero", 1 / p_zero)

        super().__init__(id_list, hparams)

        # Input side: question labels.
        question_gen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen = question_gen
        question_gen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        # Output side: atom labels.
        atom_gen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k,
                                hparams.frame_size_ms)
        self.OutputGen = atom_gen
        atom_gen.get_normalisation_params(
            dir_atom_labels, hparams.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Same generators and settings for both splits; only ids differ.
            return PyTorchLabelGensDataset(ids,
                                           self.InputGen,
                                           self.OutputGen,
                                           hparams,
                                           match_lengths=True)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        # Only install a loss when none has been set already.
        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(
                hparams.use_gpu,
                hparams.weight_zero,
                hparams.weight_non_zero,
                size_average=False,
                reduce=False)

        # Replace the placeholder scheduler with a plateau scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_patience=10,
                                plateau_factor=0.5,
                                plateau_verbose=True)
Exemplo n.º 3
0
    def __init__(self,
                 wcad_root,
                 dir_audio,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams_phrase=None):
        """Set up the phrase trainer together with its nested flat trainer.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_audio:               Path to directory that contains the .wav files.
        :param dir_atom_labels:         Path to directory that contains the .atoms files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of used theta values.
        :param k:                       k-order of each atom.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Width of the distribution surrounding each atom spike.
                                        The window is only used for amps. Thetas are surrounded by a window of 5.
                                        Forwarded unchanged to the nested AtomNeuralFilterModelTrainer.
        :param hparams_phrase:          Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams_phrase is None:
            hparams_phrase = self.create_hparams()
            hparams_phrase.out_dir = os.path.curdir

        # Hyper-parameters of the nested flat trainer. The fallback copy is
        # taken before any of the defaults below are written.
        flat_hparams = hparams_phrase.hparams_flat
        if flat_hparams is None:
            flat_hparams = copy.deepcopy(hparams_phrase)

        def default_model_path(suffix):
            # <out_dir>/<networks_dir>/<model_name><suffix>
            return os.path.join(hparams_phrase.out_dir,
                                hparams_phrase.networks_dir,
                                hparams_phrase.model_name + suffix)

        # Default locations of the pre-trained models.
        if hparams_phrase.atom_model_path is None:
            hparams_phrase.atom_model_path = default_model_path("_flat_atoms")
        if hparams_phrase.flat_model_path is None:
            hparams_phrase.flat_model_path = default_model_path("_flat")

        # Default synthesis directory.
        if hparams_phrase.synth_dir is None:
            hparams_phrase.synth_dir = os.path.join(hparams_phrase.out_dir,
                                                    "synth")

        super().__init__(id_list, hparams_phrase)

        # Input side: question labels.
        question_gen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen = question_gen
        question_gen.get_normalisation_params(
            dir_question_labels, hparams_phrase.input_norm_params_file_prefix)

        # Output side: LF0 labels created with remove_phrase=False.
        lf0_gen = FlatLF0LabelGen(dir_lf0_labels,
                                  dir_atom_labels,
                                  remove_phrase=False)
        self.OutputGen = lf0_gen
        lf0_gen.get_normalisation_params(
            dir_atom_labels, hparams_phrase.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Same generators and settings for both splits; only ids differ.
            return PyTorchLabelGensDataset(ids,
                                           self.InputGen,
                                           self.OutputGen,
                                           hparams_phrase,
                                           match_lengths=True)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        # Nested trainer, configured with the flat hyper-parameters.
        self.flat_trainer = AtomNeuralFilterModelTrainer(
            wcad_root, dir_audio, dir_atom_labels, dir_lf0_labels,
            dir_question_labels, id_list, thetas, k, num_questions,
            dist_window_size, flat_hparams)

        # Only install a loss when none has been set already.
        if self.loss_function is None:
            self.loss_function = L1WeightedVUVMSELoss(
                weight_unvoiced=hparams_phrase.weight_unvoiced,
                vuv_loss_weight=hparams_phrase.vuv_loss_weight,
                L1_loss_weight=hparams_phrase.L1_loss_weight,
                reduce=False)

        # The placeholder scheduler is replaced by the string "None".
        if hparams_phrase.scheduler_type == "default":
            hparams_phrase.scheduler_type = "None"

        # Use this class's own collate / decollate functions for batches.
        self.batch_collate_fn = self.prepare_batch
        self.batch_decollate_fn = self.decollate_network_output
Exemplo n.º 4
0
    def __init__(self, dir_world_features, id_list, hparams=None):
        """Set up generators, datasets, loss and scheduler for WaveNet training.

        :param dir_world_features:      Directory handed to WorldFeatLabelGen; also used to look up
                                        the input normalisation parameters.
        :param id_list:                 Ids forwarded to the base-class constructor.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in the defaults the caller left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input frames per second, and output samples per input frame.
        frames_per_sec = 1000.0 / hparams.frame_size_ms
        in_to_out_multiplier = int(hparams.frame_rate_output_Hz /
                                   frames_per_sec)
        # Sequence-length limits derived from the configured seconds.
        # NOTE: The validation limit exists because of memory constraints.
        max_frames_input_trainset = int(
            frames_per_sec * hparams.max_input_train_sec) * in_to_out_multiplier
        max_frames_input_testset = int(
            frames_per_sec * hparams.max_input_test_sec) * in_to_out_multiplier

        # Input side: WORLD features, sampled with sample_linearly using the
        # in/out multiplier computed above.
        upsample = partial(sample_linearly,
                           in_to_out_multiplier=in_to_out_multiplier,
                           dtype=np.float32)
        self.InputGen = WorldFeatLabelGen(dir_world_features,
                                          add_deltas=False,
                                          sampling_fn=upsample,
                                          num_coded_sps=hparams.num_coded_sps,
                                          sp_type=hparams.sp_type,
                                          load_sp=hparams.load_sp,
                                          load_lf0=hparams.load_lf0,
                                          load_vuv=hparams.load_vuv,
                                          load_bap=hparams.load_bap)
        self.InputGen.get_normalisation_params(
            dir_world_features, hparams.input_norm_params_file_prefix)

        # Output side: raw waveform; mu is only passed for mu-law input.
        # No normalisation parameters are loaded for the output.
        mu = hparams.mu if hparams.input_type == "mulaw-quantize" else None
        self.OutputGen = RawWaveformLabelGen(
            frame_rate_output_Hz=hparams.frame_rate_output_Hz,
            frame_size_ms=hparams.frame_size_ms,
            mu=mu,
            silence_threshold_quantized=hparams.silence_threshold_quantized)

        def make_dataset(ids, max_frames):
            # Both splits use random selection; ids and length limit differ.
            return LabelGensDataset(ids,
                                    self.InputGen,
                                    self.OutputGen,
                                    hparams,
                                    random_select=True,
                                    max_frames_input=max_frames)

        self.dataset_train = make_dataset(self.id_list_train,
                                          max_frames_input_trainset)
        self.dataset_val = make_dataset(self.id_list_val,
                                        max_frames_input_testset)

        # Only install a loss when none has been set already: cross entropy
        # for mu-law input, mixture of logistics otherwise.
        if self.loss_function is None:
            if hparams.input_type == "mulaw-quantize":
                self.loss_function = OneHotCrossEntropyLoss(reduction='none',
                                                            shift=1)
            else:
                self.loss_function = DiscretizedMixturelogisticLoss(
                    hparams.quantize_channels,
                    hparams.log_scale_min,
                    reduction='none',
                    hinge_loss=hparams.hinge_regularizer)

        # Replace the placeholder scheduler with a Noam scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Noam"
            # NOTE(review): the key is spelled 'wormup_steps' - presumably the
            # scheduler reads the same spelling; confirm before renaming.
            hparams.scheduler_args['wormup_steps'] = 4000

        # Use this class's own collate / decollate functions for batches.
        self.batch_collate_fn = partial(self.prepare_batch,
                                        use_cond=hparams.use_cond,
                                        one_hot_target=True)
        self.batch_decollate_fn = self.decollate_network_output
Exemplo n.º 5
0
class WaveNetVocoderTrainer(ModelTrainer):
    logger = logging.getLogger(__name__)

    #########################
    # Default constructor
    #
    def __init__(self, dir_world_features, id_list, hparams=None):
        """Set up generators, datasets, loss and scheduler for WaveNet training.

        :param dir_world_features:      Directory handed to WorldFeatLabelGen; also used to look up
                                        the input normalisation parameters.
        :param id_list:                 Ids forwarded to the base-class constructor.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in the defaults the caller left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input frames per second, and output samples per input frame.
        frames_per_sec = 1000.0 / hparams.frame_size_ms
        in_to_out_multiplier = int(hparams.frame_rate_output_Hz /
                                   frames_per_sec)
        # Sequence-length limits derived from the configured seconds.
        # NOTE: The validation limit exists because of memory constraints.
        max_frames_input_trainset = int(
            frames_per_sec * hparams.max_input_train_sec) * in_to_out_multiplier
        max_frames_input_testset = int(
            frames_per_sec * hparams.max_input_test_sec) * in_to_out_multiplier

        # Input side: WORLD features, sampled with sample_linearly using the
        # in/out multiplier computed above.
        upsample = partial(sample_linearly,
                           in_to_out_multiplier=in_to_out_multiplier,
                           dtype=np.float32)
        self.InputGen = WorldFeatLabelGen(dir_world_features,
                                          add_deltas=False,
                                          sampling_fn=upsample,
                                          num_coded_sps=hparams.num_coded_sps,
                                          sp_type=hparams.sp_type,
                                          load_sp=hparams.load_sp,
                                          load_lf0=hparams.load_lf0,
                                          load_vuv=hparams.load_vuv,
                                          load_bap=hparams.load_bap)
        self.InputGen.get_normalisation_params(
            dir_world_features, hparams.input_norm_params_file_prefix)

        # Output side: raw waveform; mu is only passed for mu-law input.
        # No normalisation parameters are loaded for the output.
        mu = hparams.mu if hparams.input_type == "mulaw-quantize" else None
        self.OutputGen = RawWaveformLabelGen(
            frame_rate_output_Hz=hparams.frame_rate_output_Hz,
            frame_size_ms=hparams.frame_size_ms,
            mu=mu,
            silence_threshold_quantized=hparams.silence_threshold_quantized)

        def make_dataset(ids, max_frames):
            # Both splits use random selection; ids and length limit differ.
            return LabelGensDataset(ids,
                                    self.InputGen,
                                    self.OutputGen,
                                    hparams,
                                    random_select=True,
                                    max_frames_input=max_frames)

        self.dataset_train = make_dataset(self.id_list_train,
                                          max_frames_input_trainset)
        self.dataset_val = make_dataset(self.id_list_val,
                                        max_frames_input_testset)

        # Only install a loss when none has been set already: cross entropy
        # for mu-law input, mixture of logistics otherwise.
        if self.loss_function is None:
            if hparams.input_type == "mulaw-quantize":
                self.loss_function = OneHotCrossEntropyLoss(reduction='none',
                                                            shift=1)
            else:
                self.loss_function = DiscretizedMixturelogisticLoss(
                    hparams.quantize_channels,
                    hparams.log_scale_min,
                    reduction='none',
                    hinge_loss=hparams.hinge_regularizer)

        # Replace the placeholder scheduler with a Noam scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Noam"
            # NOTE(review): the key is spelled 'wormup_steps' - presumably the
            # scheduler reads the same spelling; confirm before renaming.
            hparams.scheduler_args['wormup_steps'] = 4000

        # Use this class's own collate / decollate functions for batches.
        self.batch_collate_fn = partial(self.prepare_batch,
                                        use_cond=hparams.use_cond,
                                        one_hot_target=True)
        self.batch_decollate_fn = self.decollate_network_output

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create the hyper-parameters of the WaveNet vocoder trainer.

        Starts from ModelTrainer.create_hparams(hparams_string) and adds the
        data, model and feature-loading parameters with their default values.

        :param hparams_string:          Forwarded to ModelTrainer.create_hparams to parse non-default values.
        :param verbose:                 If True, log the debug string of the final container.
        :return:                        The populated hyper-parameter container.
        """
        # The base class is always called quietly; logging happens once below.
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)
        hparams.synth_vocoder = "raw"

        hparams.add_hparams(
            batch_first=True,
            frame_rate_output_Hz=16000,
            mu=255,
            bit_depth=16,
            # Beginning and end of audio below the threshold are trimmed.
            silence_threshold_quantized=None,
            teacher_forcing_in_test=True,
            ema_decay=0.9999,

            # Model parameters.
            input_type="mulaw-quantize",
            # Only used in MoL prediction (input_type="raw").
            hinge_regularizer=True,
            # Only used for mixture of logistic distributions.
            log_scale_min=float(np.log(1e-14)),
            # 256 for input type mulaw-quantize, otherwise 65536.
            quantize_channels=256)

        # One output channel per quantisation level for mu-law input,
        # otherwise num_mixtures * 3 (pi, mean, log_scale).
        out_channels = (hparams.quantize_channels
                        if hparams.input_type == "mulaw-quantize" else 10 * 3)
        hparams.add_hparam("out_channels", out_channels)

        hparams.add_hparams(
            layers=24,  # 20
            stacks=4,  # 2
            residual_channels=512,
            gate_channels=512,
            skip_out_channels=256,
            dropout=1 - 0.95,
            kernel_size=3,
            weight_normalization=True,
            use_cond=True,  # Determines if conditioning is used.
            cin_channels=63,
            upsample_conditional_features=False,
            upsample_scales=[5, 4, 2])

        # With upsampling the multiplier is the product of all scales.
        hparams.len_in_out_multiplier = (
            reduce(mul, hparams.upsample_scales, 1)
            if hparams.upsample_conditional_features else 1)

        hparams.add_hparams(freq_axis_kernel_size=3,
                            gin_channels=-1,
                            n_speakers=1,
                            use_speaker_embedding=False,
                            sp_type="mcep",
                            load_sp=True,
                            load_lf0=True,
                            load_vuv=True,
                            load_bap=True)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams

    # Load train and test data.
    @staticmethod
    def prepare_batch(batch,
                      common_divisor=1,
                      batch_first=False,
                      use_cond=True,
                      one_hot_target=True):
        """Collate a batch with ModelHandler.prepare_batch and adapt its layout.

        :param batch:                   Batch forwarded to ModelHandler.prepare_batch.
        :param common_divisor:          Forwarded to ModelHandler.prepare_batch.
        :param batch_first:             Forwarded to ModelHandler.prepare_batch. When True, dimensions
                                        1 and 2 of inputs and targets are swapped (B x T x C -> B x C x T).
        :param use_cond:                When False, None is returned in place of the inputs.
        :param one_hot_target:          When False, targets are replaced by the index of their maximum
                                        along dim 1, converted to float.
        :return:                        Tuple (inputs or None, targets, seq_lengths_input,
                                        seq_lengths_output, mask, permutation).
        """
        (inputs, targets, seq_lengths_input, seq_lengths_output, mask,
         permutation) = ModelHandler.prepare_batch(
             batch, common_divisor=common_divisor, batch_first=batch_first)

        # (B x T x C) --permute--> (B x C x T)
        # TODO: Handle case where batch_first=False: inputs = inputs.transpose(2, 0, 1).contiguous()?
        if batch_first:
            inputs = inputs.transpose(1, 2).contiguous()
        if batch_first and targets is not None:
            targets = targets.transpose(1, 2).contiguous()

        if targets is not None and not one_hot_target:
            # Keep only the index of the maximum along dim 1.
            targets = targets.max(dim=1, keepdim=True)[1].float()

        if mask is not None:
            # Drop the first entry along the second dimension of the mask.
            mask = mask[:, 1:].contiguous()

        conditioning = inputs if use_cond else None
        return conditioning, targets, seq_lengths_input, seq_lengths_output, mask, permutation

    @staticmethod
    def decollate_network_output(output,
                                 hidden,
                                 seq_lengths=None,
                                 permutation=None,
                                 batch_first=True):
        """Reorder the network output and split it with ModelTrainer.split_batch.

        :param output:                  Network output, laid out as B x C x T.
        :param hidden:                  Forwarded unchanged to ModelTrainer.split_batch.
        :param seq_lengths:             Forwarded as seq_length_output.
        :param permutation:             Forwarded unchanged to ModelTrainer.split_batch.
        :param batch_first:             Selects B x T x C (True) or T x B x C (False) as the target layout.
        :return:                        Whatever ModelTrainer.split_batch returns.
        """
        # Output of r9y9 Wavenet has batch first (B x C x T). One transpose
        # brings it to B x T x C, or straight to T x B x C when time comes first.
        axes = (0, 2, 1) if batch_first else (2, 0, 1)
        reordered = np.transpose(output, axes)
        return ModelTrainer.split_batch(reordered,
                                        hidden,
                                        seq_length_output=seq_lengths,
                                        permutation=permutation,
                                        batch_first=batch_first)

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot the post-processed output against the loaded original sample.

        The figure is saved to <out_dir>/<id>.<model>.Raw<gen_figure_ext>.

        :param id_name:                 Id of the sample; used to load the original waveform, and its
                                        basename without extension to name the figure.
        :param labels:                  Network output to post-process and plot. Labels come in as T x C.
        :param hidden:                  Not used by this method.
        :param hparams:                 Hyper-parameter container (model_name, out_dir,
                                        frame_rate_output_Hz and gen_figure_ext are read).
        """
        generated = self.dataset_train.postprocess_sample(labels)
        reference = RawWaveformLabelGen.load_sample(
            id_name, self.OutputGen.frame_rate_output_Hz)

        # Names used for the title and the output file.
        plotter = DataPlotter()
        net_name = os.path.basename(hparams.model_name)
        id_name = os.path.basename(id_name).rsplit('.', 1)[0]
        filename = os.path.join(hparams.out_dir, id_name + "." + net_name)
        plotter.set_title(id_name + " - " + net_name)

        # Both curves share the single grid with index 0.
        grid_idx = 0
        curves = [(reference, 'Org'), (generated, 'Wavenet')]
        plotter.set_data_list(grid_idx=grid_idx, data_list=curves)
        plotter.set_linewidth(grid_idx=grid_idx, linewidth=[0.1])
        plotter.set_colors(grid_idx=grid_idx, alpha=0.8)
        plotter.set_lim(grid_idx, ymin=-1, ymax=1)
        x_label = "frames [{!s} Hz]".format(hparams.frame_rate_output_Hz)
        plotter.set_label(grid_idx=grid_idx, xlabel=x_label, ylabel='raw')

        plotter.gen_plot()
        plotter.save_to_file(filename + '.Raw' + hparams.gen_figure_ext)

    # def synthesize(self, file_id_list, synth_output, hparams):
    #     self.run_raw_synth(synth_output, hparams)

    # def synth_ref(self, hparams, file_id_list):
    #     self.logger.info("Synthesise references for [{0}].".format(", ".join([id_name for id_name in file_id_list])))  # Can be different from original by sampling frequency.
    #
    #     synth_output = dict()
    #     for id_name in file_id_list:
    #         # Use extracted data. Useful to create a reference.
    #         raw = RawWaveformLabelGen.load_sample(id_name, self.OutputGen.frame_rate_output_Hz)
    #         synth_output[id_name] = raw
    #
    #     # Add identifier to suffix.
    #     old_synth_file_suffix = hparams.synth_file_suffix
    #     hparams.synth_file_suffix += '_ref'
    #
    #     # Run the WORLD synthesiser.
    #     self.run_raw_synth(synth_output, hparams)
    #
    #     # Restore identifier.
    #     hparams.synth_file_suffix = old_synth_file_suffix

    def save_for_vocoding(self, filename):
        """Save the full model plus the input normalisation parameters.

        :param filename:                Target of the full model. The normalisation parameters go to
                                        the same path with the extension replaced by "_norm_params"
                                        (np.save adds ".npy").
        """
        # Save the full model so that hyper-parameters are already set.
        handler = self.model_handler
        handler.save_full_model(filename, handler.model, verbose=True)

        # Save an easily loadable version of the normalisation parameters
        # on the input side used during training.
        stem, _ = os.path.splitext(filename)
        norm_params = np.concatenate(self.InputGen.norm_params, axis=0)
        np.save(stem + "_norm_params", norm_params)
    def __init__(self, wcad_root, dir_atom_labels, dir_lf0_labels, dir_question_labels, id_list,
                 thetas, k,
                 num_questions,
                 dist_window_size=51, hparams=None):
        """Set up generators, datasets and the weighted atom loss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the atom label files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Width of the distribution surrounding each atom spike.
                                        The window is only used for amps. Thetas are surrounded by a window of 5.
                                        Passed as window_size to AtomVUVDistPosLabelGen.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in the defaults the caller left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # Loss weights default to the inverse of the assumed occurrence of
        # non-zero and zero targets when weight_zero is missing or None.
        if getattr(hparams, "weight_zero", None) is None:
            p_non_zero = min(0.99, 0.015 / len(thetas))
            p_zero = 1 - p_non_zero
            hparams.add_hparam("weight_non_zero", 1 / p_non_zero)
            hparams.add_hparam("weight_zero", 1 / p_zero)
        # Remaining loss parameters, added when missing or None.
        for name, default in (("weight_vuv", 0.5), ("atom_loss_theta", 0.01)):
            if getattr(hparams, name, None) is None:
                hparams.add_hparam(name, default)

        # Explicitly call only the constructor of the baseclass of AtomModelTrainer.
        super(AtomModelTrainer, self).__init__(id_list, hparams)

        # NOTE(review): the correction below changes hparams.dist_window_size,
        # while the generator further down receives the dist_window_size
        # argument - confirm that the two are meant to stay independent.
        if hparams.dist_window_size % 2 == 0:
            hparams.dist_window_size += 1
            self.logger.warning("hparams.dist_window_size should be odd, changed it to " + str(hparams.dist_window_size))

        # Input side: question labels.
        question_gen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen = question_gen
        question_gen.get_normalisation_params(dir_question_labels, hparams.input_norm_params_file_prefix)

        # Overwrite OutputGen by the one with beta distribution.
        atom_gen = AtomVUVDistPosLabelGen(wcad_root, dir_atom_labels, dir_lf0_labels, thetas, k,
                                          hparams.frame_size_ms, window_size=dist_window_size)
        self.OutputGen = atom_gen
        atom_gen.get_normalisation_params(dir_atom_labels, hparams.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Same generators and settings for both splits; only ids differ.
            return PyTorchLabelGensDataset(ids, self.InputGen, self.OutputGen, hparams, match_lengths=True)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        # Only install a loss when none has been set already.
        if self.loss_function is None:
            self.loss_function = WeightedNonzeroWMSEAtomLoss(use_gpu=hparams.use_gpu,
                                                             theta=hparams.atom_loss_theta,
                                                             weights_vuv=hparams.weight_vuv,
                                                             weights_zero=hparams.weight_zero,
                                                             weights_non_zero=hparams.weight_non_zero,
                                                             reduce=False)

        # The placeholder scheduler is replaced by the string "None".
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"
Exemplo n.º 7
0
    def __init__(self,
                 wcad_root,
                 dir_audio,
                 dir_atom_labels,
                 dir_lf0_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 dist_window_size=51,
                 hparams=None):
        """Set up the neural filter trainer together with its nested atom trainer.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_audio:               Path to directory that contains the .wav files.
        :param dir_atom_labels:         Path to directory that contains the .atoms files.
        :param dir_lf0_labels:          Path to directory that contains the .lf0 files.
        :param dir_question_labels:     Path to directory that contains the .lab files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values of the used atoms.
        :param k:                       K-value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param dist_window_size:        Size of distribution around atom amplitudes when training the atom model.
        :param hparams:                 Hyper-parameter container. When None, one is created through
                                        create_hparams() and its out_dir is set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Hyper-parameters of the nested atom trainer. The fallback is a copy
        # taken before the defaults below are written, with figure generation
        # switched off and the acoustic model path cleared.
        atom_hparams = hparams.hparams_atom
        if atom_hparams is None:
            atom_hparams = copy.deepcopy(hparams)
            atom_hparams.synth_gen_figure = False
            atom_hparams.synth_acoustic_model_path = None

        # Default location of the atom model.
        if hparams.atom_model_path is None:
            hparams.atom_model_path = os.path.join(
                hparams.out_dir, hparams.networks_dir,
                hparams.model_name + "_atoms")

        # Default synthesis directory.
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input side: question labels.
        question_gen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen = question_gen
        question_gen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        # Output side: LF0 labels.
        lf0_gen = FlatLF0LabelGen(dir_lf0_labels, dir_atom_labels)
        self.OutputGen = lf0_gen
        lf0_gen.get_normalisation_params(
            dir_atom_labels, hparams.output_norm_params_file_prefix)

        def make_dataset(ids):
            # Same generators and settings for both splits; only ids differ.
            return PyTorchLabelGensDataset(ids,
                                           self.InputGen,
                                           self.OutputGen,
                                           hparams,
                                           match_lengths=True)

        self.dataset_train = make_dataset(self.id_list_train)
        self.dataset_val = make_dataset(self.id_list_val)

        # Nested trainer, configured with the atom hyper-parameters.
        self.atom_trainer = AtomVUVDistPosModelTrainer(
            wcad_root, dir_atom_labels, dir_lf0_labels, dir_question_labels,
            id_list, thetas, k, num_questions, dist_window_size, atom_hparams)

        # Only install a loss when none has been set already.
        if self.loss_function is None:
            self.loss_function = L1WeightedVUVMSELoss(
                weight_unvoiced=hparams.weight_unvoiced,
                vuv_loss_weight=hparams.vuv_loss_weight,
                L1_loss_weight=hparams.L1_loss_weight,
                reduce=False)

        # The placeholder scheduler is replaced by the string "None".
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "None"

        # Use this class's own collate / decollate functions for batches.
        self.batch_collate_fn = self.prepare_batch
        self.batch_decollate_fn = self.decollate_network_output