def __init__(self, project):
        """Build the dataset-preparation widget for ``project``.

        Fills the hotword selector from ``project.project_info['hotwords']``,
        creates one empty folder list and one pie slice per sample class
        (index 0 is the non-hotword class, label ``None``, followed by the
        hotwords in project order), and wires the UI signals.
        """
        super().__init__()
        self.ui = Ui_Prepare()
        self.ui.setupUi(self)
        self.project = project
        self.last_hotword_index = 0
        self.main_dataset = DataSet()

        # Hotword selector
        self.hotwords = self.project.project_info['hotwords']
        self.ui.current_hotword.addItems(self.hotwords)

        # Sample classes: None (non-hotword) first, then each hotword.
        self.sample_classes = [None] + self.hotwords

        # One folder list per sample class.
        self.folder_list_layout = QtWidgets.QVBoxLayout()
        self.folder_list = [[] for _ in range(len(self.hotwords) + 1)]

        # One pie slice per sample class, in the same order as sample_classes.
        self.pie_slices = [
            QtChart.QPieSlice("{} ()".format(name), 0)
            for name in ["non-hotword"] + self.hotwords
        ]
        self.display_selected_directories(0)

        # Pie chart
        self.pie_series = QtChart.QPieSeries()
        for slice_ in self.pie_slices:
            self.pie_series.append(slice_)
        self.pie_series.setHoleSize(0)

        self.pie_view = QtChart.QChartView()
        self.pie_view.setRenderHint(QtGui.QPainter.Antialiasing)
        chart = self.pie_view.chart()
        chart.layout().setContentsMargins(0, 0, 0, 0)
        chart.setMargins(QtCore.QMargins(0, 0, 0, 0))
        chart.legend().setAlignment(QtCore.Qt.AlignRight)
        chart.addSeries(self.pie_series)

        graph_layout = QtWidgets.QVBoxLayout()
        graph_layout.addWidget(self.pie_view)
        self.ui.graph_placeholder.setLayout(graph_layout)

        # Signal -> slot wiring
        ui = self.ui
        connections = (
            (ui.current_hotword.currentIndexChanged, self.display_selected_directories),
            (ui.add_folder_button.clicked, self.add_folder),
            (ui.test_subset_CB.stateChanged, ui.test_subset_percent.setEnabled),
            (ui.val_on_test_set_CB.stateChanged, self.disable_validation_set),
            (ui.training_percent.valueChanged, self.on_trainsample_update),
            (ui.validation_percent.valueChanged, self.on_valsample_update),
            (ui.test_percent.valueChanged, self.on_testsample_update),
            (ui.done_button.clicked, self.on_done_clicked),
            (ui.add_set_PB.clicked, self.add_set),
        )
        for signal, slot in connections:
            signal.connect(slot)
Esempio n. 2
0
    def __init__(self, project):
        """Build the dataset-management widget for ``project``.

        Hides the split-ratio inputs, shows the current set distribution
        (which lazily builds the pie chart), fills the class selector, and
        wires the buttons.
        """
        super().__init__()
        self.ui = Ui_Manage()
        self.ui.setupUi(self)
        self.project = project
        self.new_samples = DataSet()  # samples staged but not yet written
        self.pie_view = None  # built lazily by present_sets() via init_graph()
        self.show_ratio(False)

        self.present_sets()
        self.init_comboBox()

        self.files = []

        # Signal -> slot wiring
        ui = self.ui
        for signal, slot in (
            (ui.add_PB.clicked, self.add_samples),
            (ui.browse_PB.clicked, self.on_browse_clicked),
            (ui.split_Radio.toggled, self.show_ratio),
            (ui.verify_PB.clicked, self.verify),
            (ui.addset_PB.clicked, self.add_set),
            (ui.clear_PB.clicked, self.clear),
        ):
            signal.connect(slot)
Esempio n. 3
0
    def test_on_internal(self):
        """Evaluate the model on the project's internal sets and display the results.

        The test set is always used; the training and validation sets are
        loaded into the same DataSet as well when their checkboxes are ticked
        (presumably ``DataSet.load`` appends — confirm).  Features are
        extracted with ``files2features`` and passed to ``self.model.predict``.

        A confusion matrix of shape ``(n_hotwords + 1, n_hotwords + 1)`` is
        built with rows = true class and columns = predicted class.  Class 0 is
        non-hotword and class ``i + 1`` is ``hotwords[i]``.  Every output above
        ``self.threshold`` counts as a prediction; if none is above it, the
        prediction is class 0.  Each mismatching (file, true, predicted) triple
        is collected and handed to ``display_false_results``.
        """
        bad_results = []
        hotwords = self.project.project_info['hotwords']

        # One-hot target per hotword; the non-hotword class (None) is all zeros.
        outputs = {None: [0.0] * len(hotwords)}
        for i, hw in enumerate(hotwords):
            outputs[hw] = [0.0] * len(hotwords)
            outputs[hw][i] = 1.0

        test_set = DataSet()
        test_set.load(self.project.data_info['test_set'])
        if self.ui.training_set.isChecked():
            test_set.load(self.project.data_info['train_set'])
        if self.ui.validation_set.isChecked():
            test_set.load(self.project.data_info['val_set'])

        # Extract features. Labels that are not current hotwords are treated as
        # non-hotword (same rule as training()) instead of raising KeyError.
        files, inputs, labels = files2features(
            [s['file_path'] for s in test_set],
            self.project.features_info,
            labels=[outputs.get(s['label'], outputs[None]) for s in test_set],
            return_file=True,
            progress_callback=self.progress_display)

        # Make prediction
        self.progress_display(1, 1, "Predicting ...")
        res = self.model.predict(inputs)

        self.progress_display(1, 1, "Processing results ...")
        # Result sorting into the confusion matrix
        n_classes = len(hotwords) + 1
        result_matrix = np.zeros((n_classes, n_classes))
        for f, result, label in zip(files, res, labels):
            ct = np.argmax(label) + 1 if any(label) else 0
            # Indices of all outputs above threshold, shifted so 0 means non-hotword.
            cp = np.flatnonzero(np.asarray(result) > self.threshold) + 1
            if cp.size == 0:
                cp = [0]
            for c in cp:
                result_matrix[ct][c] += 1
                if c != ct:
                    bad_results.append((f, ct, c))

        self.progress_display(1, 1, "Done")
        # Display Metrics
        self.display_metrics(result_matrix, len(res))
        self.display_false_results(bad_results)
Esempio n. 4
0
    def generate_charts(self):
        """Plot the mean (and optionally variance) feature map of each hotword.

        For every hotword, the training-set samples with that label are turned
        into features with ``files2features(files, self.tab.generate_dict())``.
        Their mean over axis 0 is drawn, transposed, with ``imshow``.  When
        ``variance_CB`` is checked, a second column shows the variance.  A
        hotword with no training samples gets a "(no samples)" title instead
        of a plot.  While running, the analyse button shows "Processing ...";
        its original text is restored even if an error occurs.
        """
        empty_layout(self.layout)
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)
        self.layout.addWidget(self.canvas)
        base_text = self.ui.analyse_PB.text()
        self.ui.analyse_PB.setText("Processing ...")
        try:
            hotwords = self.project.project_info['hotwords']
            show_variance = self.ui.variance_CB.isChecked()
            n_graph = 2 if show_variance else 1
            dataset = DataSet()
            dataset.load(self.project.data_info['train_set'])
            for i, hotword in enumerate(hotwords):
                axes = self.figure.add_subplot(len(hotwords), n_graph,
                                               i * n_graph + 1)
                axes.clear()
                files = [
                    sample['file_path']
                    for sample in dataset.get_subset_by_label(hotword)
                ]
                if not files:
                    # Mean/variance of zero samples is undefined; don't feed
                    # an empty array to imshow.
                    axes.set_title("{} (no samples)".format(hotword))
                    continue
                data = files2features(files, self.tab.generate_dict())
                axes.imshow(np.mean(data, axis=0).T,
                            interpolation='nearest',
                            origin='lower')
                axes.set_title("{} (mean)".format(hotword))

                if show_variance:
                    axes = self.figure.add_subplot(len(hotwords), n_graph,
                                                   i * n_graph + 2)
                    axes.clear()
                    axes.imshow(np.var(data, axis=0).T,
                                interpolation='nearest',
                                origin='lower')
                    axes.set_title("{} (variance)".format(hotword))

            self.canvas.draw()
        finally:
            # Always restore the button label, even if loading/extraction fails.
            self.ui.analyse_PB.setText(base_text)
    def generate_sets(self):
        """Split the collected samples into train/val/test sets and save them.

        Each sample class in ``self.sample_classes`` is split on its own with
        the training/validation/test percentages from the UI.  The second
        ``split_dataset`` argument is ``not self.shuffle_set``.  The per-class
        parts are merged; each set is written to
        ``<project_location>/<name>.json`` and that path is stored in
        ``project.data_info['<name>_set']``.  Finally the project is marked
        as having sets and ``update_project()`` is called.
        """
        self.project.data_info['samples_location'] = dict()

        per_class = [self.main_dataset.get_subset_by_label(label)
                     for label in self.sample_classes]

        train_set = DataSet()
        val_set = DataSet()
        test_set = DataSet()

        ratios = [self.ui.training_percent.value(),
                  self.ui.validation_percent.value(),
                  self.ui.test_percent.value()]
        for subset in per_class:
            # A fresh copy per call, in case split_dataset mutates its ratio list.
            part_train, part_val, part_test = subset.split_dataset(list(ratios), not self.shuffle_set)
            train_set += part_train
            val_set += part_val
            test_set += part_test

        for name, dataset in (("train", train_set), ("val", val_set), ("test", test_set)):
            json_path = os.path.join(self.project.project_location, "{}.json".format(name))
            self.project.data_info['{}_set'.format(name)] = json_path
            dataset.write(json_path)

        self.project.data_info['set'] = True
        self.project.update_project()
Esempio n. 6
0
    def add_samples(self):
        """Write the staged samples (``self.new_samples``) into the set manifests.

        If the train/test/val radio button is selected, all samples are
        appended to that one set.  Otherwise (split mode) they are divided with
        ``split_dataset(ratios, split_using_attr=True)`` using the
        train/test/val ratio spin boxes; if all ratios are 0, a warning is shown
        and nothing is written.  When ``duplicateCheck`` is ticked, duplicates
        are removed from every destination set and the total number removed
        across all sets is reported.
        """
        n_samples = len(self.new_samples)
        n_duplicate = 0
        if self.ui.train_Radio.isChecked():
            plan = [('train_set', self.new_samples)]
        elif self.ui.test_Radio.isChecked():
            plan = [('test_set', self.new_samples)]
        elif self.ui.val_Radio.isChecked():
            plan = [('val_set', self.new_samples)]
        else:
            # Add to multiple sets
            ratios = [self.ui.train_ratio.value(), self.ui.test_ratio.value(), self.ui.val_ratio.value()]
            if sum(ratios) == 0:
                QtWidgets.QMessageBox.warning(self, "No Ratio specified", "You must specify ratios")
                return
            splits = self.new_samples.split_dataset(ratios, split_using_attr=True)
            plan = list(zip(['train_set', 'test_set', 'val_set'], splits))

        for set_name, samples in plan:
            dataset = DataSet()
            dataset.load(self.project.data_info[set_name])
            dataset += samples
            if self.ui.duplicateCheck.isChecked():
                # Accumulate: in split mode several sets may drop duplicates.
                n_duplicate += dataset.remove_duplicate()
            dataset.write(self.project.data_info[set_name])

        # A success message, so use an information box rather than a warning.
        QtWidgets.QMessageBox.information(self, "Samples added", "Added {} samples ({} duplicates ignored)".format(n_samples - n_duplicate, n_duplicate))

        self.ui.add_PB.setEnabled(False)
        self.clear()
        self.present_sets()
Esempio n. 7
0
    def present_sets(self):
        """Reload the three set manifests and summarise them in text and the pie chart.

        Creates the chart on first use, reloads ``train_set``/``val_set``/
        ``test_set`` from ``project.data_info``, and writes per-set and
        per-class counts with percentages to ``resume_TE``.  Pie slice ``i``
        gets the overall count of class ``i`` (hotwords first, non-hotword last).
        """
        self.ui.resume_TE.clear()
        if self.pie_view is None:
            self.init_graph()

        write = self.ui.resume_TE.appendPlainText
        write("Data distribution\n")
        # The attribute names coincide with the data_info keys.
        for attr in ('train_set', 'val_set', 'test_set'):
            dataset = DataSet()
            setattr(self, attr, dataset)
            dataset.load(self.project.data_info[attr])

        splits = [self.train_set, self.val_set, self.test_set]
        total_c = sum([len(d) for d in splits])
        write("Total sample: {}\n".format(total_c))
        classes = self.project.project_info['hotwords'] + [None]
        class_count = [0] * len(classes)

        for dataset, set_name in zip(splits, ['Train', 'Validation', 'Test']):
            n = len(dataset)
            share = (n / total_c * 100) if total_c > 0 else 0.0
            write('{} set : {} samples ({:.2f}%)\n'.format(set_name, n, share))
            for idx, cl in enumerate(classes):
                count = len(dataset.get_subset_by_label(cl))
                label = 'non-hotword' if cl is None else cl
                class_count[idx] += count
                pct = (count / n * 100) if n > 0 else 0
                write('\t- {} : {} ({:.2f}%)\n'.format(label, count, pct))

        write('Total sample:\n')
        for idx, (cl, count) in enumerate(zip(classes, class_count)):
            class_name = 'non-hotword' if cl is None else cl
            pct = (count / total_c * 100) if total_c > 0 else 0
            write('\t- {} : {}({:.2f}%)'.format(class_name, count, pct))
            self.pieSlices[idx].setValue(count)
            self.pieSlices[idx].setLabel('{} : {}({:.2f}%)'.format(class_name, count, pct))
Esempio n. 8
0
class Manage(QtWidgets.QWidget):
    """Widget for reviewing and extending the project's train/val/test sets.

    New samples are staged in ``self.new_samples`` (from a folder or a JSON
    manifest).  They are then appended to one set, or split across all three
    by ratio.  The current distribution is shown as text in ``resume_TE`` and
    as a pie chart.
    """
    # NOTE(review): declared but not emitted anywhere in this class.
    datasets_modified = QtCore.pyqtSignal(name='datasets_modified')

    def __init__(self, project):
        super().__init__()
        self.ui = Ui_Manage()
        self.ui.setupUi(self)
        self.project = project
        self.new_samples = DataSet()  # samples staged but not yet written
        self.pie_view = None  # built lazily by present_sets() via init_graph()
        self.show_ratio(False)

        self.present_sets()
        self.init_comboBox()

        self.files = []
        # connect
        self.ui.add_PB.clicked.connect(self.add_samples)
        self.ui.browse_PB.clicked.connect(self.on_browse_clicked)
        self.ui.split_Radio.toggled.connect(self.show_ratio)
        self.ui.verify_PB.clicked.connect(self.verify)
        self.ui.addset_PB.clicked.connect(self.add_set)
        self.ui.clear_PB.clicked.connect(self.clear)

    def init_graph(self):
        """Create the pie chart: one slice per hotword, then a 'non-hotword' slice."""
        self.pieSlices = [QtChart.QPieSlice("{} ()".format(hw), 0)
                          for hw in self.project.project_info['hotwords'] + ['non-hotword']]
        self.pie_series = QtChart.QPieSeries()
        for pie_slice in self.pieSlices:
            self.pie_series.append(pie_slice)
        self.pie_series.setHoleSize(0)

        self.pie_view = QtChart.QChartView()
        self.pie_view.setRenderHint(QtGui.QPainter.Antialiasing)
        chart = self.pie_view.chart()
        chart.layout().setContentsMargins(0, 0, 0, 0)
        chart.setMargins(QtCore.QMargins(0, 0, 0, 0))
        chart.legend().setAlignment(QtCore.Qt.AlignBottom)
        chart.addSeries(self.pie_series)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.pie_view)
        self.ui.graph_placeholder.setLayout(layout)

    def init_comboBox(self):
        """Fill the class selector with the hotwords plus 'non-hotword'."""
        self.ui.hotword_Combo.addItems(self.project.project_info['hotwords'] + ['non-hotword'])

    def on_browse_clicked(self):
        """Ask for a directory and stage its samples under the selected class."""
        res = QtWidgets.QFileDialog.getExistingDirectory(self, "Select a directory", "")
        if len(res) != 0:
            self.ui.browse_LE.setText(res)
            self.add_from_folder(res)

    def present_sets(self):
        """Reload the three set manifests and summarise them in text and the pie chart.

        Pie slice ``i`` gets the overall count of class ``i`` (hotwords first,
        non-hotword last).
        """
        self.ui.resume_TE.clear()
        if self.pie_view is None:
            self.init_graph()

        write = self.ui.resume_TE.appendPlainText
        write("Data distribution\n")
        # The attribute names coincide with the data_info keys.
        for attr in ('train_set', 'val_set', 'test_set'):
            dataset = DataSet()
            setattr(self, attr, dataset)
            dataset.load(self.project.data_info[attr])

        splits = [self.train_set, self.val_set, self.test_set]
        total_c = sum(len(d) for d in splits)
        write("Total sample: {}\n".format(total_c))
        classes = self.project.project_info['hotwords'] + [None]
        class_count = [0] * len(classes)

        for dataset, set_name in zip(splits, ['Train', 'Validation', 'Test']):
            n = len(dataset)
            share = (n / total_c * 100) if total_c > 0 else 0.0
            write('{} set : {} samples ({:.2f}%)\n'.format(set_name, n, share))
            for idx, cl in enumerate(classes):
                count = len(dataset.get_subset_by_label(cl))
                label = 'non-hotword' if cl is None else cl
                class_count[idx] += count
                pct = (count / n * 100) if n > 0 else 0
                write('\t- {} : {} ({:.2f}%)\n'.format(label, count, pct))

        write('Total sample:\n')
        for idx, (cl, count) in enumerate(zip(classes, class_count)):
            class_name = 'non-hotword' if cl is None else cl
            pct = (count / total_c * 100) if total_c > 0 else 0
            write('\t- {} : {}({:.2f}%)'.format(class_name, count, pct))
            self.pieSlices[idx].setValue(count)
            self.pieSlices[idx].setLabel('{} : {}({:.2f}%)'.format(class_name, count, pct))

    def verify(self):
        """Drop manifest entries whose files are missing, rewrite the manifests and report."""
        set_paths = [self.project.data_info['train_set'],
                     self.project.data_info['val_set'],
                     self.project.data_info['test_set']]
        missing_files = []
        for s, path in zip([self.train_set, self.val_set, self.test_set], set_paths):
            missing_files.append(s.verify_and_clear())
            s.write(path)
        n_missing = sum(len(missing) for missing in missing_files)

        msgBox = QtWidgets.QMessageBox()
        if n_missing == 0:
            msgBox.setText("No missing file !")
        else:
            msgBox.setText("{} missing file (removed). \n Check terminal for details".format(n_missing))
            for set_name, missing in zip(['Train', 'Validation', 'Test'], missing_files):
                print("{} set : {} missing files".format(set_name, len(missing)))
                for f in missing:
                    print(f['file_path'])
        msgBox.exec()
        if n_missing:
            # The manifests on disk changed; refresh the summary and chart.
            self.present_sets()

    def clear(self):
        """Discard all staged samples."""
        self.new_samples.clear()
        self.update_sample_preview()
        self.ui.clear_PB.setEnabled(False)

    def add_set(self, json_path=None):
        """Stage samples from a JSON manifest chosen in a file dialog.

        ``json_path`` is unused; the manifest path comes from the dialog.  It
        defaults to None so the method can also be called without arguments.
        When the method is connected to ``clicked``, PyQt may pass the
        button's checked state here.
        """
        res = QtWidgets.QFileDialog.getOpenFileName(self, "Select a file", "/home/", "Json file (*.json)")[0]
        if len(res) != 0:
            dialog = SetFilePreview(res)
            dialog.add_clicked.connect(self.new_samples.add_from_manifest)
            dialog.exec()
        self.update_sample_preview()

    def add_from_folder(self, folder):
        """Stage every 'wav' file in ``folder`` under the class chosen in the combo box."""
        label = self.ui.hotword_Combo.currentText()
        if label == "non-hotword":
            label = None
        self.new_samples.add_from_folder(folder, label, ext='wav')
        self.update_sample_preview()

    def update_sample_preview(self):
        """Show a per-class summary of the staged samples and sync the buttons.

        ``add_PB`` and ``clear_PB`` are enabled only while samples are staged.
        """
        self.ui.samples_TE.clear()
        has_samples = len(self.new_samples) > 0
        # Set before the early return so "Add" is disabled once nothing is staged.
        self.ui.add_PB.setEnabled(has_samples)
        self.ui.clear_PB.setEnabled(has_samples)
        if not has_samples:
            self.ui.samples_TE.appendPlainText("No samples")
            return
        self.ui.samples_TE.appendPlainText("New samples: {}\n".format(len(self.new_samples)))
        for hw in self.project.project_info['hotwords'] + [None]:
            self.ui.samples_TE.appendPlainText("\t{}: {} samples\n".format(
                hw if hw is not None else 'Non-hotword',
                len(self.new_samples.get_subset_by_label(hw))))

    def add_samples(self):
        """Write the staged samples into the set manifests.

        If the train/test/val radio button is selected, all samples go to that
        set.  Otherwise they are split by the ratio spin boxes with
        ``split_dataset(ratios, split_using_attr=True)``.  When
        ``duplicateCheck`` is ticked, the total number of duplicates removed
        across all destination sets is reported.
        """
        n_samples = len(self.new_samples)
        n_duplicate = 0
        if self.ui.train_Radio.isChecked():
            plan = [('train_set', self.new_samples)]
        elif self.ui.test_Radio.isChecked():
            plan = [('test_set', self.new_samples)]
        elif self.ui.val_Radio.isChecked():
            plan = [('val_set', self.new_samples)]
        else:
            # Add to multiple sets
            ratios = [self.ui.train_ratio.value(), self.ui.test_ratio.value(), self.ui.val_ratio.value()]
            if sum(ratios) == 0:
                QtWidgets.QMessageBox.warning(self, "No Ratio specified", "You must specify ratios")
                return
            splits = self.new_samples.split_dataset(ratios, split_using_attr=True)
            plan = list(zip(['train_set', 'test_set', 'val_set'], splits))

        for set_name, samples in plan:
            dataset = DataSet()
            dataset.load(self.project.data_info[set_name])
            dataset += samples
            if self.ui.duplicateCheck.isChecked():
                # Accumulate: in split mode several sets may drop duplicates.
                n_duplicate += dataset.remove_duplicate()
            dataset.write(self.project.data_info[set_name])

        # A success message, so use an information box rather than a warning.
        QtWidgets.QMessageBox.information(self, "Samples added", "Added {} samples ({} duplicates ignored)".format(n_samples - n_duplicate, n_duplicate))

        self.ui.add_PB.setEnabled(False)
        self.clear()
        self.present_sets()

    def show_ratio(self, visible):
        """Show or hide the three split-ratio inputs."""
        self.ui.val_ratio.setVisible(visible)
        self.ui.train_ratio.setVisible(visible)
        self.ui.test_ratio.setVisible(visible)
Esempio n. 9
0
    def training(self):
        """Train the graph using the parameters set in the GUI.

        If the project already has a model (``project.model_info['set']``), it
        is loaded once from ``project.model_path``.  Otherwise a new model is
        built with ``create_model`` from the GUI parameters; this requires one
        dense size per dense layer and a non-empty model name.  Training and
        validation features are extracted with ``files2features`` unless they
        were already vectorized.  Extraction is always redone when "positive
        only" is checked, and its result is then not cached.  ``model.fit``
        then runs from ``current_epoch`` up to ``current_epoch + n_epochs``.
        """
        def check_dense_sizes(n_dense, dense_sizes):
            """Raise ValueError unless exactly one size is given per dense layer."""
            if len(dense_sizes) != n_dense:
                raise ValueError(
                    "The number of values for dense size must be equal to the number of dense layers."
                )

        # Load or create model
        if self.project.model_info['set']:
            if self.model is None:
                try:
                    self.model = load_model(self.project.model_path)
                except Exception as err:
                    # Not a bare except: let KeyboardInterrupt/SystemExit through.
                    print("Could not load the model file at {}".format(
                        self.project.model_path))
                    print(err)
                    return
        else:
            try:
                check_dense_sizes(self.n_denses, self.dense_sizes)
            except ValueError as e:
                error_box = QtWidgets.QMessageBox(self)
                error_box.setIcon(QtWidgets.QMessageBox.Warning)
                # str(e) is the message itself; str(e.args) showed a tuple repr.
                error_box.setText(str(e))
                error_box.setWindowTitle("Error")
                error_box.setStandardButtons(QtWidgets.QMessageBox.Ok)
                error_box.exec()
                return
            if self.model_name == '':
                self.ui.model_name.setFocus()
                return
            if not self.model_name.endswith('.hdf5'):
                self.model_name = self.model_name + '.hdf5'
            self.project.project_info['model_name'] = self.model_name
            self.model = create_model(
                self.model_name,
                (self.input_x, self.input_y),
                n_denses=self.n_denses,
                dense_sizes=self.dense_sizes,
                dropout=self.dropout,
                output_size=self.output_size,
                loss_fun=self.loss_bias,
                noise_layer_derivation=self.gaussian_noise_stdder,
                unroll=self.ui.unroll_CB.isChecked())

        self.callbacks = callbacks(
            os.path.join(self.project.project_location,
                         self.project.project_info['model_name']),
            self.only_keep_best)

        self.callbacks.append(Epoch_CallBack(self.train_callback))
        self.target_epoch = self.current_epoch + self.n_epochs
        # Keras' Model.summary() prints the table itself and returns None.
        self.model.summary()

        # Read once: processEvents() below could otherwise let a toggle mid-run
        # mark a positive-only subset as the cached full vectorization.
        pos_only = self.ui.pos_only_CB.isChecked()

        # Check and vectorize samples
        if not self.vectorized or pos_only:
            self.validation_set, self.validation_set_output = [], []
            hotwords = self.project.project_info['hotwords']

            train_dataset = DataSet()
            val_dataset = DataSet()

            train_dataset.load(self.project.data_info['train_set'])
            val_dataset.load(self.project.data_info['val_set'])

            # If train on keyword only is checked
            if pos_only:
                train_dataset = train_dataset.get_subset_by_labels(hotwords)
                val_dataset = val_dataset.get_subset_by_labels(hotwords)

            # One-hot targets; unknown labels fall back to the all-zero non-hotword target.
            outputs = dict()
            outputs[None] = [0.0] * len(hotwords)
            for i, hw in enumerate(hotwords):
                outputs[hw] = [0.0] * len(hotwords)
                outputs[hw][i] = 1.0
            self.progress_display(0, 1, "Collecting training samples ...")

            unproc_train = [(s['file_path'],
                             outputs.get(s['label'], outputs[None]))
                            for s in train_dataset]
            unproc_val = [(s['file_path'],
                           outputs.get(s['label'], outputs[None]))
                          for s in val_dataset]

            self.progress_display(0, 1, "Extracting training features ...")
            QtWidgets.QApplication.instance().processEvents()

            self.train_set, self.train_set_output = files2features(
                [s[0] for s in unproc_train],
                self.project.features_info,
                labels=[s[1] for s in unproc_train],
                progress_callback=self.progress_display)

            self.progress_display(0, 1, "Extracting validation features ...")

            self.validation_set, self.validation_set_output = files2features(
                [s[0] for s in unproc_val],
                self.project.features_info,
                labels=[s[1] for s in unproc_val],
                progress_callback=self.progress_display)
            if not pos_only:
                self.vectorized = True
        self.progress_display(0, 1, "Training ...")
        self.ui.stop_button.setEnabled(True)
        QtWidgets.QApplication.instance().processEvents()

        # Fits model and uses outputs to update chart
        self.model.fit([self.train_set], [self.train_set_output],
                       batch_size=self.batch_size,
                       initial_epoch=self.current_epoch,
                       epochs=self.target_epoch,
                       callbacks=self.callbacks,
                       validation_data=([self.validation_set],
                                        [self.validation_set_output]),
                       verbose=0,
                       shuffle=self.shuffle)

        self.ui.stop_button.setEnabled(False)
        self.text_output = "Training complete !"
        self.ui.progressBar.setValue(100)

        # NOTE(review): increments by 1 although up to n_epochs were run;
        # presumably train_callback tracks per-epoch progress — confirm.
        self.current_epoch += 1
        self.setup_model()
        self.write_training_log(self.logfile_name())
class Prepare(QtWidgets.QWidget):
    #TODO sample counter for validation is borked when checkbox are checked
    #TODO test validation on test and validation on test sub_set
    last_folder = "/home"
    prepare_complete = QtCore.pyqtSignal(name='prepare_complete')

    def __init__(self, project):
        """Build the dataset-preparation widget for ``project``.

        Fills the hotword selector from ``project.project_info['hotwords']``,
        creates one empty folder list and one pie slice per sample class
        (index 0 is the non-hotword class, label ``None``, followed by the
        hotwords in project order), and wires the UI signals.
        """
        super().__init__()
        self.ui = Ui_Prepare()
        self.ui.setupUi(self)
        self.project = project
        self.last_hotword_index = 0
        self.main_dataset = DataSet()

        # Hotword selector
        self.hotwords = self.project.project_info['hotwords']
        self.ui.current_hotword.addItems(self.hotwords)

        # Sample classes: None (non-hotword) first, then each hotword.
        self.sample_classes = [None] + self.hotwords

        # One folder list per sample class.
        self.folder_list_layout = QtWidgets.QVBoxLayout()
        self.folder_list = [[] for _ in range(len(self.hotwords) + 1)]

        # One pie slice per sample class, in the same order as sample_classes.
        self.pie_slices = [
            QtChart.QPieSlice("{} ()".format(name), 0)
            for name in ["non-hotword"] + self.hotwords
        ]
        self.display_selected_directories(0)

        # Pie chart
        self.pie_series = QtChart.QPieSeries()
        for slice_ in self.pie_slices:
            self.pie_series.append(slice_)
        self.pie_series.setHoleSize(0)

        self.pie_view = QtChart.QChartView()
        self.pie_view.setRenderHint(QtGui.QPainter.Antialiasing)
        chart = self.pie_view.chart()
        chart.layout().setContentsMargins(0, 0, 0, 0)
        chart.setMargins(QtCore.QMargins(0, 0, 0, 0))
        chart.legend().setAlignment(QtCore.Qt.AlignRight)
        chart.addSeries(self.pie_series)

        graph_layout = QtWidgets.QVBoxLayout()
        graph_layout.addWidget(self.pie_view)
        self.ui.graph_placeholder.setLayout(graph_layout)

        # Signal -> slot wiring
        ui = self.ui
        connections = (
            (ui.current_hotword.currentIndexChanged, self.display_selected_directories),
            (ui.add_folder_button.clicked, self.add_folder),
            (ui.test_subset_CB.stateChanged, ui.test_subset_percent.setEnabled),
            (ui.val_on_test_set_CB.stateChanged, self.disable_validation_set),
            (ui.training_percent.valueChanged, self.on_trainsample_update),
            (ui.validation_percent.valueChanged, self.on_valsample_update),
            (ui.test_percent.valueChanged, self.on_testsample_update),
            (ui.done_button.clicked, self.on_done_clicked),
            (ui.add_set_PB.clicked, self.add_set),
        )
        for signal, slot in connections:
            signal.connect(slot)

    def display_selected_directories(self, index):
        """Show the folders registered for the sample class at ``index``.

        Connected to ``current_hotword.currentIndexChanged``.  That signal
        emits -1 when the combo box becomes empty or is reset, and it may
        emit an index with no folder list, e.g. after
        ``update_hotword_combo_size`` adds entries.  In those cases the list
        is just cleared.  Without this check, -1 silently showed the last
        class's folders and a too-large index raised IndexError.
        """
        self.ui.folder_list.clear()
        if not 0 <= index < len(self.folder_list):
            return
        for folder in self.folder_list[index]:
            l_item = QtWidgets.QListWidgetItem()
            l_widget = Folder_Line(folder)
            # The row's remove action drops the folder from the current class.
            l_widget.line_removed.connect(self.remove_folder)
            l_item.setSizeHint(l_widget.sizeHint())
            self.ui.folder_list.addItem(l_item)
            self.ui.folder_list.setItemWidget(l_item, l_widget)

    def update_hotword_combo_size(self, value):
        """Move the hotword combo box one entry toward ``value + 1`` items.

        At most one item is removed (the last) or one empty item is added per
        call.  Afterwards ``update_hotword_combo_names()`` is called.
        """
        combo = self.ui.current_hotword
        target = value + 1
        size = combo.count()
        if size > target:
            combo.removeItem(size - 1)
        elif size < target:
            combo.addItem("")
        self.update_hotword_combo_names()
    

    def remove_folder(self, folder):
        """Remove ``folder`` from the currently selected class and refresh the list.

        Connected to each ``Folder_Line.line_removed`` signal in
        ``display_selected_directories``.

        NOTE(review): ``self.hotword_samples`` and ``self.update_sample_counter``
        are not defined anywhere in the visible code (``__init__`` does not set
        ``hotword_samples``).  The last two lines look like leftovers from an
        earlier version and would raise AttributeError — confirm.
        NOTE(review): the samples ``add_folder`` loaded from ``folder`` are not
        removed from ``self.main_dataset``, and ``update_chart`` is not called.
        """
        index = self.ui.current_hotword.currentIndex()
        self.folder_list[index].remove(folder)
        self.display_selected_directories(index)
        self.hotword_samples[index].setValue(self.hotword_samples[index].value() - len([f for f in os.listdir(folder) if f.endswith('.wav')]))
        self.update_sample_counter()
    
    def add_folder(self):
        """Ask for a directory and add its samples to the selected class.

        Nothing happens if the dialog is cancelled or the directory is already
        listed for this class.  Otherwise the directory is recorded, its
        ".wav" samples are loaded into ``main_dataset`` with the class label,
        and the chart is refreshed.  The dialog's next start location becomes
        the parent of the chosen directory.
        """
        class_index = self.ui.current_hotword.currentIndex()
        chosen = QtWidgets.QFileDialog.getExistingDirectory(self, "Select a directory", self.last_folder)
        if len(chosen) == 0 or chosen in self.folder_list[class_index]:
            return
        self.last_folder = os.path.dirname(chosen)
        self.folder_list[class_index].append(chosen)
        self.display_selected_directories(class_index)
        self.main_dataset.add_from_folder(folder_path=chosen,
                                          label=self.sample_classes[class_index],
                                          ext=".wav")
        self.update_chart()

    def add_set(self):
        """Load samples from a JSON manifest chosen in a file dialog.

        A ``SetFilePreview`` dialog is shown.  Its ``add_clicked`` signal feeds
        ``main_dataset.add_from_manifest``, and the chart is refreshed after
        the dialog closes.  Does nothing if the file dialog is cancelled.
        """
        chosen_file = QtWidgets.QFileDialog.getOpenFileName(self, "Select a file", self.last_folder, "Json file (*.json)")[0]
        if len(chosen_file) == 0:
            return
        preview = SetFilePreview(chosen_file)
        preview.add_clicked.connect(self.main_dataset.add_from_manifest)
        preview.exec()
        self.update_chart()

    def update_chart(self):
        """Refresh each pie slice's value/label and the chart title from ``main_dataset``.

        Slice ``i`` shows the number of samples labelled ``sample_classes[i]``
        and its share of the total.  The title shows the total sample count.
        """
        total_sample = len(self.main_dataset)
        for i, hotword in enumerate(self.sample_classes):
            n_s = len(self.main_dataset.get_subset_by_label(hotword))
            percent = 0 if total_sample == 0 else n_s / total_sample * 100
            name = 'non-hotword' if hotword is None else hotword
            pie_slice = self.pie_slices[i]
            pie_slice.setValue(n_s)
            pie_slice.setLabel("{} : {} ({:.2f}%)".format(name, n_s, percent))
        # The title depends only on the total, so set it once instead of per slice.
        self.pie_view.chart().setTitle("{} samples".format(total_sample))

    def disable_validation_set(self, value:bool):
        """Switch the "validate on test set" mode on or off.

        Connected to ``val_on_test_set_CB.stateChanged``, so ``value`` is the
        state that signal emits.  It is presumably an int check state rather
        than a real bool — TODO confirm the annotation.  Statement order
        matters: each ``setValue`` below may fire the ``valueChanged``
        handlers connected in ``__init__``.
        """
        self.ui.validation_percent.setValue(0.0)
        # Give the test split everything not used for training.
        self.ui.test_percent.setValue(100 - self.ui.training_percent.value())
        self.ui.validation_percent.setEnabled(not value)
        self.ui.test_subset_percent.setValue(100.0)
        # The test-subset option is only available in validate-on-test mode.
        self.ui.test_subset_CB.setEnabled(value)
        self.ui.test_subset_CB.setChecked(False)
    
    def update_sets_sample_count(self):
        """Update the per-set sample counters from ``main_dataset`` and the percentages.

        In "validate on test set" mode the validation count is
        ``test_subset_percent`` percent of the test count.  Otherwise it is
        ``validation_percent`` percent of all samples.  Every count is rounded
        to an int, like the train/test counters.  The validation counter
        previously received an unrounded float in the first case.
        """
        total_sample_count = len(self.main_dataset)
        self.ui.n_train_samples.setValue(round(total_sample_count * self.ui.training_percent.value() / 100))
        self.ui.n_test_samples.setValue(round(total_sample_count * self.ui.test_percent.value() / 100))
        if self.ui.val_on_test_set_CB.isChecked():
            n_val = self.ui.n_test_samples.value() * self.ui.test_subset_percent.value() / 100
        else:
            n_val = total_sample_count * self.ui.validation_percent.value() / 100
        self.ui.n_val_samples.setValue(round(n_val))
    
    def on_trainsample_update(self, value):
        """Keep train + validation + test at or below 100 after a training-percent edit.

        ``value`` is the new training percentage.  If the total would exceed
        100, the training spin box is lowered by the excess.  The sample
        counters are then refreshed.
        """
        excess = value + self.ui.validation_percent.value() + self.ui.test_percent.value() - 100
        if excess > 0:
            self.ui.training_percent.setValue(value - excess)
        self.update_sets_sample_count()
    
    def on_valsample_update(self, value):
        """Keep train + validation + test at or below 100 after a validation-percent edit.

        ``value`` is the new validation percentage.  If the total would exceed
        100, the validation spin box is lowered by the excess.  The sample
        counters are then refreshed.
        """
        excess = self.ui.training_percent.value() + value + self.ui.test_percent.value() - 100
        if excess > 0:
            self.ui.validation_percent.setValue(value - excess)
        self.update_sets_sample_count()

    def on_testsample_update(self, value):
        """Keep train + validation + test at or below 100 after a test-percent edit.

        ``value`` is the new test percentage.  If the total would exceed 100,
        the test spin box is lowered by the excess.  The sample counters are
        then refreshed.
        """
        excess = self.ui.training_percent.value() + self.ui.validation_percent.value() + value - 100
        if excess > 0:
            self.ui.test_percent.setValue(value - excess)
        self.update_sets_sample_count()

    def on_done_clicked(self):
        """Write the train/val/test manifests, then clear the folder list widget.

        NOTE(review): only the list widget is cleared.  ``self.folder_list``
        and ``self.main_dataset`` keep their contents, and the class-level
        ``prepare_complete`` signal is not emitted here — confirm that is
        intended.
        """
        self.generate_sets()
        self.ui.folder_list.clear()

    def generate_sets(self):
        """Split the collected samples into train/val/test sets and save them.

        Each sample class in ``self.sample_classes`` is split on its own with
        the training/validation/test percentages from the UI.  The second
        ``split_dataset`` argument is ``not self.shuffle_set``.  The per-class
        parts are merged; each set is written to
        ``<project_location>/<name>.json`` and that path is stored in
        ``project.data_info['<name>_set']``.  Finally the project is marked
        as having sets and ``update_project()`` is called.
        """
        self.project.data_info['samples_location'] = dict()

        per_class = [self.main_dataset.get_subset_by_label(label)
                     for label in self.sample_classes]

        train_set = DataSet()
        val_set = DataSet()
        test_set = DataSet()

        ratios = [self.ui.training_percent.value(),
                  self.ui.validation_percent.value(),
                  self.ui.test_percent.value()]
        for subset in per_class:
            # A fresh copy per call, in case split_dataset mutates its ratio list.
            part_train, part_val, part_test = subset.split_dataset(list(ratios), not self.shuffle_set)
            train_set += part_train
            val_set += part_val
            test_set += part_test

        for name, dataset in (("train", train_set), ("val", val_set), ("test", test_set)):
            json_path = os.path.join(self.project.project_location, "{}.json".format(name))
            self.project.data_info['{}_set'.format(name)] = json_path
            dataset.write(json_path)

        self.project.data_info['set'] = True
        self.project.update_project()


    @property
    def training_percent(self):
        """Training split percentage, backed by the ``training_percent`` spin box.

        Setting it calls ``setValue`` on the spin box, which may trigger the
        ``valueChanged`` handler connected in ``__init__``.
        """
        return self.ui.training_percent.value()
    
    @training_percent.setter
    def training_percent(self, value):
        self.ui.training_percent.setValue(value)

    @property
    def validation_percent(self):
        """Validation split percentage, backed by the ``validation_percent`` spin box.

        Setting it calls ``setValue`` on the spin box, which may trigger the
        ``valueChanged`` handler connected in ``__init__``.
        """
        return self.ui.validation_percent.value()
    
    @validation_percent.setter
    def validation_percent(self, value):
        self.ui.validation_percent.setValue(value)

    @property
    def test_percent(self):
        """Test split percentage, backed by the ``test_percent`` spin box.

        Setting it calls ``setValue`` on the spin box, which may trigger the
        ``valueChanged`` handler connected in ``__init__``.
        """
        return self.ui.test_percent.value()
    
    @test_percent.setter
    def test_percent(self, value):
        self.ui.test_percent.setValue(value)

    @property
    def shuffle_set(self):
        """Whether the ``shuffle`` checkbox is ticked.

        ``generate_sets`` passes ``not shuffle_set`` to ``split_dataset``.
        """
        return self.ui.shuffle.isChecked()