예제 #1
0
파일: app.py 프로젝트: rjdbcm/BEAGLES
def parse_args(args):
    """Parse the BEAGLES GUI command line.

    Args:
        args: list of command-line argument strings to parse.

    Returns:
        argparse.Namespace with ``filename``, ``predefined_class_file``,
        ``save_directory`` and ``darkmode`` attributes.

    The default ``filename`` is a randomly chosen regular file from
    ``Flags().imgdir``, or None when that directory contains no files.
    ``darkmode`` accepts boolean-like strings ("true"/"false", "1"/"0",
    "yes"/"no", "on"/"off"); a bare ``-d`` means True.
    """
    def str_to_bool(value):
        """Convert a boolean-like command-line string to a bool."""
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    img_dir = Flags().imgdir
    candidates = [os.path.join(img_dir, f) for f in os.listdir(img_dir)
                  if os.path.isfile(os.path.join(img_dir, f))]
    # random.choice raises IndexError on an empty sequence.
    random_img = random.choice(candidates) if candidates else None
    parser.add_argument('-i', '--filename', default=random_img, help="image file to open")
    parser.add_argument('-c', '--predefined_class_file', default=Flags().labels,
                        help="text file containing class names")
    parser.add_argument('-s', '--save_directory', default=None, help="save directory")
    # Without a type every supplied value was a non-empty (truthy) string,
    # so "-d False" enabled dark mode.
    parser.add_argument('-d', '--darkmode', default=True, type=str_to_bool,
                        nargs='?', const=True,
                        help='use qdarkstyle (defaults to system theme on macos)')
    return parser.parse_args(args)
예제 #2
0
 def __init__(self, flags=None):
     """Initialise the exception with advice for a lost-gradient failure.

     Args:
         flags: Flags object describing the run; a new ``Flags()`` is
             created when None. ``flags.cli`` selects whether the hint names
             the command-line option or the GUI checkbox, and ``flags.clip``
             suppresses the clipping hint when clipping is already enabled.
     """
     flags = Flags() if flags is None else flags
     if flags.cli:
         clip_hint = "--clip argument"
     else:
         clip_hint = "'Clip Gradients' checkbox"
     if flags.clip:
         suffix = "."
     else:
         suffix = " or turning on gradient clipping using the " + clip_hint + "."
     message = ("Looks like the neural net lost the gradient try restarting"
                " from the last checkpoint with a lower learning rate" + suffix)
     Exception.__init__(self, message)
예제 #3
0
파일: flags.py 프로젝트: rjdbcm/BEAGLES
    def __init__(self, flags=None, subprogram=False):
        """Set up flag storage located in a shared-memory directory.

        Args:
            flags: initial Flags object; a new ``Flags()`` is used when this
                argument is falsy.
            subprogram: when False, the shared memory is mounted if it is not
                already; when True, flags are loaded with ``read_flags()``,
                debug logging is enabled according to ``verbalise``, and the
                constructor sleeps one second if the flag file is missing.
        """
        self.subprogram = subprogram
        self.flags = flags if flags else Flags()
        self.logger = get_logger()
        self.shm = SharedMemory()
        if not (self.subprogram or self.shm.mounted):
            self.shm.mount()
        self.flag_path = os.path.join(self.shm.path, FLAG_FILE)

        if not subprogram:
            return
        self.read_flags()
        try:
            verbose = self.flags.verbalise
        except AttributeError:
            # Flags without a `verbalise` attribute get debug logging too.
            verbose = True
        if verbose:
            self.logger.setLevel(logging.DEBUG)
        try:
            with open(self.flag_path):
                pass
        except FileNotFoundError:
            # NOTE(review): presumably gives the parent time to create the
            # flag file; the constructor does not re-check afterwards.
            time.sleep(1)
예제 #4
0
 def __init__(self, parent=None, labelfile=None):
     """Create the flow dialog and populate its checkpoint list.

     Args:
         parent: parent widget forwarded to the base dialog class.
         labelfile: path of the labels file, stored as ``self.labelfile``.
     """
     super().__init__(parent)
     self.flags = Flags()
     self.labelfile = labelfile
     # Both attributes must exist before the widgets are built and the
     # checkpoint list is filled.
     self.setupDialog()
     self.findCkpt()
예제 #5
0
class FlowDialog(BackendDialog):
    """Dialog that configures and launches a BEAGLES backend run.

    The selected flow ("Train", "Predict" or "Annotate") and widget values
    are copied into ``self.flags`` by ``assign_flags``; ``accept`` then starts
    ``BACKEND_ENTRYPOINT`` in a subprocess wrapped by a ``BackendThread``.
    """

    def __init__(self, parent=None, labelfile=None):
        super(FlowDialog, self).__init__(parent)
        self.labelfile = labelfile  # copied into flags.labels by assign_flags
        self.flags = Flags()
        self.setupDialog()
        self.findCkpt()

    def findCkpt(self):
        """Fill ``loadCmb`` with the saved checkpoint steps of the current model.

        Items are step numbers as strings in descending numeric order, always
        ending with ``'0'``. ``buttonRun`` is enabled when at least one
        checkpoint for the model is found in ``flags.backup``.
        """
        self.loadCmb.clear()
        model_name = os.path.splitext(self.modelCmb.currentText())[0]
        steps = self._checkpointSteps(self.listFiles(self.flags.backup),
                                      model_name)
        if steps:
            self.buttonRun.setDisabled(False)
        ordered = sorted(set(steps) | {0}, reverse=True)
        self.loadCmb.addItems([str(step) for step in ordered])

    @staticmethod
    def _checkpointSteps(files, model_name):
        """Return the integer steps of the ``<model_name>-<step>.*`` names in *files*."""
        # Requiring the trailing dash keeps e.g. "yolov2-100.index" out of
        # the list for model "yolo"; parsing only the suffix avoids picking
        # up digits that are part of the model name itself.
        prefix = model_name + '-'
        steps = []
        for name in files:
            if not name.startswith(prefix):
                continue
            match = re.match(r"([0-9]+)\.", name[len(prefix):])
            if match:  # skip non-checkpoint files instead of crashing
                steps.append(int(match.group(1)))
        return steps

    def updateCkptFile(self):
        """Write the selected model/checkpoint into ``<flags.backup>/checkpoint``.

        Every double-quoted string in the file is replaced with
        ``"<model>-<step>"``; the file is created empty first if missing.
        """
        model_name = os.path.splitext(self.modelCmb.currentText())[0]
        quoted = '"{}"'.format("-".join([model_name, self.loadCmb.currentText()]))
        path = os.path.join(self.flags.backup, 'checkpoint')
        open(path, 'a').close()  # touch so the read below cannot fail
        with open(path, 'r') as fh:
            data = fh.read()
        # A callable replacement is inserted literally, so backslashes in
        # the name are never interpreted as regex escapes.
        result = re.sub(r'".*?"', lambda _match: quoted, data)
        with open(path, 'w') as fh:
            fh.write(result)

    def trainerSelect(self):
        """Enable the momentum spinbox only for rmsprop, momentum and nesterov."""
        uses_momentum = self.trainerCmb.currentText() in (
            "rmsprop", "momentum", "nesterov")
        self.momentumSpd.setDisabled(not uses_momentum)

    def flowSelect(self):
        """Show or hide the option groups that apply to the selected flow."""
        flow = self.flowCmb.currentText()
        if flow == "Predict":
            self.flowGroupBox.show()
        else:
            self.flowGroupBox.hide()

        if flow == "Train":
            self.trainGroupBox.show()
            self.thresholdSpd.setDisabled(True)
        else:
            self.trainGroupBox.hide()
            self.loadCmb.setCurrentIndex(0)

        if flow == "Annotate":
            self.thresholdSpd.setDisabled(False)

    def assign_flags(self):
        """Copy the current widget values into ``self.flags``."""
        self.flags.project_name = 'default'
        self.flags.model = os.path.join(self.flags.config,
                                        self.modelCmb.currentText())
        try:
            self.flags.load = int(self.loadCmb.currentText())
        except ValueError:  # e.g. empty combobox
            self.flags.load = 0
        self.flags.trainer = self.trainerCmb.currentText()
        self.flags.threshold = self.thresholdSpd.value()
        self.flags.clip = bool(self.clipChb.checkState())
        self.flags.clip_norm = self.clipNorm.value()
        self.flags.clr_mode = self.learningModeCmb.currentText()
        self.flags.verbalise = bool(self.verbaliseChb.checkState())
        self.flags.momentum = self.momentumSpd.value()
        self.flags.lr = self.learningRateSpd.value()
        self.flags.max_lr = self.maxLearningRateSpd.value()
        self.flags.keep = self.keepSpb.value()
        self.flags.batch = self.batchSpb.value()
        self.flags.save = self.saveSpb.value()
        self.flags.epoch = self.epochSpb.value()
        self.flags.labels = self.labelfile  # use labelfile set by slgrSuite
        if self.jsonChb.isChecked():
            self.flags.output_type.append("json")
        if self.vocChb.isChecked():
            self.flags.output_type.append("voc")

    def accept(self):
        """Set flags for the backend and start it unless errors are anticipated.

        Shows a warning and returns without starting anything when the
        checkpoint is invalid, 'Save Every' is not divisible by 'Batch Size',
        or the dataset folder is empty. Also returns without starting
        anything when the user cancels a file dialog.
        """
        self.updateCkptFile()  # Make sure TFNet gets the correct checkpoint
        self.flags.get_defaults()  # Reset self.flags
        self.assign_flags()
        flow = self.flowCmb.currentText()

        if flow != "Train" and self.flags.load == 0:
            QMessageBox.warning(self, 'Error', "Invalid checkpoint",
                                QMessageBox.Ok)
            return
        if flow == "Predict":
            self.flowGroupBox.setDisabled(True)
            options = QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks
            dirname = QFileDialog.getExistingDirectory(
                self, 'BEAGLES Predict - '
                'Choose Image Folder', os.getcwd(), options)
            if not dirname:
                # Dialog cancelled: restore the group box and start nothing.
                self.flowGroupBox.setDisabled(False)
                return
            self.flags.imgdir = dirname
        if flow == "Train":
            if self.flags.save % self.flags.batch != 0:
                QMessageBox.warning(
                    self, 'Error', "The value of 'Save Every' should be "
                    "divisible by the value of 'Batch Size'", QMessageBox.Ok)
                return
            dataset = [
                f for f in os.listdir(self.flags.dataset)
                if not f.startswith('.')
            ]
            if not dataset:
                QMessageBox.warning(self, 'Error',
                                    'No frames or annotations found',
                                    QMessageBox.Ok)
                return
            self.flags.train = True
        if flow == "Annotate":
            formats = ['*.avi', '*.mp4', '*.wmv', '*.mkv', '*.mpeg']
            filters = "Video Files (%s)" % ' '.join(formats +
                                                    ['*%s' % LabelFile.suffix])
            options = QFileDialog.Options()
            options |= QFileDialog.DontUseNativeDialog
            # getOpenFileNames returns (selected_files, selected_filter).
            filenames = QFileDialog.getOpenFileNames(
                self,
                'BEAGLES Annotate - Choose Video file',
                os.getcwd(),
                filters,
                options=options)[0]
            if not filenames:
                return  # dialog cancelled
            self.flags.video = filenames

        # The backend is started for every flow; the flow itself is conveyed
        # through self.flags.
        self._startBackend()
        self.flowPrg.setMaximum(0)  # pulse until the first progress update
        self.buttonRun.setEnabled(False)
        self.buttonRun.hide()
        self.buttonStop.show()
        self.formGroupBox.setEnabled(False)
        self.trainGroupBox.setEnabled(False)

    def _startBackend(self):
        """Spawn the backend subprocess and start a BackendThread wrapping it."""
        proc = Popen([sys.executable, BACKEND_ENTRYPOINT],
                     stdout=PIPE,
                     shell=False)
        self.thread = BackendThread(self, proc=proc, flags=self.flags)
        self.thread.setTerminationEnabled(True)
        self.thread.finished.connect(self.onFinished)
        self.thread.connection.progressUpdate.connect(self.updateProgress)
        self.thread.start()

    def stopMessage(self, event):
        """Ask the user to confirm stopping; stop the backend thread if confirmed.

        Returns:
            True when the user confirmed, False otherwise. On refusal
            ``event.ignore()`` is called when *event* supports it.
        """
        option = "close" if isinstance(event, QCloseEvent) else "stop"
        msg = "Are you sure you want to {} this dialog? " \
              "This will terminate any running processes.".format(option)
        # Both buttons must be OR-ed into the `buttons` argument; the last
        # argument is the default button. Passing them as two arguments
        # showed only a "Yes" button.
        reply = QMessageBox.question(self, 'Message', msg,
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.No:
            try:
                event.ignore()
            except AttributeError:  # event may not be a QEvent
                pass
            return False
        try:
            self.thread.stop()
        except AttributeError:  # no backend thread was ever started
            pass
        return True

    def closeEvent(self, event):
        """Confirm before closing while a backend thread is running."""
        def acceptEvent(accepted):
            if accepted:
                self.buttonRun.setDisabled(False)
                self.buttonStop.hide()
                self.buttonRun.show()
                self.flowGroupBox.setEnabled(True)
                self.trainGroupBox.setEnabled(True)
                self.formGroupBox.setEnabled(True)
                try:
                    event.accept()
                except AttributeError:
                    pass

        try:
            thread_running = self.thread.isRunning()
        except AttributeError:  # self.thread not created yet
            thread_running = False
        if thread_running:
            acceptEvent(self.stopMessage(event))
        else:
            self.flowPrg.setMaximum(100)
            self.flowPrg.reset()
            acceptEvent(True)

    def rolloverLogs(self):
        """Call ``doRollover()`` on the thread's log handlers whose files are non-empty."""
        for log in (self.thread.logfile, self.thread.tf_logfile):
            if os.stat(log.baseFilename).st_size > 0:
                log.doRollover()

    def onFinished(self):
        """Restore the dialog after the backend thread finishes.

        Shows ``flags.error`` (and rolls the logs over) when set, and dumps
        all flags in a message box when ``flags.verbalise`` is set.
        """
        self.flags = self.thread.flags
        if self.flags.error:
            QMessageBox.critical(self, "Error Message", self.flags.error,
                                 QMessageBox.Ok)
            self.rolloverLogs()
        if self.flags.verbalise:
            QMessageBox.information(
                self, "Debug Message", "Process Stopped:\n" +
                "\n".join('{}: {}'.format(k, v)
                          for k, v in self.flags.items()), QMessageBox.Ok)
        self.flowGroupBox.setEnabled(True)
        self.trainGroupBox.setEnabled(True)
        self.formGroupBox.setEnabled(True)
        self.flowPrg.setMaximum(100)
        self.flowPrg.reset()
        self.buttonRun.setDisabled(False)
        self.buttonStop.hide()
        self.buttonRun.show()
        self.findCkpt()

    @pyqtSlot(int)
    def updateProgress(self, value):
        """Show *value* on the progress bar, leaving pulsing mode if needed."""
        if not self.flowPrg.maximum():  # stop pulsing before setting a value
            self.flowPrg.setMaximum(100)
        self.flowPrg.setValue(value)

    # HELPERS
    @staticmethod
    def listFiles(path):
        """Return the names of the ``*.cfg`` and ``*.index`` files in *path*."""
        directory = QDir(path)
        directory.setNameFilters(["*.cfg", "*.index"])
        return directory.entryList()
예제 #6
0
파일: test_qt.py 프로젝트: rjdbcm/BEAGLES
class TestMainWindow(TestCase):
    """Smoke tests for the BEAGLES main window.

    One application/window pair is shared by all tests. It is created in
    ``setUpClass`` with ``ArgumentParser.parse_args`` patched to return fixed
    arguments, and closed in ``tearDownClass``.
    """

    app = None
    win = None
    flags = Flags()
    labels = flags.labels  # needed by the patch decorator below

    # ---- fixture lifecycle -------------------------------------------------

    @classmethod
    @mock.patch('argparse.ArgumentParser.parse_args',
                return_value=argparse.Namespace(
                    filename='data/sample_img/sample_dog.jpg',
                    predefined_class_file=labels,
                    save_directory='tests',
                    darkmode=True))
    def setUpClass(cls, _mock_parse_args):
        cls.app, cls.win = get_main_app()

    @classmethod
    def tearDownClass(cls) -> None:
        cls.win.close()
        cls.app.quit()

    def tearDown(self) -> None:
        self.win.setClean()

    # ---- tests ---------------------------------------------------------------

    def testCanvas(self):
        self.assertRaises(AssertionError, self.win.canvas.resetAllLines)

    def testToggleAdvancedMode(self):
        win = self.win
        self.assertTrue(win.beginner())
        win.advancedMode()
        self.assertFalse(win.beginner())
        win.advancedMode(False)
        self.assertTrue(win.beginner())

    def testChangeFormat(self):
        win = self.win
        win.changeFormat()
        self.assertTrue(win.usingYoloFormat)
        win.changeFormat()
        self.assertTrue(win.usingPascalVocFormat)

    def testToggleDrawMode(self):
        self.win.toggleDrawMode(True)
        self.assertTrue(self.win.canvas.editing)

    def testTrainModel(self):
        self.win.trainModel()

    def testLoadPascalXMLByFilename(self):
        self.win.loadPascalXMLByFilename('tests/resources/test.xml')

    def testFileLoadZoom(self):
        win = self.win
        win.loadFile('data/sample_img/sample_dog.jpg')
        win.setFitWin()
        win.setFitWidth()
        win.setZoom(50)
        win.closeFile()

    def testImportDirImages(self):
        win = self.win
        win.importDirImages('data/sample_img')
        win.nextImg()
        win.prevImg()

    def testImpVideo(self):
        video = os.path.abspath('tests/resources/test.mp4')
        FileFunctions().frameCapture(video)
        frames = glob.glob('tests/resources/test_frame_*.jpg')
        self.assertNotEqual(frames, [])
        for frame in frames:  # remove the extracted frames
            os.remove(frame)

    def test_noop(self):
        pass
예제 #7
0
 def setUpClass(cls) -> None:
     """Create the shared Flags fixture and disable the diff-length limit."""
     # maxDiff = None lets unittest print complete diffs on failures.
     cls.flags, cls.maxDiff = Flags(), None
예제 #8
0
 def __init__(self):
     """Start a ``tensorboard`` QProcess and create the training dialog."""
     super(MachineLearningFunctions, self).__init__()
     tensorboard_args = ["--logdir=data/summaries", "--debugger_port=6064"]
     self.tb_process = QProcess(self)
     self.tb_process.start("tensorboard", tensorboard_args)
     default_labels = Flags().labels
     self.trainDialog = FlowDialog(parent=self, labelfile=default_labels)