Exemplo n.º 1
0
    def test_save_load(self):
        """Round-trip test for AtomLabelGen: generate, save, reload, (de)normalise.

        Generates atom labels for ``self.id_list`` into a temporary directory and
        checks that:

        * the normalisation parameters returned by ``gen_data`` equal the ones
          loaded back via ``get_normalisation_params``;
        * pre-processing a label gives the same result as indexing the generator;
        * post-processing the pre-processed label yields the recorded sum.

        The output directory and the ``wcad_.txt`` file in ``self.dir_database``
        are removed through ``addCleanup``, so they are removed even when an
        assertion fails.
        """
        dir_out = self._get_test_dir()
        wcad_file = os.path.join(self.dir_database, "wcad_.txt")

        def _remove_wcad_file():
            # Presumably written as a side effect of gen_data. Guard the removal
            # so an earlier failure is not masked by a FileNotFoundError here.
            if os.path.isfile(wcad_file):
                os.remove(wcad_file)

        # Register cleanups before any work so nothing leaks on a failed assert.
        # Cleanups run LIFO: the wcad file is removed first, then dir_out
        # (same order as the original trailing cleanup).
        self.addCleanup(shutil.rmtree, dir_out, ignore_errors=True)
        self.addCleanup(_remove_wcad_file)

        theta_start = 0.01
        theta_stop = 0.055
        theta_step = 0.005
        thetas = np.arange(theta_start, theta_stop, theta_step)
        k = 6
        frame_size_ms = 5

        atom_gen = AtomLabelGen(self.dir_wcad_root, dir_out, thetas, k,
                                frame_size_ms)
        label_dict, *extracted_norm_params = atom_gen.gen_data(
            self.dir_wav, dir_out, id_list=self.id_list, return_dict=True)

        # Call this once before starting the pre-processing.
        norm_params = atom_gen.get_normalisation_params(dir_out)
        self.assertTrue((extracted_norm_params[0] == norm_params[0]).all())
        self.assertTrue((extracted_norm_params[1] == norm_params[1]).all())

        sample_id = self.id_list[1]
        test_label = label_dict[sample_id]
        test_label_pre = atom_gen.preprocess_sample(test_label)
        self.assertTrue(np.isclose(test_label_pre, atom_gen[sample_id]).all())

        test_label_post = atom_gen.postprocess_sample(test_label_pre)
        # Post-processing does peak selection, so pre and post labels
        # are not the same and we cannot check for equality here.
        # Instead compare against the sum recorded for this test fixture.
        self.assertTrue(np.isclose(-0.2547, test_label_post.sum(),
                                   atol=0.0001))
Exemplo n.º 2
0
    def __init__(self,
                 wcad_root,
                 dir_atom_labels,
                 dir_question_labels,
                 id_list,
                 thetas,
                 k,
                 num_questions,
                 hparams=None):
        """Default constructor.

        Fills in missing hyper-parameters, sets up a question-label input
        generator and an atom-label output generator (both with normalisation
        parameters loaded from their directories), builds the train/validation
        datasets and, if none was set, a weighted zero/non-zero MSE loss.

        :param wcad_root:               Path to main directory of wcad.
        :param dir_atom_labels:         Path to directory that contains the .atom files.
        :param dir_question_labels:     Path to directory that contains the .questions files.
        :param id_list:                 List containing all ids. Subset is taken as test set.
        :param thetas:                  List of theta values. Must be non-empty
                                        (its length is used as a divisor).
        :param k:                       K value of atoms.
        :param num_questions:           Expected number of questions in question labels.
        :param hparams:                 Hyper-parameter container. Modified in place.
                                        If None, defaults from ``create_hparams()`` are
                                        used with ``out_dir`` set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Write missing default parameters.
        # Variable-length sequences are enabled by default whenever batches hold
        # more than one sample.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        # NOTE(review): os.path.join fails if a caller-supplied hparams has
        # out_dir=None; only the hparams=None path sets out_dir above.
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        # If the loss weights for zero / non-zero output frames are not given,
        # compute them as the inverse of an assumed occurrence rate so both
        # classes contribute roughly equally.
        # NOTE(review): the 0.02 constant (assumed overall non-zero rate, split
        # across thetas) is not explained here — TODO confirm its origin.
        # len(thetas) == 0 raises ZeroDivisionError.
        non_zero_occurrence = min(0.99, 0.02 / len(thetas))
        zero_occurrence = 1 - non_zero_occurrence
        if not hasattr(hparams, "weight_zero"):
            hparams.add_hparam("weight_non_zero", 1 / non_zero_occurrence)
            hparams.add_hparam("weight_zero", 1 / zero_occurrence)
        elif hparams.weight_zero is None:
            # NOTE(review): this also overwrites weight_non_zero, even if the
            # caller set it explicitly — confirm this is intended.
            hparams.weight_non_zero = 1 / non_zero_occurrence
            hparams.weight_zero = 1 / zero_occurrence

        # Presumably sets id_list_train / id_list_val and loss_function, which
        # are read below — verify against the base class.
        super().__init__(id_list, hparams)

        # Input: question labels, normalised with parameters loaded from disk.
        self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
        self.InputGen.get_normalisation_params(
            dir_question_labels, hparams.input_norm_params_file_prefix)

        # Output: atom labels, normalised with parameters loaded from disk.
        self.OutputGen = AtomLabelGen(wcad_root, dir_atom_labels, thetas, k,
                                      hparams.frame_size_ms)
        self.OutputGen.get_normalisation_params(
            dir_atom_labels, hparams.output_norm_params_file_prefix)

        # match_lengths=True is passed so input and output sequences of a
        # sample are aligned to the same length by the dataset.
        self.dataset_train = PyTorchLabelGensDataset(self.id_list_train,
                                                     self.InputGen,
                                                     self.OutputGen,
                                                     hparams,
                                                     match_lengths=True)
        self.dataset_val = PyTorchLabelGensDataset(self.id_list_val,
                                                   self.InputGen,
                                                   self.OutputGen,
                                                   hparams,
                                                   match_lengths=True)

        # Default loss: MSE weighted by the zero / non-zero weights computed above,
        # without reduction (reduction presumably happens in the trainer).
        if self.loss_function is None:
            self.loss_function = WeightedNonzeroMSELoss(
                hparams.use_gpu,
                hparams.weight_zero,
                hparams.weight_non_zero,
                size_average=False,
                reduce=False)
        # Replace the "default" scheduler with a plateau scheduler and add its
        # parameters to hparams.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Plateau"
            hparams.add_hparams(plateau_patience=10,
                                plateau_factor=0.5,
                                plateau_verbose=True)