Example #1
0
 def export_TensorflowJS(self):
     """Export the project's Keras model in TensorFlow.js format.

     The user is asked for a target directory (dialog opens at "/home/").
     If one is chosen, the model at ``self.project.model_path`` is loaded
     and written there with ``tfjs.converters.save_keras_model``; if the
     dialog is cancelled nothing happens.
     """
     import tensorflowjs as tfjs
     target_dir = QtWidgets.QFileDialog.getExistingDirectory(
         self, "Select a target directory", "/home/")
     # An empty string means the dialog was cancelled.
     if not target_dir:
         return
     keras_model = load_model(self.project.model_path)
     tfjs.converters.save_keras_model(keras_model, target_dir)
Example #2
0
    def export_TensorflowLite(self):
        """Convert the project's Keras model to TensorFlow Lite and save it.

        The model at ``self.project.model_path`` is converted with
        ``tf.lite.TFLiteConverter.from_keras_model``.

        If conversion fails, a warning box is shown and the exception is
        printed. Otherwise the user picks a destination file, defaulting to
        ``<project_location>/<project_name or 'model'>.tflite``. Cancelling
        the dialog writes nothing.
        """
        model = load_model(self.project.model_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(model)

        try:
            tflite_model = converter.convert()
        except Exception as e:
            error_box = QtWidgets.QMessageBox(self)
            error_box.setIcon(QtWidgets.QMessageBox.Warning)
            error_box.setText(
                "Unable to create the tensorflowLite model, see logs")
            error_box.setWindowTitle("Error")
            error_box.setStandardButtons(QtWidgets.QMessageBox.Ok)
            error_box.exec()
            print(str(e))
            return

        gen_name = self.project.project_info.get('project_name',
                                                 'model') + ".tflite"
        # getSaveFileName returns (path, selected_filter); path is empty on cancel.
        file_path = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save model parameter file",
            os.path.join(self.project.project_location, gen_name),
            "TFLite File (*.tflite)")[0]
        if not file_path:
            return
        with open(file_path, "wb") as f:
            f.write(tflite_model)
Example #3
0
 def export_TF_PB(self):
     """Export the project's Keras model as a TensorFlow SavedModel.

     The model at ``self.project.model_path`` is saved with
     ``tf.saved_model.save``. The destination is
     ``<project_location>/output/<project_name>/<version>/``, where the
     version is the value of ``self.ui.version_SB``. An information box
     then reports the export folder.
     """
     keras_model = load_model(self.project.model_path)
     project_name = self.project.project_info['project_name']
     version = self.ui.version_SB.value()
     export_folder = os.path.join(
         self.project.project_location,
         "output/{}/{}/".format(project_name, version))
     tf.saved_model.save(keras_model, export_folder)
     message = "Saved model exported at {}".format(export_folder)
     QtWidgets.QMessageBox.information(self, "Model exported", message)
Example #4
0
 def on_test_clicked(self):
     """Run the model test selected in the GUI.

     The result table and metrics are cleared and the model is reloaded
     from ``self.project.model_path``. The model is then tested on the
     external set when ``external_CB`` is checked, otherwise on the
     internal set.

     The widget is disabled while the test runs and is re-enabled even if
     any step raises.
     """
     self.setEnabled(False)
     try:
         self.reset_table()
         self.reset_metrics()
         self.model = load_model(self.project.model_path)
         # summary() prints the table itself and returns None; wrapping it in
         # print() would emit a stray "None".
         self.model.summary()
         if self.ui.external_CB.isChecked():
             self.test_on_external()
         else:
             self.test_on_internal()
     finally:
         # Without this, an exception would leave the whole widget disabled.
         self.setEnabled(True)
Example #5
0
 def __init__(self, project):
     """Display the project's model graph, or a placeholder text.

     When ``project.model_info['set']`` is true, the model at
     ``project.model_path`` is loaded and rendered with ``plot_model``
     (shapes shown) to ``graph.png`` in the model's directory. That image
     is then set as the pixmap. Otherwise a text saying the model has not
     been created is shown.
     """
     super().__init__()
     self.project = project
     if not project.model_info['set']:
         self.setText("Model haven't been created yet")
         return
     model_path = self.project.model_path
     # Imported lazily so the heavy TF/Keras modules load only when needed.
     from scripts.keras_functions import load_model
     from tensorflow.keras.utils import plot_model
     graph_path = os.path.join(os.path.dirname(model_path), 'graph.png')
     plot_model(load_model(model_path), to_file=graph_path, show_shapes=True)
     self.setPixmap(QtGui.QPixmap(graph_path))
 def _load_keras_model(self, model_path: str):
     """Load the Keras model at *model_path* and return its ``predict`` method.

     The loaded model is also stored on ``self.model``.
     """
     self.model = load_model(model_path)
     # `_make_predict_function` is a private API of older Keras releases that
     # eagerly builds the predict function (presumably so predict() can be
     # called from another thread later — confirm with callers). Newer
     # tf.keras models no longer define it, so only call it when present.
     make_predict_function = getattr(self.model, "_make_predict_function",
                                     None)
     if make_predict_function is not None:
         make_predict_function()
     return self.model.predict
Example #7
0
    def training(self):
        """ Train the graph using parameters set in the GUI.

        Model selection:
        - If the project already has a model, ``self.model`` is used, or it
          is loaded from ``self.project.model_path`` when it is still None.
        - Otherwise a new model is created from the GUI settings, and its
          name (with an ``.hdf5`` suffix) is stored in
          ``project_info['model_name']``.

        Features:
        - Train/validation features are extracted when they have not been
          vectorized yet.
        - They are always re-extracted when "positive only" training is
          checked, since that subset is never cached as the vectorized set.

        The model is then fitted for ``self.n_epochs`` epochs, after which
        ``setup_model()`` is called and the training log is written.

        Returns early, without training, when:
        - the model file cannot be loaded;
        - the number of dense sizes does not match the number of dense
          layers (a warning box is shown);
        - no model name is given for a new model.
        """
        def check_dense_sizes(n_dense, dense_sizes):
            # One size value is required per dense layer.
            if len(dense_sizes) != n_dense:
                raise ValueError(
                    "The number of values for dense size must be equal to the number of dense layers."
                )

        # Load or create model
        if self.project.model_info['set']:
            if self.model is None:
                try:
                    self.model = load_model(self.project.model_path)
                except Exception as e:
                    print("Could not load the model file at {}".format(
                        self.project.model_path))
                    print(str(e))
                    return
        else:
            try:
                check_dense_sizes(self.n_denses, self.dense_sizes)
            except ValueError as e:
                error_box = QtWidgets.QMessageBox(self)
                error_box.setIcon(QtWidgets.QMessageBox.Warning)
                # str(e) is the message itself; str(e.args) showed a tuple repr.
                error_box.setText(str(e))
                error_box.setWindowTitle("Error")
                error_box.setStandardButtons(QtWidgets.QMessageBox.Ok)
                error_box.exec()
                return
            if self.model_name == '':
                # A new model needs a name: point the user at the field.
                self.ui.model_name.setFocus()
                return
            if not self.model_name.endswith('.hdf5'):
                self.model_name = self.model_name + '.hdf5'
            self.project.project_info['model_name'] = self.model_name
            self.model = create_model(
                self.model_name,
                (self.input_x, self.input_y),
                n_denses=self.n_denses,
                dense_sizes=self.dense_sizes,
                dropout=self.dropout,
                output_size=self.output_size,
                loss_fun=self.loss_bias,
                noise_layer_derivation=self.gaussian_noise_stdder,
                unroll=self.ui.unroll_CB.isChecked())

        self.callbacks = callbacks(
            os.path.join(self.project.project_location,
                         self.project.project_info['model_name']),
            self.only_keep_best)
        self.callbacks.append(Epoch_CallBack(self.train_callback))
        self.target_epoch = self.current_epoch + self.n_epochs
        # summary() prints the table itself and returns None, so it is not
        # wrapped in print() (which would emit a stray "None").
        self.model.summary()

        # Read the checkbox once. processEvents() runs during this method, so
        # the user could toggle it mid-run; every decision below must agree.
        pos_only = self.ui.pos_only_CB.isChecked()

        # Check and vectorize samples
        if not self.vectorized or pos_only:
            hotwords = self.project.project_info['hotwords']

            train_dataset = DataSet()
            val_dataset = DataSet()
            train_dataset.load(self.project.data_info['train_set'])
            val_dataset.load(self.project.data_info['val_set'])

            # If train on keyword only is checked
            if pos_only:
                train_dataset = train_dataset.get_subset_by_labels(hotwords)
                val_dataset = val_dataset.get_subset_by_labels(hotwords)

            # One-hot target per hotword; any other label maps to all zeros.
            outputs = {None: [0.0] * len(hotwords)}
            for i, hw in enumerate(hotwords):
                outputs[hw] = [0.0] * len(hotwords)
                outputs[hw][i] = 1.0
            self.progress_display(0, 1, "Collecting training samples ...")

            unproc_train = [(s['file_path'],
                             outputs.get(s['label'], outputs[None]))
                            for s in train_dataset]
            unproc_val = [(s['file_path'],
                           outputs.get(s['label'], outputs[None]))
                          for s in val_dataset]

            self.progress_display(0, 1, "Extracting training features ...")
            QtWidgets.QApplication.instance().processEvents()

            self.train_set, self.train_set_output = files2features(
                [s[0] for s in unproc_train],
                self.project.features_info,
                labels=[s[1] for s in unproc_train],
                progress_callback=self.progress_display)

            self.progress_display(0, 1, "Extracting validation features ...")

            self.validation_set, self.validation_set_output = files2features(
                [s[0] for s in unproc_val],
                self.project.features_info,
                labels=[s[1] for s in unproc_val],
                progress_callback=self.progress_display)
            # Only the full (not keyword-only) feature set is marked as cached.
            if not pos_only:
                self.vectorized = True
        self.progress_display(0, 1, "Training ...")
        self.ui.stop_button.setEnabled(True)
        QtWidgets.QApplication.instance().processEvents()

        # Fits model and uses outputs to update chart
        self.model.fit([self.train_set], [self.train_set_output],
                       batch_size=self.batch_size,
                       initial_epoch=self.current_epoch,
                       epochs=self.target_epoch,
                       callbacks=self.callbacks,
                       validation_data=([self.validation_set],
                                        [self.validation_set_output]),
                       verbose=0,
                       shuffle=self.shuffle)

        self.ui.stop_button.setEnabled(False)
        self.text_output = "Training complete !"
        self.ui.progressBar.setValue(100)

        # NOTE(review): this advances by one epoch although up to n_epochs were
        # trained; confirm whether train_callback also updates current_epoch.
        self.current_epoch += 1
        self.setup_model()
        self.write_training_log(self.logfile_name())