示例#1
0
    def snapshot(self):
        """ Ask the running training process to create a snapshot.

        The snapshot is requested by sending SIGHUP to self.proc (presumably the
        caffe process writes the snapshot in reaction to it).

        Return
            True if the signal was sent to the running process
            False if the session is not running, has no process, or the
            signal could not be sent
        """
        if self.getState() is not State.RUNNING:
            return False
        if not self.proc:
            Log.error(
                'Could not take a session snapshot in state ' +
                str(self.getState()), self.getCallerId())
            return False
        # forget the previously known solverstate before requesting a new snapshot
        self.last_solverstate = None
        try:
            self.proc.send_signal(signal.SIGHUP)
        except Exception as err:
            Log.error('Failed to take snapshot: ' + str(err),
                      self.getCallerId())
            return False
        sessionName = self.getRunLogFileName(True)
        Log.log('Snapshot was saved for session ' + sessionName,
                self.getCallerId())
        return True
示例#2
0
    def _promptForAutoNetImport(self, netPath):
        """Offer to import the net definition referenced by a loaded solver definition.

        netPath is the path of a net definition extracted from a loaded solver
        definition. If that file exists, a message box showing its content asks the
        user whether to import it (via self.openNet). If it does not exist, a note
        is written to the 'file_loader' log instead.
        """
        if not os.path.isfile(netPath):
            loaderId = Log.getCallerId('file_loader')
            Log.log(
                "Found network reference in loaded solver definition, but the file {} does not exist."
                .format(netPath), loaderId)
            return

        # get the file content so it can be shown in the details section
        with open(netPath, 'r') as netFile:
            netPrototxt = netFile.read()

        dialog = QMessageBox()
        dialog.setWindowTitle("Barista")
        dialog.setText(self.tr("Solver definition contains a network reference."))
        dialog.setInformativeText(self.tr("Do you want to import the network, too?"))
        details = "File:\n{}\n\nNet definition:\n{}".format(netPath, netPrototxt)
        dialog.setDetailedText(details)
        dialog.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
        dialog.setDefaultButton(QMessageBox.Yes)
        answer = dialog.exec_()
        if answer == QMessageBox.Yes:
            self.openNet(netPath)
示例#3
0
    def fetchParserData(self):
        """Request the parser data of the session from the host and replay it locally.

        The reply's ParserKeys, ParserRows and ParserHandle entries are forwarded to
        self.parser (in that order). ParserLogs entries are written to the log:
        Log.error for entries flagged as error, Log.log otherwise.
        Every failure ends in self._handleErrors with a generic message. An error
        reply from the host additionally reports the host's own errors first.
        """
        connected = self._assertConnection() and self.transaction is not None
        if connected:
            request = {
                "key": Protocol.SESSION,
                "subkey": SessionProtocol.FETCHPARSERDATA
            }
            self.transaction.send(request)
            reply = self.transaction.asyncRead(
                staging=True, attr=("subkey", SessionProtocol.FETCHPARSERDATA))
            if reply and not reply["status"]:
                self._handleErrors(reply["error"])
            elif reply:
                rows = reply["ParserRows"]
                keys = reply["ParserKeys"]
                handles = reply["ParserHandle"]
                logs = reply["ParserLogs"]

                for phase, key in keys:
                    self.parser.sendParserRegisterKeys(phase, key)
                for phase, row in rows:
                    self.parser.sendParserUpdate(phase, row)
                for event, line, groups in handles:
                    self.parser.sendParserHandle(event, line, groups)
                for message, isError in logs:
                    report = Log.error if isError else Log.log
                    report(message, self.getCallerId())
                return
        self._handleErrors(["Failed to fetch parser data from host."])
示例#4
0
 def printLog(self, log, error=False):
     """Forward a log line to the logger, or emit it when no log id is set.

     With a log id the message goes to Log.error (error is truthy) or Log.log.
     Without one, printLogSignl is emitted with (log, error).
     """
     if self.log_id is None:
         self.printLogSignl.emit(log, error)
         return
     logFunction = Log.error if error else Log.log
     logFunction(log, self.log_id)
示例#5
0
    def _addSoftmax(self):
        """Add a softmax layer to the very end of the net, but only if a SoftmaxWithLoss layer was used before.

        The new layer is named "softmax" and is added at position len(layers).
        It consumes the first top blob of the last layer listed in "layerOrder"
        and produces a top blob named "probabilities".
        No layer is added if the deployment net already contains a Softmax layer.
        A message is logged if the deployment net has no layers or if its last
        layer has no top blob.
        """
        # check whether the net used to contain a SoftmaxWithLoss layer
        # (values() works on Python 2 and 3, unlike iteritems(); the layer ids were unused)
        softmaxWithLossWasUsed = any(
            layer["type"].name() == "SoftmaxWithLoss"
            for layer in self._originalNetworkDictionary["layers"].values())
        if not softmaxWithLossWasUsed:
            return

        deployedLayers = self._deployedNetworkDictionary["layers"]

        # ensure that the remaining deployment net has at least one layer
        if len(deployedLayers) == 0:
            Log.log(
                "Could not add Softmax layer as the remaining deployment net does not have any layers.",
                self._logId)
            return

        softmaxLayerType = info.CaffeMetaInformation().availableLayerTypes()["Softmax"]

        # do not add another softmax, if the current deployment network already contains one
        softmaxAlreadyIncluded = any(
            layer["type"].name() == softmaxLayerType.name()
            for layer in deployedLayers.values())
        if softmaxAlreadyIncluded:
            return

        # get the very last layer
        lastLayerId = self._deployedNetworkDictionary["layerOrder"][-1]
        lastLayerParameters = deployedLayers[lastLayerId]["parameters"]

        # ensure that the determined last layer does have a top blob
        if "top" not in lastLayerParameters or len(lastLayerParameters["top"]) == 0:
            Log.log(
                "Could not add Softmax layer as the very last layer of the deployment net does not have any "
                "top blobs.", self._logId)
            return

        # create new softmax layer with default values and add it to the deployment net
        position = len(deployedLayers)
        softmaxLayer, _ = self._dHelper.addLayer(softmaxLayerType, "softmax", position)

        # connect the softmax layer with the existing network
        softmaxLayer["parameters"]["bottom"] = [lastLayerParameters["top"][0]]

        # name the output
        softmaxLayer["parameters"]["top"] = ["probabilities"]
示例#6
0
    def proceed(self, snapshot=None):
        """ Continue training from the (last) snapshot.

        Only a PAUSED session is continued. Caffe is started as
        'train -solver <solver> -snapshot <snapshot>' (snapshot defaults to
        self.getLastSnapshot()), and its output is appended to the run log via
        'tee -a'. Then the state is set to RUNNING and parsing is started.
        Sessions in state FAILED or FINISHED only log an error.

        Return
            True if the process was continued
            False if the continuation failed
        """
        if self.getState() is State.PAUSED:
            self.__ensureDirectory(self.getSnapshotDirectory())
            self.__ensureDirectory(self.logs)

            if snapshot is None:
                snapshot = self.getLastSnapshot()
            self.rid += 1
            # stays None if the caffe binary could not even be resolved
            caffeBinary = None
            try:
                self.getParser().setLogging(True)
                caffeBinary = caffeVersions.getVersionByName(
                    self.project.getCaffeVersion()).getBinarypath()
                self.proc = Popen(
                    [caffeBinary, 'train', '-solver', self.getSolver(),
                     '-snapshot', snapshot],
                    stdout=PIPE,
                    stderr=STDOUT,
                    cwd=self.getDirectory())
                try:
                    # append caffe's output to the run log file
                    self.tee = Popen(
                        ['tee', '-a', self.getRunLogFileName()],
                        stdin=self.proc.stdout,
                        stdout=PIPE)
                except Exception as e:
                    # continue without tee
                    Log.error('Failed to start tee: ' + str(e),
                              self.getCallerId())
                self.setState(State.RUNNING)
                Log.log(
                    'Session ' + self.getRunLogFileName(True) +
                    ' was proceeded', self.getCallerId())
                self.__startParsing()
                return True
            except Exception as e:
                Log.error('Failed to continue session: ' + str(e),
                          self.getCallerId())
                # Give a more specific hint if the caffe binary is missing. The old code
                # called the non-existent os.file.exists and an undefined caffe_bin, so
                # this handler itself crashed.
                if caffeBinary is not None and not os.path.exists(caffeBinary):
                    Log.error(
                        'CAFFE_BINARY directory does not exists: ' +
                        caffeBinary +
                        '! Please set CAFFE_BINARY to run a session.',
                        self.getCallerId())
        elif self.getState() in (State.FAILED, State.FINISHED):
            Log.error(
                'Could not continue a session in state ' +
                str(self.getState()), self.getCallerId())
        return False
示例#7
0
 def getInternalNetFile(self, log=False):
     """ Return the file name of the session's internal net prototxt.

     If log is True and that file does not exist, a note is sent to the logger
     console. The file name is returned in either case.
     """
     netFile = self.__netInternalFile
     if log and not os.path.isfile(netFile):
         Log.log("This sessions net file: " + netFile + " does not exist.",
                 self.caller_id)
     return netFile
示例#8
0
 def getSolver(self, log=False):
     """ Return the file name of the session's solver prototxt.

     If log is True and that file does not exist, a note is sent to the logger
     console. The file name is returned in either case.
     """
     solverFile = self.__solverFile
     if log and not os.path.isfile(solverFile):
         Log.log("This sessions Solverfile: " + solverFile + " does not exist.",
                 self.caller_id)
     return solverFile
示例#9
0
 def getSession(self, SID):
     """Return the session registered under SID.

     If SID is not among getValidSIDs(), the valid ids are logged and None is
     returned.
     """
     if SID not in self.getValidSIDs():
         knownIds = ", ".join([str(i) for i in self.getValidSIDs()])
         Log.log("Session " + str(SID) + " could not be loaded. Valid IDs are: " + knownIds,
                 self.getCallerId())
         return None
     return self.__sessions[SID]
示例#10
0
 def getActiveSID(self):
     """Return the id of the active session.

     The active session may be set, but there may be sessions of which none
     matches it. In that case this is logged and the active session is reset to
     the highest available session id.
     """
     activeSID = self.__activeSID
     if activeSID and len(self.__sessions) > 0 and activeSID not in self.__sessions:
         Log.log(
             "The Active Session is no longer available. The Project seems to be broken. The active Session is set to the highest ID available.",
             self.callerId)
         # max() really yields the highest id. keys()[-1] depends on dict ordering
         # and is not even subscriptable on Python 3.
         self.setActiveSID(max(self.__sessions))
         Log.log("Active Session set to " + str(self.__activeSID),
                 self.callerId)
     return self.__activeSID
示例#11
0
    def pause(self):
        """ Pause the process.

        A snapshot is requested from the running process. After a grace period of
        one second, the process and its tee process are killed. If a snapshot file
        is found afterwards, the session becomes PAUSED and the iteration is taken
        from the snapshot file name; otherwise the session becomes WAITING.
        The session is saved in both cases. A RUNNING session without a process
        is set to UNDEFINED.

        Return
            True if the process was paused
            False if the pause failed
        """
        if self.getState() is State.RUNNING:
            if self.proc:
                self.snapshot()
                # give the caffe process a second to save the state
                time.sleep(1)
                try:
                    if self.proc:
                        self.proc.kill()
                except Exception as e:
                    Log.error('Pausing session failed: ' + str(e),
                              self.getCallerId())
                try:
                    if self.tee:
                        self.tee.kill()
                except Exception:
                    # best effort: a tee process that cannot be killed must not block pausing
                    pass
                self.proc = None
                self.tee = None
                self.last_solverstate = None
                snap = self._getLastSnapshotFromSnapshotDirectory()
                if snap is not None:
                    # reuse the checked result; a second directory scan could differ (or be None)
                    self.last_solverstate = os.path.basename(snap)
                    # raw string: \d, \. and \w are regex escapes, not string escapes
                    regex_iter = re.compile(
                        r'iter_([\d]+)\.solverstate[\.\w-]*$')
                    iter_match = regex_iter.search(self.last_solverstate)
                    if iter_match:
                        self.iteration = int(iter_match.group(1))
                    self.iterationChanged.emit()
                    self.setState(State.PAUSED)
                    Log.log(
                        'Session ' + self.getRunLogFileName(True) +
                        ' was paused', self.getCallerId())
                else:
                    self.setState(State.WAITING)
                    Log.log(
                        'Session ' + self.getRunLogFileName(True) +
                        ' was halted', self.getCallerId())
                self.save()
                return True
            else:
                Log.error(
                    'Could not pause a session in state ' +
                    str(self.getState()), self.getCallerId())
                self.setState(State.UNDEFINED)
        return False
示例#12
0
 def __ensureDirectory(self, directory):
     """ Create a directory (including missing parents) if it does not exist.

     Empty or None directory names are ignored. Failures are logged, not raised.
     """
     if not directory:
         return
     if not os.path.exists(directory):
         try:
             os.makedirs(directory)
             Log.log('Created directory: ' + directory, self.getCallerId())
         except Exception as e:
             # the reason is put in balanced parentheses after the directory name
             Log.error(
                 'Failed to create directory ' + directory + ' (' + str(e) + ')',
                 self.getCallerId())
示例#13
0
 def _printLog(self):
     """Read pending log messages from the host and forward them.

     Each entry of msg["log"] is a (log, error) pair, where log is a single
     message or a list of messages. Error entries are passed to
     self._handleErrors; the others are written with Log.log.
     Nothing happens if no message was received.
     """
     msg = self.transaction.asyncRead(attr=("subkey",
                                            SessionProtocol.PRINTLOG))
     # presumably asyncRead can return a falsy value on failure; other readers
     # of it guard with `if ret:` before indexing
     if not msg:
         return
     for log, error in msg["log"]:
         messages = log if isinstance(log, list) else [log]
         if error:
             self._handleErrors(messages)
         else:
             for message in messages:
                 Log.log(message, self.getCallerId())
示例#14
0
    def _deployAndExportUnsafe(self):
        """
        This will be triggered when the user clicks the deploy button. It validates
        the user input and displays messages and errors if additional input is
        required. If everything is validated successfully, the deployed net prototxt
        and the caffemodel of the selected snapshot are written to the destination
        directory, and the dialog is closed.
        """
        # If no snapshot exists, we show an error message and close the dialog.
        if not self._hasSnapshotsOrDisplayError():
            self.close()
            return
        # Get the destination folder from the current user input. If the input
        # is empty, we show an error message and cancel.
        destinationFolder = self._getDestinationOrDisplayError()
        if not destinationFolder:
            return
        # Check if the path already exists. If it doesn't exist yet, let the user
        # decide whether to create all missing folders and abort otherwise.
        folderExists = os.path.isdir(destinationFolder)
        if not folderExists and not self._askDirectoryCreatePermission():
            return
        # Ensure that the full path points to a folder and not a file.
        if not self._ensurePathIsFolderOrDisplayError(destinationFolder):
            return
        # Determine the destination file paths for the prototxt and the caffemodel.
        destinationPrototxtFile, caffemodelDestination = self._getDestinationFilePaths(destinationFolder)
        # Check if any of the destination files already exist and ask the user
        # if they should be replaced. Abort if the user decides not to replace one
        # of the files.
        if not self._checkFilesDontExistOrAskReplacePermission([destinationPrototxtFile, caffemodelDestination]):
            return
        # Read the caffemodel belonging to the selected snapshot (presumably the
        # snapshot name with its last 'solverstate' replaced by 'caffemodel').
        session = self._selectedSession()
        snapshot = self._selectedSnapshot()
        caffemodelContents = session.readCaffemodelFile(
            self._replaceLast(snapshot, 'solverstate', 'caffemodel'))
        # Start deployment.
        deployedNet = session.readDeployedNetAsString()

        # Write prototxt file.
        with open(destinationPrototxtFile, 'w') as prototxtFile:
            prototxtFile.write(deployedNet)
        # Write caffemodel file in binary mode. Caffemodels are binary protobuf data:
        # text mode rejects bytes on Python 3 and translates newlines on Windows.
        # Assumes readCaffemodelFile returns bytes - TODO confirm.
        with open(caffemodelDestination, 'wb') as caffemodelFile:
            caffemodelFile.write(caffemodelContents)

        Log.log("Deployment files have been saved successfully to {}.".format(destinationPrototxtFile), self.getCallerId())

        # Close the current dialog.
        self.close()
示例#15
0
 def save(self):
     """Ask the host to save the current session status to disk.

     Return
         True if the host reported success. False otherwise; any errors
         are passed to self._handleErrors.
     """
     Log.log("Saving current Session status to disk.", self.getCallerId())
     if self._assertConnection():
         self.transaction.send({"key": Protocol.SESSION, "subkey": SessionProtocol.SAVE})
         reply = self.transaction.asyncRead(attr=("subkey", SessionProtocol.SAVE))
         if reply and reply["status"]:
             return True
         if reply:
             self._handleErrors(reply["error"])
     self._handleErrors(["Could not save session."])
     return False
示例#16
0
 def changeSession(self, newSID, oldSID=None):
     """Change the active session within one project.

     The current netManager state is stored in the old session (oldSID, or the
     active session if oldSID is None). Then newSID is made active and its
     state is loaded into the netManager. If the project did not accept newSID,
     the valid ids are logged instead.
     """
     self.storeSessionState(SID=oldSID, stateDict=None)
     self._project.setActiveSID(newSID)
     if self._project.getActiveSID() != newSID:
         validSIDs = ", ".join([str(sid) for sid in self._project.getValidSIDs()])
         Log.log("New Session " + str(newSID) + " could not be set. Valid SIDs are: " + validSIDs,
                 self._viewManager.sessionController.getCallerId())
         return
     self.loadSessionState(SID=newSID)
示例#17
0
 def setActiveSID(self, sid):
     """Make sid the active session and emit activeSessionChanged(sid).

     An invalid sid is logged as an error. In that case, if no session is active
     yet, the last entry of getValidSIDs() (or None if there is none) becomes
     the active session without emitting the signal.
     """
     validSIDs = self.getValidSIDs()
     if sid not in validSIDs:
         Log.error(
             "Could not set active session to " + str(sid) +
             " valid Session-IDs are: " +
             ", ".join([str(s) for s in validSIDs]), self.getCallerId())
         if not self.__activeSID:
             self.__activeSID = validSIDs[-1] if validSIDs else None
             Log.log("Active session set to " + str(self.__activeSID),
                     self.getCallerId())
         return
     self.__activeSID = sid
     self.activeSessionChanged.emit(sid)
示例#18
0
    def save(self, includeProtoTxt=False):
        """Write the current session status to the session's json file.

        The json contains the session state, iteration, max iteration and project
        id. When available it also contains the last snapshot, the pretrained
        weights and a deep copy of the state dictionary (with every layer's "type"
        entry removed). If includeProtoTxt is True, the solver prototxt and the
        original and internal net prototxt files are written as well.
        """
        toSave = {
            "SessionState": self.state,
            "Iteration": self.iteration,
            "MaxIter": self.max_iter,
            "ProjectID": self.project.projectId,
        }

        self.__ensureDirectory(self.directory)
        Log.log("Saving current Session status to disk.", self.getCallerId())
        if self.last_solverstate:
            toSave["LastSnapshot"] = self.last_solverstate
        if self.getPretrainedWeights():
            toSave["PretrainedWeights"] = self.getPretrainedWeights()

        if self.state_dictionary:
            serializedDict = copy.deepcopy(self.state_dictionary)

            if includeProtoTxt:
                if "solver" not in self.state_dictionary:
                    Log.error(
                        "Could not save a solver prototxt file, because no solver settings are defined.",
                        self.getCallerId())
                else:
                    # build the content before opening (and truncating) the file
                    solverContent = self.buildSolverPrototxt()
                    with open(self.getSolver(log=False), 'w') as solverFile:
                        solverFile.write(solverContent)

            if "network" not in serializedDict:
                Log.error(
                    "Could not save the network state because no state was defined.",
                    self.getCallerId())
            else:
                if includeProtoTxt:
                    netTargets = ((False, self.getOriginalNetFile),
                                  (True, self.getInternalNetFile))
                    for internalVersion, fileNameGetter in netTargets:
                        netContent = self.buildNetPrototxt(internalVersion=internalVersion)
                        with open(fileNameGetter(log=False), 'w') as netFile:
                            netFile.write(netContent)
                networkState = serializedDict["network"]
                if "layers" in networkState:
                    # the layer "type" entries are not written to the json file
                    for layer in networkState["layers"].values():
                        del layer["type"]

            toSave["NetworkState"] = serializedDict

        with open(baristaSessionFile(self.directory), "w") as sessionFile:
            json.dump(toSave, sessionFile, sort_keys=True, indent=4)
 def __saveImage(self):
     """ Gives the user the opportunity to save the image on disk.
         The user can choose from the allowed extensions, and is alerted if the entered extension is not valid.
         If no extension is entered by the user, the extension is set to png.
         Any error while saving is logged instead of raised.
     """
     allowedExtensions = [".png", ".bmp", ".jpeg", ".jpg"]
     # dialog filter " *.png *.bmp *.jpeg *.jpg"; str.join avoids reduce(),
     # which is not a builtin on Python 3
     allowedAsString = "".join(" *" + extension for extension in allowedExtensions)
     callerId = self.getCallerId()
     try:
         filename = ""
         # While the entered extension does not match the allowed extensions
         while not self.__validateFilename(filename, allowedExtensions):
             fileDialog = QFileDialog()
             fileDialog.setDefaultSuffix("png")
             # NOTE(review): getSaveFileName is a static method, so the default suffix
             # set above presumably has no effect; the ".png" fallback below covers it
             filenameArray = fileDialog.getSaveFileName(
                 self,
                 "Save File",
                 filename,
                 allowedAsString
             )
             filename = filenameArray[0]
             if filename != "":
                 _, extension = os.path.splitext(filename)
                 # If no extension has been entered, append .png
                 if extension == "":
                     filename += ".png"
             else:
                 # If user clicks on abort, leave the loop
                 break
             # Show an alert message, when an unknown extension has been entered
             if not self.__validateFilename(filename, allowedExtensions):
                 msg = QMessageBox()
                 msg.setIcon(QMessageBox.Warning)
                 msg.setWindowTitle("Warning")
                 msg.setText("Please enter a valid extension:\n"
                             + allowedAsString)
                 msg.exec_()
         if filename != "":
             if self.__currentNet:
                 image = calculateConvWeights(self.__currentNet, self.__currentLayerName)
                 self.canvasWidget.saveImage(image, filename)
                 Log.log("Saved image under " + filename, callerId)
             else:
                 Log.error("Try to save image of weights without a network being loaded.", callerId)
     except Exception as e:
         Log.error("Saving the file failed. " + str(e), callerId)
示例#20
0
    def errorMsg(self, errorMsg, loggerIdString=None, addStacktrace=True):
        """Show an error message in the Logger as well as in an additional GUI popup.

        errorMsg        the message text (passed through self.tr for the popup)
        loggerIdString  optional name used to look up the logger caller id
        addStacktrace   attach the traceback of the exception currently being
                        handled, if there is one, as the popup's detailed text
        """
        import sys

        # use the logger; this is an error, so it goes to the error channel
        if loggerIdString is not None:
            callerId = Log.getCallerId(loggerIdString)
        else:
            callerId = None
        Log.error(errorMsg, callerId)

        # show message in the GUI
        msgBox = QMessageBox()
        msgBox.setWindowTitle("Barista - Error")
        msgBox.setText(self.tr(errorMsg))
        # outside of an except block format_exc() only yields a placeholder such as
        # "NoneType: None", so only attach a trace when an exception is active
        if addStacktrace and sys.exc_info()[0] is not None:
            msgBox.setDetailedText(traceback.format_exc())
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
示例#21
0
    def _insertInputLayers(self):
        """Insert new input layers with fixed dimensions.

        Note that, the number of newly-added input layer might be lower than the number of previously-removed data
        layers. On the one hand, we will add only one new input layer for each unique data blob name, while multiple
        data layers might have used the same data blob name. On the other hand, input layers will only be added, if
        at least one other layer is using the provided data.

        For the n-th name in self._dataBlobNames (n starting at 1), an input layer is added at position n - 1.
        Its only top blob is that name, and its input_param.shape is
        [{"dim": [1] + self._inputBlobShapes[name]}]. A warning is logged if any of these shapes was empty,
        i.e. the layer ends up with only the batch size.
        """
        inputLayerType = info.CaffeMetaInformation().availableLayerTypes()["Input"]
        inputShapeWarning = False
        for position, dataBlobName in enumerate(self._dataBlobNames):
            inputLayerNr = position + 1

            # create a new input layer with default values and add it to the deployment net
            name = self._getNewInputLayerName(inputLayerNr)
            inputLayer, _ = self._dHelper.addLayer(inputLayerType, name, position)

            # modify the layer template
            inputLayer["parameters"]["top"] = [dataBlobName]

            # Check the shape itself. The old check iterated the dict keys (blob names),
            # so the warning never fired.
            blobShape = self._inputBlobShapes[dataBlobName]
            if len(blobShape) < 1:
                inputShapeWarning = True

            # Set input_param.shape with batch size 1 and the dimensions of the first data element.
            # A new list is built so the batch size is not inserted into self._inputBlobShapes itself.
            inputLayer["parameters"]["input_param"] = {
                "shape": [{"dim": [1] + list(blobShape)}]
            }

        # at least one shape could not be determined automatically and needs to be set manually
        if inputShapeWarning:
            Log.log(
                "At least one input shape could not be determined automatically. Please open the created prototxt "
                "and manually fix all input shapes which include only the batch size (1).",
                self._logId)
示例#22
0
    def storeSessionState(self, SID=None, stateDict=None):
        """Stores the stateDict to a session.
        If no SID is provided, the currently active session is used.
        If no dictionary is provided, the current netManager state is used.
        If the SID is not valid, this is logged and nothing is stored."""
        state = stateDict
        if not state:
            state = self._netManager.getStateDictionary()
        targetSID = SID
        if not targetSID:
            targetSID = self._project.getActiveSID()
        elif targetSID not in self._project.getValidSIDs():
            # Log.log takes (message, callerId); the old call passed four positional
            # arguments and then still dereferenced a missing session.
            Log.log(
                "Could not store state to SID " + str(targetSID) + ". Valid SIDs are: " +
                ", ".join([str(sid) for sid in self._project.getValidSIDs()]),
                self._viewManager.sessionController.getCallerId())
            return

        session = self._project.getSession(targetSID)
        # presumably getSession returns None for an unknown id (e.g. no active session)
        if session is None:
            return
        session.setStateDict(state)
示例#23
0
    def loadSessionState(self, SID=None):
        """Loads the state from a session to the netManager.
        If no Session is provided, the currently active session is used.
        The session's state dictionary is only applied (with a cleared history) if it
        contains a network, either under the key "network" or as top-level
        "layerOrder" and "layers" entries. Failures are logged."""
        targetSID = SID
        if not targetSID:
            targetSID = self._project.getActiveSID()

        if targetSID not in self._project.getValidSIDs():
            Log.log("Could not load state from session " + str(targetSID),
                    self._viewManager.sessionController.getCallerId())
            return

        session = self._project.getSession(targetSID)
        if not session.hasStateDict():
            Log.log(
                "Could not load state from Session " + str(targetSID) +
                ", is this an old Project?",
                self._viewManager.sessionController.getCallerId())
            return

        state = session.state_dictionary
        # only apply dictionaries that actually contain a network
        if "network" in state:
            net = state["network"]
        elif "layerOrder" in state and "layers" in state:
            net = state
        else:
            net = None

        if net:
            self._netManager.setStateDictionary(dictionary=state,
                                                clearHistory=True)
示例#24
0
def ensureSolverConstraints(solverDictionary):
    """Make a solver dictionary satisfy the constraints of a Barista session.

    The dictionary is modified in place and also returned:

    * ``net`` is forced to the session's internal network file name,
    * ``net_param`` (an inline network definition) is removed,
    * any directory part is stripped from ``snapshot_prefix``.

    Every modification is reported through ``Log.log``.
    """
    internalNetFile = backend.barista.session.session_utils.Paths.FILE_NAME_NET_INTERNAL
    hasNet = "net" in solverDictionary

    # Session-internal file names are fixed and must not be altered by the user.
    if not (hasNet and solverDictionary["net"] == internalNetFile):
        previousNet = solverDictionary["net"] if hasNet else "None"
        message = ("The solver property 'net' must point to the generated network file. "
                   "Value has been changed from '{}' to '{}'.").format(previousNet, internalNetFile)
        Log.log(message, logID)
        solverDictionary["net"] = internalNetFile

    # An inline network would contradict the separately handled network definition.
    if "net_param" in solverDictionary:
        Log.log(
            "The solver property 'net_param' is not supported as it would be inconsistent with the separately "
            "handled network. Property has been removed.", logID)
        solverDictionary.pop("net_param")

    # Only a bare prefix is supported; a directory part is dropped.
    if "snapshot_prefix" in solverDictionary:
        originalPrefix = solverDictionary["snapshot_prefix"]
        directoryPart, prefixName = os.path.split(originalPrefix)
        if directoryPart:
            Log.log(
                "The solver property 'snapshot_prefix' contained an unsupported path. "
                "Property was shortened from '{}' to '{}'.".format(originalPrefix, prefixName), logID)
            solverDictionary["snapshot_prefix"] = prefixName

    return solverDictionary
示例#25
0
    def __init__(self,
                 project,
                 directory=None,
                 sid=None,
                 parse_old=False,
                 caffe_bin=None,
                 last_solverstate=None,
                 last_caffemodel=None,
                 state_dictionary=None):
        """Create a session.

        :param project: the project this session belongs to.
        :param directory: existing session directory. If None, the directory
            name is created from ``sid`` via ``__createSessionDirectoryName``.
        :param sid: session ID; required if ``directory`` is None. If both are
            given, ``sid`` takes precedence over the ID parsed from the
            directory, and a message is logged when the two differ.
        :param parse_old: stored as ``self.parse_old``.
        :param caffe_bin: caffe binary overriding the project's caffe version
            (e.g. when the session is deployed to another system).
        :param last_solverstate: stored as ``self.last_solverstate``.
        :param last_caffemodel: stored as ``self.last_caffemodel``.
        :param state_dictionary: state saved from the network manager so it can
            be restored; if given, it is also passed to ``__parseSettings``.
        :raises ValueError: if neither ``directory`` nor ``sid`` is given.
        """
        super(Session, self).__init__()
        self.caller_id = None
        self.state = State.UNDEFINED
        self.invalidErrorsList = []
        self.sid = sid
        self.directory = directory
        self.rid = 0
        self.project = project
        self.logs = None
        if self.directory is None:
            if self.sid is None:
                raise ValueError(
                    'Either directory or sid must be provided to create a session.'
                )
            self.directory = self.__createSessionDirectoryName()
            self.logs = os.path.join(self.directory, 'logs')
        else:
            self.logs = os.path.join(directory, 'logs')
            # the directory name encodes both the session ID and the run ID
            dir_sid, self.rid = self.__parseSessionId()
            if self.sid is None:
                self.sid = dir_sid
            elif self.sid != dir_sid:
                # only report a real mismatch; the explicitly provided sid wins
                Log.log(
                    'Provided sid and directory do not match (' +
                    str(self.sid) + ' vs. ' + str(dir_sid) + '), ' +
                    'provided sid is used.', self.getCallerId())

        self.parse_old = parse_old
        self.caffe_bin = caffe_bin  # overrides project caffe_root if necessary, i.e. if deployed to another system
        self.pretrainedWeights = None

        self.last_solverstate = last_solverstate
        self.last_caffemodel = last_caffemodel
        self.state_dictionary = state_dictionary  # state as saved from the network manager, such it can be restored

        self.start_time = self.__parseStartTime()

        self.snapshot_dir = None
        self.snapshot_prefix = None
        self.proc = None
        self.tee = None
        self.parser = None
        self.iteration = 0
        self.max_iter = 1
        self.parser_initialized = False

        # presumably reads persisted settings of this session - see the helper
        self.__getSettingsFromSessionFile()

        if self.state_dictionary is not None:
            self.__parseSettings(self.state_dictionary)

        self.__solverFile = os.path.join(self.directory,
                                         Paths.FILE_NAME_SOLVER)
        self.__netInternalFile = os.path.join(self.directory,
                                              Paths.FILE_NAME_NET_INTERNAL)
        self.__netOriginalFile = os.path.join(self.directory,
                                              Paths.FILE_NAME_NET_ORIGINAL)

        self.lock = Lock()
        # NOTE(review): return value is discarded; presumably called for its
        # side effect of initializing the snapshot directory - confirm.
        self.getSnapshotDirectory()
示例#26
0
    def _searchDataLayers(self):
        """Search for all data layers in self._originalNetworkDictionary and store them in self._dataLayers.

        Additionally, the names of all top blobs containing label data will be saved in self._labelBlobNames.
        """
        nLabelBlobUndetermined = 0  # check whether some label blobs couldn't be determined automatically
        nTooManyTopBlobs = 0  # check whether some data layers do have too many top blobs
        for id, layer in self._originalNetworkDictionary["layers"].iteritems():
            if layer["type"].isDataLayer():
                self._dataLayers[id] = layer

                # validate number of top blobs
                if "top" in layer["parameters"] and len(
                        layer["parameters"]["top"]) > 2:
                    nTooManyTopBlobs += 1

                # remember the associated label blob name
                labelBlobName = self._getLabelBlobName(layer)
                if labelBlobName is not None and labelBlobName not in self._labelBlobNames:
                    self._labelBlobNames.append(labelBlobName)
                elif labelBlobName is None:
                    nLabelBlobUndetermined += 1

                # remember the associated data blob name and its shape
                dataBlobName = self._getDataBlobName(layer)
                if dataBlobName is not None and dataBlobName not in self._dataBlobNames:
                    self._dataBlobNames.append(dataBlobName)

                    # calculate the shape (this assumes that data layers with the same name are using the same shape)
                    blobShape = [
                    ]  # will be left empty, if shape cannot be calculated automatically
                    if layer["type"].name() in ["Data", "HDF5Data"]:

                        if layer["type"].name() == "Data":
                            path = layer["parameters"].get("data_param",
                                                           {}).get("source")
                            type = layer["parameters"].get("data_param",
                                                           {}).get("backend")
                        elif layer["type"].name() == "HDF5Data":
                            path = layer["parameters"].get(
                                "hdf5_data_param", {}).get("source")
                            type = "HDF5TXT"

                        if path is not None and type is not None:
                            db = DatabaseObject()
                            db.openFromPath(path, type)
                            blobShapeTupel = db.getDimensions()
                            if blobShapeTupel is not None:
                                blobShapeTupel = blobShapeTupel.get(
                                    dataBlobName)
                                if blobShapeTupel is not None:
                                    blobShape = list(blobShapeTupel)

                    self._inputBlobShapes[dataBlobName] = blobShape

        # show warning, if some label blobs do not have the correct name
        if nLabelBlobUndetermined > 0:
            Log.log(
                "{} data layers might have been handled incorrectly, because their top blobs are named "
                "unconventionally. Please change the name of the blobs which provide labels to \"label\"."
                .format(nLabelBlobUndetermined), self._logId)

        # show warning, if some data layers have too many top blobs
        if nTooManyTopBlobs > 0:
            Log.log(
                "{} data layers have too many top blobs. The native caffe version does support only a maximum of 2 "
                "top blobs per data layer. Deployment result might be incorrect."
                .format(nTooManyTopBlobs), self._logId)
示例#27
0
    def save(self):
        """Export the solver definition to the file chosen in the dialog.

        Depending on the network combo box, the solver either references the
        network file entered in the dialog (key ``net``) or embeds the network
        inline (key ``net_param``). The respective other key is removed from
        ``self._solver``.

        Missing destination folders are created after asking the user. An
        invalid network path is accepted only after confirmation.

        On success, the export is logged, the paths are added to the
        recent-file lists (if default actions exist), and the dialog is closed.
        Invalid input is reported in a message box. Unexpected errors are
        logged, reported as "Unknown error.", and close the dialog.
        """
        try:
            # get user input
            netFullPath = self._textPathNet.text()
            solverFullPath = self._textPathSolver.text()
            solverDirPath = os.path.dirname(solverFullPath)
            netInSeparateFile = (self._comboNet.currentIndex() == self.COMBO_INDEX_SEPARATE_FILE)
            netPathIsValid = len(netFullPath) > 0 and os.path.isfile(netFullPath)

            # ensure that the input isn't empty
            if len(solverFullPath) == 0:
                QtWidgets.QMessageBox.critical(self,
                                               self.tr("Can't save solver"),
                                               self.tr("Please select a valid destination for the solver file."))
                return

            # A bare file name has an empty directory part (the current working
            # directory), which always exists; os.makedirs("") would raise.
            if len(solverDirPath) > 0 and not os.path.isdir(solverDirPath):
                # let the user decide whether to create all missing folders
                reply = QtWidgets.QMessageBox.question(self,
                                                       self.tr("Destination doesn't exist yet."),
                                                       self.tr("Do you want to create all non-existing folders "
                                                               "in the given path?"),
                                                       QtWidgets.QMessageBox.Yes,
                                                       QtWidgets.QMessageBox.No)
                if reply != QtWidgets.QMessageBox.Yes:
                    return
                os.makedirs(solverDirPath)

            # ensure that the full path does point to a file and not a folder
            if os.path.isdir(solverFullPath):
                QtWidgets.QMessageBox.critical(self, self.tr("Can't save solver"),
                                               self.tr("The given path points to an existing folder instead of "
                                                       "a file."))
                return

            if netInSeparateFile:
                # The referenced network file is not read here, so an invalid
                # path may be intentional: let the user decide.
                if not netPathIsValid:
                    reply = QtWidgets.QMessageBox.question(self,
                                                           self.tr("Network path seems to be invalid."),
                                                           self.tr("Do you want to continue anyway?"),
                                                           QtWidgets.QMessageBox.Yes,
                                                           QtWidgets.QMessageBox.No)
                    netPathIsValid = (reply == QtWidgets.QMessageBox.Yes)
                if not netPathIsValid:
                    return

                # point to selected network file and drop any inline definition
                self._solver["net"] = netFullPath
                if "net_param" in self._solver:
                    del self._solver["net_param"]
            else:
                # include inline definition of the network and drop any file reference
                self._solver["net_param"] = self._network
                if "net" in self._solver:
                    del self._solver["net"]

            # finally, save solver prototxt
            with open(solverFullPath, 'w') as solverFile:
                solverFile.write(saver.saveSolver(self._solver))

            callerId = Log.getCallerId('export_solver')
            Log.log(
                "The solver has been exported successfully to {}.".format(
                    solverFullPath
                ), callerId)

            # add used file paths to the recent file list
            if self._defaultActions is not None:
                self._defaultActions.recentSolverData.addRecently(solverFullPath, solverFullPath)

                if netInSeparateFile:
                    self._defaultActions.recentNetData.addRecently(netFullPath, netFullPath)

            self.close()
        except Exception as e:
            # keep the error visible in the log instead of silently discarding it
            Log.error('Failed to export solver: ' + str(e), Log.getCallerId('export_solver'))
            QtWidgets.QMessageBox.critical(self,
                                           self.tr("Can't save solver"),
                                           self.tr("Unknown error."))
            self.close()
示例#28
0
def loadNet(netstring):
    """ Load the prototxt string "netstring" into a dictionary.
        The dictionary has the following form

        {
            "name": "Somenetwork",
            "input_dim": [1,2,1,1],
            "state": {
                "phase": "TRAIN"
            },
            ...
            "layers":
            {
                "somerandomid1": {
                    "type": LayerType Instance of Pooling-Layer,
                    "parameters": {
                        "pooling_param": {
                            "kernel_size": 23,
                            "engine": "DEFAULT"
                        }
                        ....
                        "input_param": [
                            {"shape": {"dim": [...], ....  },
                            {"shape": {"dim": [...], ....  },
                        ]
                    }
                },
                "somerandomid2": {"type": ..., "parameters": ....}
            },
            "layerOrder": ["somerandomid1", "somerandomid2", ....]
        }

    Deprecated "layers" definitions of the prototxt are dropped, and a message
    is logged if any are present.

    Raises
        ParseException if netstring cannot be parsed as a NetParameter prototxt
        ValueError if the network parameters unexpectedly contain "layerOrder"
    """
    from backend.caffe.path_loader import PathLoader
    proto = PathLoader().importProto()
    # Load Protoclass for parsing
    net = proto.NetParameter()

    # Get DESCRIPTION for meta infos
    descr = info.ParameterGroupDescriptor(net)
    # "Parse" the netdefinition in prototxt-format
    try:
        text_format.Merge(netstring, net)
    except ParseError as ex:
        raise ParseException(str(ex))
    params = descr.parameter().copy()  # All Parameters of the network

    # add logger output if deprecated layers have been found, to inform the user that those can't be parsed yet
    if len(net.layers) > 0:
        callerId = Log.getCallerId('protoxt-parser')
        Log.log(
            "The given network contains deprecated layer definitions which are not supported and will be dropped.",
            callerId)

    # Layers is deprecated, Layer will be handled separately and linked to "Layers" key
    del params["layers"]
    del params["layer"]
    # "in" instead of dict.has_key(), which no longer exists in Python 3
    if "layerOrder" in params:
        raise ValueError('Key layerOrder not expected!')

    # Extract every other parameters
    res = _extract_param(net, params)

    res["layers"], res["layerOrder"] = _load_layers(net.layer)

    res = copy.deepcopy(res)
    return res
示例#29
0
    def start(self, solverstate=None, caffemodel=None):
        """ Start the process.

        Only possible in state WAITING. Increments the run ID, rewrites all
        session files, and runs ``<caffe binary> train -solver <solver file>``
        in the session directory.

        The binary is the session's own ``caffe_bin`` if set, otherwise the
        binary of the project's caffe version. Training resumes from
        ``solverstate`` (``-snapshot``) or is initialized from ``caffemodel``
        (``-weights``); ``solverstate`` takes precedence if both are given.
        The output is piped through ``tee -a`` into the run log file.

        Return
            True if the process was started
            False if the start failed
        """
        if self.getState() is State.WAITING:
            self.rid += 1
            # (re-)write all session files
            self.save(includeProtoTxt=True)
            # check if the session has its own caffeRoot
            caffeBin = self.caffe_bin
            if not caffeBin:
                # else take the project's caffeRoot path
                caffeBin = caffeVersions.getVersionByName(
                    self.project.getCaffeVersion()).getBinarypath()

            try:
                self.getParser().setLogging(True)

                cmd = [caffeBin, 'train', '-solver', self.getSolver()]
                if solverstate:
                    cmd.append('-snapshot')
                    cmd.append(str(solverstate))
                elif caffemodel:
                    cmd.append('-weights')
                    cmd.append(str(caffemodel))
                self.proc = Popen(cmd,
                                  stdout=PIPE,
                                  stderr=STDOUT,
                                  cwd=self.getDirectory())
                try:
                    self.tee = Popen(
                        ['tee', '-a', self.getRunLogFileName()],
                        stdin=self.proc.stdout,
                        stdout=PIPE)
                except Exception as e:
                    # continue without tee
                    Log.error('Failed to start tee: ' + str(e),
                              self.getCallerId())
                self.setState(State.RUNNING)
                Log.log(
                    'Session ' + self.getRunLogFileName(True) + ' was started',
                    self.getCallerId())
                self.__startParsing()
                return True
            except Exception as e:
                Log.error('Failed to start session: ' + str(e),
                          self.getCallerId())
                # Report explicitly if the binary that was actually used
                # (session override or project version) does not exist.
                if not caffeBin or not os.path.exists(caffeBin):
                    Log.error(
                        'CAFFE_BINARY does not exist: ' + str(caffeBin) +
                        '! Please set CAFFE_BINARY to run a session.',
                        self.getCallerId())
        else:
            Log.error(
                'Could not start a session in state ' + str(self.getState()),
                self.getCallerId())
        return False