Пример #1
0
    def __init__(self, project: Project):
        """Build the data-module UI for *project* and show its latest dataset.

        Args:
            project: the project whose datasets are managed; its
                ``project_updated`` signal triggers ``updateDisplay``.
        """
        _Module.__init__(self, project)
        self.ui = Ui_Form()
        self.ui.setupUi(self)
        self.project = project
        self.project.project_updated.connect(self.updateDisplay)

        # Show the most recently added dataset, or an empty DataSet when the
        # project has none yet.
        self.currentDataset = DataSet()
        if len(self.project.datasets) > 0:
            self.currentDataset = self.project.getDatasetByName(
                self.project.datasets[-1])

        self.populateDatasetCB()
        # Filling the combo box leaves its *first* entry selected, but the
        # dataset loaded above is the *last* one: select it explicitly so the
        # selector matches the displayed data. currentTextChanged is only
        # connected further down, so this does not reload the dataset.
        if len(self.project.datasets) > 0:
            self.ui.currentDataSet_CB.setCurrentIndex(
                len(self.project.datasets) - 1)

        # Chart
        self.chart = DataChart(self.currentDataset.datasetValues())
        self.ui.graphPlaceHolder.setLayout(QtWidgets.QHBoxLayout())
        self.ui.graphPlaceHolder.layout().addWidget(self.chart)

        self.updateDisplay()

        # CONNECT
        self.ui.currentDataSet_CB.currentTextChanged.connect(
            self.onDataSetChanged)
        self.currentDataset.dataset_updated.connect(self.updateDisplay)

        ## buttons
        self.ui.createDataSet_PB.clicked.connect(self.onCreateDatasetClicked)
        self.ui.delete_PB.clicked.connect(self.onDeleteDatasetClicked)
        self.ui.addFromFolder_PB.clicked.connect(self.addFromFolder)
        self.ui.export_PB.clicked.connect(self.onExportClicked)
        self.ui.import_PB.clicked.connect(self.onImportClicked)
        self.ui.remove_PB.clicked.connect(self.onRemoveClicked)
Пример #2
0
 def getDatasetByName(self, name) -> DataSet:
     """Load and return the dataset called *name*.

     The dataset is read from ``<project_location>/data/<name>/<name>.json``
     with ``DataSet.loadDataSet``. Returns None (despite the annotation) when
     *name* is not listed in ``self.datasets``.
     """
     if name in self.datasets:
         manifest_path = os.path.join(self.project_location, "data", name,
                                      name + ".json")
         loaded = DataSet()
         loaded.loadDataSet(manifest_path)
         return loaded
     return None
Пример #3
0
 def addNewDataSet(self, name):
     """Create a dataset called *name* on disk and register it in the project.

     Creates ``<project_location>/data/<name>/`` and its ``features``
     sub-folder, saves ``DataSet(name, self.keywords)`` to ``<name>.json``
     in that folder, appends *name* to ``self.datasets``, persists the
     project with ``_write()`` and emits ``dataset_updated``.

     Raises:
         OSError: e.g. FileExistsError when the dataset folder already
             exists; ``self.datasets`` is left unchanged in that case.
     """
     dataSetPath = os.path.join(self.project_location, "data", name)
     # Create the folders before touching self.datasets so that a failure
     # (such as an already existing folder) does not leave a name registered
     # in memory that has no matching files on disk.
     os.mkdir(dataSetPath)
     os.mkdir(os.path.join(dataSetPath, "features"))
     dataset = DataSet(name, self.keywords)
     dataset.saveDataSet(os.path.join(dataSetPath, name + ".json"))
     self.datasets.append(name)
     self._write()
     self.dataset_updated.emit()
def read_data_sets(f=0.9):
    """Load the image data and split it into train / test / predict sets.

    The labelled images from ``_extract_images()`` are shuffled (images and
    labels with the same permutation) and split: the first ``f`` fraction
    becomes ``train``, the remainder ``test``. The images returned by
    ``_extract_images(test=True)`` become ``predict`` (without labels).

    Args:
        f: fraction of the labelled samples used for training, in [0, 1].

    Returns:
        ``DataSets(train=..., test=..., predict=...)``.

    Raises:
        ValueError: if ``f`` is not within [0, 1].
    """
    # Validate before any loading. Without this check a negative ``f`` gives
    # a negative N, which silently slices from the end of the array.
    if not 0.0 <= f <= 1.0:
        raise ValueError("f must be in [0, 1], got %r" % (f,))

    X_train, y_train = _extract_images()
    X_predict, _ = _extract_images(test=True)

    # Shuffle samples and labels with the same permutation so pairs stay
    # aligned (same global RNG stream as arange + shuffle).
    perm = np.random.permutation(X_train.shape[0])
    X_train = X_train[perm]
    y_train = y_train[perm]

    N = int(f * X_train.shape[0])
    train = DataSet(X_train[:N], y_train[:N])
    test = DataSet(X_train[N:], y_train[N:])
    predict = DataSet(X_predict, None)
    return DataSets(train=train, test=test, predict=predict)
Пример #5
0
 def deleteDataset(self, name: str):
     """Delete the dataset *name* and show the first remaining one.

     If ``project.deleteDataSet`` raises, the error text is shown in a
     ``SimpleDialog`` and nothing else changes. Otherwise the selector is
     refilled and the first remaining dataset becomes current, or an empty
     ``DataSet`` when none is left. The display is then refreshed.
     """
     try:
         self.project.deleteDataSet(name)
     except Exception as err:
         # UI boundary: report any failure instead of letting it escape the slot.
         dialog = SimpleDialog(self, "Error", str(err))
         dialog.show()
         return

     self.populateDatasetCB()
     remaining = self.project.datasets
     if not remaining:
         self.currentDataset = DataSet()
     else:
         self.currentDataset = self.project.getDatasetByName(remaining[0])
         self.ui.currentDataSet_CB.setCurrentIndex(0)
     self.updateDisplay()
Пример #6
0
 def __getattr__(self, attribute):
     """Build aggregate datasets on demand for ``collections`` and ``documents``.

     ``collections``: a ``DataSet`` over ``self.dms.collection`` holding the
     ``collection`` of every item, keyed by the collection's ``key``.
     ``documents``: a dataset from ``self.dms.document.new_dataset()`` holding
     the ``document`` of every item, keyed by the document's ``key``.

     Raises:
         AttributeError: for any other name, so ``hasattr``/``getattr`` with a
             default keep working as usual.
     """
     if attribute == 'collections':
         dataset = DataSet(self.dms.collection)
         for key in self.keys():
             collection = self[key].collection
             dataset[collection.key] = collection
         return dataset
     if attribute == 'documents':
         dataset = self.dms.document.new_dataset()
         for key in self.keys():
             document = self[key].document
             dataset[document.key] = document
         return dataset
     raise AttributeError('No such attribute %s' % attribute)
Пример #7
0
    def evaluate(self):
        """Evaluate the current profile's trained model and display the results.

        The test set is always evaluated. The training and validation sets are
        added when their check boxes are ticked. For each sample:

        - the true class is 0 when no expected output is positive, otherwise
          1 + the index of the first positive one;
        - the predicted class is 0 when the highest prediction does not exceed
          the threshold spin box, otherwise 1 + argmax.

        Counts are accumulated in a square matrix of side ``labels + 1``,
        indexed [true, predicted]. Every misclassified sample is passed to
        ``addFalseSample``.
        """
        def load_set(path):
            data = DataSet()
            data.loadDataSet(path)
            return data

        # Data presentation: test set first, then the optional ones.
        evalSet = [load_set(self.currentProfile.testSetPath)]
        if self.ui.training_set.isChecked():
            evalSet.append(load_set(self.currentProfile.trainSetPath))
        if self.ui.validation_set.isChecked():
            evalSet.append(load_set(self.currentProfile.valSetPath))

        # Feature extraction
        samples, inputs, expectedOutput = prepare_input_output(
            evalSet,
            self.currentProfile.features,
            save_features_folder=self.currentProfile.featureFolder,
            traceCallBack=self.displayState,
            returnSamples=True)

        # Load model
        self.displayState("Loading model ...")
        model = loadModel(self.currentProfile.trainedModelPath)

        # Predictions
        self.displayState("Predicting ...")
        predictions = model.predict(inputs)

        # Result classification (index 0 = "no class")
        side = len(evalSet[0].labels) + 1
        result_matrix = np.zeros((side, side))
        for sample, prediction, expected in zip(samples, predictions,
                                                expectedOutput):
            positives = np.where(expected > 0.0)[0]
            true_class = positives[0] + 1 if len(positives) else 0
            if max(prediction) > self.ui.threshold.value():
                predicted_class = np.argmax(prediction) + 1
            else:
                predicted_class = 0
            result_matrix[true_class][predicted_class] += 1
            if true_class != predicted_class:
                self.addFalseSample(sample, (true_class, predicted_class))

        # Display results
        self.setTableValues(result_matrix)
        self.displayMetrics(result_matrix)
Пример #8
0
    def onRemoveClicked(self):
        """Open a dialog to remove the checked false samples from the sets.

        Does nothing when no sample in ``false_samples`` is checked.
        Otherwise loads the profile's test, train and validation sets (in
        that order) and opens a ``RemoveSamplesDialog`` for them, whose
        ``on_removed`` signal is routed to ``onSamplesRemoved``.
        """
        listWidget = self.ui.false_samples
        rowWidgets = (listWidget.itemWidget(listWidget.item(row))
                      for row in range(listWidget.count()))
        selectedSamples = [w.sample for w in rowWidgets if w.CB.isChecked()]
        if not selectedSamples:
            return

        sets = []
        for path in (self.currentProfile.testSetPath,
                     self.currentProfile.trainSetPath,
                     self.currentProfile.valSetPath):
            loaded = DataSet()
            loaded.loadDataSet(path)
            sets.append(loaded)

        dialog = RemoveSamplesDialog(self, selectedSamples, sets,
                                     self.currentProfile.dataset,
                                     outputFolder=self.currentProfile.folder)
        dialog.on_removed.connect(self.onSamplesRemoved)
        dialog.show()
Пример #9
0
    def train(self):
        """Train the current profile's model for ``epoch_SB`` more epochs.

        Reuses ``self.trainingSession`` when one exists. Otherwise it loads
        the saved model (when ``hasModel``), or builds one with
        ``toKerasModel`` and saves it.

        Training and validation features come from ``prepare_input_output``.
        After ``fit`` the profile is marked trained and its epoch counter is
        copied from the session. It is then written once,
        ``trained_updated`` is emitted, and the training logs are written.
        """
        # Training session
        if self.trainingSession is None:
            if self.currentTrained.hasModel:
                model = loadModel(self.currentTrained.trainedModelPath)
            else:
                self.updateState("Creating neural net")
                model = self.currentTrained.model.toKerasModel(
                    self.currentTrained.features.feature_shape,
                    len(self.project.keywords))
                saveModel(model, self.currentTrained.trainedModelPath)
                self.currentTrained.hasModel = True
            self.trainingSession = TrainingSession(model,
                                                   self.currentTrained.epoch)

        # Read the spin box once so the chart range and fit() use the same value.
        target_epoch = self.trainingSession.epoch + self.ui.epoch_SB.value()

        # Set charts ranges
        self.accChart.setRangeX(0, target_epoch)
        self.lossChart.setRangeX(0, target_epoch)

        # Fetch sets
        trainSet = DataSet()
        trainSet.loadDataSet(self.currentTrained.trainSetPath)

        valSet = DataSet()
        valSet.loadDataSet(self.currentTrained.valSetPath)

        # prepare inputs / outputs
        train_input, train_output = prepare_input_output(
            [trainSet],
            self.currentTrained.features,
            traceCallBack=self.updateState,
            save_features_folder=self.currentTrained.featureFolder)
        val_input, val_output = prepare_input_output(
            [valSet],
            self.currentTrained.features,
            traceCallBack=self.updateState,
            save_features_folder=self.currentTrained.featureFolder)

        # Set callbacks
        callbacks = callbacksDef(self.currentTrained.trainedModelPath,
                                 self.train_callback)

        # training
        # NOTE(review): Keras fit() runs epochs initial_epoch .. epochs-1, so
        # the "+ 1" trains one epoch beyond the chart range set above; kept
        # as-is, confirm whether that is intended.
        self.trainingSession.model.fit(
            train_input,
            train_output,
            batch_size=self.ui.batch_SB.value(),
            initial_epoch=self.trainingSession.epoch,
            epochs=target_epoch + 1,
            callbacks=callbacks,
            validation_data=(val_input, val_output),
            verbose=0,
            shuffle=self.ui.shuffle_CB.isChecked())

        # Update the profile completely before persisting and notifying, so it
        # is written once and listeners of trained_updated see the new epoch.
        self.currentTrained.isTrained = True
        self.currentTrained.epoch = self.trainingSession.epoch
        self.currentTrained.writeTrained()
        self.project.trained_updated.emit()

        # Write training logs
        self.writeLogs()
Пример #10
0
 def listSample(self, path):
     """Show *path* in the browse field and list its files.

     Files are listed with ``DataSet.listFolder``. It recurses when the
     ``recurs_CB`` check box is ticked. The list is stored in ``self.files``
     and its length is passed to ``displaydataInfo``.
     """
     self.ui.browse_LE.setText(path)
     recursive = self.ui.recurs_CB.isChecked()
     found = DataSet.listFolder(path, recursive)
     self.files = found
     self.displaydataInfo(len(found))
Пример #11
0
class Data(_Module):
    """Module managing the datasets of a project.

    Lets the user create, select, delete, import and export datasets, add
    samples from a folder and remove samples by folder. It shows an
    overview text and a chart of the selected dataset.
    """
    moduleTitle = "Data"
    iconName = "data.png"
    shortDescription = ''' Manage your project data '''
    category = "prep"
    moduleHelp = '''
                 The data module allow you to add audio samples to your project.
                 '''

    def __init__(self, project: Project):
        """Build the UI for *project* and show its most recent dataset.

        Args:
            project: the project whose datasets are managed; its
                ``project_updated`` signal triggers ``updateDisplay``.
        """
        _Module.__init__(self, project)
        self.ui = Ui_Form()
        self.ui.setupUi(self)
        self.project = project
        self.project.project_updated.connect(self.updateDisplay)

        # Show the most recently added dataset, or an empty DataSet when the
        # project has none yet.
        self.currentDataset = DataSet()
        if len(self.project.datasets) > 0:
            self.currentDataset = self.project.getDatasetByName(
                self.project.datasets[-1])

        self.populateDatasetCB()
        # Filling the combo box leaves its *first* entry selected, but the
        # dataset loaded above is the *last* one: select it explicitly so the
        # selector matches the displayed data. currentTextChanged is only
        # connected further down, so this does not reload the dataset.
        if len(self.project.datasets) > 0:
            self.ui.currentDataSet_CB.setCurrentIndex(
                len(self.project.datasets) - 1)

        # Chart
        self.chart = DataChart(self.currentDataset.datasetValues())
        self.ui.graphPlaceHolder.setLayout(QtWidgets.QHBoxLayout())
        self.ui.graphPlaceHolder.layout().addWidget(self.chart)

        self.updateDisplay()

        # CONNECT
        self.ui.currentDataSet_CB.currentTextChanged.connect(
            self.onDataSetChanged)
        self.currentDataset.dataset_updated.connect(self.updateDisplay)

        ## buttons
        self.ui.createDataSet_PB.clicked.connect(self.onCreateDatasetClicked)
        self.ui.delete_PB.clicked.connect(self.onDeleteDatasetClicked)
        self.ui.addFromFolder_PB.clicked.connect(self.addFromFolder)
        self.ui.export_PB.clicked.connect(self.onExportClicked)
        self.ui.import_PB.clicked.connect(self.onImportClicked)
        self.ui.remove_PB.clicked.connect(self.onRemoveClicked)

    ########################################################################
    ##### UI LOGIC
    ########################################################################

    def onCreateDatasetClicked(self):
        """Ask for a new dataset name; ``createNewDataSet`` handles the answer."""
        dialog = CreateDialog(self, self.project.datasets, "Create Dataset",
                              "Dataset name:")
        dialog.on_create.connect(self.createNewDataSet)
        dialog.show()

    def onDataSetChanged(self, name: str):
        """Switch to the dataset selected in the combo box.

        Empty names (the combo box is cleared while being refilled) and names
        the project does not know are ignored.
        """
        if not name:
            return
        dataset = self.project.getDatasetByName(name)
        if dataset is None:
            return
        self._setCurrentDataset(dataset)
        self.updateDisplay()

    def onDeleteDatasetClicked(self):
        """Ask for confirmation before deleting the selected dataset."""
        dialog = ConfirmDelete(self, "Delete Dataset", "Do you want to delete",
                               self.ui.currentDataSet_CB.currentText())
        dialog.on_delete.connect(self.deleteDataset)
        dialog.show()

    def onExportClicked(self):
        """Open the export dialog for the current dataset."""
        dialog = ExportDatasetDialog(self, self.currentDataset)
        dialog.show()

    def onRemoveClicked(self):
        """Let the user pick sample folders to remove from the current dataset."""
        dialog = RemoveFolderSamplesDialog(
            self, self.currentDataset.getSamplesFolders())
        dialog.on_removed.connect(self.onRemoveSamples)
        dialog.show()

    def populateDatasetCB(self):
        """Refill the dataset combo box with the project's dataset names."""
        self.ui.currentDataSet_CB.clear()
        for ds in self.project.datasets:
            self.ui.currentDataSet_CB.addItem(ds, userData=ds)

    def updateDisplay(self):
        """Refresh widget states, the overview text and the chart."""
        dataset_existing = len(self.project.datasets) > 0
        self.ui.overView_GB.setEnabled(dataset_existing)
        self.ui.add_GB.setEnabled(dataset_existing)
        self.ui.addFromMan_PB.setEnabled(False)  # TODO: implement
        self.ui.export_PB.setEnabled(dataset_existing)
        self.ui.remove_PB.setEnabled(dataset_existing)
        self.ui.overview_TE.clear()
        self.ui.overview_TE.appendPlainText(self.currentDataset.datasetInfo())
        self.chart.updateChart(self.currentDataset.datasetValues())

    ########################################################################
    ##### PROCESSING
    ########################################################################

    def _setCurrentDataset(self, dataset):
        """Make *dataset* current and refresh the display when it updates.

        ``__init__`` connects ``dataset_updated`` for the initial dataset; this
        does the same for every dataset selected later, so the display follows
        the dataset that is actually shown.
        """
        self.currentDataset = dataset
        dataset.dataset_updated.connect(self.updateDisplay)

    def addFromFolder(self):
        """Open the add-from-folder dialog; samples go to ``onAddSample``."""
        dialog = AddFolderDialog(self, self.currentDataset)
        dialog.addSamples.connect(self.onAddSample)
        dialog.show()

    def createNewDataSet(self, name: str):
        """Create dataset *name* in the project and select it."""
        self.project.addNewDataSet(name)
        self.populateDatasetCB()
        self.ui.currentDataSet_CB.setCurrentText(name)
        self.updateDisplay()

    def deleteDataset(self, name: str):
        """Delete dataset *name* and show the first remaining one.

        If the project refuses the deletion, its error message is shown and
        nothing else changes.
        """
        try:
            self.project.deleteDataSet(name)
        except Exception as e:
            # UI boundary: report any failure instead of letting it escape the slot.
            dialog = SimpleDialog(self, "Error", str(e))
            dialog.show()
            return
        self.populateDatasetCB()
        if len(self.project.datasets) > 0:
            self._setCurrentDataset(
                self.project.getDatasetByName(self.project.datasets[0]))
            self.ui.currentDataSet_CB.setCurrentIndex(0)
        else:
            self._setCurrentDataset(DataSet())
        self.updateDisplay()

    def onAddSample(self, label: str, files: list):
        """Add *files* under *label* to the current dataset and refresh."""
        self.currentDataset.addSampleFiles(label, files)
        self.updateDisplay()

    def onRemoveSamples(self, folders):
        """Remove the samples from *folders* in the current dataset and refresh."""
        self.currentDataset.removeFromFolders(folders)
        # Refresh like onAddSample does, rather than relying on a signal.
        self.updateDisplay()

    def onImportClicked(self):
        """Import a dataset from a JSON manifest chosen by the user.

        Requirements on the manifest:

        - it is a non-empty list of dicts, each with ``"label"`` and
          ``"file"`` keys;
        - its set of labels is exactly the current dataset's labels plus ''.

        On success a new dataset is created, filled with ``importDataSet``
        and selected. It is named after the file; when that name is taken,
        ``_imported`` is appended, then a counter if needed. Every failure
        is reported in a dialog.
        """

        def datasetLabels(manifest):
            """Return the set of labels used in *manifest*."""
            return {s["label"] for s in manifest}

        def isMatchingFormat(manifest):
            """Return True if *manifest* is a non-empty list of sample dicts.

            All entries are checked, not only the first, so that
            ``datasetLabels`` cannot fail on a malformed entry.
            """
            if not isinstance(manifest, list) or not manifest:
                return False
            return all(isinstance(s, dict) and "label" in s and "file" in s
                       for s in manifest)

        def matchAllLabels(target_labels, labels):
            """Return True if the imported labels are exactly *labels* plus ''."""
            return set(target_labels) == set(labels) | {''}

        res = QtWidgets.QFileDialog.getOpenFileName(
            self, "Select dataset json file.", filter="json(*.json)")[0]
        if not res:
            return  # dialog cancelled

        # splitext keeps dots that are part of the name ("a.b.json" -> "a.b").
        datasetName = os.path.splitext(os.path.basename(res))[0]
        try:
            with open(res, 'r', encoding='utf-8') as f:
                manifest = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and undecodable bytes.
            dialog = SimpleDialog(self, "Error", str(e))
            dialog.show()
            return

        if not isMatchingFormat(manifest):
            dialog = SimpleDialog(
                self, "Format mismatch",
                "Could not import from dataset manifest.\nTry using create a new dataset and use import from manifest."
            )
            dialog.show()
            return

        target_labels = datasetLabels(manifest)
        if not matchAllLabels(target_labels, self.currentDataset.labels):
            dialog = SimpleDialog(
                self, "Label mismatch",
                "The labels of the imported dataset do not match the project labels."
            )
            dialog.show()
            return

        # Pick a name that is not in use yet, so creating the dataset does not
        # clash with an existing one.
        if datasetName in self.project.datasets:
            base = datasetName + "_imported"
            datasetName = base
            suffix = 2
            while datasetName in self.project.datasets:
                datasetName = "{}_{}".format(base, suffix)
                suffix += 1

        self.project.addNewDataSet(datasetName)
        dataset = self.project.getDatasetByName(datasetName)
        dataset.importDataSet(res)
        self.populateDatasetCB()
        self.ui.currentDataSet_CB.setCurrentText(datasetName)
        self.updateDisplay()